Refactor create callback

This commit is contained in:
Jinzhu 2016-01-17 17:46:56 +08:00
parent e38b1e0948
commit 92213273a5
3 changed files with 46 additions and 61 deletions

View File

@ -42,30 +42,29 @@ func createCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())
// set create sql var (
var sqls, columns []string columns, placeholders []string
fields := scope.Fields() blankColumnsWithDefaultValue []string
fields = scope.Fields()
)
for _, field := range fields { for _, field := range fields {
if scope.changeableField(field) { if scope.changeableField(field) {
if field.IsNormal { if field.IsNormal {
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { if !field.IsPrimaryKey || !field.IsBlank {
if !field.IsBlank || !field.HasDefaultValue { if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName)
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
} else {
columns = append(columns, scope.Quote(field.DBName)) columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface())) placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
} else if field.HasDefaultValue {
var hasDefaultValueColumns []string
if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
hasDefaultValueColumns = oldHasDefaultValueColumns.([]string)
}
hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName)
scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns)
} }
} }
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
for _, dbName := range relationship.ForeignDBNames { for _, foreignKey := range field.Relationship.ForeignDBNames {
if relationField := fields[dbName]; !scope.changeableField(relationField) { if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) {
columns = append(columns, scope.Quote(relationField.DBName)) columns = append(columns, scope.Quote(foreignField.DBName))
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
} }
} }
} }
@ -88,35 +87,27 @@ func createCallback(scope *Scope) {
"INSERT INTO %v (%v) VALUES (%v) %v", "INSERT INTO %v (%v) VALUES (%v) %v",
scope.QuotedTableName(), scope.QuotedTableName(),
strings.Join(columns, ","), strings.Join(columns, ","),
strings.Join(sqls, ","), strings.Join(placeholders, ","),
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
)) ))
} }
// execute create sql // execute create sql
if scope.Dialect().SupportLastInsertId() { if scope.Dialect().SupportLastInsertId() || primaryField == nil {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err := result.LastInsertId() // set rows affected count
if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected() scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank { if primaryField != nil && primaryField.IsBlank {
scope.Err(scope.SetColumn(primaryField, id)) if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
} }
} }
} }
} else { } else {
if primaryField == nil { if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
} else {
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
scope.db.RowsAffected = 1 scope.db.RowsAffected = 1
} else {
scope.Err(err)
}
} }
} }
} }
@ -124,8 +115,8 @@ func createCallback(scope *Scope) {
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
func forceReloadAfterCreateCallback(scope *Scope) { func forceReloadAfterCreateCallback(scope *Scope) {
if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok { if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
scope.DB().New().Select(columns.([]string)).First(scope.Value) scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value)
} }
} }

View File

@ -58,15 +58,20 @@ func (field *Field) Set(value interface{}) (err error) {
// Fields get value's fields // Fields get value's fields
func (scope *Scope) Fields() map[string]*Field { func (scope *Scope) Fields() map[string]*Field {
if scope.fields == nil { if scope.fields == nil {
fields := map[string]*Field{} var (
modelStruct := scope.GetModelStruct() fields = map[string]*Field{}
indirectScopeValue = scope.IndirectValue()
isStruct = indirectScopeValue.Kind() == reflect.Struct
)
indirectValue := scope.IndirectValue() for _, structField := range scope.GetModelStruct().StructFields {
isStruct := indirectValue.Kind() == reflect.Struct
for _, structField := range modelStruct.StructFields {
if field, ok := fields[structField.DBName]; !ok || field.IsIgnored { if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
if isStruct { if isStruct {
fields[structField.DBName] = getField(indirectValue, structField) fieldValue := indirectScopeValue
for _, name := range structField.Names {
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
}
fields[structField.DBName] = &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}
} else { } else {
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true} fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
} }
@ -74,17 +79,6 @@ func (scope *Scope) Fields() map[string]*Field {
} }
scope.fields = fields scope.fields = fields
return fields
} }
return scope.fields return scope.fields
} }
func getField(indirectValue reflect.Value, structField *StructField) *Field {
field := &Field{StructField: structField}
for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
}
field.Field = indirectValue
field.IsBlank = isBlank(indirectValue)
return field
}

View File

@ -132,11 +132,11 @@ func toQueryCondition(scope *Scope, columns []string) string {
return strings.Join(newColumns, ",") return strings.Join(newColumns, ",")
} }
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { func toQueryValues(values [][]interface{}) (results []interface{}) {
for _, primaryValue := range primaryValues { for _, value := range values {
for _, value := range primaryValue { for _, v := range value {
values = append(values, value) results = append(results, v)
} }
} }
return values return
} }