From 92213273a5dccd52da5ed93ad1ef283af1691e57 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 17 Jan 2016 17:46:56 +0800 Subject: [PATCH] Refactor create callback --- callback_create.go | 69 ++++++++++++++++++++-------------------------- field.go | 28 ++++++++----------- utils.go | 10 +++---- 3 files changed, 46 insertions(+), 61 deletions(-) diff --git a/callback_create.go b/callback_create.go index 7e04d67d..921aa4cc 100644 --- a/callback_create.go +++ b/callback_create.go @@ -42,30 +42,29 @@ func createCallback(scope *Scope) { if !scope.HasError() { defer scope.trace(NowFunc()) - // set create sql - var sqls, columns []string - fields := scope.Fields() + var ( + columns, placeholders []string + blankColumnsWithDefaultValue []string + fields = scope.Fields() + ) + for _, field := range fields { if scope.changeableField(field) { if field.IsNormal { - if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { - if !field.IsBlank || !field.HasDefaultValue { + if !field.IsPrimaryKey || !field.IsBlank { + if field.IsBlank && field.HasDefaultValue { + blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName) + scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) + } else { columns = append(columns, scope.Quote(field.DBName)) - sqls = append(sqls, scope.AddToVars(field.Field.Interface())) - } else if field.HasDefaultValue { - var hasDefaultValueColumns []string - if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok { - hasDefaultValueColumns = oldHasDefaultValueColumns.([]string) - } - hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName) - scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns) + placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) } } - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, dbName := range relationship.ForeignDBNames { - if relationField := fields[dbName]; !scope.changeableField(relationField) { - columns = append(columns, scope.Quote(relationField.DBName)) - sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) + } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { + for _, foreignKey := range field.Relationship.ForeignDBNames { + if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) { + columns = append(columns, scope.Quote(foreignField.DBName)) + placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) } } } @@ -88,35 +87,27 @@ func createCallback(scope *Scope) { "INSERT INTO %v (%v) VALUES (%v) %v", scope.QuotedTableName(), strings.Join(columns, ","), - strings.Join(sqls, ","), + strings.Join(placeholders, ","), scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), )) } // execute create sql - if scope.Dialect().SupportLastInsertId() { + if scope.Dialect().SupportLastInsertId() || primaryField == nil { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { - id, err := result.LastInsertId() - if scope.Err(err) == nil { - scope.db.RowsAffected, _ = result.RowsAffected() - if primaryField != nil && primaryField.IsBlank { - scope.Err(scope.SetColumn(primaryField, id)) + // set rows affected count + scope.db.RowsAffected, _ = result.RowsAffected() + + // set primary value to primary field + if primaryField != nil && primaryField.IsBlank { + if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { + scope.Err(primaryField.Set(primaryValue)) } } } } else { - if primaryField == nil { - if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { - scope.db.RowsAffected, _ = results.RowsAffected() - } else { - scope.Err(err) - } - } else { - if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { - scope.db.RowsAffected = 1 - } else { - scope.Err(err) - } + if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + scope.db.RowsAffected = 1 } } } @@ -124,8 +115,8 @@ func createCallback(scope *Scope) { // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object func forceReloadAfterCreateCallback(scope *Scope) { - if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok { - scope.DB().New().Select(columns.([]string)).First(scope.Value) + if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { + scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value) } } diff --git a/field.go b/field.go index 2ed4e732..2f0daf77 100644 --- a/field.go +++ b/field.go @@ -58,15 +58,20 @@ func (field *Field) Set(value interface{}) (err error) { // Fields get value's fields func (scope *Scope) Fields() map[string]*Field { if scope.fields == nil { - fields := map[string]*Field{} - modelStruct := scope.GetModelStruct() + var ( + fields = map[string]*Field{} + indirectScopeValue = scope.IndirectValue() + isStruct = indirectScopeValue.Kind() == reflect.Struct + ) - indirectValue := scope.IndirectValue() - isStruct := indirectValue.Kind() == reflect.Struct - for _, structField := range modelStruct.StructFields { + for _, structField := range scope.GetModelStruct().StructFields { if field, ok := fields[structField.DBName]; !ok || field.IsIgnored { if isStruct { - fields[structField.DBName] = getField(indirectValue, structField) + fieldValue := indirectScopeValue + for _, name := range structField.Names { + fieldValue = reflect.Indirect(fieldValue).FieldByName(name) + } + fields[structField.DBName] = &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)} } else { fields[structField.DBName] = &Field{StructField: structField, IsBlank: true} } @@ -74,17 +79,6 @@ func (scope *Scope) Fields() map[string]*Field { } scope.fields = fields - return fields } return scope.fields } - -func getField(indirectValue reflect.Value, structField *StructField) *Field { - field := &Field{StructField: structField} - for _, name := range structField.Names { - indirectValue = reflect.Indirect(indirectValue).FieldByName(name) - } - field.Field = indirectValue - field.IsBlank = isBlank(indirectValue) - return field -} diff --git a/utils.go b/utils.go index 58b14ac4..43d0031c 100644 --- a/utils.go +++ b/utils.go @@ -132,11 +132,11 @@ func toQueryCondition(scope *Scope, columns []string) string { return strings.Join(newColumns, ",") } -func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { - for _, primaryValue := range primaryValues { - for _, value := range primaryValue { - values = append(values, value) +func toQueryValues(values [][]interface{}) (results []interface{}) { + for _, value := range values { + for _, v := range value { + results = append(results, v) } } - return values + return }