From f92e6747cb12d5a5bc2bf7e0d76cb8e5f69cd637 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 23 Mar 2022 17:24:25 +0800 Subject: [PATCH] Handle field set value error --- callbacks/associations.go | 14 +++++++------- callbacks/create.go | 18 +++++++++--------- callbacks/preload.go | 14 +++++++------- callbacks/update.go | 2 +- scan.go | 4 ++-- schema/field.go | 5 +++-- statement.go | 8 ++++---- tests/go.mod | 2 +- 8 files changed, 34 insertions(+), 33 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 644ef185..fd3141cf 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -159,9 +159,9 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) - ref.ForeignKey.Set(db.Statement.Context, f, fv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)) } assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } @@ -193,9 +193,9 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) - ref.ForeignKey.Set(db.Statement.Context, elem, pv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue)) } } @@ -261,12 +261,12 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) - ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue)) } else { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) - ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) } } joins = reflect.Append(joins, joinValue) diff --git a/callbacks/create.go b/callbacks/create.go index 0a43cacb..e94b7eca 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -121,7 +121,7 @@ func Create(config *Config) func(db *gorm.DB) { _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } @@ -133,7 +133,7 @@ func Create(config *Config) func(db *gorm.DB) { } if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } @@ -141,7 +141,7 @@ func Create(config *Config) func(db *gorm.DB) { case reflect.Struct: _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } } } @@ -227,13 +227,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { if field.DefaultValueInterface != nil { values.Values[i][idx] = field.DefaultValueInterface - field.Set(stmt.Context, rv, field.DefaultValueInterface) + stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.Context, rv, curTime) + stmt.AddError(field.Set(stmt.Context, rv, curTime)) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(stmt.Context, rv, curTime) + stmt.AddError(field.Set(stmt.Context, rv, curTime)) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } @@ -267,13 +267,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface - field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.Context, stmt.ReflectValue, curTime) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(stmt.Context, stmt.ReflectValue, curTime) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } diff --git a/callbacks/preload.go b/callbacks/preload.go index 888f832d..ea2570ba 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -123,17 +123,17 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: - rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: - rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())) } } } @@ -158,12 +158,12 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(tx.Statement.Context, data, elem.Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface())) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())) } else { - rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 1964973b..01f40509 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { for _, rel := range db.Statement.Schema.Relationships.BelongsTo { if _, ok := dest[rel.Name]; ok { - rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]) + db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])) } } } diff --git a/scan.go b/scan.go index 89d92354..42642ec6 100644 --- a/scan.go +++ b/scan.go @@ -69,7 +69,7 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values for idx, field := range fields { if field != nil { if len(joinFields) == 0 || joinFields[idx][0] == nil { - field.Set(db.Statement.Context, reflectValue, values[idx]) + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) } else { relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) if relValue.Kind() == reflect.Ptr && relValue.IsNil() { @@ -79,7 +79,7 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values relValue.Set(reflect.New(relValue.Type().Elem())) } - joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]) + db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) } // release data to pool diff --git a/schema/field.go b/schema/field.go index 96291816..3b5cc5c5 100644 --- a/schema/field.go +++ b/schema/field.go @@ -12,6 +12,7 @@ import ( "time" "github.com/jinzhu/now" + "gorm.io/gorm/clause" "gorm.io/gorm/utils" ) @@ -567,8 +568,8 @@ func (field *Field) setupValuerAndSetter() { if v, err = valuer.Value(); err == nil { err = setter(ctx, value, v) } - } else { - return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) + } else if _, ok := v.(clause.Expr); !ok { + return fmt.Errorf("failed to set value %#v to field %s", v, field.Name) } } diff --git a/statement.go b/statement.go index abf646b8..9fcee09c 100644 --- a/statement.go +++ b/statement.go @@ -562,7 +562,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . switch destValue.Kind() { case reflect.Struct: - field.Set(stmt.Context, destValue, value) + stmt.AddError(field.Set(stmt.Context, destValue, value)) default: stmt.AddError(ErrInvalidData) } @@ -572,10 +572,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . case reflect.Slice, reflect.Array: if len(fromCallbacks) > 0 { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)) } } else { - field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value)) } case reflect.Struct: if !stmt.ReflectValue.CanAddr() { @@ -583,7 +583,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . return } - field.Set(stmt.Context, stmt.ReflectValue, value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value)) } } else { stmt.AddError(ErrInvalidField) diff --git a/tests/go.mod b/tests/go.mod index 17e5d350..b85ebdad 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect + golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1