mirror of https://github.com/go-gorm/gorm.git
Code optimize (#4415)
* optimize gormSourceDir replace * fmt.Errorf adjust and Optimize for-break * strings trim * feat: avoid using the same name field and if..else optimization adjustment * optimization callbacks/create.go Create func if...else logic * fix: callbacks/create.go Create func * fix FileWithLineNum func and add gormSourceDir unit test * remove debug print and utils_filenum_test.go
This commit is contained in:
parent
00b252559f
commit
50e85e14d4
|
@ -26,7 +26,7 @@ func (db *DB) Association(column string) *Association {
|
|||
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
|
||||
|
||||
if association.Relationship == nil {
|
||||
association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
|
||||
association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
|
||||
}
|
||||
|
||||
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
|
||||
|
@ -355,7 +355,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||
} else if ev.Type().Elem().AssignableTo(elemType) {
|
||||
fieldValue = reflect.Append(fieldValue, ev.Elem())
|
||||
} else {
|
||||
association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name)
|
||||
association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
|
||||
}
|
||||
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
|
|
10
callbacks.go
10
callbacks.go
|
@ -212,7 +212,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
|
|||
}
|
||||
|
||||
func (c *callback) Remove(name string) error {
|
||||
c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
|
||||
c.name = name
|
||||
c.remove = true
|
||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||
|
@ -220,7 +220,7 @@ func (c *callback) Remove(name string) error {
|
|||
}
|
||||
|
||||
func (c *callback) Replace(name string, fn func(*DB)) error {
|
||||
c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
|
||||
c.name = name
|
||||
c.handler = fn
|
||||
c.replace = true
|
||||
|
@ -250,7 +250,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||
for _, c := range cs {
|
||||
// show warning message the callback name already exists
|
||||
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
|
||||
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
|
||||
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
|
||||
}
|
||||
names = append(names, c.name)
|
||||
}
|
||||
|
@ -266,7 +266,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||
// if before callback already sorted, append current callback just after it
|
||||
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
||||
} else if curIdx > sortedIdx {
|
||||
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
|
||||
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
|
||||
}
|
||||
} else if idx := getRIndex(names, c.before); idx != -1 {
|
||||
// if before callback exists
|
||||
|
@ -284,7 +284,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||
// if after callback sorted, append current callback to last
|
||||
sorted = append(sorted, c.name)
|
||||
} else if curIdx < sortedIdx {
|
||||
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
|
||||
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
|
||||
}
|
||||
} else if idx := getRIndex(names, c.after); idx != -1 {
|
||||
// if after callback exists but haven't sorted
|
||||
|
|
|
@ -33,75 +33,81 @@ func BeforeCreate(db *gorm.DB) {
|
|||
func Create(config *Config) func(db *gorm.DB) {
|
||||
if config.WithReturning {
|
||||
return CreateWithReturning
|
||||
} else {
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
// maybe record logger TODO
|
||||
return
|
||||
}
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
if err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
if db.RowsAffected > 0 {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
if !db.DryRun && db.Error == nil {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
|
||||
if isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
if !(db.RowsAffected > 0) {
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
|
||||
if isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
case reflect.Struct:
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -190,16 +190,17 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||
|
||||
if tx.Error != nil || int(result.RowsAffected) < batchSize {
|
||||
break
|
||||
} else {
|
||||
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
||||
if result.Statement.Schema.PrioritizedPrimaryField == nil {
|
||||
tx.AddError(ErrPrimaryKeyRequired)
|
||||
break
|
||||
} else {
|
||||
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
|
||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||
}
|
||||
}
|
||||
|
||||
// Optimize for-break
|
||||
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
||||
if result.Statement.Schema.PrioritizedPrimaryField == nil {
|
||||
tx.AddError(ErrPrimaryKeyRequired)
|
||||
break
|
||||
}
|
||||
|
||||
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
|
||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||
}
|
||||
|
||||
tx.RowsAffected = rowsAffected
|
||||
|
|
4
gorm.go
4
gorm.go
|
@ -409,7 +409,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
|
|||
}
|
||||
ref.ForeignKey = f
|
||||
} else {
|
||||
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
||||
return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -422,7 +422,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
|
|||
|
||||
relation.JoinTable = joinSchema
|
||||
} else {
|
||||
return fmt.Errorf("failed to found relation: %v", field)
|
||||
return fmt.Errorf("failed to found relation: %s", field)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -119,13 +119,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||
if constraint.Schema == stmt.Schema {
|
||||
if !tx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if constraint := rel.ParseConstraint(); constraint != nil &&
|
||||
constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -294,16 +291,20 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
|
|||
|
||||
func (m Migrator) AddColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
if !field.IgnoreMigration {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ADD ? ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field),
|
||||
).Error
|
||||
}
|
||||
return nil
|
||||
// avoid using the same name field
|
||||
f := stmt.Schema.LookUpField(field)
|
||||
if f == nil {
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
|
||||
if !f.IgnoreMigration {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ADD ? ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f),
|
||||
).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -198,28 +198,28 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
field.DataType = Bool
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err)
|
||||
schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
field.DataType = Int
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err)
|
||||
schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
field.DataType = Uint
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err)
|
||||
schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
field.DataType = Float
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err)
|
||||
schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.String:
|
||||
|
@ -227,7 +227,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
|
||||
field.DefaultValue = strings.Trim(field.DefaultValue, "\"")
|
||||
field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
|
||||
field.DefaultValueInterface = field.DefaultValue
|
||||
}
|
||||
case reflect.Struct:
|
||||
|
@ -392,7 +392,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
|
||||
schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -423,12 +423,12 @@ func (field *Field) setupValuerAndSetter() {
|
|||
} else {
|
||||
v = v.Field(-idx - 1)
|
||||
|
||||
if v.Type().Elem().Kind() == reflect.Struct {
|
||||
if !v.IsNil() {
|
||||
v = v.Elem()
|
||||
} else {
|
||||
return nil, true
|
||||
}
|
||||
if v.Type().Elem().Kind() != reflect.Struct {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
if !v.IsNil() {
|
||||
v = v.Elem()
|
||||
} else {
|
||||
return nil, true
|
||||
}
|
||||
|
@ -736,7 +736,7 @@ func (field *Field) setupValuerAndSetter() {
|
|||
if t, err := now.Parse(data); err == nil {
|
||||
field.ReflectValueOf(value).Set(reflect.ValueOf(t))
|
||||
} else {
|
||||
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
||||
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
|
||||
}
|
||||
default:
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
|
@ -765,7 +765,7 @@ func (field *Field) setupValuerAndSetter() {
|
|||
}
|
||||
fieldValue.Elem().Set(reflect.ValueOf(t))
|
||||
} else {
|
||||
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
||||
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
|
||||
}
|
||||
default:
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
|
|
|
@ -74,7 +74,9 @@ func (ns NamingStrategy) IndexName(table, column string) string {
|
|||
}
|
||||
|
||||
func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||
formattedName := strings.Replace(fmt.Sprintf("%v_%v_%v", prefix, table, name), ".", "_", -1)
|
||||
formattedName := strings.Replace(strings.Join([]string{
|
||||
prefix, table, name,
|
||||
}, "_"), ".", "_", -1)
|
||||
|
||||
if utf8.RuneCountInString(formattedName) > 64 {
|
||||
h := sha1.New()
|
||||
|
|
|
@ -85,7 +85,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
|||
case reflect.Slice:
|
||||
schema.guessRelation(relation, field, guessHas)
|
||||
default:
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,11 +143,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
|||
}
|
||||
|
||||
if relation.Polymorphic.PolymorphicType == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
|
||||
}
|
||||
|
||||
if relation.Polymorphic.PolymorphicID == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
|
||||
}
|
||||
|
||||
if schema.err == nil {
|
||||
|
@ -159,7 +159,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
|||
primaryKeyField := schema.PrioritizedPrimaryField
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
|
||||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
|
||||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -203,7 +203,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
if field := schema.LookUpField(foreignKey); field != nil {
|
||||
ownForeignFields = append(ownForeignFields, field)
|
||||
} else {
|
||||
schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey)
|
||||
schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -215,7 +215,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
|
||||
refForeignFields = append(refForeignFields, field)
|
||||
} else {
|
||||
schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey)
|
||||
schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -379,7 +379,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||
schema.guessRelation(relation, field, guessEmbeddedHas)
|
||||
// case guessEmbeddedHas:
|
||||
default:
|
||||
schema.err = fmt.Errorf("invalid field found for struct %v's field %v: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
|
||||
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -45,9 +45,9 @@ type Schema struct {
|
|||
|
||||
func (schema Schema) String() string {
|
||||
if schema.ModelType.Name() == "" {
|
||||
return fmt.Sprintf("%v(%v)", schema.Name, schema.Table)
|
||||
return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
|
||||
}
|
||||
return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name())
|
||||
return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
|
||||
}
|
||||
|
||||
func (schema Schema) MakeSlice() reflect.Value {
|
||||
|
@ -86,7 +86,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||
if modelType.PkgPath() == "" {
|
||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
}
|
||||
|
||||
if v, ok := cacheStore.Load(modelType); ok {
|
||||
|
@ -275,7 +275,7 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e
|
|||
if modelType.PkgPath() == "" {
|
||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
}
|
||||
|
||||
if v, ok := cacheStore.Load(modelType); ok {
|
||||
|
|
|
@ -178,17 +178,18 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa
|
|||
}
|
||||
|
||||
return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues
|
||||
} else {
|
||||
columns := make([]clause.Column, len(foreignKeys))
|
||||
for idx, key := range foreignKeys {
|
||||
columns[idx] = clause.Column{Table: table, Name: key}
|
||||
}
|
||||
|
||||
for idx, r := range foreignValues {
|
||||
queryValues[idx] = r
|
||||
}
|
||||
return columns, queryValues
|
||||
}
|
||||
|
||||
columns := make([]clause.Column, len(foreignKeys))
|
||||
for idx, key := range foreignKeys {
|
||||
columns[idx] = clause.Column{Table: table, Name: key}
|
||||
}
|
||||
|
||||
for idx, r := range foreignValues {
|
||||
queryValues[idx] = r
|
||||
}
|
||||
|
||||
return columns, queryValues
|
||||
}
|
||||
|
||||
type embeddedNamer struct {
|
||||
|
|
|
@ -3,8 +3,8 @@ package utils
|
|||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -15,17 +15,20 @@ var gormSourceDir string
|
|||
|
||||
func init() {
|
||||
_, file, _, _ := runtime.Caller(0)
|
||||
gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "")
|
||||
// Here is the directory to get the gorm source code. Here, the filepath.Dir mode is enough,
|
||||
// and the filepath is compatible with various operating systems
|
||||
gormSourceDir = filepath.Dir(filepath.Dir(file))
|
||||
}
|
||||
|
||||
// FileWithLineNum return the file name and line number of the current file
|
||||
func FileWithLineNum() string {
|
||||
for i := 2; i < 15; i++ {
|
||||
for i := 1; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
|
||||
if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) {
|
||||
return file + ":" + strconv.FormatInt(int64(line), 10)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue