Refactor create callback

This commit is contained in:
Jinzhu 2016-01-17 17:46:56 +08:00
parent e38b1e0948
commit 92213273a5
3 changed files with 46 additions and 61 deletions

View File

@ -42,30 +42,29 @@ func createCallback(scope *Scope) {
if !scope.HasError() {
defer scope.trace(NowFunc())
// set create sql
var sqls, columns []string
fields := scope.Fields()
var (
columns, placeholders []string
blankColumnsWithDefaultValue []string
fields = scope.Fields()
)
for _, field := range fields {
if scope.changeableField(field) {
if field.IsNormal {
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
if !field.IsBlank || !field.HasDefaultValue {
if !field.IsPrimaryKey || !field.IsBlank {
if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName)
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
} else {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
} else if field.HasDefaultValue {
var hasDefaultValueColumns []string
if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
hasDefaultValueColumns = oldHasDefaultValueColumns.([]string)
}
hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName)
scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns)
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
}
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
for _, dbName := range relationship.ForeignDBNames {
if relationField := fields[dbName]; !scope.changeableField(relationField) {
columns = append(columns, scope.Quote(relationField.DBName))
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
for _, foreignKey := range field.Relationship.ForeignDBNames {
if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) {
columns = append(columns, scope.Quote(foreignField.DBName))
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
}
}
}
@ -88,35 +87,27 @@ func createCallback(scope *Scope) {
"INSERT INTO %v (%v) VALUES (%v) %v",
scope.QuotedTableName(),
strings.Join(columns, ","),
strings.Join(sqls, ","),
strings.Join(placeholders, ","),
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
))
}
// execute create sql
if scope.Dialect().SupportLastInsertId() {
if scope.Dialect().SupportLastInsertId() || primaryField == nil {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err := result.LastInsertId()
if scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
scope.Err(scope.SetColumn(primaryField, id))
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
}
} else {
if primaryField == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
} else {
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
scope.db.RowsAffected = 1
} else {
scope.Err(err)
}
}
}
}
@ -124,8 +115,8 @@ func createCallback(scope *Scope) {
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
func forceReloadAfterCreateCallback(scope *Scope) {
if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
scope.DB().New().Select(columns.([]string)).First(scope.Value)
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value)
}
}

View File

@ -58,15 +58,20 @@ func (field *Field) Set(value interface{}) (err error) {
// Fields get value's fields
func (scope *Scope) Fields() map[string]*Field {
if scope.fields == nil {
fields := map[string]*Field{}
modelStruct := scope.GetModelStruct()
var (
fields = map[string]*Field{}
indirectScopeValue = scope.IndirectValue()
isStruct = indirectScopeValue.Kind() == reflect.Struct
)
indirectValue := scope.IndirectValue()
isStruct := indirectValue.Kind() == reflect.Struct
for _, structField := range modelStruct.StructFields {
for _, structField := range scope.GetModelStruct().StructFields {
if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
if isStruct {
fields[structField.DBName] = getField(indirectValue, structField)
fieldValue := indirectScopeValue
for _, name := range structField.Names {
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
}
fields[structField.DBName] = &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}
} else {
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
}
@ -74,17 +79,6 @@ func (scope *Scope) Fields() map[string]*Field {
}
scope.fields = fields
return fields
}
return scope.fields
}
func getField(indirectValue reflect.Value, structField *StructField) *Field {
field := &Field{StructField: structField}
for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
}
field.Field = indirectValue
field.IsBlank = isBlank(indirectValue)
return field
}

View File

@ -132,11 +132,11 @@ func toQueryCondition(scope *Scope, columns []string) string {
return strings.Join(newColumns, ",")
}
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
for _, primaryValue := range primaryValues {
for _, value := range primaryValue {
values = append(values, value)
func toQueryValues(values [][]interface{}) (results []interface{}) {
for _, value := range values {
for _, v := range value {
results = append(results, v)
}
}
return values
return
}