From 406a01875c37c47c9b471a0cda93ec0691826c9f Mon Sep 17 00:00:00 2001 From: Jinzhu <wosmvp@gmail.com> Date: Fri, 20 Feb 2015 22:06:49 +0800 Subject: [PATCH] Test using pointer as struct fields --- callback_create.go | 12 +++---- model_struct.go | 6 +++- pointer_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++ scope.go | 5 +-- 4 files changed, 97 insertions(+), 10 deletions(-) create mode 100644 pointer_test.go diff --git a/callback_create.go b/callback_create.go index a27a9eb4..7d57f1bd 100644 --- a/callback_create.go +++ b/callback_create.go @@ -55,12 +55,14 @@ func Create(scope *Scope) { } // execute create sql - var id interface{} if scope.Dialect().SupportLastInsertId() { if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { - id, err = result.LastInsertId() + id, err := result.LastInsertId() if scope.Err(err) == nil { scope.db.RowsAffected, _ = result.RowsAffected() + if primaryField != nil { + scope.SetColumn(primaryField, id) + } } } } else { @@ -68,14 +70,10 @@ func Create(scope *Scope) { if results, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); err != nil { scope.db.RowsAffected, _ = results.RowsAffected() } - } else if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) == nil { + } else if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil { scope.db.RowsAffected = 1 } } - - if primaryField != nil && primaryField.IsBlank && !scope.HasError() { - scope.SetColumn(primaryField, id) - } } } diff --git a/model_struct.go b/model_struct.go index 828a4022..9e7e64df 100644 --- a/model_struct.go +++ b/model_struct.go @@ -65,7 +65,11 @@ type Relationship struct { func (scope *Scope) generateSqlTag(field *StructField) { var sqlType string - reflectValue := reflect.Indirect(reflect.New(field.Struct.Type)) + structType := field.Struct.Type + if structType.Kind() == reflect.Ptr { + structType = structType.Elem() + } + reflectValue := reflect.Indirect(reflect.New(structType)) sqlSettings := parseTagSetting(field.Tag.Get("sql")) if value, ok := sqlSettings["TYPE"]; ok { diff --git a/pointer_test.go b/pointer_test.go new file mode 100644 index 00000000..b47717f3 --- /dev/null +++ b/pointer_test.go @@ -0,0 +1,84 @@ +package gorm_test + +import "testing" + +type PointerStruct struct { + ID int64 + Name *string + Num *int +} + +type NormalStruct struct { + ID int64 + Name string + Num int +} + +func TestPointerFields(t *testing.T) { + DB.DropTable(&PointerStruct{}) + DB.AutoMigrate(&PointerStruct{}) + var name = "pointer struct 1" + var num = 100 + pointerStruct := PointerStruct{Name: &name, Num: &num} + if DB.Create(&pointerStruct).Error != nil { + t.Errorf("Failed to save pointer struct") + } + + var pointerStructResult PointerStruct + if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { + t.Errorf("Failed to query saved pointer struct") + } + + var tableName = DB.NewScope(&PointerStruct{}).TableName() + + var normalStruct NormalStruct + DB.Table(tableName).First(&normalStruct) + if normalStruct.Name != name || normalStruct.Num != num { + t.Errorf("Failed to query saved Normal struct") + } + + var nilPointerStruct = PointerStruct{} + if err := DB.Create(&nilPointerStruct).Error; err != nil { + t.Errorf("Failed to save nil pointer struct", err) + } + + var pointerStruct2 PointerStruct + if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { + t.Errorf("Failed to query saved nil pointer struct", err) + } + + var normalStruct2 NormalStruct + if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { + t.Errorf("Failed to query saved nil pointer struct", err) + } + + var partialNilPointerStruct1 = PointerStruct{Num: &num} + if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { + t.Errorf("Failed to save partial nil pointer struct", err) + } + + var pointerStruct3 PointerStruct + if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { + t.Errorf("Failed to query saved partial nil pointer struct", err) + } + + var normalStruct3 NormalStruct + if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { + t.Errorf("Failed to query saved partial pointer struct", err) + } + + var partialNilPointerStruct2 = PointerStruct{Name: &name} + if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { + t.Errorf("Failed to save partial nil pointer struct", err) + } + + var pointerStruct4 PointerStruct + if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { + t.Errorf("Failed to query saved partial nil pointer struct", err) + } + + var normalStruct4 NormalStruct + if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { + t.Errorf("Failed to query saved partial pointer struct", err) + } +} diff --git a/scope.go b/scope.go index 56b0e52b..73abc194 100644 --- a/scope.go +++ b/scope.go @@ -35,8 +35,9 @@ func (scope *Scope) IndirectValue() reflect.Value { // NewScope create scope for callbacks, including DB's search information func (db *DB) NewScope(value interface{}) *Scope { - db.Value = value - return &Scope{db: db, Search: db.search, Value: value} + dbClone := db.clone() + dbClone.Value = value + return &Scope{db: dbClone, Search: dbClone.search, Value: value} } func (scope *Scope) NeedPtr() *Scope {