mirror of https://github.com/go-gorm/gorm.git
Refactor create callback
This commit is contained in:
parent
e38b1e0948
commit
92213273a5
|
@ -42,30 +42,29 @@ func createCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
// set create sql
|
var (
|
||||||
var sqls, columns []string
|
columns, placeholders []string
|
||||||
fields := scope.Fields()
|
blankColumnsWithDefaultValue []string
|
||||||
|
fields = scope.Fields()
|
||||||
|
)
|
||||||
|
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if scope.changeableField(field) {
|
if scope.changeableField(field) {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
|
if !field.IsPrimaryKey || !field.IsBlank {
|
||||||
if !field.IsBlank || !field.HasDefaultValue {
|
if field.IsBlank && field.HasDefaultValue {
|
||||||
|
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName)
|
||||||
|
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
||||||
|
} else {
|
||||||
columns = append(columns, scope.Quote(field.DBName))
|
columns = append(columns, scope.Quote(field.DBName))
|
||||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
|
||||||
} else if field.HasDefaultValue {
|
|
||||||
var hasDefaultValueColumns []string
|
|
||||||
if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
|
|
||||||
hasDefaultValueColumns = oldHasDefaultValueColumns.([]string)
|
|
||||||
}
|
|
||||||
hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName)
|
|
||||||
scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
|
||||||
for _, dbName := range relationship.ForeignDBNames {
|
for _, foreignKey := range field.Relationship.ForeignDBNames {
|
||||||
if relationField := fields[dbName]; !scope.changeableField(relationField) {
|
if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) {
|
||||||
columns = append(columns, scope.Quote(relationField.DBName))
|
columns = append(columns, scope.Quote(foreignField.DBName))
|
||||||
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
|
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,35 +87,27 @@ func createCallback(scope *Scope) {
|
||||||
"INSERT INTO %v (%v) VALUES (%v) %v",
|
"INSERT INTO %v (%v) VALUES (%v) %v",
|
||||||
scope.QuotedTableName(),
|
scope.QuotedTableName(),
|
||||||
strings.Join(columns, ","),
|
strings.Join(columns, ","),
|
||||||
strings.Join(sqls, ","),
|
strings.Join(placeholders, ","),
|
||||||
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
|
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute create sql
|
// execute create sql
|
||||||
if scope.Dialect().SupportLastInsertId() {
|
if scope.Dialect().SupportLastInsertId() || primaryField == nil {
|
||||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||||
id, err := result.LastInsertId()
|
// set rows affected count
|
||||||
if scope.Err(err) == nil {
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
|
||||||
if primaryField != nil && primaryField.IsBlank {
|
// set primary value to primary field
|
||||||
scope.Err(scope.SetColumn(primaryField, id))
|
if primaryField != nil && primaryField.IsBlank {
|
||||||
|
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||||
|
scope.Err(primaryField.Set(primaryValue))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if primaryField == nil {
|
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
|
scope.db.RowsAffected = 1
|
||||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
|
||||||
} else {
|
|
||||||
scope.Err(err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
|
|
||||||
scope.db.RowsAffected = 1
|
|
||||||
} else {
|
|
||||||
scope.Err(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -124,8 +115,8 @@ func createCallback(scope *Scope) {
|
||||||
|
|
||||||
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
||||||
func forceReloadAfterCreateCallback(scope *Scope) {
|
func forceReloadAfterCreateCallback(scope *Scope) {
|
||||||
if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
|
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
||||||
scope.DB().New().Select(columns.([]string)).First(scope.Value)
|
scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
28
field.go
28
field.go
|
@ -58,15 +58,20 @@ func (field *Field) Set(value interface{}) (err error) {
|
||||||
// Fields get value's fields
|
// Fields get value's fields
|
||||||
func (scope *Scope) Fields() map[string]*Field {
|
func (scope *Scope) Fields() map[string]*Field {
|
||||||
if scope.fields == nil {
|
if scope.fields == nil {
|
||||||
fields := map[string]*Field{}
|
var (
|
||||||
modelStruct := scope.GetModelStruct()
|
fields = map[string]*Field{}
|
||||||
|
indirectScopeValue = scope.IndirectValue()
|
||||||
|
isStruct = indirectScopeValue.Kind() == reflect.Struct
|
||||||
|
)
|
||||||
|
|
||||||
indirectValue := scope.IndirectValue()
|
for _, structField := range scope.GetModelStruct().StructFields {
|
||||||
isStruct := indirectValue.Kind() == reflect.Struct
|
|
||||||
for _, structField := range modelStruct.StructFields {
|
|
||||||
if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
|
if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
|
||||||
if isStruct {
|
if isStruct {
|
||||||
fields[structField.DBName] = getField(indirectValue, structField)
|
fieldValue := indirectScopeValue
|
||||||
|
for _, name := range structField.Names {
|
||||||
|
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
|
||||||
|
}
|
||||||
|
fields[structField.DBName] = &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}
|
||||||
} else {
|
} else {
|
||||||
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
|
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
|
||||||
}
|
}
|
||||||
|
@ -74,17 +79,6 @@ func (scope *Scope) Fields() map[string]*Field {
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.fields = fields
|
scope.fields = fields
|
||||||
return fields
|
|
||||||
}
|
}
|
||||||
return scope.fields
|
return scope.fields
|
||||||
}
|
}
|
||||||
|
|
||||||
func getField(indirectValue reflect.Value, structField *StructField) *Field {
|
|
||||||
field := &Field{StructField: structField}
|
|
||||||
for _, name := range structField.Names {
|
|
||||||
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
|
|
||||||
}
|
|
||||||
field.Field = indirectValue
|
|
||||||
field.IsBlank = isBlank(indirectValue)
|
|
||||||
return field
|
|
||||||
}
|
|
||||||
|
|
10
utils.go
10
utils.go
|
@ -132,11 +132,11 @@ func toQueryCondition(scope *Scope, columns []string) string {
|
||||||
return strings.Join(newColumns, ",")
|
return strings.Join(newColumns, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
func toQueryValues(values [][]interface{}) (results []interface{}) {
|
||||||
for _, primaryValue := range primaryValues {
|
for _, value := range values {
|
||||||
for _, value := range primaryValue {
|
for _, v := range value {
|
||||||
values = append(values, value)
|
results = append(results, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return values
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue