diff --git a/association.go b/association.go index 516a8c57..aa740fc5 100644 --- a/association.go +++ b/association.go @@ -102,10 +102,10 @@ func (association *Association) Replace(values ...interface{}) error { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) } case reflect.Struct: - rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) } for _, ref := range rel.References { @@ -187,18 +187,17 @@ func (association *Association) Replace(values ...interface{}) error { func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { var ( - reflectValue = association.DB.Statement.ReflectValue - rel = association.Relationship - primaryFields, foreignFields []*schema.Field - foreignKeys []string - updateAttrs = map[string]interface{}{} - conds []clause.Expression + reflectValue = association.DB.Statement.ReflectValue + rel = association.Relationship + primaryFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} + conds []clause.Expression ) for _, ref := range rel.References { if ref.PrimaryValue == "" { primaryFields = append(primaryFields, ref.PrimaryKey) - foreignFields = append(foreignFields, ref.ForeignKey) foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateAttrs[ref.ForeignKey.DBName] = nil } else { @@ -284,21 +283,23 @@ func (association *Association) Delete(values ...interface{}) error { } } - rel.Field.Set(data, validFieldValues.Interface()) + association.Error = rel.Field.Set(data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range rel.FieldSchema.PrimaryFields { primaryValues[idx], _ = field.ValueOf(fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { - rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()) + if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + break + } if rel.JoinTable == nil { for _, ref := range rel.References { if ref.OwnPrimaryKey || ref.PrimaryValue != "" { - ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } else { - ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -436,12 +437,18 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if len(values) != reflectValue.Len() { if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { - association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + association.Error = err + break + } if association.Relationship.JoinTable == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()) + if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + association.Error = err + break + } } } } @@ -461,12 +468,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case reflect.Struct: if clear && len(values) == 0 { - association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) - if association.Relationship.JoinTable == nil { + if association.Relationship.JoinTable == nil && association.Error == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } diff --git a/callbacks/associations.go b/callbacks/associations.go index 408f3fc9..3508335a 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -21,7 +21,7 @@ func SaveBeforeAssociations(db *gorm.DB) { for _, ref := range rel.References { if !ref.OwnPrimaryKey { pv, _ := ref.PrimaryKey.ValueOf(elem) - ref.ForeignKey.Set(obj, pv) + db.AddError(ref.ForeignKey.Set(obj, pv)) if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { dest[ref.ForeignKey.DBName] = pv @@ -121,9 +121,9 @@ func SaveAfterAssociations(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(obj) - ref.ForeignKey.Set(rv, fv) + db.AddError(ref.ForeignKey.Set(rv, fv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(rv, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) } } diff --git a/callbacks/helper.go b/callbacks/helper.go index 1b06e0b7..7bd910f6 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -9,7 +9,7 @@ import ( // ConvertMapToValuesForCreate convert map to values func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { - columns := make([]string, 0, len(mapValue)) + values.Columns = make([]clause.Column, 0, len(mapValue)) selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) var keys []string @@ -25,7 +25,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { - columns = append(columns, k) + values.Columns = append(values.Columns, clause.Column{Name: k}) values.Values[0] = append(values.Values[0], value) } } diff --git a/callbacks/interface.go b/callbacks/interface.go deleted file mode 100644 index ee0044e8..00000000 --- a/callbacks/interface.go +++ /dev/null @@ -1,11 +0,0 @@ -package callbacks - -import "gorm.io/gorm" - -type beforeSaveInterface interface { - BeforeSave(*gorm.DB) error -} - -type beforeCreateInterface interface { - BeforeCreate(*gorm.DB) error -} diff --git a/chainable_api.go b/chainable_api.go index 7ee20324..730f6308 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -41,7 +41,7 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { return } -var tableRegexp = regexp.MustCompile("(?i).+ AS (\\w+)\\s*$") +var tableRegexp = regexp.MustCompile(`(?i).+ AS (\w+)\s*$`) // Table specify the table you would like to run db operations func (db *DB) Table(name string) (tx *DB) { diff --git a/clause/clause.go b/clause/clause.go index c7d1efeb..d413d0ee 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -18,7 +18,7 @@ type Writer interface { // Builder builder interface type Builder interface { Writer - WriteQuoted(field interface{}) error + WriteQuoted(field interface{}) AddVar(Writer, ...interface{}) } diff --git a/clause/joins.go b/clause/joins.go index 8d9055cd..f3e373f2 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -4,9 +4,9 @@ type JoinType string const ( CrossJoin JoinType = "CROSS" - InnerJoin = "INNER" - LeftJoin = "LEFT" - RightJoin = "RIGHT" + InnerJoin JoinType = "INNER" + LeftJoin JoinType = "LEFT" + RightJoin JoinType = "RIGHT" ) // Join join clause for from diff --git a/clause/where.go b/clause/where.go index a0f4598d..9af9701c 100644 --- a/clause/where.go +++ b/clause/where.go @@ -33,8 +33,6 @@ func (where Where) Build(builder Builder) { expr.Build(builder) } - - return } // MergeClause merge where clauses diff --git a/finisher_api.go b/finisher_api.go index d70b3cd0..6bfe5d20 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -138,11 +138,11 @@ func (tx *DB) assignExprsToValue(exprs []clause.Expression) { switch column := eq.Column.(type) { case string: if field := tx.Statement.Schema.LookUpField(column); field != nil { - field.Set(tx.Statement.ReflectValue, eq.Value) + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) } case clause.Column: if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - field.Set(tx.Statement.ReflectValue, eq.Value) + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) } default: } @@ -433,7 +433,7 @@ func (db *DB) Rollback() *DB { func (db *DB) SavePoint(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { - savePointer.SavePoint(db, name) + db.AddError(savePointer.SavePoint(db, name)) } else { db.AddError(ErrUnsupportedDriver) } @@ -442,7 +442,7 @@ func (db *DB) SavePoint(name string) *DB { func (db *DB) RollbackTo(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { - savePointer.RollbackTo(db, name) + db.AddError(savePointer.RollbackTo(db, name)) } else { db.AddError(ErrUnsupportedDriver) } diff --git a/logger/logger.go b/logger/logger.go index 2a5e445c..49ae988c 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -129,7 +129,7 @@ func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { // Trace print sql message func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.LogLevel > 0 { - elapsed := time.Now().Sub(begin) + elapsed := time.Since(begin) switch { case err != nil && l.LogLevel >= Error: sql, rows := fc() diff --git a/logger/sql_test.go b/logger/sql_test.go index 8bc48116..180570b8 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -31,19 +31,19 @@ func TestExplainSQL(t *testing.T) { }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", - NumericRegexp: regexp.MustCompile("@p(\\d+)"), + NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", - NumericRegexp: regexp.MustCompile("\\$(\\d+)"), + NumericRegexp: regexp.MustCompile(`\$(\d+)`), Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", - NumericRegexp: regexp.MustCompile("@p(\\d+)"), + NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, diff --git a/migrator/migrator.go b/migrator/migrator.go index 169701e4..3e5d86d3 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -1,6 +1,7 @@ package migrator import ( + "context" "database/sql" "fmt" "reflect" @@ -139,7 +140,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] - createTableSQL += fmt.Sprintf("? ?") + createTableSQL += "? ?" hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) createTableSQL += "," @@ -534,7 +535,9 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i dep := Dependency{ Statement: &gorm.Statement{DB: m.DB, Dest: value}, } - dep.Parse(value) + if err := dep.Parse(value); err != nil { + m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) + } for _, rel := range dep.Schema.Relationships.Relations { if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { diff --git a/schema/field.go b/schema/field.go index 3e08802a..2c43229b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -25,12 +25,12 @@ const ( const ( Bool DataType = "bool" - Int = "int" - Uint = "uint" - Float = "float" - String = "string" - Time = "time" - Bytes = "bytes" + Int DataType = "int" + Uint DataType = "uint" + Float DataType = "float" + String DataType = "string" + Time DataType = "time" + Bytes DataType = "bytes" ) type Field struct { @@ -455,13 +455,13 @@ func (field *Field) setupValuerAndSetter() { if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - setter(value, v) + err = setter(value, v) } } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - setter(value, reflectV.Elem().Interface()) + err = setter(value, reflectV.Elem().Interface()) } } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) @@ -744,7 +744,7 @@ func (field *Field) setupValuerAndSetter() { if reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - field.Set(value, reflectV.Elem().Interface()) + err = field.Set(value, reflectV.Elem().Interface()) } } else { fieldValue := field.ReflectValueOf(value) diff --git a/schema/relationship.go b/schema/relationship.go index e3ff0307..c290c5ba 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -71,9 +71,9 @@ func (schema *Schema) parseRelation(field *Field) { return } - if polymorphic, _ := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { schema.buildPolymorphicRelation(relation, field, polymorphic) - } else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" { + } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) } else { switch field.IndirectFieldType.Kind() { @@ -312,7 +312,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel OwnPrimaryKey: ownPriamryField, }) } - return } func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { diff --git a/statement.go b/statement.go index 142c7c31..38154939 100644 --- a/statement.go +++ b/statement.go @@ -60,9 +60,8 @@ func (stmt *Statement) WriteByte(c byte) error { } // WriteQuoted write quoted value -func (stmt *Statement) WriteQuoted(value interface{}) error { +func (stmt *Statement) WriteQuoted(value interface{}) { stmt.QuoteTo(&stmt.SQL, value) - return nil } // QuoteTo write quoted value to writer @@ -215,7 +214,7 @@ func (stmt *Statement) AddClause(v clause.Interface) { optimizer.ModifyStatement(stmt) } else { name := v.Name() - c, _ := stmt.Clauses[name] + c := stmt.Clauses[name] c.Name = name v.MergeClause(&c) stmt.Clauses[name] = c diff --git a/utils/utils.go b/utils/utils.go index 9bf00683..3d7e395b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -15,7 +15,7 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) - gormSourceDir = regexp.MustCompile("utils.utils\\.go").ReplaceAllString(file, "") + gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") } func FileWithLineNum() string {