forked from mirror/gorm
Refactor for performance
This commit is contained in:
parent
13f96f7a15
commit
aaf0725771
|
@ -192,19 +192,22 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
|
||||||
return ConvertSliceOfMapToValuesForCreate(stmt, value)
|
return ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||||
default:
|
default:
|
||||||
var (
|
var (
|
||||||
values = clause.Values{}
|
values = clause.Values{Columns: make([]clause.Column, len(stmt.Schema.DBNames))}
|
||||||
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
|
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
|
||||||
curTime = stmt.DB.NowFunc()
|
curTime = stmt.DB.NowFunc()
|
||||||
isZero = false
|
isZero = false
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var columns int
|
||||||
for _, db := range stmt.Schema.DBNames {
|
for _, db := range stmt.Schema.DBNames {
|
||||||
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
||||||
values.Columns = append(values.Columns, clause.Column{Name: db})
|
values.Columns[columns] = clause.Column{Name: db}
|
||||||
|
columns++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
values.Columns = values.Columns[:columns]
|
||||||
|
|
||||||
switch stmt.ReflectValue.Kind() {
|
switch stmt.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
|
|
@ -53,38 +53,28 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(db.Statement.Selects) > 0 {
|
if len(db.Statement.Selects) > 0 {
|
||||||
for _, name := range db.Statement.Selects {
|
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
|
||||||
|
for idx, name := range db.Statement.Selects {
|
||||||
if db.Statement.Schema == nil {
|
if db.Statement.Schema == nil {
|
||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||||
Name: name,
|
|
||||||
Raw: true,
|
|
||||||
})
|
|
||||||
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
|
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
|
||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
|
||||||
Name: f.DBName,
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||||
Name: name,
|
|
||||||
Raw: true,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// inline joins
|
// inline joins
|
||||||
if len(db.Statement.Joins) != 0 {
|
if len(db.Statement.Joins) != 0 {
|
||||||
joins := []clause.Join{}
|
|
||||||
|
|
||||||
if len(db.Statement.Selects) == 0 {
|
if len(db.Statement.Selects) == 0 {
|
||||||
for _, dbName := range db.Statement.Schema.DBNames {
|
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||||
Table: db.Statement.Table,
|
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||||
Name: dbName,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
joins := []clause.Join{}
|
||||||
for name, conds := range db.Statement.Joins {
|
for name, conds := range db.Statement.Joins {
|
||||||
if db.Statement.Schema == nil {
|
if db.Statement.Schema == nil {
|
||||||
joins = append(joins, clause.Join{
|
joins = append(joins, clause.Join{
|
||||||
|
@ -101,24 +91,24 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var exprs []clause.Expression
|
exprs := make([]clause.Expression, len(relation.References))
|
||||||
for _, ref := range relation.References {
|
for idx, ref := range relation.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
exprs = append(exprs, clause.Eq{
|
exprs[idx] = clause.Eq{
|
||||||
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName},
|
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName},
|
||||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||||
})
|
}
|
||||||
} else {
|
} else {
|
||||||
if ref.PrimaryValue == "" {
|
if ref.PrimaryValue == "" {
|
||||||
exprs = append(exprs, clause.Eq{
|
exprs[idx] = clause.Eq{
|
||||||
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName},
|
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName},
|
||||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||||
})
|
}
|
||||||
} else {
|
} else {
|
||||||
exprs = append(exprs, clause.Eq{
|
exprs[idx] = clause.Eq{
|
||||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||||
Value: ref.PrimaryValue,
|
Value: ref.PrimaryValue,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -146,42 +136,40 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Preload(db *gorm.DB) {
|
func Preload(db *gorm.DB) {
|
||||||
if db.Error == nil {
|
if db.Error == nil && len(db.Statement.Preloads) > 0 {
|
||||||
if len(db.Statement.Preloads) > 0 {
|
preloadMap := map[string][]string{}
|
||||||
preloadMap := map[string][]string{}
|
for name := range db.Statement.Preloads {
|
||||||
for name := range db.Statement.Preloads {
|
preloadFields := strings.Split(name, ".")
|
||||||
preloadFields := strings.Split(name, ".")
|
for idx := range preloadFields {
|
||||||
for idx := range preloadFields {
|
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
|
||||||
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
preloadNames := make([]string, len(preloadMap))
|
||||||
|
idx := 0
|
||||||
|
for key := range preloadMap {
|
||||||
|
preloadNames[idx] = key
|
||||||
|
idx++
|
||||||
|
}
|
||||||
|
sort.Strings(preloadNames)
|
||||||
|
|
||||||
|
for _, name := range preloadNames {
|
||||||
|
var (
|
||||||
|
curSchema = db.Statement.Schema
|
||||||
|
preloadFields = preloadMap[name]
|
||||||
|
rels = make([]*schema.Relationship, len(preloadFields))
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx, preloadField := range preloadFields {
|
||||||
|
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
|
||||||
|
rels[idx] = rel
|
||||||
|
curSchema = rel.FieldSchema
|
||||||
|
} else {
|
||||||
|
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
preloadNames := make([]string, len(preloadMap))
|
preload(db, rels, db.Statement.Preloads[name])
|
||||||
idx := 0
|
|
||||||
for key := range preloadMap {
|
|
||||||
preloadNames[idx] = key
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
sort.Strings(preloadNames)
|
|
||||||
|
|
||||||
for _, name := range preloadNames {
|
|
||||||
var (
|
|
||||||
curSchema = db.Statement.Schema
|
|
||||||
preloadFields = preloadMap[name]
|
|
||||||
rels = make([]*schema.Relationship, len(preloadFields))
|
|
||||||
)
|
|
||||||
|
|
||||||
for idx, preloadField := range preloadFields {
|
|
||||||
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
|
|
||||||
rels[idx] = rel
|
|
||||||
curSchema = rel.FieldSchema
|
|
||||||
} else {
|
|
||||||
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
preload(db, rels, db.Statement.Preloads[name])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -140,7 +140,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
set = make([]clause.Assignment, 0, len(value))
|
set = make([]clause.Assignment, 0, len(value))
|
||||||
|
|
||||||
var keys []string
|
keys := make([]string, 0, len(value))
|
||||||
for k := range value {
|
for k := range value {
|
||||||
keys = append(keys, k)
|
keys = append(keys, k)
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,20 +38,15 @@ func (set Set) MergeClause(clause *Clause) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Assignments(values map[string]interface{}) Set {
|
func Assignments(values map[string]interface{}) Set {
|
||||||
var keys []string
|
keys := make([]string, 0, len(values))
|
||||||
var assignments []Assignment
|
|
||||||
|
|
||||||
for key := range values {
|
for key := range values {
|
||||||
keys = append(keys, key)
|
keys = append(keys, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
|
||||||
for _, key := range keys {
|
assignments := make([]Assignment, len(keys))
|
||||||
assignments = append(assignments, Assignment{
|
for idx, key := range keys {
|
||||||
Column: Column{Name: key},
|
assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
|
||||||
Value: values[key],
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
return assignments
|
return assignments
|
||||||
}
|
}
|
||||||
|
|
79
gorm.go
79
gorm.go
|
@ -205,53 +205,11 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
||||||
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
|
||||||
var (
|
|
||||||
tx = db.getInstance()
|
|
||||||
stmt = tx.Statement
|
|
||||||
modelSchema, joinSchema *schema.Schema
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := stmt.Parse(model); err == nil {
|
|
||||||
modelSchema = stmt.Schema
|
|
||||||
} else {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := stmt.Parse(joinTable); err == nil {
|
|
||||||
joinSchema = stmt.Schema
|
|
||||||
} else {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
|
|
||||||
for _, ref := range relation.References {
|
|
||||||
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
|
|
||||||
f.DataType = ref.ForeignKey.DataType
|
|
||||||
ref.ForeignKey = f
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
relation.JoinTable = joinSchema
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("failed to found relation: %v", field)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Callback returns callback manager
|
// Callback returns callback manager
|
||||||
func (db *DB) Callback() *callbacks {
|
func (db *DB) Callback() *callbacks {
|
||||||
return db.callbacks
|
return db.callbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoMigrate run auto migration for given models
|
|
||||||
func (db *DB) AutoMigrate(dst ...interface{}) error {
|
|
||||||
return db.Migrator().AutoMigrate(dst...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddError add error to db
|
// AddError add error to db
|
||||||
func (db *DB) AddError(err error) error {
|
func (db *DB) AddError(err error) error {
|
||||||
if db.Error == nil {
|
if db.Error == nil {
|
||||||
|
@ -289,3 +247,40 @@ func (db *DB) getInstance() *DB {
|
||||||
func Expr(expr string, args ...interface{}) clause.Expr {
|
func Expr(expr string, args ...interface{}) clause.Expr {
|
||||||
return clause.Expr{SQL: expr, Vars: args}
|
return clause.Expr{SQL: expr, Vars: args}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
||||||
|
var (
|
||||||
|
tx = db.getInstance()
|
||||||
|
stmt = tx.Statement
|
||||||
|
modelSchema, joinSchema *schema.Schema
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := stmt.Parse(model); err == nil {
|
||||||
|
modelSchema = stmt.Schema
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stmt.Parse(joinTable); err == nil {
|
||||||
|
joinSchema = stmt.Schema
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
|
||||||
|
for _, ref := range relation.References {
|
||||||
|
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
|
||||||
|
f.DataType = ref.ForeignKey.DataType
|
||||||
|
ref.ForeignKey = f
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
relation.JoinTable = joinSchema
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("failed to found relation: %v", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -9,6 +9,11 @@ func (db *DB) Migrator() Migrator {
|
||||||
return db.Dialector.Migrator(db)
|
return db.Dialector.Migrator(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AutoMigrate run auto migration for given models
|
||||||
|
func (db *DB) AutoMigrate(dst ...interface{}) error {
|
||||||
|
return db.Migrator().AutoMigrate(dst...)
|
||||||
|
}
|
||||||
|
|
||||||
// ViewOption view option
|
// ViewOption view option
|
||||||
type ViewOption struct {
|
type ViewOption struct {
|
||||||
Replace bool
|
Replace bool
|
||||||
|
|
31
scan.go
31
scan.go
|
@ -71,20 +71,27 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
default:
|
default:
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
reflectValueType := db.Statement.ReflectValue.Type().Elem()
|
var (
|
||||||
isPtr := reflectValueType.Kind() == reflect.Ptr
|
reflectValueType = db.Statement.ReflectValue.Type().Elem()
|
||||||
|
isPtr = reflectValueType.Kind() == reflect.Ptr
|
||||||
|
fields = make([]*schema.Field, len(columns))
|
||||||
|
joinFields [][2]*schema.Field
|
||||||
|
)
|
||||||
|
|
||||||
if isPtr {
|
if isPtr {
|
||||||
reflectValueType = reflectValueType.Elem()
|
reflectValueType = reflectValueType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
|
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
|
||||||
fields := make([]*schema.Field, len(columns))
|
|
||||||
joinFields := make([][2]*schema.Field, len(columns))
|
|
||||||
|
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||||
fields[idx] = field
|
fields[idx] = field
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
|
if len(joinFields) == 0 {
|
||||||
|
joinFields = make([][2]*schema.Field, len(columns))
|
||||||
|
}
|
||||||
|
|
||||||
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
fields[idx] = field
|
fields[idx] = field
|
||||||
|
@ -98,26 +105,26 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pluck values into slice of data
|
||||||
|
isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
|
|
||||||
elem := reflect.New(reflectValueType).Elem()
|
elem := reflect.New(reflectValueType).Elem()
|
||||||
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
|
if isPluck {
|
||||||
// pluck
|
db.AddError(rows.Scan(elem.Addr().Interface()))
|
||||||
values[0] = elem.Addr().Interface()
|
|
||||||
db.AddError(rows.Scan(values...))
|
|
||||||
} else {
|
} else {
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
if field != nil {
|
if field != nil {
|
||||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
if joinFields[idx][0] != nil {
|
if len(joinFields) != 0 && joinFields[idx][0] != nil {
|
||||||
value := reflect.ValueOf(values[idx]).Elem()
|
value := reflect.ValueOf(values[idx]).Elem()
|
||||||
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
||||||
|
|
||||||
|
@ -145,11 +152,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,7 @@ func (stmt *Statement) WriteQuoted(value interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuoteTo write quoted value to writer
|
// QuoteTo write quoted value to writer
|
||||||
func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||||
switch v := field.(type) {
|
switch v := field.(type) {
|
||||||
case clause.Table:
|
case clause.Table:
|
||||||
if v.Name == clause.CurrentTable {
|
if v.Name == clause.CurrentTable {
|
||||||
|
@ -109,7 +109,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||||
case []string:
|
case []string:
|
||||||
writer.WriteByte('(')
|
writer.WriteByte('(')
|
||||||
for idx, d := range v {
|
for idx, d := range v {
|
||||||
if idx != 0 {
|
if idx > 0 {
|
||||||
writer.WriteString(",")
|
writer.WriteString(",")
|
||||||
}
|
}
|
||||||
stmt.DB.Dialector.QuoteTo(writer, d)
|
stmt.DB.Dialector.QuoteTo(writer, d)
|
||||||
|
@ -121,7 +121,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quote returns quoted value
|
// Quote returns quoted value
|
||||||
func (stmt Statement) Quote(field interface{}) string {
|
func (stmt *Statement) Quote(field interface{}) string {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
stmt.QuoteTo(&builder, field)
|
stmt.QuoteTo(&builder, field)
|
||||||
return builder.String()
|
return builder.String()
|
||||||
|
@ -219,7 +219,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildCondition build condition
|
// BuildCondition build condition
|
||||||
func (stmt Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
|
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
|
||||||
if sql, ok := query.(string); ok {
|
if sql, ok := query.(string); ok {
|
||||||
// if it is a number, then treats it as primary key
|
// if it is a number, then treats it as primary key
|
||||||
if _, err := strconv.Atoi(sql); err != nil {
|
if _, err := strconv.Atoi(sql); err != nil {
|
||||||
|
|
Loading…
Reference in New Issue