forked from mirror/gorm
feat: go code style adjust and optimize code for callbacks package (#4861)
* feat: go code style adjust and optimize code for callbacks package * Update scan.go
This commit is contained in:
parent
b8f33a42a4
commit
9d5f315b6d
|
@ -39,7 +39,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
var (
|
var (
|
||||||
objs = make([]reflect.Value, 0, db.Statement.ReflectValue.Len())
|
rValLen = db.Statement.ReflectValue.Len()
|
||||||
|
objs = make([]reflect.Value, 0, rValLen)
|
||||||
fieldType = rel.Field.FieldType
|
fieldType = rel.Field.FieldType
|
||||||
isPtr = fieldType.Kind() == reflect.Ptr
|
isPtr = fieldType.Kind() == reflect.Ptr
|
||||||
)
|
)
|
||||||
|
@ -49,22 +50,21 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
for i := 0; i < rValLen; i++ {
|
||||||
obj := db.Statement.ReflectValue.Index(i)
|
obj := db.Statement.ReflectValue.Index(i)
|
||||||
|
if reflect.Indirect(obj).Kind() != reflect.Struct {
|
||||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
|
||||||
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
|
|
||||||
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
|
|
||||||
objs = append(objs, obj)
|
|
||||||
if isPtr {
|
|
||||||
elems = reflect.Append(elems, rv)
|
|
||||||
} else {
|
|
||||||
elems = reflect.Append(elems, rv.Addr())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
|
||||||
|
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
|
||||||
|
objs = append(objs, obj)
|
||||||
|
if isPtr {
|
||||||
|
elems = reflect.Append(elems, rv)
|
||||||
|
} else {
|
||||||
|
elems = reflect.Append(elems, rv.Addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
|
|
|
@ -200,15 +200,16 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
|
|
||||||
switch stmt.ReflectValue.Kind() {
|
switch stmt.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
stmt.SQL.Grow(stmt.ReflectValue.Len() * 18)
|
rValLen := stmt.ReflectValue.Len()
|
||||||
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
|
stmt.SQL.Grow(rValLen * 18)
|
||||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
values.Values = make([][]interface{}, rValLen)
|
||||||
if stmt.ReflectValue.Len() == 0 {
|
if rValLen == 0 {
|
||||||
stmt.AddError(gorm.ErrEmptySlice)
|
stmt.AddError(gorm.ErrEmptySlice)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||||
|
for i := 0; i < rValLen; i++ {
|
||||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||||
if !rv.IsValid() {
|
if !rv.IsValid() {
|
||||||
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
|
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
|
||||||
|
@ -234,11 +235,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
|
|
||||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
if v, isZero := field.ValueOf(rv); !isZero {
|
if rvOfvalue, isZero := field.ValueOf(rv); !isZero {
|
||||||
if len(defaultValueFieldsHavingValue[field]) == 0 {
|
if len(defaultValueFieldsHavingValue[field]) == 0 {
|
||||||
defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len())
|
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
|
||||||
}
|
}
|
||||||
defaultValueFieldsHavingValue[field][i] = v
|
defaultValueFieldsHavingValue[field][i] = rvOfvalue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -274,9 +275,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
|
|
||||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||||
values.Values[0] = append(values.Values[0], v)
|
values.Values[0] = append(values.Values[0], rvOfvalue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -156,16 +156,19 @@ func Delete(config *Config) func(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.DryRun && db.Error == nil {
|
if !db.DryRun && db.Error == nil {
|
||||||
if ok, mode := hasReturning(db, supportReturning); ok {
|
ok, mode := hasReturning(db, supportReturning)
|
||||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
if !ok {
|
||||||
gorm.Scan(rows, db, mode)
|
|
||||||
rows.Close()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
if db.AddError(err) == nil {
|
if db.AddError(err) == nil {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||||
|
gorm.Scan(rows, db, mode)
|
||||||
|
rows.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,12 +61,13 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
||||||
fieldValues := make([]interface{}, len(joinForeignFields))
|
fieldValues := make([]interface{}, len(joinForeignFields))
|
||||||
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
|
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
|
||||||
for i := 0; i < joinResults.Len(); i++ {
|
for i := 0; i < joinResults.Len(); i++ {
|
||||||
|
joinIndexValue := joinResults.Index(i)
|
||||||
for idx, field := range joinForeignFields {
|
for idx, field := range joinForeignFields {
|
||||||
fieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
|
fieldValues[idx], _ = field.ValueOf(joinIndexValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, field := range joinRelForeignFields {
|
for idx, field := range joinRelForeignFields {
|
||||||
joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
|
joinFieldValues[idx], _ = field.ValueOf(joinIndexValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
||||||
|
|
|
@ -9,8 +9,9 @@ func RawExec(db *gorm.DB) {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
} else {
|
return
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,16 +7,17 @@ import (
|
||||||
func RowQuery(db *gorm.DB) {
|
func RowQuery(db *gorm.DB) {
|
||||||
if db.Error == nil {
|
if db.Error == nil {
|
||||||
BuildQuerySQL(db)
|
BuildQuerySQL(db)
|
||||||
|
if db.DryRun {
|
||||||
if !db.DryRun {
|
return
|
||||||
if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
|
|
||||||
db.Statement.Settings.Delete("rows")
|
|
||||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
|
||||||
} else {
|
|
||||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
|
||||||
}
|
|
||||||
|
|
||||||
db.RowsAffected = -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
|
||||||
|
db.Statement.Settings.Delete("rows")
|
||||||
|
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
} else {
|
||||||
|
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.RowsAffected = -1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,11 +20,12 @@ func BeginTransaction(db *gorm.DB) {
|
||||||
func CommitOrRollbackTransaction(db *gorm.DB) {
|
func CommitOrRollbackTransaction(db *gorm.DB) {
|
||||||
if !db.Config.SkipDefaultTransaction {
|
if !db.Config.SkipDefaultTransaction {
|
||||||
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
|
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
|
||||||
if db.Error == nil {
|
if db.Error != nil {
|
||||||
db.Commit()
|
|
||||||
} else {
|
|
||||||
db.Rollback()
|
db.Rollback()
|
||||||
|
} else {
|
||||||
|
db.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Statement.ConnPool = db.ConnPool
|
db.Statement.ConnPool = db.ConnPool
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -157,7 +157,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if size := stmt.ReflectValue.Len(); size > 0 {
|
if size := stmt.ReflectValue.Len(); size > 0 {
|
||||||
var primaryKeyExprs []clause.Expression
|
var primaryKeyExprs []clause.Expression
|
||||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
for i := 0; i < size; i++ {
|
||||||
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||||
var notZero bool
|
var notZero bool
|
||||||
for idx, field := range stmt.Schema.PrimaryFields {
|
for idx, field := range stmt.Schema.PrimaryFields {
|
||||||
|
|
|
@ -156,7 +156,9 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||||
return tableList, m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).Scan(&tableList).Error
|
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
|
||||||
|
Scan(&tableList).Error
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Migrator) CreateTable(values ...interface{}) error {
|
func (m Migrator) CreateTable(values ...interface{}) error {
|
||||||
|
|
6
scan.go
6
scan.go
|
@ -102,9 +102,9 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re
|
||||||
type ScanMode uint8
|
type ScanMode uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ScanInitialized ScanMode = 1 << 0
|
ScanInitialized ScanMode = 1 << 0 // 1
|
||||||
ScanUpdate = 1 << 1
|
ScanUpdate ScanMode = 1 << 1 // 2
|
||||||
ScanOnConflictDoNothing = 1 << 2
|
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
|
||||||
)
|
)
|
||||||
|
|
||||||
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
||||||
|
|
Loading…
Reference in New Issue