From 406a01875c37c47c9b471a0cda93ec0691826c9f Mon Sep 17 00:00:00 2001
From: Jinzhu <wosmvp@gmail.com>
Date: Fri, 20 Feb 2015 22:06:49 +0800
Subject: [PATCH] Test using pointer as struct fields

---
 callback_create.go | 12 +++----
 model_struct.go    |  6 +++-
 pointer_test.go    | 84 ++++++++++++++++++++++++++++++++++++++++++++++
 scope.go           |  5 +--
 4 files changed, 97 insertions(+), 10 deletions(-)
 create mode 100644 pointer_test.go

diff --git a/callback_create.go b/callback_create.go
index a27a9eb4..7d57f1bd 100644
--- a/callback_create.go
+++ b/callback_create.go
@@ -55,12 +55,14 @@ func Create(scope *Scope) {
 		}
 
 		// execute create sql
-		var id interface{}
 		if scope.Dialect().SupportLastInsertId() {
 			if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
-				id, err = result.LastInsertId()
+				id, err := result.LastInsertId()
 				if scope.Err(err) == nil {
 					scope.db.RowsAffected, _ = result.RowsAffected()
+					if primaryField != nil {
+						scope.SetColumn(primaryField, id)
+					}
 				}
 			}
 		} else {
@@ -68,14 +70,10 @@ func Create(scope *Scope) {
 				if results, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); err != nil {
 					scope.db.RowsAffected, _ = results.RowsAffected()
 				}
-			} else if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) == nil {
+			} else if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
 				scope.db.RowsAffected = 1
 			}
 		}
-
-		if primaryField != nil && primaryField.IsBlank && !scope.HasError() {
-			scope.SetColumn(primaryField, id)
-		}
 	}
 }
 
diff --git a/model_struct.go b/model_struct.go
index 828a4022..9e7e64df 100644
--- a/model_struct.go
+++ b/model_struct.go
@@ -65,7 +65,11 @@ type Relationship struct {
 
 func (scope *Scope) generateSqlTag(field *StructField) {
 	var sqlType string
-	reflectValue := reflect.Indirect(reflect.New(field.Struct.Type))
+	structType := field.Struct.Type
+	if structType.Kind() == reflect.Ptr {
+		structType = structType.Elem()
+	}
+	reflectValue := reflect.Indirect(reflect.New(structType))
 	sqlSettings := parseTagSetting(field.Tag.Get("sql"))
 
 	if value, ok := sqlSettings["TYPE"]; ok {
diff --git a/pointer_test.go b/pointer_test.go
new file mode 100644
index 00000000..b47717f3
--- /dev/null
+++ b/pointer_test.go
@@ -0,0 +1,84 @@
+package gorm_test
+
+import "testing"
+
+type PointerStruct struct {
+	ID   int64
+	Name *string
+	Num  *int
+}
+
+type NormalStruct struct {
+	ID   int64
+	Name string
+	Num  int
+}
+
+func TestPointerFields(t *testing.T) {
+	DB.DropTable(&PointerStruct{})
+	DB.AutoMigrate(&PointerStruct{})
+	var name = "pointer struct 1"
+	var num = 100
+	pointerStruct := PointerStruct{Name: &name, Num: &num}
+	if DB.Create(&pointerStruct).Error != nil {
+		t.Errorf("Failed to save pointer struct")
+	}
+
+	var pointerStructResult PointerStruct
+	if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num {
+		t.Errorf("Failed to query saved pointer struct")
+	}
+
+	var tableName = DB.NewScope(&PointerStruct{}).TableName()
+
+	var normalStruct NormalStruct
+	DB.Table(tableName).First(&normalStruct)
+	if normalStruct.Name != name || normalStruct.Num != num {
+		t.Errorf("Failed to query saved Normal struct")
+	}
+
+	var nilPointerStruct = PointerStruct{}
+	if err := DB.Create(&nilPointerStruct).Error; err != nil {
+		t.Errorf("Failed to save nil pointer struct", err)
+	}
+
+	var pointerStruct2 PointerStruct
+	if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
+		t.Errorf("Failed to query saved nil pointer struct", err)
+	}
+
+	var normalStruct2 NormalStruct
+	if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
+		t.Errorf("Failed to query saved nil pointer struct", err)
+	}
+
+	var partialNilPointerStruct1 = PointerStruct{Num: &num}
+	if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
+		t.Errorf("Failed to save partial nil pointer struct", err)
+	}
+
+	var pointerStruct3 PointerStruct
+	if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
+		t.Errorf("Failed to query saved partial nil pointer struct", err)
+	}
+
+	var normalStruct3 NormalStruct
+	if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
+		t.Errorf("Failed to query saved partial pointer struct", err)
+	}
+
+	var partialNilPointerStruct2 = PointerStruct{Name: &name}
+	if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
+		t.Errorf("Failed to save partial nil pointer struct", err)
+	}
+
+	var pointerStruct4 PointerStruct
+	if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
+		t.Errorf("Failed to query saved partial nil pointer struct", err)
+	}
+
+	var normalStruct4 NormalStruct
+	if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
+		t.Errorf("Failed to query saved partial pointer struct", err)
+	}
+}
diff --git a/scope.go b/scope.go
index 56b0e52b..73abc194 100644
--- a/scope.go
+++ b/scope.go
@@ -35,8 +35,9 @@ func (scope *Scope) IndirectValue() reflect.Value {
 
 // NewScope create scope for callbacks, including DB's search information
 func (db *DB) NewScope(value interface{}) *Scope {
-	db.Value = value
-	return &Scope{db: db, Search: db.search, Value: value}
+	dbClone := db.clone()
+	dbClone.Value = value
+	return &Scope{db: dbClone, Search: dbClone.search, Value: value}
 }
 
 func (scope *Scope) NeedPtr() *Scope {