From cf80e191578618d94f32a0d1dece3169c4f99bf2 Mon Sep 17 00:00:00 2001 From: Ethan Koenig Date: Mon, 27 Feb 2017 20:35:55 -0500 Subject: [PATCH] Optimize and unit test Issue_ReplaceLabels (#1080) --- models/issue.go | 45 ++++++++++++++++------------------------ models/issue_test.go | 35 +++++++++++++++++++++++++++++++ models/setup_for_test.go | 7 ++++--- 3 files changed, 57 insertions(+), 30 deletions(-) create mode 100644 models/issue_test.go diff --git a/models/issue.go b/models/issue.go index 4677da401..d70cf02ca 100644 --- a/models/issue.go +++ b/models/issue.go @@ -476,31 +476,24 @@ func (issue *Issue) ReplaceLabels(labels []*Label, doer *User) (err error) { sort.Sort(labelSorter(issue.Labels)) var toAdd, toRemove []*Label - for _, l := range labels { - var exist bool - for _, oriLabel := range issue.Labels { - if oriLabel.ID == l.ID { - exist = true - break - } - } - if !exist { - toAdd = append(toAdd, l) - } - } - for _, oriLabel := range issue.Labels { - var exist bool - for _, l := range labels { - if oriLabel.ID == l.ID { - exist = true - break - } - } - if !exist { - toRemove = append(toRemove, oriLabel) + addIndex, removeIndex := 0, 0 + for addIndex < len(labels) && removeIndex < len(issue.Labels) { + addLabel := labels[addIndex] + removeLabel := issue.Labels[removeIndex] + if addLabel.ID == removeLabel.ID { + addIndex++ + removeIndex++ + } else if addLabel.ID < removeLabel.ID { + toAdd = append(toAdd, addLabel) + addIndex++ + } else { + toRemove = append(toRemove, removeLabel) + removeIndex++ } } + toAdd = append(toAdd, labels[addIndex:]...) + toRemove = append(toRemove, issue.Labels[removeIndex:]...) if len(toAdd) > 0 { if err = issue.addLabels(sess, toAdd, doer); err != nil { @@ -508,11 +501,9 @@ func (issue *Issue) ReplaceLabels(labels []*Label, doer *User) (err error) { } } - if len(toRemove) > 0 { - for _, l := range toRemove { - if err = issue.removeLabel(sess, doer, l); err != nil { - return fmt.Errorf("removeLabel: %v", err) - } + for _, l := range toRemove { + if err = issue.removeLabel(sess, doer, l); err != nil { + return fmt.Errorf("removeLabel: %v", err) } } diff --git a/models/issue_test.go b/models/issue_test.go new file mode 100644 index 000000000..646a1a5a9 --- /dev/null +++ b/models/issue_test.go @@ -0,0 +1,35 @@ +// Copyright 2017 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package models + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIssue_ReplaceLabels(t *testing.T) { + assert.NoError(t, PrepareTestDatabase()) + + testSuccess := func(issueID int64, labelIDs []int64) { + issue := AssertExistsAndLoadBean(t, &Issue{ID: issueID}).(*Issue) + repo := AssertExistsAndLoadBean(t, &Repository{ID: issue.RepoID}).(*Repository) + doer := AssertExistsAndLoadBean(t, &User{ID: repo.OwnerID}).(*User) + + labels := make([]*Label, len(labelIDs)) + for i, labelID := range labelIDs { + labels[i] = AssertExistsAndLoadBean(t, &Label{ID: labelID, RepoID: repo.ID}).(*Label) + } + assert.NoError(t, issue.ReplaceLabels(labels, doer)) + AssertCount(t, &IssueLabel{IssueID: issueID}, len(labelIDs)) + for _, labelID := range labelIDs { + AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issueID, LabelID: labelID}) + } + } + + testSuccess(1, []int64{2}) + testSuccess(1, []int64{1, 2}) + testSuccess(1, []int64{}) +} diff --git a/models/setup_for_test.go b/models/setup_for_test.go index 91d5442c8..5a17eac78 100644 --- a/models/setup_for_test.go +++ b/models/setup_for_test.go @@ -107,8 +107,9 @@ func AssertSuccessfulInsert(t *testing.T, beans ...interface{}) { assert.NoError(t, err) } -// AssertSuccessfulUpdate assert that bean is successfully updated -func AssertSuccessfulUpdate(t *testing.T, bean interface{}, conditions ...interface{}) { - _, err := x.Update(bean, conditions...) +// AssertCount assert the count of a bean +func AssertCount(t *testing.T, bean interface{}, expected interface{}) { + actual, err := x.Count(bean) assert.NoError(t, err) + assert.EqualValues(t, expected, actual) }