diff --git a/clause/set.go b/clause/set.go index 590e27d5..4adfe68f 100644 --- a/clause/set.go +++ b/clause/set.go @@ -1,5 +1,7 @@ package clause +import "sort" + type Set []Assignment type Assignment struct { @@ -32,3 +34,22 @@ func (set Set) Build(builder Builder) { func (set Set) MergeClause(clause *Clause) { clause.Expression = set } + +func Assignments(values map[string]interface{}) Set { + var keys []string + var assignments []Assignment + + for key := range values { + keys = append(keys, key) + } + + sort.Strings(keys) + + for _, key := range keys { + assignments = append(assignments, Assignment{ + Column: Column{Table: CurrentTable, Name: key}, + Value: values[key], + }) + } + return assignments +} diff --git a/clause/set_test.go b/clause/set_test.go index dbc1e970..56fac706 100644 --- a/clause/set_test.go +++ b/clause/set_test.go @@ -2,6 +2,8 @@ package clause_test import ( "fmt" + "sort" + "strings" "testing" "gorm.io/gorm/clause" @@ -36,3 +38,20 @@ func TestSet(t *testing.T) { }) } } + +func TestAssignments(t *testing.T) { + set := clause.Assignments(map[string]interface{}{ + "name": "jinzhu", + "age": 18, + }) + + assignments := []clause.Assignment(set) + + sort.Slice(assignments, func(i, j int) bool { + return strings.Compare(assignments[i].Column.Name, assignments[j].Column.Name) > 0 + }) + + if len(assignments) != 2 || assignments[0].Column.Name != "name" || assignments[0].Value.(string) != "jinzhu" || assignments[1].Column.Name != "age" || assignments[1].Value.(int) != 18 { + t.Errorf("invalid assignments, got %v", assignments) + } +}