mirror of https://github.com/go-gorm/gorm.git
Refactor create callback
This commit is contained in:
parent
e38b1e0948
commit
92213273a5
|
@ -42,30 +42,29 @@ func createCallback(scope *Scope) {
|
|||
if !scope.HasError() {
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
// set create sql
|
||||
var sqls, columns []string
|
||||
fields := scope.Fields()
|
||||
var (
|
||||
columns, placeholders []string
|
||||
blankColumnsWithDefaultValue []string
|
||||
fields = scope.Fields()
|
||||
)
|
||||
|
||||
for _, field := range fields {
|
||||
if scope.changeableField(field) {
|
||||
if field.IsNormal {
|
||||
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
|
||||
if !field.IsBlank || !field.HasDefaultValue {
|
||||
if !field.IsPrimaryKey || !field.IsBlank {
|
||||
if field.IsBlank && field.HasDefaultValue {
|
||||
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName)
|
||||
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
||||
} else {
|
||||
columns = append(columns, scope.Quote(field.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
||||
} else if field.HasDefaultValue {
|
||||
var hasDefaultValueColumns []string
|
||||
if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
|
||||
hasDefaultValueColumns = oldHasDefaultValueColumns.([]string)
|
||||
}
|
||||
hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName)
|
||||
scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns)
|
||||
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
|
||||
}
|
||||
}
|
||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
for _, dbName := range relationship.ForeignDBNames {
|
||||
if relationField := fields[dbName]; !scope.changeableField(relationField) {
|
||||
columns = append(columns, scope.Quote(relationField.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
|
||||
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
|
||||
for _, foreignKey := range field.Relationship.ForeignDBNames {
|
||||
if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) {
|
||||
columns = append(columns, scope.Quote(foreignField.DBName))
|
||||
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -88,35 +87,27 @@ func createCallback(scope *Scope) {
|
|||
"INSERT INTO %v (%v) VALUES (%v) %v",
|
||||
scope.QuotedTableName(),
|
||||
strings.Join(columns, ","),
|
||||
strings.Join(sqls, ","),
|
||||
strings.Join(placeholders, ","),
|
||||
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
|
||||
))
|
||||
}
|
||||
|
||||
// execute create sql
|
||||
if scope.Dialect().SupportLastInsertId() {
|
||||
if scope.Dialect().SupportLastInsertId() || primaryField == nil {
|
||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
id, err := result.LastInsertId()
|
||||
if scope.Err(err) == nil {
|
||||
// set rows affected count
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
// set primary value to primary field
|
||||
if primaryField != nil && primaryField.IsBlank {
|
||||
scope.Err(scope.SetColumn(primaryField, id))
|
||||
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||
scope.Err(primaryField.Set(primaryValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if primaryField == nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
|
||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
||||
} else {
|
||||
scope.Err(err)
|
||||
}
|
||||
} else {
|
||||
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
|
||||
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||
scope.db.RowsAffected = 1
|
||||
} else {
|
||||
scope.Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -124,8 +115,8 @@ func createCallback(scope *Scope) {
|
|||
|
||||
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
||||
func forceReloadAfterCreateCallback(scope *Scope) {
|
||||
if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
|
||||
scope.DB().New().Select(columns.([]string)).First(scope.Value)
|
||||
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
||||
scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
28
field.go
28
field.go
|
@ -58,15 +58,20 @@ func (field *Field) Set(value interface{}) (err error) {
|
|||
// Fields get value's fields
|
||||
func (scope *Scope) Fields() map[string]*Field {
|
||||
if scope.fields == nil {
|
||||
fields := map[string]*Field{}
|
||||
modelStruct := scope.GetModelStruct()
|
||||
var (
|
||||
fields = map[string]*Field{}
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
isStruct = indirectScopeValue.Kind() == reflect.Struct
|
||||
)
|
||||
|
||||
indirectValue := scope.IndirectValue()
|
||||
isStruct := indirectValue.Kind() == reflect.Struct
|
||||
for _, structField := range modelStruct.StructFields {
|
||||
for _, structField := range scope.GetModelStruct().StructFields {
|
||||
if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
|
||||
if isStruct {
|
||||
fields[structField.DBName] = getField(indirectValue, structField)
|
||||
fieldValue := indirectScopeValue
|
||||
for _, name := range structField.Names {
|
||||
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
|
||||
}
|
||||
fields[structField.DBName] = &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}
|
||||
} else {
|
||||
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
|
||||
}
|
||||
|
@ -74,17 +79,6 @@ func (scope *Scope) Fields() map[string]*Field {
|
|||
}
|
||||
|
||||
scope.fields = fields
|
||||
return fields
|
||||
}
|
||||
return scope.fields
|
||||
}
|
||||
|
||||
func getField(indirectValue reflect.Value, structField *StructField) *Field {
|
||||
field := &Field{StructField: structField}
|
||||
for _, name := range structField.Names {
|
||||
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
|
||||
}
|
||||
field.Field = indirectValue
|
||||
field.IsBlank = isBlank(indirectValue)
|
||||
return field
|
||||
}
|
||||
|
|
10
utils.go
10
utils.go
|
@ -132,11 +132,11 @@ func toQueryCondition(scope *Scope, columns []string) string {
|
|||
return strings.Join(newColumns, ",")
|
||||
}
|
||||
|
||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
||||
for _, primaryValue := range primaryValues {
|
||||
for _, value := range primaryValue {
|
||||
values = append(values, value)
|
||||
func toQueryValues(values [][]interface{}) (results []interface{}) {
|
||||
for _, value := range values {
|
||||
for _, v := range value {
|
||||
results = append(results, v)
|
||||
}
|
||||
}
|
||||
return values
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue