From 608fd976c4a9cbad05681d5dc6b4602be30c8fec Mon Sep 17 00:00:00 2001 From: Christophe de Vienne Date: Mon, 20 Jun 2016 15:00:19 +0200 Subject: [PATCH 1/4] Fix auto_increment on postgres database. --- create_test.go | 12 ++++++++++++ migration_test.go | 1 + model_struct.go | 4 ++++ 3 files changed, 17 insertions(+) diff --git a/create_test.go b/create_test.go index dc82de50..28a049f4 100644 --- a/create_test.go +++ b/create_test.go @@ -57,6 +57,18 @@ func TestCreate(t *testing.T) { } } +func TestCreateWithAutoIncrement(t *testing.T) { + user1 := User{} + user2 := User{} + + DB.Create(&user1) + DB.Create(&user2) + + if user2.Sequence-user1.Sequence != 1 { + t.Errorf("Auto increment should apply on Sequence") + } +} + func TestCreateWithNoGORMPrimayKey(t *testing.T) { if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column") diff --git a/migration_test.go b/migration_test.go index 592dadb7..ec33efc1 100644 --- a/migration_test.go +++ b/migration_test.go @@ -33,6 +33,7 @@ type User struct { Company Company Role PasswordHash []byte + Sequence uint `gorm:"AUTO_INCREMENT"` IgnoreMe int64 `sql:"-"` IgnoreStringSlice []string `sql:"-"` Ignored struct{ Name string } `sql:"-"` diff --git a/model_struct.go b/model_struct.go index eb3762f4..d8f9ed1b 100644 --- a/model_struct.go +++ b/model_struct.go @@ -175,6 +175,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.HasDefaultValue = true } + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + field.HasDefaultValue = true + } + indirectType := fieldStruct.Type for indirectType.Kind() == reflect.Ptr { indirectType = indirectType.Elem() From 328fe672c8f1aceaf7e65d66da1c443779b20051 Mon Sep 17 00:00:00 2001 From: Christophe de Vienne Date: Mon, 20 Jun 2016 16:00:38 +0200 Subject: [PATCH 2/4] Test AUTO_INCREMENT only on postgres Only the postgres dialect handles AUTO_INCREMENT on non-primary key. So we skip the auto increment test for other dialects. The mysql case is a little trickier because the simple presence of the 'AUTH_INCREMENT' tag produces a faulty 'CREATE TABLE' statement. Hence we need to remove it when present. --- create_test.go | 3 +++ dialect_mysql.go | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/create_test.go b/create_test.go index 28a049f4..2d71c9a6 100644 --- a/create_test.go +++ b/create_test.go @@ -58,6 +58,9 @@ func TestCreate(t *testing.T) { } func TestCreateWithAutoIncrement(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { + t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column") + } user1 := User{} user2 := User{} diff --git a/dialect_mysql.go b/dialect_mysql.go index bc4828de..0ddcea4d 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -30,6 +30,14 @@ func (mysql) Quote(key string) string { func (mysql) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + // MySQL allows only one auto increment column per table, and it must + // be a KEY column. + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey { + delete(field.TagSettings, "AUTO_INCREMENT") + } + } + if sqlType == "" { switch dataValue.Kind() { case reflect.Bool: From 24501a3c1a9de1c0116001e9c4c74d532872267a Mon Sep 17 00:00:00 2001 From: Richard Knop Date: Sat, 9 Jul 2016 18:51:38 +0800 Subject: [PATCH 3/4] Fixed bug when preload duplicates has many related objects. --- callback_query_preload.go | 7 ++++ preload_test.go | 78 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/callback_query_preload.go b/callback_query_preload.go index d9ec8bdd..c9bfa866 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -186,6 +186,13 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) + if j > 0 { + prevObject := indirect(indirectScopeValue.Index(j - 1)) + prevObjectRealValue := getValueFromFields(prevObject, relation.AssociationForeignFieldNames) + if toString(prevObjectRealValue) == toString(objectRealValue) { + continue + } + } if results, ok := preloadMap[toString(objectRealValue)]; ok { f := object.FieldByName(field.Name) f.Set(reflect.Append(f, results...)) diff --git a/preload_test.go b/preload_test.go index da3ee38f..fd5b3af6 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1509,6 +1509,84 @@ func TestNilPointerSlice2(t *testing.T) { } } +func TestPrefixedPreloadDuplication(t *testing.T) { + type ( + Level4 struct { + ID uint + Level3ID uint + } + Level3 struct { + ID uint + Level4s []*Level4 + } + Level2 struct { + ID uint + Level3ID sql.NullInt64 `sql:"index"` + Level3 *Level3 + } + Level1 struct { + ID uint + Level2ID sql.NullInt64 `sql:"index"` + Level2 *Level2 + } + ) + + DB.DropTableIfExists(new(Level3)) + DB.DropTableIfExists(new(Level4)) + DB.DropTableIfExists(new(Level2)) + DB.DropTableIfExists(new(Level1)) + + if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)).Error; err != nil { + t.Error(err) + } + + lvl := new(Level3) + if err := DB.Save(lvl).Error; err != nil { + t.Error(err) + } + + sublvl1 := &Level4{Level3ID: lvl.ID} + if err := DB.Save(sublvl1).Error; err != nil { + t.Error(err) + } + sublvl2 := &Level4{Level3ID: lvl.ID} + if err := DB.Save(sublvl2).Error; err != nil { + t.Error(err) + } + + lvl.Level4s = []*Level4{sublvl1, sublvl2} + + want1 := Level1{ + Level2: &Level2{ + Level3: lvl, + }, + } + if err := DB.Save(&want1).Error; err != nil { + t.Error(err) + } + + want2 := Level1{ + Level2: &Level2{ + Level3: lvl, + }, + } + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + want := []Level1{want1, want2} + + var got []Level1 + err := DB.Preload("Level2.Level3.Level4s").Find(&got).Error + if err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + func toJSONString(v interface{}) []byte { r, _ := json.MarshalIndent(v, "", " ") return r From ca46038cb43072306bca032b73ba22d873fe1afc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Jul 2016 21:34:37 +0800 Subject: [PATCH 4/4] Fix preload duplicates has many related objects --- callback_query_preload.go | 7 ------- preload_test.go | 6 +++++- scope.go | 7 +++++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index c9bfa866..d9ec8bdd 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -186,13 +186,6 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) - if j > 0 { - prevObject := indirect(indirectScopeValue.Index(j - 1)) - prevObjectRealValue := getValueFromFields(prevObject, relation.AssociationForeignFieldNames) - if toString(prevObjectRealValue) == toString(objectRealValue) { - continue - } - } if results, ok := preloadMap[toString(objectRealValue)]; ok { f := object.FieldByName(field.Name) f.Set(reflect.Append(f, results...)) diff --git a/preload_test.go b/preload_test.go index fd5b3af6..8c56a8ac 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1513,19 +1513,23 @@ func TestPrefixedPreloadDuplication(t *testing.T) { type ( Level4 struct { ID uint + Name string Level3ID uint } Level3 struct { ID uint + Name string Level4s []*Level4 } Level2 struct { ID uint + Name string Level3ID sql.NullInt64 `sql:"index"` Level3 *Level3 } Level1 struct { ID uint + Name string Level2ID sql.NullInt64 `sql:"index"` Level2 *Level2 } @@ -1540,7 +1544,7 @@ func TestPrefixedPreloadDuplication(t *testing.T) { t.Error(err) } - lvl := new(Level3) + lvl := &Level3{} if err := DB.Save(lvl).Error; err != nil { t.Error(err) } diff --git a/scope.go b/scope.go index 0ecf43df..974ff035 100644 --- a/scope.go +++ b/scope.go @@ -1237,6 +1237,7 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { fieldType = fieldType.Elem() } + resultsMap := map[interface{}]bool{} results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() for i := 0; i < indirectScopeValue.Len(); i++ { @@ -1244,11 +1245,13 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { if result.Kind() == reflect.Slice { for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() { + if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { + resultsMap[elem.Addr()] = true results = reflect.Append(results, elem.Addr()) } } - } else if result.CanAddr() { + } else if result.CanAddr() && resultsMap[result.Addr()] != true { + resultsMap[result.Addr()] = true results = reflect.Append(results, result.Addr()) } }