forked from mirror/gorm
commit
9d57c6b961
|
@ -0,0 +1,2 @@
|
||||||
|
documents
|
||||||
|
_book
|
598
association.go
598
association.go
|
@ -4,32 +4,289 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Association Mode contains some helper methods to handle relationship things easily.
|
||||||
type Association struct {
|
type Association struct {
|
||||||
Scope *Scope
|
|
||||||
Column string
|
|
||||||
Error error
|
Error error
|
||||||
Field *Field
|
scope *Scope
|
||||||
|
column string
|
||||||
|
field *Field
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) setErr(err error) *Association {
|
// Find find out all related associations
|
||||||
if err != nil {
|
func (association *Association) Find(value interface{}) *Association {
|
||||||
association.Error = err
|
association.scope.related(value, association.column)
|
||||||
|
return association.setErr(association.scope.db.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
|
||||||
|
func (association *Association) Append(values ...interface{}) *Association {
|
||||||
|
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
|
||||||
|
return association.Replace(values...)
|
||||||
|
}
|
||||||
|
return association.saveAssociations(values...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace replace current associations with new one
|
||||||
|
func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
|
var (
|
||||||
|
relationship = association.field.Relationship
|
||||||
|
scope = association.scope
|
||||||
|
field = association.field.Field
|
||||||
|
newDB = scope.NewDB()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Append new values
|
||||||
|
association.field.Set(reflect.Zero(association.field.Field.Type()))
|
||||||
|
association.saveAssociations(values...)
|
||||||
|
|
||||||
|
// Belongs To
|
||||||
|
if relationship.Kind == "belongs_to" {
|
||||||
|
// Set foreign key to be null when clearing value (length equals 0)
|
||||||
|
if len(values) == 0 {
|
||||||
|
// Set foreign key to be nil
|
||||||
|
var foreignKeyMap = map[string]interface{}{}
|
||||||
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
foreignKeyMap[foreignKey] = nil
|
||||||
|
}
|
||||||
|
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Polymorphic Relations
|
||||||
|
if relationship.PolymorphicDBName != "" {
|
||||||
|
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete Relations except new created
|
||||||
|
if len(values) > 0 {
|
||||||
|
var associationForeignFieldNames []string
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
// if many to many relations, get association fields name from association foreign keys
|
||||||
|
associationScope := scope.New(reflect.New(field.Type()).Interface())
|
||||||
|
for _, dbName := range relationship.AssociationForeignFieldNames {
|
||||||
|
if field, ok := associationScope.FieldByName(dbName); ok {
|
||||||
|
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If other relations, use primary keys
|
||||||
|
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
|
||||||
|
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
|
||||||
|
|
||||||
|
if len(newPrimaryKeys) > 0 {
|
||||||
|
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
|
||||||
|
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
// if many to many relations, delete related relations from join table
|
||||||
|
var sourceForeignFieldNames []string
|
||||||
|
|
||||||
|
for _, dbName := range relationship.ForeignFieldNames {
|
||||||
|
if field, ok := scope.FieldByName(dbName); ok {
|
||||||
|
sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
||||||
|
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
||||||
|
|
||||||
|
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||||
|
}
|
||||||
|
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||||
|
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
|
||||||
|
var foreignKeyMap = map[string]interface{}{}
|
||||||
|
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
foreignKeyMap[foreignKey] = nil
|
||||||
|
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
|
||||||
|
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||||
|
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return association
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) Find(value interface{}) *Association {
|
// Delete remove relationship between source & passed arguments, but won't delete those arguments
|
||||||
association.Scope.related(value, association.Column)
|
func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
return association.setErr(association.Scope.db.Error)
|
var (
|
||||||
|
relationship = association.field.Relationship
|
||||||
|
scope = association.scope
|
||||||
|
field = association.field.Field
|
||||||
|
newDB = scope.NewDB()
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(values) == 0 {
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
|
||||||
|
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
|
||||||
|
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
|
||||||
|
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
|
||||||
|
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
// source value's foreign keys
|
||||||
|
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||||
|
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get association's foreign fields name
|
||||||
|
var associationScope = scope.New(reflect.New(field.Type()).Interface())
|
||||||
|
var associationForeignFieldNames []string
|
||||||
|
for _, associationDBName := range relationship.AssociationForeignFieldNames {
|
||||||
|
if field, ok := associationScope.FieldByName(associationDBName); ok {
|
||||||
|
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// association value's foreign keys
|
||||||
|
deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
|
||||||
|
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
||||||
|
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
||||||
|
|
||||||
|
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||||
|
} else {
|
||||||
|
var foreignKeyMap = map[string]interface{}{}
|
||||||
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
foreignKeyMap[foreignKey] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.Kind == "belongs_to" {
|
||||||
|
// find with deleting relation's foreign keys
|
||||||
|
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
|
||||||
|
newDB = newDB.Where(
|
||||||
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||||
|
toQueryValues(primaryKeys)...,
|
||||||
|
)
|
||||||
|
|
||||||
|
// set foreign key to be null if there are some records affected
|
||||||
|
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
|
||||||
|
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
|
||||||
|
if results.RowsAffected > 0 {
|
||||||
|
scope.updatedAttrsWithValues(foreignKeyMap)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
association.setErr(results.Error)
|
||||||
|
}
|
||||||
|
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||||
|
// find all relations
|
||||||
|
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||||
|
newDB = newDB.Where(
|
||||||
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||||
|
toQueryValues(primaryKeys)...,
|
||||||
|
)
|
||||||
|
|
||||||
|
// only include those deleting relations
|
||||||
|
newDB = newDB.Where(
|
||||||
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
|
||||||
|
toQueryValues(deletingPrimaryKeys)...,
|
||||||
|
)
|
||||||
|
|
||||||
|
// set matched relation's foreign key to be null
|
||||||
|
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||||
|
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove deleted records from source's field
|
||||||
|
if association.Error == nil {
|
||||||
|
if field.Kind() == reflect.Slice {
|
||||||
|
leftValues := reflect.Zero(field.Type())
|
||||||
|
|
||||||
|
for i := 0; i < field.Len(); i++ {
|
||||||
|
reflectValue := field.Index(i)
|
||||||
|
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
|
||||||
|
var isDeleted = false
|
||||||
|
for _, pk := range deletingPrimaryKeys {
|
||||||
|
if equalAsString(primaryKey, pk) {
|
||||||
|
isDeleted = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isDeleted {
|
||||||
|
leftValues = reflect.Append(leftValues, reflectValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
association.field.Set(leftValues)
|
||||||
|
} else if field.Kind() == reflect.Struct {
|
||||||
|
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
|
||||||
|
for _, pk := range deletingPrimaryKeys {
|
||||||
|
if equalAsString(primaryKey, pk) {
|
||||||
|
association.field.Set(reflect.Zero(field.Type()))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear remove relationship between source & current associations, won't delete those associations
|
||||||
|
func (association *Association) Clear() *Association {
|
||||||
|
return association.Replace()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count return the count of current associations
|
||||||
|
func (association *Association) Count() int {
|
||||||
|
var (
|
||||||
|
count = 0
|
||||||
|
relationship = association.field.Relationship
|
||||||
|
scope = association.scope
|
||||||
|
fieldValue = association.field.Field.Interface()
|
||||||
|
query = scope.DB()
|
||||||
|
)
|
||||||
|
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
|
||||||
|
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||||
|
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||||
|
query = query.Where(
|
||||||
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||||
|
toQueryValues(primaryKeys)...,
|
||||||
|
)
|
||||||
|
} else if relationship.Kind == "belongs_to" {
|
||||||
|
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
|
||||||
|
query = query.Where(
|
||||||
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
|
||||||
|
toQueryValues(primaryKeys)...,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.PolymorphicType != "" {
|
||||||
|
query = query.Where(
|
||||||
|
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
|
||||||
|
scope.TableName(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
query.Model(fieldValue).Count(&count)
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveAssociations save passed values as associations
|
||||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||||
scope := association.Scope
|
var (
|
||||||
field := association.Field
|
scope = association.scope
|
||||||
relationship := association.Field.Relationship
|
field = association.field
|
||||||
|
relationship = field.Relationship
|
||||||
|
)
|
||||||
|
|
||||||
saveAssociation := func(reflectValue reflect.Value) {
|
saveAssociation := func(reflectValue reflect.Value) {
|
||||||
// value has to been pointer
|
// value has to been pointer
|
||||||
|
@ -94,318 +351,9 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
|
||||||
return association
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) Append(values ...interface{}) *Association {
|
func (association *Association) setErr(err error) *Association {
|
||||||
if relationship := association.Field.Relationship; relationship.Kind == "has_one" {
|
if err != nil {
|
||||||
return association.Replace(values...)
|
association.Error = err
|
||||||
}
|
|
||||||
return association.saveAssociations(values...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (association *Association) Replace(values ...interface{}) *Association {
|
|
||||||
var (
|
|
||||||
relationship = association.Field.Relationship
|
|
||||||
scope = association.Scope
|
|
||||||
field = association.Field.Field
|
|
||||||
newDB = scope.NewDB()
|
|
||||||
)
|
|
||||||
|
|
||||||
// Append new values
|
|
||||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
|
||||||
association.saveAssociations(values...)
|
|
||||||
|
|
||||||
// Belongs To
|
|
||||||
if relationship.Kind == "belongs_to" {
|
|
||||||
// Set foreign key to be null only when clearing value
|
|
||||||
if len(values) == 0 {
|
|
||||||
// Set foreign key to be nil
|
|
||||||
var foreignKeyMap = map[string]interface{}{}
|
|
||||||
for _, foreignKey := range relationship.ForeignDBNames {
|
|
||||||
foreignKeyMap[foreignKey] = nil
|
|
||||||
}
|
|
||||||
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Relations
|
|
||||||
if relationship.PolymorphicDBName != "" {
|
|
||||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Relations except new created
|
|
||||||
if len(values) > 0 {
|
|
||||||
var newPrimaryKeys [][]interface{}
|
|
||||||
var associationForeignFieldNames []string
|
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
|
||||||
// If many to many relations, get it from foreign key
|
|
||||||
associationForeignFieldNames = relationship.AssociationForeignFieldNames
|
|
||||||
} else {
|
|
||||||
// If other relations, get real primary keys
|
|
||||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
|
|
||||||
if field.IsPrimaryKey {
|
|
||||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface())
|
|
||||||
|
|
||||||
if len(newPrimaryKeys) > 0 {
|
|
||||||
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
|
|
||||||
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
|
||||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
|
||||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
|
||||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
|
||||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
|
||||||
var foreignKeyMap = map[string]interface{}{}
|
|
||||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
|
||||||
foreignKeyMap[foreignKey] = nil
|
|
||||||
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
|
|
||||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
|
|
||||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return association
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) Delete(values ...interface{}) *Association {
|
|
||||||
var (
|
|
||||||
relationship = association.Field.Relationship
|
|
||||||
scope = association.Scope
|
|
||||||
field = association.Field.Field
|
|
||||||
newDB = scope.NewDB()
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(values) == 0 {
|
|
||||||
return association
|
|
||||||
}
|
|
||||||
|
|
||||||
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
|
|
||||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
|
|
||||||
if field.IsPrimaryKey {
|
|
||||||
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
|
|
||||||
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
deletingPrimaryKeys := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, values...)
|
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
|
||||||
// source value's foreign keys
|
|
||||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
|
||||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
|
||||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// association value's foreign keys
|
|
||||||
deletingPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
|
|
||||||
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
|
||||||
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
|
||||||
|
|
||||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
|
||||||
} else {
|
|
||||||
var foreignKeyMap = map[string]interface{}{}
|
|
||||||
for _, foreignKey := range relationship.ForeignDBNames {
|
|
||||||
foreignKeyMap[foreignKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if relationship.Kind == "belongs_to" {
|
|
||||||
// find with deleting relation's foreign keys
|
|
||||||
primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
|
|
||||||
newDB = newDB.Where(
|
|
||||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
|
||||||
toQueryValues(primaryKeys)...,
|
|
||||||
)
|
|
||||||
|
|
||||||
// set foreign key to be null
|
|
||||||
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
|
|
||||||
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
|
|
||||||
if results.RowsAffected > 0 {
|
|
||||||
scope.updatedAttrsWithValues(foreignKeyMap, false)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
association.setErr(results.Error)
|
|
||||||
}
|
|
||||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
|
||||||
// find all relations
|
|
||||||
primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, scope.Value)
|
|
||||||
newDB = newDB.Where(
|
|
||||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
|
||||||
toQueryValues(primaryKeys)...,
|
|
||||||
)
|
|
||||||
|
|
||||||
// only include those deleting relations
|
|
||||||
newDB = newDB.Where(
|
|
||||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
|
|
||||||
toQueryValues(deletingPrimaryKeys)...,
|
|
||||||
)
|
|
||||||
|
|
||||||
// set matched relation's foreign key to be null
|
|
||||||
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
|
|
||||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove deleted records from field
|
|
||||||
if association.Error == nil {
|
|
||||||
if association.Field.Field.Kind() == reflect.Slice {
|
|
||||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
|
||||||
|
|
||||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
|
||||||
reflectValue := association.Field.Field.Index(i)
|
|
||||||
primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
|
|
||||||
var included = false
|
|
||||||
for _, pk := range deletingPrimaryKeys {
|
|
||||||
if equalAsString(primaryKey, pk) {
|
|
||||||
included = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !included {
|
|
||||||
leftValues = reflect.Append(leftValues, reflectValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
association.Field.Set(leftValues)
|
|
||||||
} else if association.Field.Field.Kind() == reflect.Struct {
|
|
||||||
primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
|
|
||||||
for _, pk := range deletingPrimaryKeys {
|
|
||||||
if equalAsString(primaryKey, pk) {
|
|
||||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return association
|
|
||||||
}
|
|
||||||
|
|
||||||
func (association *Association) Clear() *Association {
|
|
||||||
return association.Replace()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (association *Association) Count() int {
|
|
||||||
var (
|
|
||||||
count = 0
|
|
||||||
relationship = association.Field.Relationship
|
|
||||||
scope = association.Scope
|
|
||||||
fieldValue = association.Field.Field.Interface()
|
|
||||||
newScope = scope.New(fieldValue)
|
|
||||||
)
|
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
|
||||||
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count)
|
|
||||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
|
||||||
query := scope.DB()
|
|
||||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
|
||||||
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
|
|
||||||
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
|
|
||||||
field.Field.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if relationship.PolymorphicType != "" {
|
|
||||||
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
|
||||||
}
|
|
||||||
query.Model(fieldValue).Count(&count)
|
|
||||||
} else if relationship.Kind == "belongs_to" {
|
|
||||||
query := scope.DB()
|
|
||||||
for idx, primaryKey := range relationship.AssociationForeignDBNames {
|
|
||||||
if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok {
|
|
||||||
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)),
|
|
||||||
field.Field.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
query.Model(fieldValue).Count(&count)
|
|
||||||
}
|
|
||||||
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|
||||||
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) {
|
|
||||||
scope := association.Scope
|
|
||||||
|
|
||||||
for _, value := range values {
|
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
|
||||||
if reflectValue.Kind() == reflect.Slice {
|
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
|
||||||
primaryKeys := []interface{}{}
|
|
||||||
newScope := scope.New(reflectValue.Index(i).Interface())
|
|
||||||
for _, column := range columns {
|
|
||||||
if field, ok := newScope.FieldByName(column); ok {
|
|
||||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
|
||||||
} else {
|
|
||||||
primaryKeys = append(primaryKeys, "")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
results = append(results, primaryKeys)
|
|
||||||
}
|
|
||||||
} else if reflectValue.Kind() == reflect.Struct {
|
|
||||||
newScope := scope.New(value)
|
|
||||||
var primaryKeys []interface{}
|
|
||||||
for _, column := range columns {
|
|
||||||
if field, ok := newScope.FieldByName(column); ok {
|
|
||||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
|
||||||
} else {
|
|
||||||
primaryKeys = append(primaryKeys, "")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
results = append(results, primaryKeys)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
|
||||||
var results []string
|
|
||||||
|
|
||||||
for _, primaryValue := range primaryValues {
|
|
||||||
var marks []string
|
|
||||||
for _, _ = range primaryValue {
|
|
||||||
marks = append(marks, "?")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(marks) > 1 {
|
|
||||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
|
||||||
} else {
|
|
||||||
results = append(results, strings.Join(marks, ""))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(results, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func toQueryCondition(scope *Scope, columns []string) string {
|
|
||||||
var newColumns []string
|
|
||||||
for _, column := range columns {
|
|
||||||
newColumns = append(newColumns, scope.Quote(column))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(columns) > 1 {
|
|
||||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
|
||||||
} else {
|
|
||||||
return strings.Join(newColumns, ",")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
|
||||||
for _, primaryValue := range primaryValues {
|
|
||||||
for _, value := range primaryValue {
|
|
||||||
values = append(values, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
|
|
|
@ -5,6 +5,8 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBelongsTo(t *testing.T) {
|
func TestBelongsTo(t *testing.T) {
|
||||||
|
@ -16,7 +18,7 @@ func TestBelongsTo(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Save(&post).Error; err != nil {
|
if err := DB.Save(&post).Error; err != nil {
|
||||||
t.Errorf("Got errors when save post", err.Error())
|
t.Error("Got errors when save post", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if post.Category.ID == 0 || post.MainCategory.ID == 0 {
|
if post.Category.ID == 0 || post.MainCategory.ID == 0 {
|
||||||
|
@ -177,6 +179,49 @@ func TestBelongsTo(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBelongsToOverrideForeignKey1(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Profile Profile `gorm:"ForeignKey:ProfileRefer"`
|
||||||
|
ProfileRefer int
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "belongs_to" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBelongsToOverrideForeignKey2(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Refer string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"`
|
||||||
|
ProfileID int
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "belongs_to" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHasOne(t *testing.T) {
|
func TestHasOne(t *testing.T) {
|
||||||
user := User{
|
user := User{
|
||||||
Name: "has one",
|
Name: "has one",
|
||||||
|
@ -184,7 +229,7 @@ func TestHasOne(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Save(&user).Error; err != nil {
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
t.Errorf("Got errors when save user", err.Error())
|
t.Error("Got errors when save user", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.CreditCard.UserId.Int64 == 0 {
|
if user.CreditCard.UserId.Int64 == 0 {
|
||||||
|
@ -323,6 +368,49 @@ func TestHasOne(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHasOneOverrideForeignKey1(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
UserRefer uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Profile Profile `gorm:"ForeignKey:UserRefer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "has_one" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasOneOverrideForeignKey2(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
UserID uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Refer string
|
||||||
|
Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "has_one" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHasMany(t *testing.T) {
|
func TestHasMany(t *testing.T) {
|
||||||
post := Post{
|
post := Post{
|
||||||
Title: "post has many",
|
Title: "post has many",
|
||||||
|
@ -331,7 +419,7 @@ func TestHasMany(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Save(&post).Error; err != nil {
|
if err := DB.Save(&post).Error; err != nil {
|
||||||
t.Errorf("Got errors when save post", err.Error())
|
t.Error("Got errors when save post", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, comment := range post.Comments {
|
for _, comment := range post.Comments {
|
||||||
|
@ -462,6 +550,49 @@ func TestHasMany(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHasManyOverrideForeignKey1(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
UserRefer uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Profile []Profile `gorm:"ForeignKey:UserRefer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "has_many" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasManyOverrideForeignKey2(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
UserID uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Refer string
|
||||||
|
Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "has_many" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestManyToMany(t *testing.T) {
|
func TestManyToMany(t *testing.T) {
|
||||||
DB.Raw("delete from languages")
|
DB.Raw("delete from languages")
|
||||||
var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
|
var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
|
||||||
|
|
210
callback.go
210
callback.go
|
@ -4,34 +4,39 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type callback struct {
|
// DefaultCallback default callbacks defined by gorm
|
||||||
|
var DefaultCallback = &Callback{}
|
||||||
|
|
||||||
|
// Callback is a struct that contains all CURD callbacks
|
||||||
|
// Field `creates` contains callbacks will be call when creating object
|
||||||
|
// Field `updates` contains callbacks will be call when updating object
|
||||||
|
// Field `deletes` contains callbacks will be call when deleting object
|
||||||
|
// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
|
||||||
|
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
||||||
|
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||||
|
type Callback struct {
|
||||||
creates []*func(scope *Scope)
|
creates []*func(scope *Scope)
|
||||||
updates []*func(scope *Scope)
|
updates []*func(scope *Scope)
|
||||||
deletes []*func(scope *Scope)
|
deletes []*func(scope *Scope)
|
||||||
queries []*func(scope *Scope)
|
queries []*func(scope *Scope)
|
||||||
rowQueries []*func(scope *Scope)
|
rowQueries []*func(scope *Scope)
|
||||||
processors []*callbackProcessor
|
processors []*CallbackProcessor
|
||||||
}
|
}
|
||||||
|
|
||||||
type callbackProcessor struct {
|
// CallbackProcessor contains callback informations
|
||||||
name string
|
type CallbackProcessor struct {
|
||||||
before string
|
name string // current callback's name
|
||||||
after string
|
before string // register current callback before a callback
|
||||||
replace bool
|
after string // register current callback after a callback
|
||||||
remove bool
|
replace bool // replace callbacks with same name
|
||||||
typ string
|
remove bool // delete callbacks with same name
|
||||||
processor *func(scope *Scope)
|
kind string // callback type: create, update, delete, query, row_query
|
||||||
callback *callback
|
processor *func(scope *Scope) // callback handler
|
||||||
|
parent *Callback
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) addProcessor(typ string) *callbackProcessor {
|
func (c *Callback) clone() *Callback {
|
||||||
cp := &callbackProcessor{typ: typ, callback: c}
|
return &Callback{
|
||||||
c.processors = append(c.processors, cp)
|
|
||||||
return cp
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *callback) clone() *callback {
|
|
||||||
return &callback{
|
|
||||||
creates: c.creates,
|
creates: c.creates,
|
||||||
updates: c.updates,
|
updates: c.updates,
|
||||||
deletes: c.deletes,
|
deletes: c.deletes,
|
||||||
|
@ -40,57 +45,95 @@ func (c *callback) clone() *callback {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Create() *callbackProcessor {
|
// Create could be used to register callbacks for creating object
|
||||||
return c.addProcessor("create")
|
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
|
||||||
|
// // business logic
|
||||||
|
// ...
|
||||||
|
//
|
||||||
|
// // set error if some thing wrong happened, will rollback the creating
|
||||||
|
// scope.Err(errors.New("error"))
|
||||||
|
// })
|
||||||
|
func (c *Callback) Create() *CallbackProcessor {
|
||||||
|
return &CallbackProcessor{kind: "create", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Update() *callbackProcessor {
|
// Update could be used to register callbacks for updating object, refer `Create` for usage
|
||||||
return c.addProcessor("update")
|
func (c *Callback) Update() *CallbackProcessor {
|
||||||
|
return &CallbackProcessor{kind: "update", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Delete() *callbackProcessor {
|
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
|
||||||
return c.addProcessor("delete")
|
func (c *Callback) Delete() *CallbackProcessor {
|
||||||
|
return &CallbackProcessor{kind: "delete", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Query() *callbackProcessor {
|
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
|
||||||
return c.addProcessor("query")
|
// Refer `Create` for usage
|
||||||
|
func (c *Callback) Query() *CallbackProcessor {
|
||||||
|
return &CallbackProcessor{kind: "query", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) RowQuery() *callbackProcessor {
|
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
|
||||||
return c.addProcessor("row_query")
|
func (c *Callback) RowQuery() *CallbackProcessor {
|
||||||
|
return &CallbackProcessor{kind: "row_query", parent: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callbackProcessor) Before(name string) *callbackProcessor {
|
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
|
||||||
cp.before = name
|
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
|
||||||
|
cp.after = callbackName
|
||||||
return cp
|
return cp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callbackProcessor) After(name string) *callbackProcessor {
|
// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
|
||||||
cp.after = name
|
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
||||||
|
cp.before = callbackName
|
||||||
return cp
|
return cp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
|
// Register a new callback, refer `Callbacks.Create`
|
||||||
cp.name = name
|
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||||
cp.processor = &fc
|
cp.name = callbackName
|
||||||
cp.callback.sort()
|
cp.processor = &callback
|
||||||
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
|
cp.parent.reorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callbackProcessor) Remove(name string) {
|
// Remove a registered callback
|
||||||
fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
|
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||||
cp.name = name
|
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||||
|
fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||||
|
cp.name = callbackName
|
||||||
cp.remove = true
|
cp.remove = true
|
||||||
cp.callback.sort()
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
|
cp.parent.reorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
|
// Replace a registered callback with new callback
|
||||||
fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
|
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
||||||
cp.name = name
|
// scope.SetColumn("Created", now)
|
||||||
cp.processor = &fc
|
// scope.SetColumn("Updated", now)
|
||||||
|
// })
|
||||||
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||||
|
fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||||
|
cp.name = callbackName
|
||||||
|
cp.processor = &callback
|
||||||
cp.replace = true
|
cp.replace = true
|
||||||
cp.callback.sort()
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
|
cp.parent.reorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get registered callback
|
||||||
|
// db.Callback().Create().Get("gorm:create")
|
||||||
|
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||||
|
for _, p := range cp.parent.processors {
|
||||||
|
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
||||||
|
return *p.processor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRIndex get right index from string slice
|
||||||
func getRIndex(strs []string, str string) int {
|
func getRIndex(strs []string, str string) int {
|
||||||
for i := len(strs) - 1; i >= 0; i-- {
|
for i := len(strs) - 1; i >= 0; i-- {
|
||||||
if strs[i] == str {
|
if strs[i] == str {
|
||||||
|
@ -100,83 +143,77 @@ func getRIndex(strs []string, str string) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
|
// sortProcessors sort callback processors based on its before, after, remove, replace
|
||||||
var sortCallbackProcessor func(c *callbackProcessor)
|
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||||
var names, sortedNames = []string{}, []string{}
|
var (
|
||||||
|
allNames, sortedNames []string
|
||||||
|
sortCallbackProcessor func(c *CallbackProcessor)
|
||||||
|
)
|
||||||
|
|
||||||
for _, cp := range cps {
|
for _, cp := range cps {
|
||||||
if index := getRIndex(names, cp.name); index > -1 {
|
// show warning message the callback name already exists
|
||||||
if !cp.replace && !cp.remove {
|
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||||
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
||||||
}
|
}
|
||||||
}
|
allNames = append(allNames, cp.name)
|
||||||
names = append(names, cp.name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sortCallbackProcessor = func(c *callbackProcessor) {
|
sortCallbackProcessor = func(c *CallbackProcessor) {
|
||||||
if getRIndex(sortedNames, c.name) > -1 {
|
if getRIndex(sortedNames, c.name) == -1 { // if not sorted
|
||||||
return
|
if c.before != "" { // if defined before callback
|
||||||
}
|
if index := getRIndex(sortedNames, c.before); index != -1 {
|
||||||
|
// if before callback already sorted, append current callback just after it
|
||||||
if len(c.before) > 0 {
|
|
||||||
if index := getRIndex(sortedNames, c.before); index > -1 {
|
|
||||||
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
|
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
|
||||||
} else if index := getRIndex(names, c.before); index > -1 {
|
} else if index := getRIndex(allNames, c.before); index != -1 {
|
||||||
|
// if before callback exists but haven't sorted, append current callback to last
|
||||||
sortedNames = append(sortedNames, c.name)
|
sortedNames = append(sortedNames, c.name)
|
||||||
sortCallbackProcessor(cps[index])
|
sortCallbackProcessor(cps[index])
|
||||||
} else {
|
|
||||||
sortedNames = append(sortedNames, c.name)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.after) > 0 {
|
if c.after != "" { // if defined after callback
|
||||||
if index := getRIndex(sortedNames, c.after); index > -1 {
|
if index := getRIndex(sortedNames, c.after); index != -1 {
|
||||||
|
// if after callback already sorted, append current callback just before it
|
||||||
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
|
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
|
||||||
} else if index := getRIndex(names, c.after); index > -1 {
|
} else if index := getRIndex(allNames, c.after); index != -1 {
|
||||||
|
// if after callback exists but haven't sorted
|
||||||
cp := cps[index]
|
cp := cps[index]
|
||||||
if len(cp.before) == 0 {
|
// set after callback's before callback to current callback
|
||||||
|
if cp.before == "" {
|
||||||
cp.before = c.name
|
cp.before = c.name
|
||||||
}
|
}
|
||||||
sortCallbackProcessor(cp)
|
sortCallbackProcessor(cp)
|
||||||
} else {
|
|
||||||
sortedNames = append(sortedNames, c.name)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if current callback haven't been sorted, append it to last
|
||||||
if getRIndex(sortedNames, c.name) == -1 {
|
if getRIndex(sortedNames, c.name) == -1 {
|
||||||
sortedNames = append(sortedNames, c.name)
|
sortedNames = append(sortedNames, c.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, cp := range cps {
|
for _, cp := range cps {
|
||||||
sortCallbackProcessor(cp)
|
sortCallbackProcessor(cp)
|
||||||
}
|
}
|
||||||
|
|
||||||
var funcs = []*func(scope *Scope){}
|
var sortedFuncs []*func(scope *Scope)
|
||||||
var sortedFuncs = []*func(scope *Scope){}
|
|
||||||
for _, name := range sortedNames {
|
for _, name := range sortedNames {
|
||||||
index := getRIndex(names, name)
|
if index := getRIndex(allNames, name); !cps[index].remove {
|
||||||
if !cps[index].remove {
|
|
||||||
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, cp := range cps {
|
return sortedFuncs
|
||||||
if sindex := getRIndex(sortedNames, cp.name); sindex == -1 {
|
|
||||||
if !cp.remove {
|
|
||||||
funcs = append(funcs, cp.processor)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(sortedFuncs, funcs...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) sort() {
|
// reorder all registered processors, and reset CURD callbacks
|
||||||
var creates, updates, deletes, queries, rowQueries []*callbackProcessor
|
func (c *Callback) reorder() {
|
||||||
|
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
|
||||||
|
|
||||||
for _, processor := range c.processors {
|
for _, processor := range c.processors {
|
||||||
switch processor.typ {
|
if processor.name != "" {
|
||||||
|
switch processor.kind {
|
||||||
case "create":
|
case "create":
|
||||||
creates = append(creates, processor)
|
creates = append(creates, processor)
|
||||||
case "update":
|
case "update":
|
||||||
|
@ -189,6 +226,7 @@ func (c *callback) sort() {
|
||||||
rowQueries = append(rowQueries, processor)
|
rowQueries = append(rowQueries, processor)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.creates = sortProcessors(creates)
|
c.creates = sortProcessors(creates)
|
||||||
c.updates = sortProcessors(updates)
|
c.updates = sortProcessors(updates)
|
||||||
|
@ -196,5 +234,3 @@ func (c *callback) sort() {
|
||||||
c.queries = sortProcessors(queries)
|
c.queries = sortProcessors(queries)
|
||||||
c.rowQueries = sortProcessors(rowQueries)
|
c.rowQueries = sortProcessors(rowQueries)
|
||||||
}
|
}
|
||||||
|
|
||||||
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
|
|
||||||
|
|
|
@ -5,12 +5,31 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BeforeCreate(scope *Scope) {
|
// Define callbacks for creating
|
||||||
scope.CallMethodWithErrorCheck("BeforeSave")
|
func init() {
|
||||||
scope.CallMethodWithErrorCheck("BeforeCreate")
|
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:create", createCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateTimeStampWhenCreate(scope *Scope) {
|
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
||||||
|
func beforeCreateCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("BeforeSave")
|
||||||
|
}
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("BeforeCreate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
||||||
|
func updateTimeStampForCreateCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
now := NowFunc()
|
now := NowFunc()
|
||||||
scope.SetColumn("CreatedAt", now)
|
scope.SetColumn("CreatedAt", now)
|
||||||
|
@ -18,109 +37,108 @@ func UpdateTimeStampWhenCreate(scope *Scope) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Create(scope *Scope) {
|
// createCallback the callback used to insert data into database
|
||||||
defer scope.Trace(NowFunc())
|
func createCallback(scope *Scope) {
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
// set create sql
|
defer scope.trace(NowFunc())
|
||||||
var sqls, columns []string
|
|
||||||
fields := scope.Fields()
|
var (
|
||||||
for _, field := range fields {
|
columns, placeholders []string
|
||||||
|
blankColumnsWithDefaultValue []string
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
if scope.changeableField(field) {
|
if scope.changeableField(field) {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
|
if !field.IsPrimaryKey || !field.IsBlank {
|
||||||
if !field.IsBlank || !field.HasDefaultValue {
|
if field.IsBlank && field.HasDefaultValue {
|
||||||
|
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName)
|
||||||
|
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
||||||
|
} else {
|
||||||
columns = append(columns, scope.Quote(field.DBName))
|
columns = append(columns, scope.Quote(field.DBName))
|
||||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
|
||||||
} else if field.HasDefaultValue {
|
|
||||||
var hasDefaultValueColumns []string
|
|
||||||
if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
|
|
||||||
hasDefaultValueColumns = oldHasDefaultValueColumns.([]string)
|
|
||||||
}
|
|
||||||
hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName)
|
|
||||||
scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
|
||||||
for _, dbName := range relationship.ForeignDBNames {
|
for _, foreignKey := range field.Relationship.ForeignDBNames {
|
||||||
if relationField := fields[dbName]; !scope.changeableField(relationField) {
|
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||||
columns = append(columns, scope.Quote(relationField.DBName))
|
columns = append(columns, scope.Quote(foreignField.DBName))
|
||||||
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
|
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
returningKey := "*"
|
var (
|
||||||
primaryField := scope.PrimaryField()
|
returningColumn = "*"
|
||||||
if primaryField != nil {
|
quotedTableName = scope.QuotedTableName()
|
||||||
returningKey = scope.Quote(primaryField.DBName)
|
primaryField = scope.PrimaryField()
|
||||||
|
extraOption string
|
||||||
|
)
|
||||||
|
|
||||||
|
if str, ok := scope.Get("gorm:insert_option"); ok {
|
||||||
|
extraOption = fmt.Sprint(str)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if primaryField != nil {
|
||||||
|
returningColumn = scope.Quote(primaryField.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||||
|
|
||||||
if len(columns) == 0 {
|
if len(columns) == 0 {
|
||||||
scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v",
|
scope.Raw(fmt.Sprintf(
|
||||||
scope.QuotedTableName(),
|
"INSERT INTO %v DEFAULT VALUES%v%v",
|
||||||
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
|
quotedTableName,
|
||||||
|
addExtraSpaceIfExist(extraOption),
|
||||||
|
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
"INSERT INTO %v (%v) VALUES (%v) %v",
|
"INSERT INTO %v (%v) VALUES (%v)%v%v",
|
||||||
scope.QuotedTableName(),
|
scope.QuotedTableName(),
|
||||||
strings.Join(columns, ","),
|
strings.Join(columns, ","),
|
||||||
strings.Join(sqls, ","),
|
strings.Join(placeholders, ","),
|
||||||
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
|
addExtraSpaceIfExist(extraOption),
|
||||||
|
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute create sql
|
// execute create sql
|
||||||
if scope.Dialect().SupportLastInsertId() {
|
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
||||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
id, err := result.LastInsertId()
|
// set rows affected count
|
||||||
if scope.Err(err) == nil {
|
|
||||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
// set primary value to primary field
|
||||||
if primaryField != nil && primaryField.IsBlank {
|
if primaryField != nil && primaryField.IsBlank {
|
||||||
scope.Err(scope.SetColumn(primaryField, id))
|
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||||
|
scope.Err(primaryField.Set(primaryValue))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if primaryField == nil {
|
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
|
|
||||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
|
||||||
} else {
|
|
||||||
scope.Err(err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
|
|
||||||
scope.db.RowsAffected = 1
|
scope.db.RowsAffected = 1
|
||||||
} else {
|
|
||||||
scope.Err(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ForceReloadAfterCreate(scope *Scope) {
|
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
||||||
if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
|
func forceReloadAfterCreateCallback(scope *Scope) {
|
||||||
scope.DB().New().Select(columns.([]string)).First(scope.Value)
|
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
||||||
|
scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AfterCreate(scope *Scope) {
|
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
|
||||||
scope.CallMethodWithErrorCheck("AfterCreate")
|
func afterCreateCallback(scope *Scope) {
|
||||||
scope.CallMethodWithErrorCheck("AfterSave")
|
if !scope.HasError() {
|
||||||
}
|
scope.CallMethod("AfterCreate")
|
||||||
|
}
|
||||||
func init() {
|
if !scope.HasError() {
|
||||||
DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
|
scope.CallMethod("AfterSave")
|
||||||
DefaultCallback.Create().Register("gorm:before_create", BeforeCreate)
|
}
|
||||||
DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
|
|
||||||
DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
|
|
||||||
DefaultCallback.Create().Register("gorm:create", Create)
|
|
||||||
DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
|
|
||||||
DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
|
|
||||||
DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
|
|
||||||
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,35 +2,52 @@ package gorm
|
||||||
|
|
||||||
import "fmt"
|
import "fmt"
|
||||||
|
|
||||||
func BeforeDelete(scope *Scope) {
|
// Define callbacks for deleting
|
||||||
scope.CallMethodWithErrorCheck("BeforeDelete")
|
func init() {
|
||||||
|
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||||
|
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
|
||||||
|
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
|
||||||
|
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
|
||||||
|
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Delete(scope *Scope) {
|
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
||||||
|
func beforeDeleteCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("BeforeDelete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
||||||
|
func deleteCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
var extraOption string
|
||||||
|
if str, ok := scope.Get("gorm:delete_option"); ok {
|
||||||
|
extraOption = fmt.Sprint(str)
|
||||||
|
}
|
||||||
|
|
||||||
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
|
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
|
||||||
scope.Raw(
|
scope.Raw(fmt.Sprintf(
|
||||||
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
|
"UPDATE %v SET deleted_at=%v%v%v",
|
||||||
scope.QuotedTableName(),
|
scope.QuotedTableName(),
|
||||||
scope.AddToVars(NowFunc()),
|
scope.AddToVars(NowFunc()),
|
||||||
scope.CombinedConditionSql(),
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
))
|
addExtraSpaceIfExist(extraOption),
|
||||||
|
)).Exec()
|
||||||
} else {
|
} else {
|
||||||
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql()))
|
scope.Raw(fmt.Sprintf(
|
||||||
|
"DELETE FROM %v%v%v",
|
||||||
|
scope.QuotedTableName(),
|
||||||
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
|
addExtraSpaceIfExist(extraOption),
|
||||||
|
)).Exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.Exec()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AfterDelete(scope *Scope) {
|
// afterDeleteCallback will invoke `AfterDelete` method after deleting
|
||||||
scope.CallMethodWithErrorCheck("AfterDelete")
|
func afterDeleteCallback(scope *Scope) {
|
||||||
}
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("AfterDelete")
|
||||||
func init() {
|
}
|
||||||
DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
|
|
||||||
DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
|
|
||||||
DefaultCallback.Delete().Register("gorm:delete", Delete)
|
|
||||||
DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
|
|
||||||
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,115 +6,89 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Query(scope *Scope) {
|
// Define callbacks for querying
|
||||||
defer scope.Trace(NowFunc())
|
func init() {
|
||||||
|
DefaultCallback.Query().Register("gorm:query", queryCallback)
|
||||||
|
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
|
||||||
|
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryCallback used to query data from database
|
||||||
|
func queryCallback(scope *Scope) {
|
||||||
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
var (
|
var (
|
||||||
isSlice bool
|
isSlice bool
|
||||||
isPtr bool
|
isPtr bool
|
||||||
anyRecordFound bool
|
results = scope.IndirectValue()
|
||||||
destType reflect.Type
|
resultType reflect.Type
|
||||||
)
|
)
|
||||||
|
|
||||||
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
||||||
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
|
if primaryField := scope.PrimaryField(); primaryField != nil {
|
||||||
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy))
|
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var dest = scope.IndirectValue()
|
|
||||||
if value, ok := scope.Get("gorm:query_destination"); ok {
|
if value, ok := scope.Get("gorm:query_destination"); ok {
|
||||||
dest = reflect.Indirect(reflect.ValueOf(value))
|
results = reflect.Indirect(reflect.ValueOf(value))
|
||||||
}
|
}
|
||||||
|
|
||||||
if kind := dest.Kind(); kind == reflect.Slice {
|
if kind := results.Kind(); kind == reflect.Slice {
|
||||||
isSlice = true
|
isSlice = true
|
||||||
destType = dest.Type().Elem()
|
resultType = results.Type().Elem()
|
||||||
dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
|
results.Set(reflect.MakeSlice(results.Type(), 0, 0))
|
||||||
|
|
||||||
if destType.Kind() == reflect.Ptr {
|
if resultType.Kind() == reflect.Ptr {
|
||||||
isPtr = true
|
isPtr = true
|
||||||
destType = destType.Elem()
|
resultType = resultType.Elem()
|
||||||
}
|
}
|
||||||
} else if kind != reflect.Struct {
|
} else if kind != reflect.Struct {
|
||||||
scope.Err(errors.New("unsupported destination, should be slice or struct"))
|
scope.Err(errors.New("unsupported destination, should be slice or struct"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.prepareQuerySql()
|
scope.prepareQuerySQL()
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
|
||||||
scope.db.RowsAffected = 0
|
scope.db.RowsAffected = 0
|
||||||
|
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||||
if scope.Err(err) != nil {
|
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
scope.db.RowsAffected++
|
scope.db.RowsAffected++
|
||||||
|
|
||||||
anyRecordFound = true
|
elem := results
|
||||||
elem := dest
|
|
||||||
if isSlice {
|
if isSlice {
|
||||||
elem = reflect.New(destType).Elem()
|
elem = reflect.New(resultType).Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
var values = make([]interface{}, len(columns))
|
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).fieldsMap())
|
||||||
|
|
||||||
fields := scope.New(elem.Addr().Interface()).Fields()
|
|
||||||
|
|
||||||
for index, column := range columns {
|
|
||||||
if field, ok := fields[column]; ok {
|
|
||||||
if field.Field.Kind() == reflect.Ptr {
|
|
||||||
values[index] = field.Field.Addr().Interface()
|
|
||||||
} else {
|
|
||||||
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
|
|
||||||
reflectValue.Elem().Set(field.Field.Addr())
|
|
||||||
values[index] = reflectValue.Interface()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var value interface{}
|
|
||||||
values[index] = &value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.Err(rows.Scan(values...))
|
|
||||||
|
|
||||||
for index, column := range columns {
|
|
||||||
value := values[index]
|
|
||||||
if field, ok := fields[column]; ok {
|
|
||||||
if field.Field.Kind() == reflect.Ptr {
|
|
||||||
field.Field.Set(reflect.ValueOf(value).Elem())
|
|
||||||
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
|
|
||||||
field.Field.Set(v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isSlice {
|
if isSlice {
|
||||||
if isPtr {
|
if isPtr {
|
||||||
dest.Set(reflect.Append(dest, elem.Addr()))
|
results.Set(reflect.Append(results, elem.Addr()))
|
||||||
} else {
|
} else {
|
||||||
dest.Set(reflect.Append(dest, elem))
|
results.Set(reflect.Append(results, elem))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !anyRecordFound && !isSlice {
|
if scope.db.RowsAffected == 0 && !isSlice {
|
||||||
scope.Err(RecordNotFound)
|
scope.Err(ErrRecordNotFound)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AfterQuery(scope *Scope) {
|
// afterQueryCallback will invoke `AfterFind` method after querying
|
||||||
scope.CallMethodWithErrorCheck("AfterFind")
|
func afterQueryCallback(scope *Scope) {
|
||||||
}
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("AfterFind")
|
||||||
func init() {
|
}
|
||||||
DefaultCallback.Query().Register("gorm:query", Query)
|
|
||||||
DefaultCallback.Query().Register("gorm:preload", Preload)
|
|
||||||
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,308 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// preloadCallback used to preload associations
|
||||||
|
func preloadCallback(scope *Scope) {
|
||||||
|
if scope.Search.preload == nil || scope.HasError() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
preloadedMap = map[string]bool{}
|
||||||
|
fields = scope.Fields()
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, preload := range scope.Search.preload {
|
||||||
|
var (
|
||||||
|
preloadFields = strings.Split(preload.schema, ".")
|
||||||
|
currentScope = scope
|
||||||
|
currentFields = fields
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx, preloadField := range preloadFields {
|
||||||
|
var currentPreloadConditions []interface{}
|
||||||
|
|
||||||
|
// if not preloaded
|
||||||
|
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
|
||||||
|
|
||||||
|
// assign search conditions to last preload
|
||||||
|
if idx == len(preloadFields)-1 {
|
||||||
|
currentPreloadConditions = preload.conditions
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range currentFields {
|
||||||
|
if field.Name != preloadField || field.Relationship == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch field.Relationship.Kind {
|
||||||
|
case "has_one":
|
||||||
|
currentScope.handleHasOnePreload(field, currentPreloadConditions)
|
||||||
|
case "has_many":
|
||||||
|
currentScope.handleHasManyPreload(field, currentPreloadConditions)
|
||||||
|
case "belongs_to":
|
||||||
|
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
||||||
|
case "many_to_many":
|
||||||
|
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
||||||
|
default:
|
||||||
|
scope.Err(errors.New("unsupported relation"))
|
||||||
|
}
|
||||||
|
|
||||||
|
preloadedMap[preloadKey] = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if !preloadedMap[preloadKey] {
|
||||||
|
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// preload next level
|
||||||
|
if idx < len(preloadFields)-1 {
|
||||||
|
currentScope = currentScope.getColumnAsScope(preloadField)
|
||||||
|
currentFields = currentScope.Fields()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
|
||||||
|
var (
|
||||||
|
preloadDB = scope.NewDB()
|
||||||
|
preloadConditions []interface{}
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, condition := range conditions {
|
||||||
|
if scopes, ok := condition.(func(*DB) *DB); ok {
|
||||||
|
preloadDB = scopes(preloadDB)
|
||||||
|
} else {
|
||||||
|
preloadConditions = append(preloadConditions, condition)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return preloadDB, preloadConditions
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHasOnePreload used to preload has one associations
|
||||||
|
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
||||||
|
relation := field.Relationship
|
||||||
|
|
||||||
|
// get relations's primary keys
|
||||||
|
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||||
|
if len(primaryKeys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// preload conditions
|
||||||
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||||
|
|
||||||
|
// find relations
|
||||||
|
results := makeSlice(field.Struct.Type)
|
||||||
|
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
||||||
|
|
||||||
|
// assign find results
|
||||||
|
var (
|
||||||
|
resultsValue = indirect(reflect.ValueOf(results))
|
||||||
|
indirectScopeValue = scope.IndirectValue()
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
|
result := resultsValue.Index(i)
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
||||||
|
indirectValue.FieldByName(field.Name).Set(result)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.Err(field.Set(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHasManyPreload used to preload has many associations
|
||||||
|
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
||||||
|
relation := field.Relationship
|
||||||
|
|
||||||
|
// get relations's primary keys
|
||||||
|
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||||
|
if len(primaryKeys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// preload conditions
|
||||||
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||||
|
|
||||||
|
// find relations
|
||||||
|
results := makeSlice(field.Struct.Type)
|
||||||
|
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
||||||
|
|
||||||
|
// assign find results
|
||||||
|
var (
|
||||||
|
resultsValue = indirect(reflect.ValueOf(results))
|
||||||
|
indirectScopeValue = scope.IndirectValue()
|
||||||
|
)
|
||||||
|
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
|
result := resultsValue.Index(i)
|
||||||
|
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
|
if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) {
|
||||||
|
objectField := object.FieldByName(field.Name)
|
||||||
|
objectField.Set(reflect.Append(objectField, result))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.Err(field.Set(resultsValue))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleBelongsToPreload used to preload belongs to associations
|
||||||
|
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
||||||
|
relation := field.Relationship
|
||||||
|
|
||||||
|
// preload conditions
|
||||||
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||||
|
|
||||||
|
// get relations's primary keys
|
||||||
|
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
|
||||||
|
if len(primaryKeys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// find relations
|
||||||
|
results := makeSlice(field.Struct.Type)
|
||||||
|
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
||||||
|
|
||||||
|
// assign find results
|
||||||
|
var (
|
||||||
|
resultsValue = indirect(reflect.ValueOf(results))
|
||||||
|
indirectScopeValue = scope.IndirectValue()
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
|
result := resultsValue.Index(i)
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
|
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
||||||
|
object.FieldByName(field.Name).Set(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.Err(field.Set(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleManyToManyPreload used to preload many to many associations
|
||||||
|
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
||||||
|
var (
|
||||||
|
relation = field.Relationship
|
||||||
|
joinTableHandler = relation.JoinTableHandler
|
||||||
|
fieldType = field.Struct.Type.Elem()
|
||||||
|
foreignKeyValue interface{}
|
||||||
|
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
|
||||||
|
linkHash = map[string][]reflect.Value{}
|
||||||
|
isPtr bool
|
||||||
|
)
|
||||||
|
|
||||||
|
if fieldType.Kind() == reflect.Ptr {
|
||||||
|
isPtr = true
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
var sourceKeys = []string{}
|
||||||
|
for _, key := range joinTableHandler.SourceForeignKeys() {
|
||||||
|
sourceKeys = append(sourceKeys, key.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// preload conditions
|
||||||
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||||
|
|
||||||
|
// generate query with join table
|
||||||
|
newScope := scope.New(reflect.New(fieldType).Interface())
|
||||||
|
preloadDB = preloadDB.Table(newScope.TableName()).Select("*")
|
||||||
|
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
|
||||||
|
|
||||||
|
// preload inline conditions
|
||||||
|
if len(preloadConditions) > 0 {
|
||||||
|
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := preloadDB.Rows()
|
||||||
|
|
||||||
|
if scope.Err(err) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
columns, _ := rows.Columns()
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
elem = reflect.New(fieldType).Elem()
|
||||||
|
fields = scope.New(elem.Addr().Interface()).fieldsMap()
|
||||||
|
)
|
||||||
|
|
||||||
|
// register foreign keys in join tables
|
||||||
|
for _, sourceKey := range sourceKeys {
|
||||||
|
fields[sourceKey] = &Field{Field: reflect.New(foreignKeyType).Elem()}
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.scan(rows, columns, fields)
|
||||||
|
|
||||||
|
// generate hashed forkey keys in join table
|
||||||
|
var foreignKeys = make([]interface{}, len(sourceKeys))
|
||||||
|
for idx, sourceKey := range sourceKeys {
|
||||||
|
foreignKeys[idx] = fields[sourceKey].Field.Elem().Interface()
|
||||||
|
}
|
||||||
|
hashedSourceKeys := toString(foreignKeys)
|
||||||
|
|
||||||
|
if isPtr {
|
||||||
|
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
|
||||||
|
} else {
|
||||||
|
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign find results
|
||||||
|
var (
|
||||||
|
indirectScopeValue = scope.IndirectValue()
|
||||||
|
fieldsSourceMap = map[string]reflect.Value{}
|
||||||
|
foreignFieldNames = []string{}
|
||||||
|
fields = scope.fieldsMap()
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, dbName := range relation.ForeignFieldNames {
|
||||||
|
if field, ok := fields[dbName]; ok {
|
||||||
|
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
|
fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name)
|
||||||
|
}
|
||||||
|
} else if indirectScopeValue.IsValid() {
|
||||||
|
fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
for source, link := range linkHash {
|
||||||
|
fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...))
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,15 +2,15 @@ package gorm
|
||||||
|
|
||||||
import "reflect"
|
import "reflect"
|
||||||
|
|
||||||
func BeginTransaction(scope *Scope) {
|
func beginTransactionCallback(scope *Scope) {
|
||||||
scope.Begin()
|
scope.Begin()
|
||||||
}
|
}
|
||||||
|
|
||||||
func CommitOrRollbackTransaction(scope *Scope) {
|
func commitOrRollbackTransactionCallback(scope *Scope) {
|
||||||
scope.CommitOrRollback()
|
scope.CommitOrRollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveBeforeAssociations(scope *Scope) {
|
func saveBeforeAssociationsCallback(scope *Scope) {
|
||||||
if !scope.shouldSaveAssociations() {
|
if !scope.shouldSaveAssociations() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,7 @@ func SaveBeforeAssociations(scope *Scope) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveAfterAssociations(scope *Scope) {
|
func saveAfterAssociationsCallback(scope *Scope) {
|
||||||
if !scope.shouldSaveAssociations() {
|
if !scope.shouldSaveAssociations() {
|
||||||
return
|
return
|
||||||
}
|
}
|
|
@ -23,7 +23,7 @@ func afterCreate1(s *Scope) {}
|
||||||
func afterCreate2(s *Scope) {}
|
func afterCreate2(s *Scope) {}
|
||||||
|
|
||||||
func TestRegisterCallback(t *testing.T) {
|
func TestRegisterCallback(t *testing.T) {
|
||||||
var callback = &callback{processors: []*callbackProcessor{}}
|
var callback = &Callback{}
|
||||||
|
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
callback.Create().Register("before_create2", beforeCreate2)
|
callback.Create().Register("before_create2", beforeCreate2)
|
||||||
|
@ -37,7 +37,7 @@ func TestRegisterCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterCallbackWithOrder(t *testing.T) {
|
func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||||
var callback1 = &callback{processors: []*callbackProcessor{}}
|
var callback1 = &Callback{}
|
||||||
callback1.Create().Register("before_create1", beforeCreate1)
|
callback1.Create().Register("before_create1", beforeCreate1)
|
||||||
callback1.Create().Register("create", create)
|
callback1.Create().Register("create", create)
|
||||||
callback1.Create().Register("after_create1", afterCreate1)
|
callback1.Create().Register("after_create1", afterCreate1)
|
||||||
|
@ -46,7 +46,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
|
|
||||||
var callback2 = &callback{processors: []*callbackProcessor{}}
|
var callback2 = &Callback{}
|
||||||
|
|
||||||
callback2.Update().Register("create", create)
|
callback2.Update().Register("create", create)
|
||||||
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
||||||
|
@ -60,7 +60,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||||
var callback1 = &callback{processors: []*callbackProcessor{}}
|
var callback1 = &Callback{}
|
||||||
|
|
||||||
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback1.Query().Register("before_create1", beforeCreate1)
|
callback1.Query().Register("before_create1", beforeCreate1)
|
||||||
|
@ -70,7 +70,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
|
|
||||||
var callback2 = &callback{processors: []*callbackProcessor{}}
|
var callback2 = &Callback{}
|
||||||
|
|
||||||
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
||||||
|
@ -86,7 +86,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||||
func replaceCreate(s *Scope) {}
|
func replaceCreate(s *Scope) {}
|
||||||
|
|
||||||
func TestReplaceCallback(t *testing.T) {
|
func TestReplaceCallback(t *testing.T) {
|
||||||
var callback = &callback{processors: []*callbackProcessor{}}
|
var callback = &Callback{}
|
||||||
|
|
||||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
|
@ -99,7 +99,7 @@ func TestReplaceCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveCallback(t *testing.T) {
|
func TestRemoveCallback(t *testing.T) {
|
||||||
var callback = &callback{processors: []*callbackProcessor{}}
|
var callback = &Callback{}
|
||||||
|
|
||||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
|
|
|
@ -5,91 +5,102 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func AssignUpdateAttributes(scope *Scope) {
|
// Define callbacks for updating
|
||||||
|
func init() {
|
||||||
|
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:update", updateCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// assignUpdatingAttributesCallback assign updating attributes to model
|
||||||
|
func assignUpdatingAttributesCallback(scope *Scope) {
|
||||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||||
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
|
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
|
||||||
protected, ok := scope.Get("gorm:ignore_protected_attrs")
|
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
|
||||||
_, updateColumn := scope.Get("gorm:update_column")
|
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||||
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
|
} else {
|
||||||
|
|
||||||
if updateColumn {
|
|
||||||
scope.InstanceSet("gorm:update_attrs", maps)
|
|
||||||
} else if len(updateAttrs) > 0 {
|
|
||||||
scope.InstanceSet("gorm:update_attrs", updateAttrs)
|
|
||||||
} else if !hasUpdate {
|
|
||||||
scope.SkipLeft()
|
scope.SkipLeft()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BeforeUpdate(scope *Scope) {
|
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||||
|
func beforeUpdateCallback(scope *Scope) {
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
scope.CallMethodWithErrorCheck("BeforeSave")
|
if !scope.HasError() {
|
||||||
scope.CallMethodWithErrorCheck("BeforeUpdate")
|
scope.CallMethod("BeforeSave")
|
||||||
|
}
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("BeforeUpdate")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateTimeStampWhenUpdate(scope *Scope) {
|
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
||||||
|
func updateTimeStampForUpdateCallback(scope *Scope) {
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
scope.SetColumn("UpdatedAt", NowFunc())
|
scope.SetColumn("UpdatedAt", NowFunc())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Update(scope *Scope) {
|
// updateCallback the callback used to update data to database
|
||||||
|
func updateCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
|
|
||||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||||
for key, value := range updateAttrs.(map[string]interface{}) {
|
for column, value := range updateAttrs.(map[string]interface{}) {
|
||||||
if scope.changeableDBColumn(key) {
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fields := scope.Fields()
|
for _, field := range scope.Fields() {
|
||||||
for _, field := range fields {
|
if scope.changeableField(field) {
|
||||||
if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
|
if !field.IsPrimaryKey && field.IsNormal {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
for _, dbName := range relationship.ForeignDBNames {
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
|
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||||
sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))
|
sqls = append(sqls,
|
||||||
sqls = append(sqls, sql)
|
fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var extraOption string
|
||||||
|
if str, ok := scope.Get("gorm:update_option"); ok {
|
||||||
|
extraOption = fmt.Sprint(str)
|
||||||
|
}
|
||||||
|
|
||||||
if len(sqls) > 0 {
|
if len(sqls) > 0 {
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
"UPDATE %v SET %v %v",
|
"UPDATE %v SET %v%v%v",
|
||||||
scope.QuotedTableName(),
|
scope.QuotedTableName(),
|
||||||
strings.Join(sqls, ", "),
|
strings.Join(sqls, ", "),
|
||||||
scope.CombinedConditionSql(),
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
))
|
addExtraSpaceIfExist(extraOption),
|
||||||
scope.Exec()
|
)).Exec()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AfterUpdate(scope *Scope) {
|
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
|
||||||
|
func afterUpdateCallback(scope *Scope) {
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
scope.CallMethodWithErrorCheck("AfterUpdate")
|
if !scope.HasError() {
|
||||||
scope.CallMethodWithErrorCheck("AfterSave")
|
scope.CallMethod("AfterUpdate")
|
||||||
|
}
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("AfterSave")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
|
||||||
DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
|
|
||||||
DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
|
|
||||||
DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
|
|
||||||
DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
|
|
||||||
DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
|
|
||||||
DefaultCallback.Update().Register("gorm:update", Update)
|
|
||||||
DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
|
|
||||||
DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
|
|
||||||
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,117 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type commonDialect struct{}
|
|
||||||
|
|
||||||
func (commonDialect) BinVar(i int) string {
|
|
||||||
return "$$" // ?
|
|
||||||
}
|
|
||||||
|
|
||||||
func (commonDialect) SupportLastInsertId() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (commonDialect) HasTop() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
return "BOOLEAN"
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
if autoIncrease {
|
|
||||||
return "INTEGER AUTO_INCREMENT"
|
|
||||||
}
|
|
||||||
return "INTEGER"
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
if autoIncrease {
|
|
||||||
return "BIGINT AUTO_INCREMENT"
|
|
||||||
}
|
|
||||||
return "BIGINT"
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
return "FLOAT"
|
|
||||||
case reflect.String:
|
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("VARCHAR(%d)", size)
|
|
||||||
}
|
|
||||||
return "VARCHAR(65532)"
|
|
||||||
case reflect.Struct:
|
|
||||||
if _, ok := value.Interface().(time.Time); ok {
|
|
||||||
return "TIMESTAMP"
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if _, ok := value.Interface().([]byte); ok {
|
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("BINARY(%d)", size)
|
|
||||||
}
|
|
||||||
return "BINARY(65532)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (commonDialect) ReturningStr(tableName, key string) string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (commonDialect) SelectFromDummyTable() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (commonDialect) Quote(key string) string {
|
|
||||||
return fmt.Sprintf(`"%s"`, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
|
|
||||||
var (
|
|
||||||
count int
|
|
||||||
databaseName = c.CurrentDatabase(scope)
|
|
||||||
)
|
|
||||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
|
||||||
var (
|
|
||||||
count int
|
|
||||||
databaseName = c.CurrentDatabase(scope)
|
|
||||||
)
|
|
||||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
|
||||||
var (
|
|
||||||
count int
|
|
||||||
databaseName = c.CurrentDatabase(scope)
|
|
||||||
)
|
|
||||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
|
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RawScanInt scans the first column of the first row into the `scan' int pointer.
|
|
||||||
// This function captures raw query errors and propagates them to the original scope.
|
|
||||||
func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
|
|
||||||
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
|
|
||||||
}
|
|
||||||
|
|
||||||
// RawScanString scans the first column of the first row into the `scan' string pointer.
|
|
||||||
// This function captures raw query errors and propagates them to the original scope.
|
|
||||||
func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) {
|
|
||||||
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
|
|
||||||
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) {
|
||||||
DB.AutoMigrate(&CustomizeColumn{})
|
DB.AutoMigrate(&CustomizeColumn{})
|
||||||
|
|
||||||
scope := DB.NewScope(&CustomizeColumn{})
|
scope := DB.NewScope(&CustomizeColumn{})
|
||||||
if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
|
if !scope.Dialect().HasColumn(scope.TableName(), col) {
|
||||||
t.Errorf("CustomizeColumn should have column %s", col)
|
t.Errorf("CustomizeColumn should have column %s", col)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,7 @@ func TestDdlErrors(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
DB.HasTable("foobarbaz")
|
if err := DB.Find(&User{}).Error; err == nil {
|
||||||
if DB.Error == nil {
|
|
||||||
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
|
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ func TestSoftDelete(t *testing.T) {
|
||||||
type User struct {
|
type User struct {
|
||||||
Id int64
|
Id int64
|
||||||
Name string
|
Name string
|
||||||
DeletedAt time.Time
|
DeletedAt *time.Time
|
||||||
}
|
}
|
||||||
DB.AutoMigrate(&User{})
|
DB.AutoMigrate(&User{})
|
||||||
|
|
||||||
|
|
115
dialect.go
115
dialect.go
|
@ -1,41 +1,100 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Dialect interface contains behaviors that differ across SQL database
|
||||||
type Dialect interface {
|
type Dialect interface {
|
||||||
BinVar(i int) string
|
// GetName get dialect's name
|
||||||
SupportLastInsertId() bool
|
GetName() string
|
||||||
HasTop() bool
|
|
||||||
SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
// SetDB set db for dialect
|
||||||
ReturningStr(tableName, key string) string
|
SetDB(db *sql.DB)
|
||||||
SelectFromDummyTable() string
|
|
||||||
|
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||||
|
BindVar(i int) string
|
||||||
|
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||||
Quote(key string) string
|
Quote(key string) string
|
||||||
HasTable(scope *Scope, tableName string) bool
|
// DataTypeOf return data's sql type
|
||||||
HasColumn(scope *Scope, tableName string, columnName string) bool
|
DataTypeOf(field *StructField) string
|
||||||
HasIndex(scope *Scope, tableName string, indexName string) bool
|
|
||||||
RemoveIndex(scope *Scope, indexName string)
|
// HasIndex check has index or not
|
||||||
CurrentDatabase(scope *Scope) string
|
HasIndex(tableName string, indexName string) bool
|
||||||
|
// HasForeignKey check has foreign key or not
|
||||||
|
HasForeignKey(tableName string, foreignKeyName string) bool
|
||||||
|
// RemoveIndex remove index
|
||||||
|
RemoveIndex(tableName string, indexName string) error
|
||||||
|
// HasTable check has table or not
|
||||||
|
HasTable(tableName string) bool
|
||||||
|
// HasColumn check has column or not
|
||||||
|
HasColumn(tableName string, columnName string) bool
|
||||||
|
|
||||||
|
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||||
|
LimitAndOffsetSQL(limit, offset int) string
|
||||||
|
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||||
|
SelectFromDummyTable() string
|
||||||
|
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||||
|
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDialect(driver string) Dialect {
|
var dialectsMap = map[string]Dialect{}
|
||||||
var d Dialect
|
|
||||||
switch driver {
|
func newDialect(name string, db *sql.DB) Dialect {
|
||||||
case "postgres":
|
if value, ok := dialectsMap[name]; ok {
|
||||||
d = &postgres{}
|
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
|
||||||
case "foundation":
|
dialect.SetDB(db)
|
||||||
d = &foundation{}
|
return dialect
|
||||||
case "mysql":
|
|
||||||
d = &mysql{}
|
|
||||||
case "sqlite3":
|
|
||||||
d = &sqlite3{}
|
|
||||||
case "mssql":
|
|
||||||
d = &mssql{}
|
|
||||||
default:
|
|
||||||
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
|
|
||||||
d = &commonDialect{}
|
|
||||||
}
|
}
|
||||||
return d
|
|
||||||
|
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
|
||||||
|
commontDialect := &commonDialect{}
|
||||||
|
commontDialect.SetDB(db)
|
||||||
|
return commontDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterDialect register new dialect
|
||||||
|
func RegisterDialect(name string, dialect Dialect) {
|
||||||
|
dialectsMap[name] = dialect
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseFieldStructForDialect parse field struct for dialect
|
||||||
|
func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
||||||
|
// Get redirected field type
|
||||||
|
var reflectType = field.Struct.Type
|
||||||
|
for reflectType.Kind() == reflect.Ptr {
|
||||||
|
reflectType = reflectType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get redirected field value
|
||||||
|
fieldValue = reflect.Indirect(reflect.New(reflectType))
|
||||||
|
|
||||||
|
// Get scanner's real value
|
||||||
|
var getScannerValue func(reflect.Value)
|
||||||
|
getScannerValue = func(value reflect.Value) {
|
||||||
|
fieldValue = value
|
||||||
|
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
|
||||||
|
getScannerValue(fieldValue.Field(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
getScannerValue(fieldValue)
|
||||||
|
|
||||||
|
// Default Size
|
||||||
|
if num, ok := field.TagSettings["SIZE"]; ok {
|
||||||
|
size, _ = strconv.Atoi(num)
|
||||||
|
} else {
|
||||||
|
size = 255
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default type from tag setting
|
||||||
|
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||||
|
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||||
|
additionalType = additionalType + " DEFAULT " + value
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,137 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type commonDialect struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterDialect("common", &commonDialect{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) GetName() string {
|
||||||
|
return "common"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *commonDialect) SetDB(db *sql.DB) {
|
||||||
|
s.db = db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) BindVar(i int) string {
|
||||||
|
return "$$" // ?
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) Quote(key string) string {
|
||||||
|
return fmt.Sprintf(`"%s"`, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) DataTypeOf(field *StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "BOOLEAN"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||||
|
sqlType = "INTEGER AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "INTEGER"
|
||||||
|
}
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||||
|
sqlType = "BIGINT AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "BIGINT"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "FLOAT"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("VARCHAR(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "VARCHAR(65532)"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
sqlType = "TIMESTAMP"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := dataValue.Interface().([]byte); ok {
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("BINARY(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "BINARY(65532)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
|
||||||
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasTable(tableName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) currentDatabase() (name string) {
|
||||||
|
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) {
|
||||||
|
if limit > 0 || offset > 0 {
|
||||||
|
if limit >= 0 {
|
||||||
|
sql += fmt.Sprintf(" LIMIT %d", limit)
|
||||||
|
}
|
||||||
|
if offset >= 0 {
|
||||||
|
sql += fmt.Sprintf(" OFFSET %d", offset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) SelectFromDummyTable() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
|
return ""
|
||||||
|
}
|
|
@ -0,0 +1,113 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mysql struct {
|
||||||
|
commonDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterDialect("mysql", &mysql{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mysql) GetName() string {
|
||||||
|
return "mysql"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mysql) Quote(key string) string {
|
||||||
|
return fmt.Sprintf("`%s`", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Data Type for MySQL Dialect
|
||||||
|
func (mysql) DataTypeOf(field *StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "boolean"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
sqlType = "int AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "int"
|
||||||
|
}
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
sqlType = "int unsigned AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "int unsigned"
|
||||||
|
}
|
||||||
|
case reflect.Int64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
sqlType = "bigint AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint"
|
||||||
|
}
|
||||||
|
case reflect.Uint64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
sqlType = "bigint unsigned AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint unsigned"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "double"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "longtext"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
if _, ok := field.TagSettings["NOT NULL"]; ok {
|
||||||
|
sqlType = "timestamp"
|
||||||
|
} else {
|
||||||
|
sqlType = "timestamp NULL"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := dataValue.Interface().([]byte); ok {
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varbinary(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "longblob"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mysql) RemoveIndex(tableName string, indexName string) error {
|
||||||
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.currentDatabase(), foreignKeyName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mysql) currentDatabase() (name string) {
|
||||||
|
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mysql) SelectFromDummyTable() string {
|
||||||
|
return "FROM DUAL"
|
||||||
|
}
|
|
@ -0,0 +1,128 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type postgres struct {
|
||||||
|
commonDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterDialect("postgres", &postgres{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (postgres) GetName() string {
|
||||||
|
return "postgres"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (postgres) BindVar(i int) string {
|
||||||
|
return fmt.Sprintf("$%v", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (postgres) DataTypeOf(field *StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "boolean"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
sqlType = "serial"
|
||||||
|
} else {
|
||||||
|
sqlType = "integer"
|
||||||
|
}
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
sqlType = "bigserial"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "numeric"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "text"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
sqlType = "timestamp with time zone"
|
||||||
|
}
|
||||||
|
case reflect.Map:
|
||||||
|
if dataValue.Type().Name() == "Hstore" {
|
||||||
|
sqlType = "hstore"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if isByteArrayOrSlice(dataValue) {
|
||||||
|
sqlType = "bytea"
|
||||||
|
} else if isUUID(dataValue) {
|
||||||
|
sqlType = "uuid"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) HasIndex(tableName string, indexName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", s.currentDatabase(), foreignKeyName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) HasTable(tableName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) HasColumn(tableName string, columnName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) currentDatabase() (name string) {
|
||||||
|
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||||
|
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (postgres) SupportLastInsertID() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isByteArrayOrSlice(value reflect.Value) bool {
|
||||||
|
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isUUID(value reflect.Value) bool {
|
||||||
|
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
typename := value.Type().Name()
|
||||||
|
lower := strings.ToLower(typename)
|
||||||
|
return "uuid" == lower || "guid" == lower
|
||||||
|
}
|
|
@ -0,0 +1,106 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sqlite3 struct {
|
||||||
|
commonDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterDialect("sqlite", &sqlite3{})
|
||||||
|
RegisterDialect("sqlite3", &sqlite3{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sqlite3) GetName() string {
|
||||||
|
return "sqlite3"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Data Type for Sqlite Dialect
|
||||||
|
func (sqlite3) DataTypeOf(field *StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "bool"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
sqlType = "integer primary key autoincrement"
|
||||||
|
} else {
|
||||||
|
sqlType = "integer"
|
||||||
|
}
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
sqlType = "integer primary key autoincrement"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "real"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "text"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
sqlType = "datetime"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := dataValue.Interface().([]byte); ok {
|
||||||
|
sqlType = "blob"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) HasTable(tableName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) currentDatabase() (name string) {
|
||||||
|
var (
|
||||||
|
ifaces = make([]interface{}, 3)
|
||||||
|
pointers = make([]*string, 3)
|
||||||
|
i int
|
||||||
|
)
|
||||||
|
for i = 0; i < 3; i++ {
|
||||||
|
ifaces[i] = &pointers[i]
|
||||||
|
}
|
||||||
|
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if pointers[1] != nil {
|
||||||
|
name = *pointers[1]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,150 @@
|
||||||
|
package mssql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/denisenkom/go-mssqldb"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setIdentityInsert(scope *gorm.Scope) {
|
||||||
|
if scope.Dialect().GetName() == "mssql" {
|
||||||
|
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
|
||||||
|
gorm.RegisterDialect("mssql", &mssql{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type mssql struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) GetName() string {
|
||||||
|
return "mssql"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mssql) SetDB(db *sql.DB) {
|
||||||
|
s.db = db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) BindVar(i int) string {
|
||||||
|
return "$$" // ?
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) Quote(key string) string {
|
||||||
|
return fmt.Sprintf(`"%s"`, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) DataTypeOf(field *gorm.StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "bit"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
sqlType = "int IDENTITY(1,1)"
|
||||||
|
} else {
|
||||||
|
sqlType = "int"
|
||||||
|
}
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
sqlType = "bigint IDENTITY(1,1)"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "float"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("nvarchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "text"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
sqlType = "datetime2"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := dataValue.Interface().([]byte); ok {
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "text"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
||||||
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasTable(tableName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.currentDatabase()).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) currentDatabase() (name string) {
|
||||||
|
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) LimitAndOffsetSQL(limit, offset int) (sql string) {
|
||||||
|
if limit > 0 || offset > 0 {
|
||||||
|
if offset < 0 {
|
||||||
|
offset = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
sql += fmt.Sprintf(" OFFSET %d ROWS", offset)
|
||||||
|
|
||||||
|
if limit >= 0 {
|
||||||
|
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) SelectFromDummyTable() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
|
return ""
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
package mysql
|
||||||
|
|
||||||
|
import _ "github.com/go-sql-driver/mysql"
|
|
@ -0,0 +1,52 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
"github.com/lib/pq/hstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Hstore map[string]*string
|
||||||
|
|
||||||
|
func (h Hstore) Value() (driver.Value, error) {
|
||||||
|
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
|
||||||
|
if len(h) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range h {
|
||||||
|
var s sql.NullString
|
||||||
|
if value != nil {
|
||||||
|
s.String = *value
|
||||||
|
s.Valid = true
|
||||||
|
}
|
||||||
|
hstore.Map[key] = s
|
||||||
|
}
|
||||||
|
return hstore.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hstore) Scan(value interface{}) error {
|
||||||
|
hstore := hstore.Hstore{}
|
||||||
|
|
||||||
|
if err := hstore.Scan(value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(hstore.Map) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
*h = Hstore{}
|
||||||
|
for k := range hstore.Map {
|
||||||
|
if hstore.Map[k].Valid {
|
||||||
|
s := hstore.Map[k].String
|
||||||
|
(*h)[k] = &s
|
||||||
|
} else {
|
||||||
|
(*h)[k] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
import _ "github.com/mattn/go-sqlite3"
|
|
@ -1,68 +0,0 @@
|
||||||
# Gorm Development
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
The most notable component of Gorm is`gorm.DB`, which hold database connection. It could be initialized like this:
|
|
||||||
|
|
||||||
db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
|
||||||
|
|
||||||
Gorm has chainable API, `gorm.DB` is the bridge of chains, it save related information and pass it to the next chain.
|
|
||||||
|
|
||||||
Lets use below code to explain how it works:
|
|
||||||
|
|
||||||
db.Where("name = ?", "jinzhu").Find(&users)
|
|
||||||
|
|
||||||
// equivalent code
|
|
||||||
newdb := db.Where("name =?", "jinzhu")
|
|
||||||
newdb.Find(&user)
|
|
||||||
|
|
||||||
`newdb` is `db`'s clone, in addition, it contains search conditions from the `Where` method.
|
|
||||||
`Find` is a query method, it creates a `Scope` instance, and pass it as argument to query callbacks.
|
|
||||||
|
|
||||||
There are four kinds of callbacks corresponds to sql's CURD: create callbacks, update callbacks, query callbacks, delete callbacks.
|
|
||||||
|
|
||||||
## Callbacks
|
|
||||||
|
|
||||||
### Register a new callback
|
|
||||||
|
|
||||||
func updateCreated(scope *Scope) {
|
|
||||||
if scope.HasColumn("Created") {
|
|
||||||
scope.SetColumn("Created", NowFunc())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Callback().Create().Register("update_created_at", updateCreated)
|
|
||||||
// register a callback for Create process
|
|
||||||
|
|
||||||
### Delete an existing callback
|
|
||||||
|
|
||||||
db.Callback().Create().Remove("gorm:create")
|
|
||||||
// delete callback `gorm:create` from Create callbacks
|
|
||||||
|
|
||||||
### Replace an existing callback
|
|
||||||
|
|
||||||
db.Callback().Create().Replace("gorm:create", newCreateFunction)
|
|
||||||
// replace callback `gorm:create` with new function `newCreateFunction` for Create process
|
|
||||||
|
|
||||||
### Register callback orders
|
|
||||||
|
|
||||||
db.Callback().Create().Before("gorm:create").Register("update_created_at", updateCreated)
|
|
||||||
db.Callback().Create().After("gorm:create").Register("update_created_at", updateCreated)
|
|
||||||
db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery)
|
|
||||||
db.Callback().Delete().After("gorm:delete").Register("my_plugin:after_delete", afterDelete)
|
|
||||||
db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate)
|
|
||||||
db.Callback().Create().Before("gorm:create").After("gorm:before_create").Register("my_plugin:before_create", beforeCreate)
|
|
||||||
|
|
||||||
### Callback API
|
|
||||||
|
|
||||||
Gorm is powered by callbacks, so you could refer below links to learn how to write callbacks
|
|
||||||
|
|
||||||
[Create callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
|
|
||||||
|
|
||||||
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
|
|
||||||
|
|
||||||
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
|
|
||||||
|
|
||||||
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)
|
|
||||||
|
|
||||||
View [https://github.com/jinzhu/gorm/blob/master/scope.go](https://github.com/jinzhu/gorm/blob/master/scope.go) for all available API
|
|
17
errors.go
17
errors.go
|
@ -6,25 +6,31 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
RecordNotFound = errors.New("record not found")
|
// ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
|
||||||
InvalidSql = errors.New("invalid sql")
|
ErrRecordNotFound = errors.New("record not found")
|
||||||
NoNewAttrs = errors.New("no new attributes")
|
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
|
||||||
NoValidTransaction = errors.New("no valid transaction")
|
ErrInvalidSQL = errors.New("invalid SQL")
|
||||||
CantStartTransaction = errors.New("can't start transaction")
|
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
||||||
|
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||||
|
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
|
||||||
|
ErrCantStartTransaction = errors.New("can't start transaction")
|
||||||
)
|
)
|
||||||
|
|
||||||
type errorsInterface interface {
|
type errorsInterface interface {
|
||||||
GetErrors() []error
|
GetErrors() []error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Errors contains all happened errors
|
||||||
type Errors struct {
|
type Errors struct {
|
||||||
errors []error
|
errors []error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetErrors get all happened errors
|
||||||
func (errs Errors) GetErrors() []error {
|
func (errs Errors) GetErrors() []error {
|
||||||
return errs.errors
|
return errs.errors
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add add an error
|
||||||
func (errs *Errors) Add(err error) {
|
func (errs *Errors) Add(err error) {
|
||||||
if errors, ok := err.(errorsInterface); ok {
|
if errors, ok := err.(errorsInterface); ok {
|
||||||
for _, err := range errors.GetErrors() {
|
for _, err := range errors.GetErrors() {
|
||||||
|
@ -40,6 +46,7 @@ func (errs *Errors) Add(err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Error format happened errors
|
||||||
func (errs Errors) Error() string {
|
func (errs Errors) Error() string {
|
||||||
var errors = []string{}
|
var errors = []string{}
|
||||||
for _, e := range errs.errors {
|
for _, e := range errs.errors {
|
||||||
|
|
55
field.go
55
field.go
|
@ -7,12 +7,14 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Field model field definition
|
||||||
type Field struct {
|
type Field struct {
|
||||||
*StructField
|
*StructField
|
||||||
IsBlank bool
|
IsBlank bool
|
||||||
Field reflect.Value
|
Field reflect.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set set a value to the field
|
||||||
func (field *Field) Set(value interface{}) (err error) {
|
func (field *Field) Set(value interface{}) (err error) {
|
||||||
if !field.Field.IsValid() {
|
if !field.Field.IsValid() {
|
||||||
return errors.New("field value not valid")
|
return errors.New("field value not valid")
|
||||||
|
@ -56,35 +58,34 @@ func (field *Field) Set(value interface{}) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fields get value's fields
|
// Fields get value's fields
|
||||||
func (scope *Scope) Fields() map[string]*Field {
|
func (scope *Scope) Fields() []*Field {
|
||||||
if scope.fields == nil {
|
var (
|
||||||
fields := map[string]*Field{}
|
fields []*Field
|
||||||
modelStruct := scope.GetModelStruct()
|
indirectScopeValue = scope.IndirectValue()
|
||||||
|
isStruct = indirectScopeValue.Kind() == reflect.Struct
|
||||||
|
)
|
||||||
|
|
||||||
indirectValue := scope.IndirectValue()
|
for _, structField := range scope.GetModelStruct().StructFields {
|
||||||
isStruct := indirectValue.Kind() == reflect.Struct
|
|
||||||
for _, structField := range modelStruct.StructFields {
|
|
||||||
if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
|
|
||||||
if isStruct {
|
if isStruct {
|
||||||
fields[structField.DBName] = getField(indirectValue, structField)
|
fieldValue := indirectScopeValue
|
||||||
} else {
|
|
||||||
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.fields = fields
|
|
||||||
return fields
|
|
||||||
}
|
|
||||||
return scope.fields
|
|
||||||
}
|
|
||||||
|
|
||||||
func getField(indirectValue reflect.Value, structField *StructField) *Field {
|
|
||||||
field := &Field{StructField: structField}
|
|
||||||
for _, name := range structField.Names {
|
for _, name := range structField.Names {
|
||||||
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
|
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
|
||||||
}
|
}
|
||||||
field.Field = indirectValue
|
fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
|
||||||
field.IsBlank = isBlank(indirectValue)
|
} else {
|
||||||
return field
|
fields = append(fields, &Field{StructField: structField, IsBlank: true})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) fieldsMap() map[string]*Field {
|
||||||
|
var results = map[string]*Field{}
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
|
if field.IsNormal {
|
||||||
|
results[field.DBName] = field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,12 +32,16 @@ type CalculateFieldCategory struct {
|
||||||
|
|
||||||
func TestCalculateField(t *testing.T) {
|
func TestCalculateField(t *testing.T) {
|
||||||
var field CalculateField
|
var field CalculateField
|
||||||
fields := DB.NewScope(&field).Fields()
|
var scope = DB.NewScope(&field)
|
||||||
if fields["children"].Relationship == nil || fields["category"].Relationship == nil {
|
if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil {
|
||||||
t.Errorf("Should calculate fields correctly for the first time")
|
t.Errorf("Should calculate fields correctly for the first time")
|
||||||
}
|
}
|
||||||
|
|
||||||
if field, ok := fields["embedded_name"]; !ok {
|
if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil {
|
||||||
|
t.Errorf("Should calculate fields correctly for the first time")
|
||||||
|
}
|
||||||
|
|
||||||
|
if field, ok := scope.FieldByName("embedded_name"); !ok {
|
||||||
t.Errorf("should find embedded field")
|
t.Errorf("should find embedded field")
|
||||||
} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
|
} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
|
||||||
t.Errorf("should find embedded field's tag settings")
|
t.Errorf("should find embedded field's tag settings")
|
||||||
|
|
|
@ -1,83 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type foundation struct {
|
|
||||||
commonDialect
|
|
||||||
}
|
|
||||||
|
|
||||||
func (foundation) BinVar(i int) string {
|
|
||||||
return fmt.Sprintf("$%v", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (foundation) SupportLastInsertId() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
return "boolean"
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
if autoIncrease {
|
|
||||||
return "serial"
|
|
||||||
}
|
|
||||||
return "int"
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
if autoIncrease {
|
|
||||||
return "bigserial"
|
|
||||||
}
|
|
||||||
return "bigint"
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
return "double"
|
|
||||||
case reflect.String:
|
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
|
||||||
}
|
|
||||||
return "clob"
|
|
||||||
case reflect.Struct:
|
|
||||||
if _, ok := value.Interface().(time.Time); ok {
|
|
||||||
return "datetime"
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if _, ok := value.Interface().([]byte); ok {
|
|
||||||
return "blob"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s foundation) ReturningStr(tableName, key string) string {
|
|
||||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s foundation) HasTable(scope *Scope, tableName string) bool {
|
|
||||||
var count int
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s foundation) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
|
||||||
var count int
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s foundation) RemoveIndex(scope *Scope, indexName string) {
|
|
||||||
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
|
||||||
var count int
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s foundation) CurrentDatabase(scope *Scope) (name string) {
|
|
||||||
s.RawScanString(scope, &name, "SELECT CURRENT_SCHEMA")
|
|
||||||
return
|
|
||||||
}
|
|
Binary file not shown.
Before Width: | Height: | Size: 65 KiB |
|
@ -7,40 +7,54 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// JoinTableHandlerInterface is an interface for how to handle many2many relations
|
||||||
type JoinTableHandlerInterface interface {
|
type JoinTableHandlerInterface interface {
|
||||||
|
// initialize join table handler
|
||||||
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
||||||
|
// Table return join table's table name
|
||||||
Table(db *DB) string
|
Table(db *DB) string
|
||||||
|
// Add create relationship in join table for source and destination
|
||||||
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||||
|
// Delete delete relationship in join table for sources
|
||||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||||
|
// JoinWith query with `Join` conditions
|
||||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||||
|
// SourceForeignKeys return source foreign keys
|
||||||
SourceForeignKeys() []JoinTableForeignKey
|
SourceForeignKeys() []JoinTableForeignKey
|
||||||
|
// DestinationForeignKeys return destination foreign keys
|
||||||
DestinationForeignKeys() []JoinTableForeignKey
|
DestinationForeignKeys() []JoinTableForeignKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JoinTableForeignKey join table foreign key struct
|
||||||
type JoinTableForeignKey struct {
|
type JoinTableForeignKey struct {
|
||||||
DBName string
|
DBName string
|
||||||
AssociationDBName string
|
AssociationDBName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JoinTableSource is a struct that contains model type and foreign keys
|
||||||
type JoinTableSource struct {
|
type JoinTableSource struct {
|
||||||
ModelType reflect.Type
|
ModelType reflect.Type
|
||||||
ForeignKeys []JoinTableForeignKey
|
ForeignKeys []JoinTableForeignKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JoinTableHandler default join table handler
|
||||||
type JoinTableHandler struct {
|
type JoinTableHandler struct {
|
||||||
TableName string `sql:"-"`
|
TableName string `sql:"-"`
|
||||||
Source JoinTableSource `sql:"-"`
|
Source JoinTableSource `sql:"-"`
|
||||||
Destination JoinTableSource `sql:"-"`
|
Destination JoinTableSource `sql:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SourceForeignKeys return source foreign keys
|
||||||
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
|
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
|
||||||
return s.Source.ForeignKeys
|
return s.Source.ForeignKeys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DestinationForeignKeys return destination foreign keys
|
||||||
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
|
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
|
||||||
return s.Destination.ForeignKeys
|
return s.Destination.ForeignKeys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Setup initialize a default join table handler
|
||||||
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
||||||
s.TableName = tableName
|
s.TableName = tableName
|
||||||
|
|
||||||
|
@ -61,11 +75,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Table return join table's table name
|
||||||
func (s JoinTableHandler) Table(db *DB) string {
|
func (s JoinTableHandler) Table(db *DB) string {
|
||||||
return s.TableName
|
return s.TableName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||||
values := map[string]interface{}{}
|
values := map[string]interface{}{}
|
||||||
|
|
||||||
for _, source := range sources {
|
for _, source := range sources {
|
||||||
|
@ -74,20 +89,25 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
|
||||||
|
|
||||||
if s.Source.ModelType == modelType {
|
if s.Source.ModelType == modelType {
|
||||||
for _, foreignKey := range s.Source.ForeignKeys {
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||||
|
values[foreignKey.DBName] = field.Field.Interface()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if s.Destination.ModelType == modelType {
|
} else if s.Destination.ModelType == modelType {
|
||||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||||
|
values[foreignKey.DBName] = field.Field.Interface()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
|
// Add create relationship in join table for source and destination
|
||||||
|
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
||||||
scope := db.NewScope("")
|
scope := db.NewScope("")
|
||||||
searchMap := s.GetSearchMap(db, source1, source2)
|
searchMap := s.getSearchMap(db, source, destination)
|
||||||
|
|
||||||
var assignColumns, binVars, conditions []string
|
var assignColumns, binVars, conditions []string
|
||||||
var values []interface{}
|
var values []interface{}
|
||||||
|
@ -116,6 +136,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
|
||||||
return db.Exec(sql, values...).Error
|
return db.Exec(sql, values...).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete delete relationship in join table for sources
|
||||||
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
|
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
|
||||||
var (
|
var (
|
||||||
scope = db.NewScope(nil)
|
scope = db.NewScope(nil)
|
||||||
|
@ -123,7 +144,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
||||||
values []interface{}
|
values []interface{}
|
||||||
)
|
)
|
||||||
|
|
||||||
for key, value := range s.GetSearchMap(db, sources...) {
|
for key, value := range s.getSearchMap(db, sources...) {
|
||||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
|
@ -131,6 +152,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
||||||
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JoinWith query with `Join` conditions
|
||||||
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||||
var (
|
var (
|
||||||
scope = db.NewScope(source)
|
scope = db.NewScope(source)
|
||||||
|
@ -151,10 +173,12 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
||||||
|
|
||||||
for _, foreignKey := range s.Source.ForeignKeys {
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
|
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
|
||||||
foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
|
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||||
|
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames)
|
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
|
||||||
|
|
||||||
var condString string
|
var condString string
|
||||||
if len(foreignFieldValues) > 0 {
|
if len(foreignFieldValues) > 0 {
|
||||||
|
@ -165,7 +189,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
||||||
|
|
||||||
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
|
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
|
||||||
|
|
||||||
keys := scope.getColumnAsArray(foreignFieldNames)
|
keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
|
||||||
values = append(values, toQueryValues(keys))
|
values = append(values, toQueryValues(keys))
|
||||||
} else {
|
} else {
|
||||||
condString = fmt.Sprintf("1 <> 1")
|
condString = fmt.Sprintf("1 <> 1")
|
||||||
|
@ -173,8 +197,8 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
||||||
|
|
||||||
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
|
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
|
||||||
Where(condString, toQueryValues(foreignFieldValues)...)
|
Where(condString, toQueryValues(foreignFieldValues)...)
|
||||||
} else {
|
}
|
||||||
|
|
||||||
db.Error = errors.New("wrong source type for join table handler")
|
db.Error = errors.New("wrong source type for join table handler")
|
||||||
return db
|
return db
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ type PersonAddress struct {
|
||||||
gorm.JoinTableHandler
|
gorm.JoinTableHandler
|
||||||
PersonID int
|
PersonID int
|
||||||
AddressID int
|
AddressID int
|
||||||
DeletedAt time.Time
|
DeletedAt *time.Time
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
57
logger.go
57
logger.go
|
@ -8,25 +8,28 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
||||||
|
sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
||||||
)
|
)
|
||||||
|
|
||||||
type logger interface {
|
type logger interface {
|
||||||
Print(v ...interface{})
|
Print(v ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
type LogWriter interface {
|
type logWriter interface {
|
||||||
Println(v ...interface{})
|
Println(v ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Logger default logger
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
LogWriter
|
logWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
// Print format & print log
|
||||||
|
|
||||||
// Format log
|
|
||||||
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
|
||||||
|
|
||||||
func (logger Logger) Print(values ...interface{}) {
|
func (logger Logger) Print(values ...interface{}) {
|
||||||
if len(values) > 1 {
|
if len(values) > 1 {
|
||||||
level := values[0]
|
level := values[0]
|
||||||
|
@ -38,29 +41,44 @@ func (logger Logger) Print(values ...interface{}) {
|
||||||
// duration
|
// duration
|
||||||
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
||||||
// sql
|
// sql
|
||||||
var formatedValues []interface{}
|
var sql string
|
||||||
|
var formattedValues []string
|
||||||
|
|
||||||
for _, value := range values[4].([]interface{}) {
|
for _, value := range values[4].([]interface{}) {
|
||||||
indirectValue := reflect.Indirect(reflect.ValueOf(value))
|
indirectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
if indirectValue.IsValid() {
|
if indirectValue.IsValid() {
|
||||||
value = indirectValue.Interface()
|
value = indirectValue.Interface()
|
||||||
if t, ok := value.(time.Time); ok {
|
if t, ok := value.(time.Time); ok {
|
||||||
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
|
||||||
} else if b, ok := value.([]byte); ok {
|
} else if b, ok := value.([]byte); ok {
|
||||||
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b)))
|
if str := string(b); isPrintable(str) {
|
||||||
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
|
||||||
|
} else {
|
||||||
|
formattedValues = append(formattedValues, "'<binary>'")
|
||||||
|
}
|
||||||
} else if r, ok := value.(driver.Valuer); ok {
|
} else if r, ok := value.(driver.Valuer); ok {
|
||||||
if value, err := r.Value(); err == nil && value != nil {
|
if value, err := r.Value(); err == nil && value != nil {
|
||||||
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||||
} else {
|
} else {
|
||||||
formatedValues = append(formatedValues, "NULL")
|
formattedValues = append(formattedValues, "NULL")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...))
|
|
||||||
|
var formattedValuesLength = len(formattedValues)
|
||||||
|
for index, value := range sqlRegexp.Split(values[3].(string), -1) {
|
||||||
|
sql += value
|
||||||
|
if index < formattedValuesLength {
|
||||||
|
sql += formattedValues[index]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, sql)
|
||||||
} else {
|
} else {
|
||||||
messages = append(messages, "\033[31;1m")
|
messages = append(messages, "\033[31;1m")
|
||||||
messages = append(messages, values[2:]...)
|
messages = append(messages, values[2:]...)
|
||||||
|
@ -69,3 +87,12 @@ func (logger Logger) Print(values ...interface{}) {
|
||||||
logger.Println(messages...)
|
logger.Println(messages...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isPrintable(s string) bool {
|
||||||
|
for _, r := range s {
|
||||||
|
if !unicode.IsPrint(r) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
279
main.go
279
main.go
|
@ -6,24 +6,14 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NowFunc returns current time, this function is exported in order to be able
|
// DB contains information for current db connection
|
||||||
// to give the flexibility to the developer to customize it according to their
|
|
||||||
// needs
|
|
||||||
//
|
|
||||||
// e.g: return time.Now().UTC()
|
|
||||||
//
|
|
||||||
var NowFunc = func() time.Time {
|
|
||||||
return time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Error error
|
Error error
|
||||||
RowsAffected int64
|
RowsAffected int64
|
||||||
callback *callback
|
callbacks *Callback
|
||||||
db sqlCommon
|
db sqlCommon
|
||||||
parent *DB
|
parent *DB
|
||||||
search *search
|
search *search
|
||||||
|
@ -36,7 +26,18 @@ type DB struct {
|
||||||
joinTableHandlers map[string]JoinTableHandler
|
joinTableHandlers map[string]JoinTableHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
func Open(dialect string, args ...interface{}) (DB, error) {
|
// Open initialize a new db connection, need to import driver first, e.g:
|
||||||
|
//
|
||||||
|
// import _ "github.com/go-sql-driver/mysql"
|
||||||
|
// func main() {
|
||||||
|
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
|
||||||
|
// }
|
||||||
|
// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
|
||||||
|
// import _ "github.com/jinzhu/gorm/dialects/mysql"
|
||||||
|
// // import _ "github.com/jinzhu/gorm/dialects/postgres"
|
||||||
|
// // import _ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
|
// // import _ "github.com/jinzhu/gorm/dialects/mssql"
|
||||||
|
func Open(dialect string, args ...interface{}) (*DB, error) {
|
||||||
var db DB
|
var db DB
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
@ -44,7 +45,7 @@ func Open(dialect string, args ...interface{}) (DB, error) {
|
||||||
err = errors.New("invalid database source")
|
err = errors.New("invalid database source")
|
||||||
} else {
|
} else {
|
||||||
var source string
|
var source string
|
||||||
var dbSql sqlCommon
|
var dbSQL sqlCommon
|
||||||
|
|
||||||
switch value := args[0].(type) {
|
switch value := args[0].(type) {
|
||||||
case string:
|
case string:
|
||||||
|
@ -55,22 +56,19 @@ func Open(dialect string, args ...interface{}) (DB, error) {
|
||||||
driver = value
|
driver = value
|
||||||
source = args[1].(string)
|
source = args[1].(string)
|
||||||
}
|
}
|
||||||
if driver == "foundation" {
|
dbSQL, err = sql.Open(driver, source)
|
||||||
driver = "postgres" // FoundationDB speaks a postgres-compatible protocol.
|
|
||||||
}
|
|
||||||
dbSql, err = sql.Open(driver, source)
|
|
||||||
case sqlCommon:
|
case sqlCommon:
|
||||||
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
||||||
dbSql = value
|
dbSQL = value
|
||||||
}
|
}
|
||||||
|
|
||||||
db = DB{
|
db = DB{
|
||||||
dialect: NewDialect(dialect),
|
dialect: newDialect(dialect, dbSQL.(*sql.DB)),
|
||||||
logger: defaultLogger,
|
logger: defaultLogger,
|
||||||
callback: DefaultCallback,
|
callbacks: DefaultCallback,
|
||||||
source: source,
|
source: source,
|
||||||
values: map[string]interface{}{},
|
values: map[string]interface{}{},
|
||||||
db: dbSql,
|
db: dbSQL,
|
||||||
}
|
}
|
||||||
db.parent = &db
|
db.parent = &db
|
||||||
|
|
||||||
|
@ -79,17 +77,20 @@ func Open(dialect string, args ...interface{}) (DB, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return db, err
|
return &db, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close close current db connection
|
||||||
func (s *DB) Close() error {
|
func (s *DB) Close() error {
|
||||||
return s.parent.db.(*sql.DB).Close()
|
return s.parent.db.(*sql.DB).Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DB get `*sql.DB` from current connection
|
||||||
func (s *DB) DB() *sql.DB {
|
func (s *DB) DB() *sql.DB {
|
||||||
return s.db.(*sql.DB)
|
return s.db.(*sql.DB)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New clone a new db connection without search conditions
|
||||||
func (s *DB) New() *DB {
|
func (s *DB) New() *DB {
|
||||||
clone := s.clone()
|
clone := s.clone()
|
||||||
clone.search = nil
|
clone.search = nil
|
||||||
|
@ -97,29 +98,32 @@ func (s *DB) New() *DB {
|
||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewScope create scope for callbacks, including DB's search information
|
// NewScope create a scope for current operation
|
||||||
func (db *DB) NewScope(value interface{}) *Scope {
|
func (s *DB) NewScope(value interface{}) *Scope {
|
||||||
dbClone := db.clone()
|
dbClone := s.clone()
|
||||||
dbClone.Value = value
|
dbClone.Value = value
|
||||||
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
|
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CommonDB Return the underlying sql.DB or sql.Tx instance.
|
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
|
||||||
// Use of this method is discouraged. It's mainly intended to allow
|
|
||||||
// coexistence with legacy non-GORM code.
|
|
||||||
func (s *DB) CommonDB() sqlCommon {
|
func (s *DB) CommonDB() sqlCommon {
|
||||||
return s.db
|
return s.db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Callback() *callback {
|
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
|
||||||
s.parent.callback = s.parent.callback.clone()
|
// db.Callback().Create().Register("update_created_at", updateCreated)
|
||||||
return s.parent.callback
|
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
|
||||||
|
func (s *DB) Callback() *Callback {
|
||||||
|
s.parent.callbacks = s.parent.callbacks.clone()
|
||||||
|
return s.parent.callbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) SetLogger(l logger) {
|
// SetLogger replace default logger
|
||||||
s.logger = l
|
func (s *DB) SetLogger(log logger) {
|
||||||
|
s.logger = log
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
|
||||||
func (s *DB) LogMode(enable bool) *DB {
|
func (s *DB) LogMode(enable bool) *DB {
|
||||||
if enable {
|
if enable {
|
||||||
s.logMode = 2
|
s.logMode = 2
|
||||||
|
@ -129,55 +133,82 @@ func (s *DB) LogMode(enable bool) *DB {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SingularTable use singular table by default
|
||||||
func (s *DB) SingularTable(enable bool) {
|
func (s *DB) SingularTable(enable bool) {
|
||||||
modelStructsMap = newModelStructsMap()
|
modelStructsMap = newModelStructsMap()
|
||||||
s.parent.singularTable = enable
|
s.parent.singularTable = enable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query
|
||||||
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
||||||
return s.clone().search.Where(query, args...).db
|
return s.clone().search.Where(query, args...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Or filter records that match before conditions or this one, similar to `Where`
|
||||||
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
|
||||||
return s.clone().search.Or(query, args...).db
|
return s.clone().search.Or(query, args...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Not filter records that don't match current conditions, similar to `Where`
|
||||||
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
||||||
return s.clone().search.Not(query, args...).db
|
return s.clone().search.Not(query, args...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Limit(value interface{}) *DB {
|
// Limit specify the number of records to be retrieved
|
||||||
return s.clone().search.Limit(value).db
|
func (s *DB) Limit(limit int) *DB {
|
||||||
|
return s.clone().search.Limit(limit).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Offset(value interface{}) *DB {
|
// Offset specify the number of records to skip before starting to return the records
|
||||||
return s.clone().search.Offset(value).db
|
func (s *DB) Offset(offset int) *DB {
|
||||||
|
return s.clone().search.Offset(offset).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
|
||||||
func (s *DB) Order(value string, reorder ...bool) *DB {
|
func (s *DB) Order(value string, reorder ...bool) *DB {
|
||||||
return s.clone().search.Order(value, reorder...).db
|
return s.clone().search.Order(value, reorder...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Select specify fields that you want to retrieve from database when querying, by default, will select all fields;
|
||||||
|
// When creating/updating, specify fields that you want to save to database
|
||||||
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
|
||||||
return s.clone().search.Select(query, args...).db
|
return s.clone().search.Select(query, args...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Omit specify fields that you want to ignore when saving to database for creating, updating
|
||||||
func (s *DB) Omit(columns ...string) *DB {
|
func (s *DB) Omit(columns ...string) *DB {
|
||||||
return s.clone().search.Omit(columns...).db
|
return s.clone().search.Omit(columns...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Group specify the group method on the find
|
||||||
func (s *DB) Group(query string) *DB {
|
func (s *DB) Group(query string) *DB {
|
||||||
return s.clone().search.Group(query).db
|
return s.clone().search.Group(query).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Having specify HAVING conditions for GROUP BY
|
||||||
func (s *DB) Having(query string, values ...interface{}) *DB {
|
func (s *DB) Having(query string, values ...interface{}) *DB {
|
||||||
return s.clone().search.Having(query, values...).db
|
return s.clone().search.Having(query, values...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Joins(query string) *DB {
|
// Joins specify Joins conditions
|
||||||
return s.clone().search.Joins(query).db
|
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||||
|
func (s *DB) Joins(query string, args ...interface{}) *DB {
|
||||||
|
return s.clone().search.Joins(query, args...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
|
||||||
|
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||||
|
// return db.Where("amount > ?", 1000)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||||
|
// return func (db *gorm.DB) *gorm.DB {
|
||||||
|
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||||
|
// Refer https://jinzhu.github.io/gorm/curd.html#scopes
|
||||||
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
||||||
for _, f := range funcs {
|
for _, f := range funcs {
|
||||||
s = f(s)
|
s = f(s)
|
||||||
|
@ -185,60 +216,91 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/curd.html#soft-delete
|
||||||
func (s *DB) Unscoped() *DB {
|
func (s *DB) Unscoped() *DB {
|
||||||
return s.clone().search.unscoped().db
|
return s.clone().search.unscoped().db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attrs initalize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||||
func (s *DB) Attrs(attrs ...interface{}) *DB {
|
func (s *DB) Attrs(attrs ...interface{}) *DB {
|
||||||
return s.clone().search.Attrs(attrs...).db
|
return s.clone().search.Attrs(attrs...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||||
func (s *DB) Assign(attrs ...interface{}) *DB {
|
func (s *DB) Assign(attrs ...interface{}) *DB {
|
||||||
return s.clone().search.Assign(attrs...).db
|
return s.clone().search.Assign(attrs...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// First find first record that match given conditions, order by primary key
|
||||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.clone().NewScope(out)
|
newScope := s.clone().NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Last find last record that match given conditions, order by primary key
|
||||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.clone().NewScope(out)
|
newScope := s.clone().NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Find find records that match given conditions
|
||||||
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||||
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scan scan value to a struct
|
||||||
func (s *DB) Scan(dest interface{}) *DB {
|
func (s *DB) Scan(dest interface{}) *DB {
|
||||||
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
|
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Row return `*sql.Row` with given conditions
|
||||||
func (s *DB) Row() *sql.Row {
|
func (s *DB) Row() *sql.Row {
|
||||||
return s.NewScope(s.Value).row()
|
return s.NewScope(s.Value).row()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rows return `*sql.Rows` with given conditions
|
||||||
func (s *DB) Rows() (*sql.Rows, error) {
|
func (s *DB) Rows() (*sql.Rows, error) {
|
||||||
return s.NewScope(s.Value).rows()
|
return s.NewScope(s.Value).rows()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanRows scan `*sql.Rows` to give struct
|
||||||
|
func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
|
||||||
|
var (
|
||||||
|
clone = s.clone()
|
||||||
|
scope = clone.NewScope(result)
|
||||||
|
columns, err = rows.Columns()
|
||||||
|
)
|
||||||
|
|
||||||
|
if clone.AddError(err) == nil {
|
||||||
|
scope.scan(rows, columns, scope.fieldsMap())
|
||||||
|
}
|
||||||
|
|
||||||
|
return clone.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pluck used to query single column from a model as a map
|
||||||
|
// var ages []int64
|
||||||
|
// db.Find(&users).Pluck("age", &ages)
|
||||||
func (s *DB) Pluck(column string, value interface{}) *DB {
|
func (s *DB) Pluck(column string, value interface{}) *DB {
|
||||||
return s.NewScope(s.Value).pluck(column, value).db
|
return s.NewScope(s.Value).pluck(column, value).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Count get how many records for a model
|
||||||
func (s *DB) Count(value interface{}) *DB {
|
func (s *DB) Count(value interface{}) *DB {
|
||||||
return s.NewScope(s.Value).count(value).db
|
return s.NewScope(s.Value).count(value).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Related get related associations
|
||||||
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
||||||
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
|
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FirstOrInit find first matched record or initalize a new one with given conditions (only works with struct, map conditions)
|
||||||
|
// https://jinzhu.github.io/gorm/curd.html#firstorinit
|
||||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
if result := c.First(out, where...); result.Error != nil {
|
if result := c.First(out, where...); result.Error != nil {
|
||||||
|
@ -247,82 +309,100 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||||
}
|
}
|
||||||
c.NewScope(out).inlineCondition(where...).initialize()
|
c.NewScope(out).inlineCondition(where...).initialize()
|
||||||
} else {
|
} else {
|
||||||
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false)
|
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
|
||||||
|
// https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||||
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
if result := c.First(out, where...); result.Error != nil {
|
if result := c.First(out, where...); result.Error != nil {
|
||||||
if !result.RecordNotFound() {
|
if !result.RecordNotFound() {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error)
|
c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db.Error)
|
||||||
} else if len(c.search.assignAttrs) > 0 {
|
} else if len(c.search.assignAttrs) > 0 {
|
||||||
c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error)
|
c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db.Error)
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||||
func (s *DB) Update(attrs ...interface{}) *DB {
|
func (s *DB) Update(attrs ...interface{}) *DB {
|
||||||
return s.Updates(toSearchableMap(attrs...), true)
|
return s.Updates(toSearchableMap(attrs...), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||||
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
||||||
return s.clone().NewScope(s.Value).
|
return s.clone().NewScope(s.Value).
|
||||||
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
||||||
InstanceSet("gorm:update_interface", values).
|
InstanceSet("gorm:update_interface", values).
|
||||||
callCallbacks(s.parent.callback.updates).db
|
callCallbacks(s.parent.callbacks.updates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||||
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
||||||
return s.UpdateColumns(toSearchableMap(attrs...))
|
return s.UpdateColumns(toSearchableMap(attrs...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||||
func (s *DB) UpdateColumns(values interface{}) *DB {
|
func (s *DB) UpdateColumns(values interface{}) *DB {
|
||||||
return s.clone().NewScope(s.Value).
|
return s.clone().NewScope(s.Value).
|
||||||
Set("gorm:update_column", true).
|
Set("gorm:update_column", true).
|
||||||
Set("gorm:save_associations", false).
|
Set("gorm:save_associations", false).
|
||||||
InstanceSet("gorm:update_interface", values).
|
InstanceSet("gorm:update_interface", values).
|
||||||
callCallbacks(s.parent.callback.updates).db
|
callCallbacks(s.parent.callbacks.updates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||||
func (s *DB) Save(value interface{}) *DB {
|
func (s *DB) Save(value interface{}) *DB {
|
||||||
scope := s.clone().NewScope(value)
|
scope := s.clone().NewScope(value)
|
||||||
if scope.PrimaryKeyZero() {
|
if scope.PrimaryKeyZero() {
|
||||||
return scope.callCallbacks(s.parent.callback.creates).db
|
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||||
}
|
}
|
||||||
return scope.callCallbacks(s.parent.callback.updates).db
|
return scope.callCallbacks(s.parent.callbacks.updates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create insert the value into database
|
||||||
func (s *DB) Create(value interface{}) *DB {
|
func (s *DB) Create(value interface{}) *DB {
|
||||||
scope := s.clone().NewScope(value)
|
scope := s.clone().NewScope(value)
|
||||||
return scope.callCallbacks(s.parent.callback.creates).db
|
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||||
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
||||||
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db
|
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Raw use raw sql as conditions, won't run it unless invoked by other methods
|
||||||
|
// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
|
||||||
func (s *DB) Raw(sql string, values ...interface{}) *DB {
|
func (s *DB) Raw(sql string, values ...interface{}) *DB {
|
||||||
return s.clone().search.Raw(true).Where(sql, values...).db
|
return s.clone().search.Raw(true).Where(sql, values...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Exec execute raw sql
|
||||||
func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
||||||
scope := s.clone().NewScope(nil)
|
scope := s.clone().NewScope(nil)
|
||||||
generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
|
generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
|
||||||
generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")")
|
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
|
||||||
scope.Raw(generatedSql)
|
scope.Raw(generatedSQL)
|
||||||
return scope.Exec().db
|
return scope.Exec().db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Model specify the model you would like to run db operations
|
||||||
|
// // update all users's name to `hello`
|
||||||
|
// db.Model(&User{}).Update("name", "hello")
|
||||||
|
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||||
|
// db.Model(&user).Update("name", "hello")
|
||||||
func (s *DB) Model(value interface{}) *DB {
|
func (s *DB) Model(value interface{}) *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
c.Value = value
|
c.Value = value
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Table specify the table you would like to run db operations
|
||||||
func (s *DB) Table(name string) *DB {
|
func (s *DB) Table(name string) *DB {
|
||||||
clone := s.clone()
|
clone := s.clone()
|
||||||
clone.search.Table(name)
|
clone.search.Table(name)
|
||||||
|
@ -330,10 +410,12 @@ func (s *DB) Table(name string) *DB {
|
||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Debug start debug mode
|
||||||
func (s *DB) Debug() *DB {
|
func (s *DB) Debug() *DB {
|
||||||
return s.clone().LogMode(true)
|
return s.clone().LogMode(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Begin begin a transaction
|
||||||
func (s *DB) Begin() *DB {
|
func (s *DB) Begin() *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
if db, ok := c.db.(sqlDb); ok {
|
if db, ok := c.db.(sqlDb); ok {
|
||||||
|
@ -341,46 +423,56 @@ func (s *DB) Begin() *DB {
|
||||||
c.db = interface{}(tx).(sqlCommon)
|
c.db = interface{}(tx).(sqlCommon)
|
||||||
c.AddError(err)
|
c.AddError(err)
|
||||||
} else {
|
} else {
|
||||||
c.AddError(CantStartTransaction)
|
c.AddError(ErrCantStartTransaction)
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Commit commit a transaction
|
||||||
func (s *DB) Commit() *DB {
|
func (s *DB) Commit() *DB {
|
||||||
if db, ok := s.db.(sqlTx); ok {
|
if db, ok := s.db.(sqlTx); ok {
|
||||||
s.AddError(db.Commit())
|
s.AddError(db.Commit())
|
||||||
} else {
|
} else {
|
||||||
s.AddError(NoValidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rollback rollback a transaction
|
||||||
func (s *DB) Rollback() *DB {
|
func (s *DB) Rollback() *DB {
|
||||||
if db, ok := s.db.(sqlTx); ok {
|
if db, ok := s.db.(sqlTx); ok {
|
||||||
s.AddError(db.Rollback())
|
s.AddError(db.Rollback())
|
||||||
} else {
|
} else {
|
||||||
s.AddError(NoValidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewRecord check if value's primary key is blank
|
||||||
func (s *DB) NewRecord(value interface{}) bool {
|
func (s *DB) NewRecord(value interface{}) bool {
|
||||||
return s.clone().NewScope(value).PrimaryKeyZero()
|
return s.clone().NewScope(value).PrimaryKeyZero()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordNotFound check if returning ErrRecordNotFound error
|
||||||
func (s *DB) RecordNotFound() bool {
|
func (s *DB) RecordNotFound() bool {
|
||||||
return s.Error == RecordNotFound
|
for _, err := range s.GetErrors() {
|
||||||
|
if err == ErrRecordNotFound {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrations
|
// CreateTable create table for models
|
||||||
func (s *DB) CreateTable(values ...interface{}) *DB {
|
func (s *DB) CreateTable(models ...interface{}) *DB {
|
||||||
db := s.clone()
|
db := s.clone()
|
||||||
for _, value := range values {
|
for _, model := range models {
|
||||||
db = db.NewScope(value).createTable().db
|
db = db.NewScope(model).createTable().db
|
||||||
}
|
}
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DropTable drop table for models
|
||||||
func (s *DB) DropTable(values ...interface{}) *DB {
|
func (s *DB) DropTable(values ...interface{}) *DB {
|
||||||
db := s.clone()
|
db := s.clone()
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
|
@ -393,18 +485,18 @@ func (s *DB) DropTable(values ...interface{}) *DB {
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DropTableIfExists drop table if it is exist
|
||||||
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
|
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
|
||||||
db := s.clone()
|
db := s.clone()
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
if tableName, ok := value.(string); ok {
|
if s.HasTable(value) {
|
||||||
db = db.Table(tableName)
|
db.AddError(s.DropTable(value).Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
db = db.NewScope(value).dropTableIfExists().db
|
|
||||||
}
|
}
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasTable check has table or not
|
||||||
func (s *DB) HasTable(value interface{}) bool {
|
func (s *DB) HasTable(value interface{}) bool {
|
||||||
var (
|
var (
|
||||||
scope = s.clone().NewScope(value)
|
scope = s.clone().NewScope(value)
|
||||||
|
@ -417,69 +509,64 @@ func (s *DB) HasTable(value interface{}) bool {
|
||||||
tableName = scope.TableName()
|
tableName = scope.TableName()
|
||||||
}
|
}
|
||||||
|
|
||||||
has := scope.Dialect().HasTable(scope, tableName)
|
has := scope.Dialect().HasTable(tableName)
|
||||||
s.AddError(scope.db.Error)
|
s.AddError(scope.db.Error)
|
||||||
return has
|
return has
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
|
||||||
func (s *DB) AutoMigrate(values ...interface{}) *DB {
|
func (s *DB) AutoMigrate(values ...interface{}) *DB {
|
||||||
db := s.clone()
|
db := s.clone()
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
db = db.NewScope(value).NeedPtr().autoMigrate().db
|
db = db.NewScope(value).autoMigrate().db
|
||||||
}
|
}
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModifyColumn modify column to type
|
||||||
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value)
|
||||||
scope.modifyColumn(column, typ)
|
scope.modifyColumn(column, typ)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DropColumn drop a column
|
||||||
func (s *DB) DropColumn(column string) *DB {
|
func (s *DB) DropColumn(column string) *DB {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value)
|
||||||
scope.dropColumn(column)
|
scope.dropColumn(column)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) AddIndex(indexName string, column ...string) *DB {
|
// AddIndex add index for columns with given name
|
||||||
|
func (s *DB) AddIndex(indexName string, columns ...string) *DB {
|
||||||
scope := s.Unscoped().NewScope(s.Value)
|
scope := s.Unscoped().NewScope(s.Value)
|
||||||
scope.addIndex(false, indexName, column...)
|
scope.addIndex(false, indexName, columns...)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
|
// AddUniqueIndex add unique index for columns with given name
|
||||||
|
func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value)
|
||||||
scope.addIndex(true, indexName, column...)
|
scope.addIndex(true, indexName, columns...)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveIndex remove index with name
|
||||||
func (s *DB) RemoveIndex(indexName string) *DB {
|
func (s *DB) RemoveIndex(indexName string) *DB {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value)
|
||||||
scope.removeIndex(indexName)
|
scope.removeIndex(indexName)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) CurrentDatabase() string {
|
// AddForeignKey Add foreign key to the given scope, e.g:
|
||||||
var (
|
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
||||||
scope = s.clone().NewScope(s.Value)
|
|
||||||
name = s.dialect.CurrentDatabase(scope)
|
|
||||||
)
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
Add foreign key to the given scope
|
|
||||||
|
|
||||||
Example:
|
|
||||||
db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
|
||||||
*/
|
|
||||||
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value)
|
||||||
scope.addForeignKey(field, dest, onDelete, onUpdate)
|
scope.addForeignKey(field, dest, onDelete, onUpdate)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
|
||||||
func (s *DB) Association(column string) *Association {
|
func (s *DB) Association(column string) *Association {
|
||||||
var err error
|
var err error
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value)
|
||||||
|
@ -491,7 +578,7 @@ func (s *DB) Association(column string) *Association {
|
||||||
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
|
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
|
||||||
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
||||||
} else {
|
} else {
|
||||||
return &Association{Scope: scope, Column: column, Field: field}
|
return &Association{scope: scope, column: column, field: field}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
||||||
|
@ -501,26 +588,30 @@ func (s *DB) Association(column string) *Association {
|
||||||
return &Association{Error: err}
|
return &Association{Error: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Preload preload associations with given conditions
|
||||||
|
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||||
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
|
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
|
||||||
return s.clone().search.Preload(column, conditions...).db
|
return s.clone().search.Preload(column, conditions...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set set value by name
|
// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting
|
||||||
func (s *DB) Set(name string, value interface{}) *DB {
|
func (s *DB) Set(name string, value interface{}) *DB {
|
||||||
return s.clone().InstantSet(name, value)
|
return s.clone().InstantSet(name, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InstantSet instant set setting, will affect current db
|
||||||
func (s *DB) InstantSet(name string, value interface{}) *DB {
|
func (s *DB) InstantSet(name string, value interface{}) *DB {
|
||||||
s.values[name] = value
|
s.values[name] = value
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get get value by name
|
// Get get setting by name
|
||||||
func (s *DB) Get(name string) (value interface{}, ok bool) {
|
func (s *DB) Get(name string) (value interface{}, ok bool) {
|
||||||
value, ok = s.values[name]
|
value, ok = s.values[name]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetJoinTableHandler set a model's join table handler for a relation
|
||||||
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
|
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
|
||||||
scope := s.NewScope(source)
|
scope := s.NewScope(source)
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
|
@ -530,7 +621,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
||||||
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||||
handler.Setup(field.Relationship, many2many, source, destination)
|
handler.Setup(field.Relationship, many2many, source, destination)
|
||||||
field.Relationship.JoinTableHandler = handler
|
field.Relationship.JoinTableHandler = handler
|
||||||
if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
|
if table := handler.Table(s); scope.Dialect().HasTable(table) {
|
||||||
s.Table(table).AutoMigrate(handler)
|
s.Table(table).AutoMigrate(handler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -538,9 +629,10 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddError add error to the db
|
||||||
func (s *DB) AddError(err error) error {
|
func (s *DB) AddError(err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != RecordNotFound {
|
if err != ErrRecordNotFound {
|
||||||
if s.logMode == 0 {
|
if s.logMode == 0 {
|
||||||
go s.print(fileWithLineNum(), err)
|
go s.print(fileWithLineNum(), err)
|
||||||
} else {
|
} else {
|
||||||
|
@ -559,6 +651,7 @@ func (s *DB) AddError(err error) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetErrors get happened errors from the db
|
||||||
func (s *DB) GetErrors() (errors []error) {
|
func (s *DB) GetErrors() (errors []error) {
|
||||||
if errs, ok := s.Error.(errorsInterface); ok {
|
if errs, ok := s.Error.(errorsInterface); ok {
|
||||||
return errs.GetErrors()
|
return errs.GetErrors()
|
||||||
|
|
|
@ -10,7 +10,7 @@ func (s *DB) clone() *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.search == nil {
|
if s.search == nil {
|
||||||
db.search = &search{}
|
db.search = &search{limit: -1, offset: -1}
|
||||||
} else {
|
} else {
|
||||||
db.search = s.search.clone()
|
db.search = s.search.clone()
|
||||||
}
|
}
|
||||||
|
|
96
main_test.go
96
main_test.go
|
@ -4,23 +4,23 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
|
||||||
|
|
||||||
_ "github.com/denisenkom/go-mssqldb"
|
|
||||||
testdb "github.com/erikstmartin/go-testdb"
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
"github.com/jinzhu/now"
|
|
||||||
_ "github.com/lib/pq"
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
|
||||||
|
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/erikstmartin/go-testdb"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/mssql"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||||
|
"github.com/jinzhu/gorm/dialects/postgres"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
|
"github.com/jinzhu/now"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
DB gorm.DB
|
DB *gorm.DB
|
||||||
t1, t2, t3, t4, t5 time.Time
|
t1, t2, t3, t4, t5 time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ func init() {
|
||||||
runMigration()
|
runMigration()
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenTestConnection() (db gorm.DB, err error) {
|
func OpenTestConnection() (db *gorm.DB, err error) {
|
||||||
switch os.Getenv("GORM_DIALECT") {
|
switch os.Getenv("GORM_DIALECT") {
|
||||||
case "mysql":
|
case "mysql":
|
||||||
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
|
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
|
||||||
|
@ -115,7 +115,7 @@ func TestSetTable(t *testing.T) {
|
||||||
DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
|
DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
|
||||||
|
|
||||||
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
|
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
|
||||||
t.Errorf("No errors should happen if set table for pluck", err.Error())
|
t.Error("No errors should happen if set table for pluck", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var users []User
|
var users []User
|
||||||
|
@ -376,7 +376,7 @@ func TestRows(t *testing.T) {
|
||||||
|
|
||||||
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Not error should happen, but got")
|
t.Errorf("Not error should happen, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
count := 0
|
count := 0
|
||||||
|
@ -386,8 +386,39 @@ func TestRows(t *testing.T) {
|
||||||
rows.Scan(&name, &age)
|
rows.Scan(&name, &age)
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
|
|
||||||
if count != 2 {
|
if count != 2 {
|
||||||
t.Errorf("Should found two records with name 3")
|
t.Errorf("Should found two records")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanRows(t *testing.T) {
|
||||||
|
user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
|
||||||
|
user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
|
||||||
|
user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Not error should happen, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []Result
|
||||||
|
for rows.Next() {
|
||||||
|
var result Result
|
||||||
|
if err := DB.ScanRows(rows, &result); err != nil {
|
||||||
|
t.Errorf("should get no error, but got %v", err)
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
||||||
|
t.Errorf("Should find expected results")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -448,7 +479,7 @@ func TestRaw(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
|
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
|
||||||
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound {
|
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
|
||||||
t.Error("Raw sql to update records")
|
t.Error("Raw sql to update records")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -470,14 +501,33 @@ func TestGroup(t *testing.T) {
|
||||||
func TestJoins(t *testing.T) {
|
func TestJoins(t *testing.T) {
|
||||||
var user = User{
|
var user = User{
|
||||||
Name: "joins",
|
Name: "joins",
|
||||||
|
CreditCard: CreditCard{Number: "411111111111"},
|
||||||
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
|
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
|
||||||
}
|
}
|
||||||
DB.Save(&user)
|
DB.Save(&user)
|
||||||
|
|
||||||
var result User
|
var users1 []User
|
||||||
DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result)
|
DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1)
|
||||||
if result.Name != "joins" || result.Id != user.Id {
|
if len(users1) != 2 {
|
||||||
t.Errorf("Should find all two emails with Join")
|
t.Errorf("should find two users using left join")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users2 []User
|
||||||
|
DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2)
|
||||||
|
if len(users2) != 1 {
|
||||||
|
t.Errorf("should find one users using left join with conditions")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users3 []User
|
||||||
|
DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3)
|
||||||
|
if len(users3) != 1 {
|
||||||
|
t.Errorf("should find one users using multiple left join conditions")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users4 []User
|
||||||
|
DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4)
|
||||||
|
if len(users4) != 0 {
|
||||||
|
t.Errorf("should find no user when searching with unexisting credit card")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -557,7 +607,7 @@ func TestTimeWithZone(t *testing.T) {
|
||||||
DB.First(&findUser, "name = ?", name)
|
DB.First(&findUser, "name = ?", name)
|
||||||
foundBirthday = findUser.Birthday.UTC().Format(format)
|
foundBirthday = findUser.Birthday.UTC().Format(format)
|
||||||
if foundBirthday != expectedBirthday {
|
if foundBirthday != expectedBirthday {
|
||||||
t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday)
|
t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
|
if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
|
||||||
|
@ -573,7 +623,7 @@ func TestTimeWithZone(t *testing.T) {
|
||||||
func TestHstore(t *testing.T) {
|
func TestHstore(t *testing.T) {
|
||||||
type Details struct {
|
type Details struct {
|
||||||
Id int64
|
Id int64
|
||||||
Bulk gorm.Hstore
|
Bulk postgres.Hstore
|
||||||
}
|
}
|
||||||
|
|
||||||
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
|
||||||
|
@ -659,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var user User
|
var user User
|
||||||
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
|
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
|
||||||
t.Errorf("Should have found existing record")
|
t.Errorf("Should have found existing record")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ func TestIndexes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
scope := DB.NewScope(&Email{})
|
scope := DB.NewScope(&Email{})
|
||||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||||
t.Errorf("Email should have index idx_email_email")
|
t.Errorf("Email should have index idx_email_email")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ func TestIndexes(t *testing.T) {
|
||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||||
t.Errorf("Email's index idx_email_email should be deleted")
|
t.Errorf("Email's index idx_email_email should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ func TestIndexes(t *testing.T) {
|
||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ func TestIndexes(t *testing.T) {
|
||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ func TestIndexes(t *testing.T) {
|
||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ func TestIndexes(t *testing.T) {
|
||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,11 +117,11 @@ func TestAutoMigration(t *testing.T) {
|
||||||
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
|
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
|
||||||
|
|
||||||
scope := DB.NewScope(&BigEmail{})
|
scope := DB.NewScope(&BigEmail{})
|
||||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
4
model.go
4
model.go
|
@ -2,6 +2,10 @@ package gorm
|
||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models
|
||||||
|
// type User struct {
|
||||||
|
// gorm.Model
|
||||||
|
// }
|
||||||
type Model struct {
|
type Model struct {
|
||||||
ID uint `gorm:"primary_key"`
|
ID uint `gorm:"primary_key"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
|
|
|
@ -3,10 +3,8 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -14,6 +12,7 @@ import (
|
||||||
"github.com/jinzhu/inflection"
|
"github.com/jinzhu/inflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DefaultTableNameHandler default table name handler
|
||||||
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||||
return defaultTableName
|
return defaultTableName
|
||||||
}
|
}
|
||||||
|
@ -41,6 +40,7 @@ func newModelStructsMap() *safeModelStructsMap {
|
||||||
|
|
||||||
var modelStructsMap = newModelStructsMap()
|
var modelStructsMap = newModelStructsMap()
|
||||||
|
|
||||||
|
// ModelStruct model definition
|
||||||
type ModelStruct struct {
|
type ModelStruct struct {
|
||||||
PrimaryFields []*StructField
|
PrimaryFields []*StructField
|
||||||
StructFields []*StructField
|
StructFields []*StructField
|
||||||
|
@ -48,10 +48,12 @@ type ModelStruct struct {
|
||||||
defaultTableName string
|
defaultTableName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TableName get model's table name
|
||||||
func (s *ModelStruct) TableName(db *DB) string {
|
func (s *ModelStruct) TableName(db *DB) string {
|
||||||
return DefaultTableNameHandler(db, s.defaultTableName)
|
return DefaultTableNameHandler(db, s.defaultTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StructField model field's struct definition
|
||||||
type StructField struct {
|
type StructField struct {
|
||||||
DBName string
|
DBName string
|
||||||
Name string
|
Name string
|
||||||
|
@ -107,7 +109,7 @@ func getForeignField(column string, fields []*StructField) *StructField {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelStruct generate model struct & relationships based on struct and tag definition
|
// GetModelStruct get value's model struct, relationships based on struct and tag definition
|
||||||
func (scope *Scope) GetModelStruct() *ModelStruct {
|
func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
var modelStruct ModelStruct
|
var modelStruct ModelStruct
|
||||||
// Scope value can't be nil
|
// Scope value can't be nil
|
||||||
|
@ -296,7 +298,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
if len(associationForeignKeys) == 0 {
|
if len(associationForeignKeys) == 0 {
|
||||||
for _, foreignKey := range foreignKeys {
|
for _, foreignKey := range foreignKeys {
|
||||||
if strings.HasPrefix(foreignKey, associationType) {
|
if strings.HasPrefix(foreignKey, associationType) {
|
||||||
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType))
|
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
||||||
|
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
|
||||||
|
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||||
|
@ -389,7 +394,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
if len(associationForeignKeys) == 0 {
|
if len(associationForeignKeys) == 0 {
|
||||||
for _, foreignKey := range foreignKeys {
|
for _, foreignKey := range foreignKeys {
|
||||||
if strings.HasPrefix(foreignKey, associationType) {
|
if strings.HasPrefix(foreignKey, associationType) {
|
||||||
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType))
|
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
||||||
|
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
|
||||||
|
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||||
|
@ -445,7 +453,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
if len(associationForeignKeys) == 0 {
|
if len(associationForeignKeys) == 0 {
|
||||||
for _, foreignKey := range foreignKeys {
|
for _, foreignKey := range foreignKeys {
|
||||||
if strings.HasPrefix(foreignKey, field.Name) {
|
if strings.HasPrefix(foreignKey, field.Name) {
|
||||||
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, field.Name))
|
associationForeignKey := strings.TrimPrefix(foreignKey, field.Name)
|
||||||
|
if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
|
||||||
|
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||||
|
@ -508,63 +519,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
return &modelStruct
|
return &modelStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStructFields get model's field structs
|
||||||
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
||||||
return scope.GetModelStruct().StructFields
|
return scope.GetModelStruct().StructFields
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) generateSqlTag(field *StructField) string {
|
|
||||||
var sqlType string
|
|
||||||
structType := field.Struct.Type
|
|
||||||
if structType.Kind() == reflect.Ptr {
|
|
||||||
structType = structType.Elem()
|
|
||||||
}
|
|
||||||
reflectValue := reflect.Indirect(reflect.New(structType))
|
|
||||||
|
|
||||||
if value, ok := field.TagSettings["TYPE"]; ok {
|
|
||||||
sqlType = value
|
|
||||||
}
|
|
||||||
|
|
||||||
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
|
||||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
|
||||||
additionalType = additionalType + " DEFAULT " + value
|
|
||||||
}
|
|
||||||
|
|
||||||
if field.IsScanner {
|
|
||||||
var getScannerValue func(reflect.Value)
|
|
||||||
getScannerValue = func(value reflect.Value) {
|
|
||||||
reflectValue = value
|
|
||||||
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
|
|
||||||
getScannerValue(reflectValue.Field(0))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
getScannerValue(reflectValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sqlType == "" {
|
|
||||||
var size = 255
|
|
||||||
|
|
||||||
if value, ok := field.TagSettings["SIZE"]; ok {
|
|
||||||
size, _ = strconv.Atoi(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
v, autoIncrease := field.TagSettings["AUTO_INCREMENT"]
|
|
||||||
if field.IsPrimaryKey {
|
|
||||||
autoIncrease = true
|
|
||||||
}
|
|
||||||
if v == "FALSE" {
|
|
||||||
autoIncrease = false
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.TrimSpace(additionalType) == "" {
|
|
||||||
return sqlType
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
||||||
setting := map[string]string{}
|
setting := map[string]string{}
|
||||||
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
||||||
|
|
80
mssql.go
80
mssql.go
|
@ -1,80 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mssql struct {
|
|
||||||
commonDialect
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mssql) HasTop() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
return "bit"
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
if autoIncrease {
|
|
||||||
return "int IDENTITY(1,1)"
|
|
||||||
}
|
|
||||||
return "int"
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
if autoIncrease {
|
|
||||||
return "bigint IDENTITY(1,1)"
|
|
||||||
}
|
|
||||||
return "bigint"
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
return "float"
|
|
||||||
case reflect.String:
|
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("nvarchar(%d)", size)
|
|
||||||
}
|
|
||||||
return "text"
|
|
||||||
case reflect.Struct:
|
|
||||||
if _, ok := value.Interface().(time.Time); ok {
|
|
||||||
return "datetime2"
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if _, ok := value.Interface().([]byte); ok {
|
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
|
||||||
}
|
|
||||||
return "text"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s mssql) HasTable(scope *Scope, tableName string) bool {
|
|
||||||
var (
|
|
||||||
count int
|
|
||||||
databaseName = s.CurrentDatabase(scope)
|
|
||||||
)
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
|
||||||
var (
|
|
||||||
count int
|
|
||||||
databaseName = s.CurrentDatabase(scope)
|
|
||||||
)
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
|
||||||
var count int
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s mssql) CurrentDatabase(scope *Scope) (name string) {
|
|
||||||
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -21,7 +21,7 @@ type Tag struct {
|
||||||
ID uint `gorm:"primary_key"`
|
ID uint `gorm:"primary_key"`
|
||||||
Locale string `gorm:"primary_key"`
|
Locale string `gorm:"primary_key"`
|
||||||
Value string
|
Value string
|
||||||
Blogs []*Blog `gorm:"many2many:"blogs_tags`
|
Blogs []*Blog `gorm:"many2many:blogs_tags"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func compareTags(tags []Tag, contents []string) bool {
|
func compareTags(tags []Tag, contents []string) bool {
|
||||||
|
|
70
mysql.go
70
mysql.go
|
@ -1,70 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mysql struct {
|
|
||||||
commonDialect
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
return "boolean"
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
|
|
||||||
if autoIncrease {
|
|
||||||
return "int AUTO_INCREMENT"
|
|
||||||
}
|
|
||||||
return "int"
|
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
if autoIncrease {
|
|
||||||
return "int unsigned AUTO_INCREMENT"
|
|
||||||
}
|
|
||||||
return "int unsigned"
|
|
||||||
case reflect.Int64:
|
|
||||||
if autoIncrease {
|
|
||||||
return "bigint AUTO_INCREMENT"
|
|
||||||
}
|
|
||||||
return "bigint"
|
|
||||||
case reflect.Uint64:
|
|
||||||
if autoIncrease {
|
|
||||||
return "bigint unsigned AUTO_INCREMENT"
|
|
||||||
}
|
|
||||||
return "bigint unsigned"
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
return "double"
|
|
||||||
case reflect.String:
|
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
|
||||||
}
|
|
||||||
return "longtext"
|
|
||||||
case reflect.Struct:
|
|
||||||
if _, ok := value.Interface().(time.Time); ok {
|
|
||||||
return "timestamp NULL"
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if _, ok := value.Interface().([]byte); ok {
|
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("varbinary(%d)", size)
|
|
||||||
}
|
|
||||||
return "longblob"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mysql) Quote(key string) string {
|
|
||||||
return fmt.Sprintf("`%s`", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mysql) SelectFromDummyTable() string {
|
|
||||||
return "FROM DUAL"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s mysql) CurrentDatabase(scope *Scope) (name string) {
|
|
||||||
s.RawScanString(scope, &name, "SELECT DATABASE()")
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -39,46 +39,46 @@ func TestPointerFields(t *testing.T) {
|
||||||
|
|
||||||
var nilPointerStruct = PointerStruct{}
|
var nilPointerStruct = PointerStruct{}
|
||||||
if err := DB.Create(&nilPointerStruct).Error; err != nil {
|
if err := DB.Create(&nilPointerStruct).Error; err != nil {
|
||||||
t.Errorf("Failed to save nil pointer struct", err)
|
t.Error("Failed to save nil pointer struct", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var pointerStruct2 PointerStruct
|
var pointerStruct2 PointerStruct
|
||||||
if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
|
if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
|
||||||
t.Errorf("Failed to query saved nil pointer struct", err)
|
t.Error("Failed to query saved nil pointer struct", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var normalStruct2 NormalStruct
|
var normalStruct2 NormalStruct
|
||||||
if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
|
if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
|
||||||
t.Errorf("Failed to query saved nil pointer struct", err)
|
t.Error("Failed to query saved nil pointer struct", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var partialNilPointerStruct1 = PointerStruct{Num: &num}
|
var partialNilPointerStruct1 = PointerStruct{Num: &num}
|
||||||
if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
|
if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
|
||||||
t.Errorf("Failed to save partial nil pointer struct", err)
|
t.Error("Failed to save partial nil pointer struct", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var pointerStruct3 PointerStruct
|
var pointerStruct3 PointerStruct
|
||||||
if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
|
if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
|
||||||
t.Errorf("Failed to query saved partial nil pointer struct", err)
|
t.Error("Failed to query saved partial nil pointer struct", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var normalStruct3 NormalStruct
|
var normalStruct3 NormalStruct
|
||||||
if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
|
if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
|
||||||
t.Errorf("Failed to query saved partial pointer struct", err)
|
t.Error("Failed to query saved partial pointer struct", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var partialNilPointerStruct2 = PointerStruct{Name: &name}
|
var partialNilPointerStruct2 = PointerStruct{Name: &name}
|
||||||
if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
|
if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
|
||||||
t.Errorf("Failed to save partial nil pointer struct", err)
|
t.Error("Failed to save partial nil pointer struct", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var pointerStruct4 PointerStruct
|
var pointerStruct4 PointerStruct
|
||||||
if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
|
if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
|
||||||
t.Errorf("Failed to query saved partial nil pointer struct", err)
|
t.Error("Failed to query saved partial nil pointer struct", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var normalStruct4 NormalStruct
|
var normalStruct4 NormalStruct
|
||||||
if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
|
if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
|
||||||
t.Errorf("Failed to query saved partial pointer struct", err)
|
t.Error("Failed to query saved partial pointer struct", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
154
postgres.go
154
postgres.go
|
@ -1,154 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"database/sql/driver"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lib/pq/hstore"
|
|
||||||
)
|
|
||||||
|
|
||||||
type postgres struct {
|
|
||||||
commonDialect
|
|
||||||
}
|
|
||||||
|
|
||||||
func (postgres) BinVar(i int) string {
|
|
||||||
return fmt.Sprintf("$%v", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (postgres) SupportLastInsertId() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
return "boolean"
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
if autoIncrease {
|
|
||||||
return "serial"
|
|
||||||
}
|
|
||||||
return "integer"
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
if autoIncrease {
|
|
||||||
return "bigserial"
|
|
||||||
}
|
|
||||||
return "bigint"
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
return "numeric"
|
|
||||||
case reflect.String:
|
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
|
||||||
}
|
|
||||||
return "text"
|
|
||||||
case reflect.Struct:
|
|
||||||
if _, ok := value.Interface().(time.Time); ok {
|
|
||||||
return "timestamp with time zone"
|
|
||||||
}
|
|
||||||
case reflect.Map:
|
|
||||||
if value.Type() == hstoreType {
|
|
||||||
return "hstore"
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if isByteArrayOrSlice(value) {
|
|
||||||
return "bytea"
|
|
||||||
} else if isUUID(value) {
|
|
||||||
return "uuid"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
|
||||||
}
|
|
||||||
|
|
||||||
var byteType = reflect.TypeOf(uint8(0))
|
|
||||||
|
|
||||||
func isByteArrayOrSlice(value reflect.Value) bool {
|
|
||||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == byteType
|
|
||||||
}
|
|
||||||
|
|
||||||
func isUUID(value reflect.Value) bool {
|
|
||||||
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
typename := value.Type().Name()
|
|
||||||
lower := strings.ToLower(typename)
|
|
||||||
return "uuid" == lower || "guid" == lower
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s postgres) ReturningStr(tableName, key string) string {
|
|
||||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s postgres) HasTable(scope *Scope, tableName string) bool {
|
|
||||||
var count int
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
|
||||||
var count int
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (postgres) RemoveIndex(scope *Scope, indexName string) {
|
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
|
||||||
var count int
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s postgres) CurrentDatabase(scope *Scope) (name string) {
|
|
||||||
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var hstoreType = reflect.TypeOf(Hstore{})
|
|
||||||
|
|
||||||
type Hstore map[string]*string
|
|
||||||
|
|
||||||
func (h Hstore) Value() (driver.Value, error) {
|
|
||||||
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
|
|
||||||
if len(h) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, value := range h {
|
|
||||||
var s sql.NullString
|
|
||||||
if value != nil {
|
|
||||||
s.String = *value
|
|
||||||
s.Valid = true
|
|
||||||
}
|
|
||||||
hstore.Map[key] = s
|
|
||||||
}
|
|
||||||
return hstore.Value()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Hstore) Scan(value interface{}) error {
|
|
||||||
hstore := hstore.Hstore{}
|
|
||||||
|
|
||||||
if err := hstore.Scan(value); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(hstore.Map) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
*h = Hstore{}
|
|
||||||
for k := range hstore.Map {
|
|
||||||
if hstore.Map[k].Valid {
|
|
||||||
s := hstore.Map[k].String
|
|
||||||
(*h)[k] = &s
|
|
||||||
} else {
|
|
||||||
(*h)[k] = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
388
preload.go
388
preload.go
|
@ -1,388 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql/driver"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
|
|
||||||
// If value is a nil pointer, Indirect returns a zero Value!
|
|
||||||
// Therefor we need to check for a zero value,
|
|
||||||
// as FieldByName could panic
|
|
||||||
if pointedValue := reflect.Indirect(value); pointedValue.IsValid() {
|
|
||||||
for _, column := range columns {
|
|
||||||
if pointedValue.FieldByName(column).IsValid() {
|
|
||||||
result := pointedValue.FieldByName(column).Interface()
|
|
||||||
if r, ok := result.(driver.Valuer); ok {
|
|
||||||
result, _ = r.Value()
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func equalAsString(a interface{}, b interface{}) bool {
|
|
||||||
return toString(a) == toString(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Preload(scope *Scope) {
|
|
||||||
if scope.Search.preload == nil || scope.HasError() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
preloadMap := map[string]bool{}
|
|
||||||
fields := scope.Fields()
|
|
||||||
for _, preload := range scope.Search.preload {
|
|
||||||
schema, conditions := preload.schema, preload.conditions
|
|
||||||
keys := strings.Split(schema, ".")
|
|
||||||
currentScope := scope
|
|
||||||
currentFields := fields
|
|
||||||
originalConditions := conditions
|
|
||||||
conditions = []interface{}{}
|
|
||||||
for i, key := range keys {
|
|
||||||
var found bool
|
|
||||||
if preloadMap[strings.Join(keys[:i+1], ".")] {
|
|
||||||
goto nextLoop
|
|
||||||
}
|
|
||||||
|
|
||||||
if i == len(keys)-1 {
|
|
||||||
conditions = originalConditions
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, field := range currentFields {
|
|
||||||
if field.Name != key || field.Relationship == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
found = true
|
|
||||||
switch field.Relationship.Kind {
|
|
||||||
case "has_one":
|
|
||||||
currentScope.handleHasOnePreload(field, conditions)
|
|
||||||
case "has_many":
|
|
||||||
currentScope.handleHasManyPreload(field, conditions)
|
|
||||||
case "belongs_to":
|
|
||||||
currentScope.handleBelongsToPreload(field, conditions)
|
|
||||||
case "many_to_many":
|
|
||||||
currentScope.handleManyToManyPreload(field, conditions)
|
|
||||||
default:
|
|
||||||
currentScope.Err(errors.New("not supported relation"))
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
value := reflect.ValueOf(currentScope.Value)
|
|
||||||
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
|
|
||||||
value = value.Index(0).Elem()
|
|
||||||
}
|
|
||||||
scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
preloadMap[strings.Join(keys[:i+1], ".")] = true
|
|
||||||
|
|
||||||
nextLoop:
|
|
||||||
if i < len(keys)-1 {
|
|
||||||
currentScope = currentScope.getColumnsAsScope(key)
|
|
||||||
currentFields = currentScope.Fields()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeSlice(typ reflect.Type) interface{} {
|
|
||||||
if typ.Kind() == reflect.Slice {
|
|
||||||
typ = typ.Elem()
|
|
||||||
}
|
|
||||||
sliceType := reflect.SliceOf(typ)
|
|
||||||
slice := reflect.New(sliceType)
|
|
||||||
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
|
||||||
return slice.Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
|
||||||
relation := field.Relationship
|
|
||||||
|
|
||||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
|
|
||||||
if len(primaryKeys) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
results := makeSlice(field.Struct.Type)
|
|
||||||
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
|
||||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
|
||||||
|
|
||||||
for i := 0; i < resultValues.Len(); i++ {
|
|
||||||
result := resultValues.Index(i)
|
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
|
||||||
value := getRealValue(result, relation.ForeignFieldNames)
|
|
||||||
objects := scope.IndirectValue()
|
|
||||||
for j := 0; j < objects.Len(); j++ {
|
|
||||||
if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) {
|
|
||||||
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := scope.SetColumn(field, result); err != nil {
|
|
||||||
scope.Err(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
|
||||||
relation := field.Relationship
|
|
||||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
|
|
||||||
if len(primaryKeys) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
results := makeSlice(field.Struct.Type)
|
|
||||||
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
|
||||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
|
||||||
|
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
|
||||||
preloadMap := make(map[string][]reflect.Value)
|
|
||||||
for i := 0; i < resultValues.Len(); i++ {
|
|
||||||
result := resultValues.Index(i)
|
|
||||||
value := getRealValue(result, relation.ForeignFieldNames)
|
|
||||||
preloadMap[toString(value)] = append(preloadMap[toString(value)], result)
|
|
||||||
}
|
|
||||||
|
|
||||||
objects := scope.IndirectValue()
|
|
||||||
for j := 0; j < objects.Len(); j++ {
|
|
||||||
object := reflect.Indirect(objects.Index(j))
|
|
||||||
objectRealValue := getRealValue(object, relation.AssociationForeignFieldNames)
|
|
||||||
objectStringValue := toString(objectRealValue)
|
|
||||||
if results, ok := preloadMap[objectStringValue]; ok {
|
|
||||||
if object.Kind() == reflect.Ptr {
|
|
||||||
object = object.Elem()
|
|
||||||
}
|
|
||||||
f := object.FieldByName(field.Name)
|
|
||||||
f.Set(reflect.Append(f, results...))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
scope.SetColumn(field, resultValues)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
|
||||||
relation := field.Relationship
|
|
||||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
|
|
||||||
if len(primaryKeys) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
results := makeSlice(field.Struct.Type)
|
|
||||||
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
|
||||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
|
||||||
|
|
||||||
for i := 0; i < resultValues.Len(); i++ {
|
|
||||||
result := resultValues.Index(i)
|
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
|
||||||
value := getRealValue(result, relation.AssociationForeignFieldNames)
|
|
||||||
objects := scope.IndirectValue()
|
|
||||||
for j := 0; j < objects.Len(); j++ {
|
|
||||||
object := reflect.Indirect(objects.Index(j))
|
|
||||||
if object.Kind() == reflect.Ptr {
|
|
||||||
object = reflect.Indirect(objects.Index(j).Elem())
|
|
||||||
}
|
|
||||||
if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) {
|
|
||||||
object.FieldByName(field.Name).Set(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
scope.SetColumn(field, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
|
||||||
relation := field.Relationship
|
|
||||||
joinTableHandler := relation.JoinTableHandler
|
|
||||||
destType := field.StructField.Struct.Type.Elem()
|
|
||||||
var isPtr bool
|
|
||||||
if destType.Kind() == reflect.Ptr {
|
|
||||||
isPtr = true
|
|
||||||
destType = destType.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
var sourceKeys []string
|
|
||||||
var linkHash = make(map[string][]reflect.Value)
|
|
||||||
|
|
||||||
for _, key := range joinTableHandler.SourceForeignKeys() {
|
|
||||||
sourceKeys = append(sourceKeys, key.DBName)
|
|
||||||
}
|
|
||||||
|
|
||||||
db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*")
|
|
||||||
|
|
||||||
preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value)
|
|
||||||
|
|
||||||
if len(conditions) > 0 {
|
|
||||||
preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
|
|
||||||
}
|
|
||||||
rows, err := preloadJoinDB.Rows()
|
|
||||||
|
|
||||||
if scope.Err(err) != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
columns, _ := rows.Columns()
|
|
||||||
for rows.Next() {
|
|
||||||
elem := reflect.New(destType).Elem()
|
|
||||||
var values = make([]interface{}, len(columns))
|
|
||||||
|
|
||||||
fields := scope.New(elem.Addr().Interface()).Fields()
|
|
||||||
|
|
||||||
var foundFields = map[string]bool{}
|
|
||||||
for index, column := range columns {
|
|
||||||
if field, ok := fields[column]; ok && !foundFields[column] {
|
|
||||||
if field.Field.Kind() == reflect.Ptr {
|
|
||||||
values[index] = field.Field.Addr().Interface()
|
|
||||||
} else {
|
|
||||||
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
|
|
||||||
}
|
|
||||||
foundFields[column] = true
|
|
||||||
} else {
|
|
||||||
var i interface{}
|
|
||||||
values[index] = &i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.Err(rows.Scan(values...))
|
|
||||||
|
|
||||||
var sourceKey []interface{}
|
|
||||||
|
|
||||||
var scannedFields = map[string]bool{}
|
|
||||||
for index, column := range columns {
|
|
||||||
value := values[index]
|
|
||||||
if field, ok := fields[column]; ok && !scannedFields[column] {
|
|
||||||
if field.Field.Kind() == reflect.Ptr {
|
|
||||||
field.Field.Set(reflect.ValueOf(value).Elem())
|
|
||||||
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
|
|
||||||
field.Field.Set(v)
|
|
||||||
}
|
|
||||||
scannedFields[column] = true
|
|
||||||
} else if strInSlice(column, sourceKeys) {
|
|
||||||
sourceKey = append(sourceKey, *(value.(*interface{})))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(sourceKey) != 0 {
|
|
||||||
if isPtr {
|
|
||||||
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem.Addr())
|
|
||||||
} else {
|
|
||||||
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var foreignFieldNames []string
|
|
||||||
for _, dbName := range relation.ForeignFieldNames {
|
|
||||||
if field, ok := scope.FieldByName(dbName); ok {
|
|
||||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
|
||||||
objects := scope.IndirectValue()
|
|
||||||
for j := 0; j < objects.Len(); j++ {
|
|
||||||
object := reflect.Indirect(objects.Index(j))
|
|
||||||
if object.Kind() == reflect.Ptr {
|
|
||||||
object = object.Elem()
|
|
||||||
}
|
|
||||||
source := getRealValue(object, foreignFieldNames)
|
|
||||||
field := object.FieldByName(field.Name)
|
|
||||||
for _, link := range linkHash[toString(source)] {
|
|
||||||
field.Set(reflect.Append(field, link))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if object := scope.IndirectValue(); object.IsValid() {
|
|
||||||
source := getRealValue(object, foreignFieldNames)
|
|
||||||
field := object.FieldByName(field.Name)
|
|
||||||
for _, link := range linkHash[toString(source)] {
|
|
||||||
field.Set(reflect.Append(field, link))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
|
|
||||||
values := scope.IndirectValue()
|
|
||||||
switch values.Kind() {
|
|
||||||
case reflect.Slice:
|
|
||||||
for i := 0; i < values.Len(); i++ {
|
|
||||||
var result []interface{}
|
|
||||||
for _, column := range columns {
|
|
||||||
value := reflect.Indirect(values.Index(i))
|
|
||||||
if value.Kind() == reflect.Ptr {
|
|
||||||
value = reflect.Indirect(values.Index(i).Elem())
|
|
||||||
}
|
|
||||||
result = append(result, value.FieldByName(column).Interface())
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
case reflect.Struct:
|
|
||||||
var result []interface{}
|
|
||||||
for _, column := range columns {
|
|
||||||
result = append(result, values.FieldByName(column).Interface())
|
|
||||||
}
|
|
||||||
return [][]interface{}{result}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) getColumnsAsScope(column string) *Scope {
|
|
||||||
values := scope.IndirectValue()
|
|
||||||
switch values.Kind() {
|
|
||||||
case reflect.Slice:
|
|
||||||
modelType := values.Type().Elem()
|
|
||||||
if modelType.Kind() == reflect.Ptr {
|
|
||||||
modelType = modelType.Elem()
|
|
||||||
}
|
|
||||||
fieldStruct, _ := modelType.FieldByName(column)
|
|
||||||
var columns reflect.Value
|
|
||||||
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
|
|
||||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
|
|
||||||
} else {
|
|
||||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
|
|
||||||
}
|
|
||||||
for i := 0; i < values.Len(); i++ {
|
|
||||||
column := reflect.Indirect(values.Index(i)).FieldByName(column)
|
|
||||||
if column.Kind() == reflect.Ptr {
|
|
||||||
column = column.Elem()
|
|
||||||
}
|
|
||||||
if column.Kind() == reflect.Slice {
|
|
||||||
for i := 0; i < column.Len(); i++ {
|
|
||||||
elem := column.Index(i)
|
|
||||||
if elem.CanAddr() {
|
|
||||||
columns = reflect.Append(columns, elem.Addr())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if column.CanAddr() {
|
|
||||||
columns = reflect.Append(columns, column.Addr())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return scope.New(columns.Interface())
|
|
||||||
case reflect.Struct:
|
|
||||||
field := values.FieldByName(column)
|
|
||||||
if !field.CanAddr() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return scope.New(field.Addr().Interface())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
251
preload_test.go
251
preload_test.go
|
@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) {
|
||||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound {
|
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -818,90 +818,6 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManyToManyPreloadForPointer(t *testing.T) {
|
|
||||||
type (
|
|
||||||
Level1 struct {
|
|
||||||
ID uint
|
|
||||||
Value string
|
|
||||||
}
|
|
||||||
Level2 struct {
|
|
||||||
ID uint
|
|
||||||
Value string
|
|
||||||
Level1s []*Level1 `gorm:"many2many:levels;"`
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
DB.DropTableIfExists(&Level2{})
|
|
||||||
DB.DropTableIfExists(&Level1{})
|
|
||||||
DB.DropTableIfExists("levels")
|
|
||||||
|
|
||||||
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
want := Level2{Value: "Bob", Level1s: []*Level1{
|
|
||||||
{Value: "ru"},
|
|
||||||
{Value: "en"},
|
|
||||||
}}
|
|
||||||
if err := DB.Save(&want).Error; err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
want2 := Level2{Value: "Tom", Level1s: []*Level1{
|
|
||||||
{Value: "zh"},
|
|
||||||
{Value: "de"},
|
|
||||||
}}
|
|
||||||
if err := DB.Save(&want2).Error; err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var got Level2
|
|
||||||
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(got, want) {
|
|
||||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
|
||||||
}
|
|
||||||
|
|
||||||
var got2 Level2
|
|
||||||
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(got2, want2) {
|
|
||||||
t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
|
|
||||||
}
|
|
||||||
|
|
||||||
var got3 []Level2
|
|
||||||
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(got3, []Level2{got, got2}) {
|
|
||||||
t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
|
|
||||||
}
|
|
||||||
|
|
||||||
var got4 []Level2
|
|
||||||
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var got5 Level2
|
|
||||||
DB.Preload("Level1s").First(&got5, "value = ?", "bogus")
|
|
||||||
|
|
||||||
var ruLevel1 Level1
|
|
||||||
var zhLevel1 Level1
|
|
||||||
DB.First(&ruLevel1, "value = ?", "ru")
|
|
||||||
DB.First(&zhLevel1, "value = ?", "zh")
|
|
||||||
|
|
||||||
got.Level1s = []*Level1{&ruLevel1}
|
|
||||||
got2.Level1s = []*Level1{&zhLevel1}
|
|
||||||
if !reflect.DeepEqual(got4, []Level2{got, got2}) {
|
|
||||||
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManyToManyPreloadForNestedPointer(t *testing.T) {
|
func TestManyToManyPreloadForNestedPointer(t *testing.T) {
|
||||||
type (
|
type (
|
||||||
Level1 struct {
|
Level1 struct {
|
||||||
|
@ -1065,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
|
||||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
|
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1122,12 +1038,87 @@ func TestNestedManyToManyPreload2(t *testing.T) {
|
||||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
|
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNestedManyToManyPreload3(t *testing.T) {
|
func TestNestedManyToManyPreload3(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Level1 struct {
|
||||||
|
ID uint
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
Level2 struct {
|
||||||
|
ID uint
|
||||||
|
Value string
|
||||||
|
Level1s []*Level1 `gorm:"many2many:level1_level2;"`
|
||||||
|
}
|
||||||
|
Level3 struct {
|
||||||
|
ID uint
|
||||||
|
Value string
|
||||||
|
Level2ID sql.NullInt64
|
||||||
|
Level2 *Level2
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
DB.DropTableIfExists(&Level1{})
|
||||||
|
DB.DropTableIfExists(&Level2{})
|
||||||
|
DB.DropTableIfExists(&Level3{})
|
||||||
|
DB.DropTableIfExists("level1_level2")
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
level1Zh := &Level1{Value: "zh"}
|
||||||
|
level1Ru := &Level1{Value: "ru"}
|
||||||
|
level1En := &Level1{Value: "en"}
|
||||||
|
|
||||||
|
level21 := &Level2{
|
||||||
|
Value: "Level2-1",
|
||||||
|
Level1s: []*Level1{level1Zh, level1Ru},
|
||||||
|
}
|
||||||
|
|
||||||
|
level22 := &Level2{
|
||||||
|
Value: "Level2-2",
|
||||||
|
Level1s: []*Level1{level1Zh, level1En},
|
||||||
|
}
|
||||||
|
|
||||||
|
wants := []*Level3{
|
||||||
|
{
|
||||||
|
Value: "Level3-1",
|
||||||
|
Level2: level21,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Value: "Level3-2",
|
||||||
|
Level2: level22,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Value: "Level3-3",
|
||||||
|
Level2: level21,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, want := range wants {
|
||||||
|
if err := DB.Save(&want).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var gots []*Level3
|
||||||
|
if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Order("level1.id ASC")
|
||||||
|
}).Find(&gots).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(gots, wants) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNestedManyToManyPreload4(t *testing.T) {
|
||||||
type (
|
type (
|
||||||
Level4 struct {
|
Level4 struct {
|
||||||
ID uint
|
ID uint
|
||||||
|
@ -1185,6 +1176,90 @@ func TestNestedManyToManyPreload3(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestManyToManyPreloadForPointer(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Level1 struct {
|
||||||
|
ID uint
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
Level2 struct {
|
||||||
|
ID uint
|
||||||
|
Value string
|
||||||
|
Level1s []*Level1 `gorm:"many2many:levels;"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
DB.DropTableIfExists(&Level2{})
|
||||||
|
DB.DropTableIfExists(&Level1{})
|
||||||
|
DB.DropTableIfExists("levels")
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := Level2{Value: "Bob", Level1s: []*Level1{
|
||||||
|
{Value: "ru"},
|
||||||
|
{Value: "en"},
|
||||||
|
}}
|
||||||
|
if err := DB.Save(&want).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want2 := Level2{Value: "Tom", Level1s: []*Level1{
|
||||||
|
{Value: "zh"},
|
||||||
|
{Value: "de"},
|
||||||
|
}}
|
||||||
|
if err := DB.Save(&want2).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got Level2
|
||||||
|
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got2 Level2
|
||||||
|
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got2, want2) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got3 []Level2
|
||||||
|
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got3, []Level2{got, got2}) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got4 []Level2
|
||||||
|
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got5 Level2
|
||||||
|
DB.Preload("Level1s").First(&got5, "value = ?", "bogus")
|
||||||
|
|
||||||
|
var ruLevel1 Level1
|
||||||
|
var zhLevel1 Level1
|
||||||
|
DB.First(&ruLevel1, "value = ?", "ru")
|
||||||
|
DB.First(&zhLevel1, "value = ?", "zh")
|
||||||
|
|
||||||
|
got.Level1s = []*Level1{&ruLevel1}
|
||||||
|
got2.Level1s = []*Level1{&zhLevel1}
|
||||||
|
if !reflect.DeepEqual(got4, []Level2{got, got2}) {
|
||||||
|
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNilPointerSlice(t *testing.T) {
|
func TestNilPointerSlice(t *testing.T) {
|
||||||
type (
|
type (
|
||||||
Level3 struct {
|
Level3 struct {
|
||||||
|
@ -1234,7 +1309,7 @@ func TestNilPointerSlice(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(got) != 2 {
|
if len(got) != 2 {
|
||||||
t.Error("got %v items, expected 2", len(got))
|
t.Errorf("got %v items, expected 2", len(got))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
|
if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
|
||||||
|
|
|
@ -629,14 +629,3 @@ func TestSelectWithArrayInput(t *testing.T) {
|
||||||
t.Errorf("Should have selected both age and name")
|
t.Errorf("Should have selected both age and name")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCurrentDatabase(t *testing.T) {
|
|
||||||
databaseName := DB.CurrentDatabase()
|
|
||||||
if err := DB.Error; err != nil {
|
|
||||||
t.Errorf("Problem getting current db name: %s", err)
|
|
||||||
}
|
|
||||||
if databaseName == "" {
|
|
||||||
t.Errorf("Current db name returned empty; this should never happen!")
|
|
||||||
}
|
|
||||||
t.Logf("Got current db name: %v", databaseName)
|
|
||||||
}
|
|
||||||
|
|
310
scope.go
310
scope.go
|
@ -1,48 +1,32 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Scope contain current operation's information when you perform any operation on the database
|
||||||
type Scope struct {
|
type Scope struct {
|
||||||
Search *search
|
Search *search
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Sql string
|
SQL string
|
||||||
SqlVars []interface{}
|
SQLVars []interface{}
|
||||||
db *DB
|
db *DB
|
||||||
indirectValue *reflect.Value
|
instanceID string
|
||||||
instanceId string
|
|
||||||
primaryKeyField *Field
|
primaryKeyField *Field
|
||||||
skipLeft bool
|
skipLeft bool
|
||||||
fields map[string]*Field
|
fields map[string]*Field
|
||||||
selectAttrs *[]string
|
selectAttrs *[]string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IndirectValue return scope's reflect value's indirect value
|
||||||
func (scope *Scope) IndirectValue() reflect.Value {
|
func (scope *Scope) IndirectValue() reflect.Value {
|
||||||
if scope.indirectValue == nil {
|
return indirect(reflect.ValueOf(scope.Value))
|
||||||
value := reflect.Indirect(reflect.ValueOf(scope.Value))
|
|
||||||
if value.Kind() == reflect.Ptr {
|
|
||||||
value = value.Elem()
|
|
||||||
}
|
|
||||||
scope.indirectValue = &value
|
|
||||||
}
|
|
||||||
return *scope.indirectValue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) NeedPtr() *Scope {
|
|
||||||
reflectKind := reflect.ValueOf(scope.Value).Kind()
|
|
||||||
if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
|
|
||||||
err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value")
|
|
||||||
scope.Err(err)
|
|
||||||
fmt.Printf(err.Error())
|
|
||||||
}
|
|
||||||
return scope
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New create a new Scope without search information
|
// New create a new Scope without search information
|
||||||
|
@ -61,12 +45,13 @@ func (scope *Scope) NewDB() *DB {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DB return scope's DB connection
|
||||||
func (scope *Scope) DB() *DB {
|
func (scope *Scope) DB() *DB {
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// SqlDB return *sql.DB
|
// SQLDB return *sql.DB
|
||||||
func (scope *Scope) SqlDB() sqlCommon {
|
func (scope *Scope) SQLDB() sqlCommon {
|
||||||
return scope.db.db
|
return scope.db.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,7 +60,7 @@ func (scope *Scope) SkipLeft() {
|
||||||
scope.skipLeft = true
|
scope.skipLeft = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quote used to quote database column name according to database dialect
|
// Quote used to quote string to escape them for database
|
||||||
func (scope *Scope) Quote(str string) string {
|
func (scope *Scope) Quote(str string) string {
|
||||||
if strings.Index(str, ".") != -1 {
|
if strings.Index(str, ".") != -1 {
|
||||||
newStrs := []string{}
|
newStrs := []string{}
|
||||||
|
@ -83,12 +68,12 @@ func (scope *Scope) Quote(str string) string {
|
||||||
newStrs = append(newStrs, scope.Dialect().Quote(str))
|
newStrs = append(newStrs, scope.Dialect().Quote(str))
|
||||||
}
|
}
|
||||||
return strings.Join(newStrs, ".")
|
return strings.Join(newStrs, ".")
|
||||||
} else {
|
|
||||||
return scope.Dialect().Quote(str)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return scope.Dialect().Quote(str)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) QuoteIfPossible(str string) string {
|
func (scope *Scope) quoteIfPossible(str string) string {
|
||||||
if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
|
if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
|
||||||
return scope.Quote(str)
|
return scope.Quote(str)
|
||||||
}
|
}
|
||||||
|
@ -100,7 +85,7 @@ func (scope *Scope) Dialect() Dialect {
|
||||||
return scope.db.parent.dialect
|
return scope.db.parent.dialect
|
||||||
}
|
}
|
||||||
|
|
||||||
// Err write error
|
// Err add error to Scope
|
||||||
func (scope *Scope) Err(err error) error {
|
func (scope *Scope) Err(err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
scope.db.AddError(err)
|
scope.db.AddError(err)
|
||||||
|
@ -118,27 +103,30 @@ func (scope *Scope) HasError() bool {
|
||||||
return scope.db.Error != nil
|
return scope.db.Error != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) PrimaryFields() []*Field {
|
// PrimaryFields return scope's primary fields
|
||||||
var fields = []*Field{}
|
func (scope *Scope) PrimaryFields() (fields []*Field) {
|
||||||
for _, field := range scope.GetModelStruct().PrimaryFields {
|
for _, field := range scope.Fields() {
|
||||||
fields = append(fields, scope.Fields()[field.DBName])
|
if field.IsPrimaryKey {
|
||||||
|
fields = append(fields, field)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
|
||||||
func (scope *Scope) PrimaryField() *Field {
|
func (scope *Scope) PrimaryField() *Field {
|
||||||
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
|
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
|
||||||
if len(primaryFields) > 1 {
|
if len(primaryFields) > 1 {
|
||||||
if field, ok := scope.Fields()["id"]; ok {
|
if field, ok := scope.FieldByName("id"); ok {
|
||||||
return field
|
return field
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return scope.Fields()[primaryFields[0].DBName]
|
return scope.PrimaryFields()[0]
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrimaryKey get the primary key's column name
|
// PrimaryKey get main primary field's db name
|
||||||
func (scope *Scope) PrimaryKey() string {
|
func (scope *Scope) PrimaryKey() string {
|
||||||
if field := scope.PrimaryField(); field != nil {
|
if field := scope.PrimaryField(); field != nil {
|
||||||
return field.DBName
|
return field.DBName
|
||||||
|
@ -146,7 +134,7 @@ func (scope *Scope) PrimaryKey() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrimaryKeyZero check the primary key is blank or not
|
// PrimaryKeyZero check main primary field's value is blank or not
|
||||||
func (scope *Scope) PrimaryKeyZero() bool {
|
func (scope *Scope) PrimaryKeyZero() bool {
|
||||||
field := scope.PrimaryField()
|
field := scope.PrimaryField()
|
||||||
return field == nil || field.IsBlank
|
return field == nil || field.IsBlank
|
||||||
|
@ -170,80 +158,85 @@ func (scope *Scope) HasColumn(column string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetColumn to set the column's value
|
// SetColumn to set the column's value, column could be field or field's name/dbname
|
||||||
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
||||||
|
var updateAttrs = map[string]interface{}{}
|
||||||
|
if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||||
|
updateAttrs = attrs.(map[string]interface{})
|
||||||
|
defer scope.InstanceSet("gorm:update_attrs", updateAttrs)
|
||||||
|
}
|
||||||
|
|
||||||
if field, ok := column.(*Field); ok {
|
if field, ok := column.(*Field); ok {
|
||||||
|
updateAttrs[field.DBName] = value
|
||||||
return field.Set(value)
|
return field.Set(value)
|
||||||
} else if name, ok := column.(string); ok {
|
} else if name, ok := column.(string); ok {
|
||||||
|
var (
|
||||||
if field, ok := scope.Fields()[name]; ok {
|
dbName = ToDBName(name)
|
||||||
|
mostMatchedField *Field
|
||||||
|
)
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
|
if field.DBName == value {
|
||||||
|
updateAttrs[field.DBName] = value
|
||||||
return field.Set(value)
|
return field.Set(value)
|
||||||
}
|
}
|
||||||
|
if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) {
|
||||||
dbName := ToDBName(name)
|
mostMatchedField = field
|
||||||
if field, ok := scope.Fields()[dbName]; ok {
|
}
|
||||||
return field.Set(value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if field, ok := scope.FieldByName(name); ok {
|
if mostMatchedField != nil {
|
||||||
return field.Set(value)
|
updateAttrs[mostMatchedField.DBName] = value
|
||||||
|
return mostMatchedField.Set(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return errors.New("could not convert column to field")
|
return errors.New("could not convert column to field")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) CallMethod(name string, checkError bool) {
|
func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
|
||||||
if scope.Value == nil || (checkError && scope.HasError()) {
|
if reflectValue.CanAddr() {
|
||||||
|
reflectValue = reflectValue.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
|
||||||
|
switch method := methodValue.Interface().(type) {
|
||||||
|
case func():
|
||||||
|
method()
|
||||||
|
case func(*Scope):
|
||||||
|
method(scope)
|
||||||
|
case func(*DB):
|
||||||
|
newDB := scope.NewDB()
|
||||||
|
method(newDB)
|
||||||
|
scope.Err(newDB.Error)
|
||||||
|
case func() error:
|
||||||
|
scope.Err(method())
|
||||||
|
case func(*Scope) error:
|
||||||
|
scope.Err(method(scope))
|
||||||
|
case func(*DB) error:
|
||||||
|
newDB := scope.NewDB()
|
||||||
|
scope.Err(method(newDB))
|
||||||
|
scope.Err(newDB.Error)
|
||||||
|
default:
|
||||||
|
scope.Err(fmt.Errorf("unsupported function %v", methodName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
|
||||||
|
func (scope *Scope) CallMethod(methodName string) {
|
||||||
|
if scope.Value == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
call := func(value interface{}) {
|
if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice {
|
||||||
if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
|
for i := 0; i < indirectScopeValue.Len(); i++ {
|
||||||
switch f := fm.Interface().(type) {
|
scope.callMethod(methodName, indirectScopeValue.Index(i))
|
||||||
case func():
|
|
||||||
f()
|
|
||||||
case func(s *Scope):
|
|
||||||
f(scope)
|
|
||||||
case func(s *DB):
|
|
||||||
newDB := scope.NewDB()
|
|
||||||
f(newDB)
|
|
||||||
scope.Err(newDB.Error)
|
|
||||||
case func() error:
|
|
||||||
scope.Err(f())
|
|
||||||
case func(s *Scope) error:
|
|
||||||
scope.Err(f(scope))
|
|
||||||
case func(s *DB) error:
|
|
||||||
newDB := scope.NewDB()
|
|
||||||
scope.Err(f(newDB))
|
|
||||||
scope.Err(newDB.Error)
|
|
||||||
default:
|
|
||||||
scope.Err(fmt.Errorf("unsupported function %v", name))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if values := scope.IndirectValue(); values.Kind() == reflect.Slice {
|
|
||||||
for i := 0; i < values.Len(); i++ {
|
|
||||||
value := values.Index(i).Addr().Interface()
|
|
||||||
if values.Index(i).Kind() == reflect.Ptr {
|
|
||||||
value = values.Index(i).Interface()
|
|
||||||
}
|
|
||||||
call(value)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if scope.IndirectValue().CanAddr() {
|
scope.callMethod(methodName, indirectScopeValue)
|
||||||
call(scope.IndirectValue().Addr().Interface())
|
|
||||||
} else {
|
|
||||||
call(scope.IndirectValue().Interface())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) CallMethodWithErrorCheck(name string) {
|
// AddToVars add value as sql's vars, used to prevent SQL injection
|
||||||
scope.CallMethod(name, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddToVars add value as sql's vars, gorm will escape them
|
|
||||||
func (scope *Scope) AddToVars(value interface{}) string {
|
func (scope *Scope) AddToVars(value interface{}) string {
|
||||||
if expr, ok := value.(*expr); ok {
|
if expr, ok := value.(*expr); ok {
|
||||||
exp := expr.expr
|
exp := expr.expr
|
||||||
|
@ -251,10 +244,10 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
||||||
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
|
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
|
||||||
}
|
}
|
||||||
return exp
|
return exp
|
||||||
} else {
|
|
||||||
scope.SqlVars = append(scope.SqlVars, value)
|
|
||||||
return scope.Dialect().BinVar(len(scope.SqlVars))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
scope.SQLVars = append(scope.SQLVars, value)
|
||||||
|
return scope.Dialect().BindVar(len(scope.SQLVars))
|
||||||
}
|
}
|
||||||
|
|
||||||
type tabler interface {
|
type tabler interface {
|
||||||
|
@ -265,7 +258,7 @@ type dbTabler interface {
|
||||||
TableName(*DB) string
|
TableName(*DB) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName get table name
|
// TableName return table name
|
||||||
func (scope *Scope) TableName() string {
|
func (scope *Scope) TableName() string {
|
||||||
if scope.Search != nil && len(scope.Search.tableName) > 0 {
|
if scope.Search != nil && len(scope.Search.tableName) > 0 {
|
||||||
return scope.Search.tableName
|
return scope.Search.tableName
|
||||||
|
@ -282,44 +275,54 @@ func (scope *Scope) TableName() string {
|
||||||
return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
|
return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QuotedTableName return quoted table name
|
||||||
func (scope *Scope) QuotedTableName() (name string) {
|
func (scope *Scope) QuotedTableName() (name string) {
|
||||||
if scope.Search != nil && len(scope.Search.tableName) > 0 {
|
if scope.Search != nil && len(scope.Search.tableName) > 0 {
|
||||||
if strings.Index(scope.Search.tableName, " ") != -1 {
|
if strings.Index(scope.Search.tableName, " ") != -1 {
|
||||||
return scope.Search.tableName
|
return scope.Search.tableName
|
||||||
}
|
}
|
||||||
return scope.Quote(scope.Search.tableName)
|
return scope.Quote(scope.Search.tableName)
|
||||||
} else {
|
|
||||||
return scope.Quote(scope.TableName())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return scope.Quote(scope.TableName())
|
||||||
}
|
}
|
||||||
|
|
||||||
// CombinedConditionSql get combined condition sql
|
// CombinedConditionSql return combined condition sql
|
||||||
func (scope *Scope) CombinedConditionSql() string {
|
func (scope *Scope) CombinedConditionSql() string {
|
||||||
return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
|
return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() +
|
||||||
scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
|
scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FieldByName find `gorm.Field` with field name or db name
|
||||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||||
|
var (
|
||||||
|
dbName = ToDBName(name)
|
||||||
|
mostMatchedField *Field
|
||||||
|
)
|
||||||
|
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if field.Name == name || field.DBName == name {
|
if field.Name == name || field.DBName == name {
|
||||||
return field, true
|
return field, true
|
||||||
}
|
}
|
||||||
|
if field.DBName == dbName {
|
||||||
|
mostMatchedField = field
|
||||||
}
|
}
|
||||||
return nil, false
|
}
|
||||||
|
return mostMatchedField, mostMatchedField != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Raw set sql
|
// Raw set raw sql
|
||||||
func (scope *Scope) Raw(sql string) *Scope {
|
func (scope *Scope) Raw(sql string) *Scope {
|
||||||
scope.Sql = strings.Replace(sql, "$$", "?", -1)
|
scope.SQL = strings.Replace(sql, "$$", "?", -1)
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec invoke sql
|
// Exec perform generated SQL
|
||||||
func (scope *Scope) Exec() *Scope {
|
func (scope *Scope) Exec() *Scope {
|
||||||
defer scope.Trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
||||||
scope.db.RowsAffected = count
|
scope.db.RowsAffected = count
|
||||||
}
|
}
|
||||||
|
@ -334,37 +337,32 @@ func (scope *Scope) Set(name string, value interface{}) *Scope {
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get get value by name
|
// Get get setting by name
|
||||||
func (scope *Scope) Get(name string) (interface{}, bool) {
|
func (scope *Scope) Get(name string) (interface{}, bool) {
|
||||||
return scope.db.Get(name)
|
return scope.db.Get(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InstanceId get InstanceId for scope
|
// InstanceID get InstanceID for scope
|
||||||
func (scope *Scope) InstanceId() string {
|
func (scope *Scope) InstanceID() string {
|
||||||
if scope.instanceId == "" {
|
if scope.instanceID == "" {
|
||||||
scope.instanceId = fmt.Sprintf("%v%v", &scope, &scope.db)
|
scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db)
|
||||||
}
|
}
|
||||||
return scope.instanceId
|
return scope.instanceID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback
|
||||||
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
|
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
|
||||||
return scope.Set(name+scope.InstanceId(), value)
|
return scope.Set(name+scope.InstanceID(), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InstanceGet get instance setting from current operation
|
||||||
func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
||||||
return scope.Get(name + scope.InstanceId())
|
return scope.Get(name + scope.InstanceID())
|
||||||
}
|
|
||||||
|
|
||||||
// Trace print sql log
|
|
||||||
func (scope *Scope) Trace(t time.Time) {
|
|
||||||
if len(scope.Sql) > 0 {
|
|
||||||
scope.db.slog(scope.Sql, t, scope.SqlVars...)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin start a transaction
|
// Begin start a transaction
|
||||||
func (scope *Scope) Begin() *Scope {
|
func (scope *Scope) Begin() *Scope {
|
||||||
if db, ok := scope.SqlDB().(sqlDb); ok {
|
if db, ok := scope.SQLDB().(sqlDb); ok {
|
||||||
if tx, err := db.Begin(); err == nil {
|
if tx, err := db.Begin(); err == nil {
|
||||||
scope.db.db = interface{}(tx).(sqlCommon)
|
scope.db.db = interface{}(tx).(sqlCommon)
|
||||||
scope.InstanceSet("gorm:started_transaction", true)
|
scope.InstanceSet("gorm:started_transaction", true)
|
||||||
|
@ -373,7 +371,7 @@ func (scope *Scope) Begin() *Scope {
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
// CommitOrRollback commit current transaction if there is no error, otherwise rollback it
|
// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it
|
||||||
func (scope *Scope) CommitOrRollback() *Scope {
|
func (scope *Scope) CommitOrRollback() *Scope {
|
||||||
if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
|
if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
|
||||||
if db, ok := scope.db.db.(sqlTx); ok {
|
if db, ok := scope.db.db.(sqlTx); ok {
|
||||||
|
@ -388,6 +386,7 @@ func (scope *Scope) CommitOrRollback() *Scope {
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SelectAttrs return selected attributes
|
||||||
func (scope *Scope) SelectAttrs() []string {
|
func (scope *Scope) SelectAttrs() []string {
|
||||||
if scope.selectAttrs == nil {
|
if scope.selectAttrs == nil {
|
||||||
attrs := []string{}
|
attrs := []string{}
|
||||||
|
@ -407,57 +406,38 @@ func (scope *Scope) SelectAttrs() []string {
|
||||||
return *scope.selectAttrs
|
return *scope.selectAttrs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OmitAttrs return omited attributes
|
||||||
func (scope *Scope) OmitAttrs() []string {
|
func (scope *Scope) OmitAttrs() []string {
|
||||||
return scope.Search.omits
|
return scope.Search.omits
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) changeableDBColumn(column string) bool {
|
func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
|
||||||
selectAttrs := scope.SelectAttrs()
|
var values = make([]interface{}, len(columns))
|
||||||
omitAttrs := scope.OmitAttrs()
|
var ignored interface{}
|
||||||
|
|
||||||
if len(selectAttrs) > 0 {
|
for index, column := range columns {
|
||||||
for _, attr := range selectAttrs {
|
if field, ok := fieldsMap[column]; ok {
|
||||||
if column == ToDBName(attr) {
|
if field.Field.Kind() == reflect.Ptr {
|
||||||
return true
|
values[index] = field.Field.Addr().Interface()
|
||||||
|
} else {
|
||||||
|
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
|
||||||
|
reflectValue.Elem().Set(field.Field.Addr())
|
||||||
|
values[index] = reflectValue.Interface()
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
values[index] = &ignored
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, attr := range omitAttrs {
|
scope.Err(rows.Scan(values...))
|
||||||
if column == ToDBName(attr) {
|
|
||||||
return false
|
for index, column := range columns {
|
||||||
|
if field, ok := fieldsMap[column]; ok {
|
||||||
|
if field.Field.Kind() != reflect.Ptr {
|
||||||
|
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
|
||||||
|
field.Field.Set(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) changeableField(field *Field) bool {
|
|
||||||
selectAttrs := scope.SelectAttrs()
|
|
||||||
omitAttrs := scope.OmitAttrs()
|
|
||||||
|
|
||||||
if len(selectAttrs) > 0 {
|
|
||||||
for _, attr := range selectAttrs {
|
|
||||||
if field.Name == attr || field.DBName == attr {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, attr := range omitAttrs {
|
|
||||||
if field.Name == attr || field.DBName == attr {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return !field.IsIgnored
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) shouldSaveAssociations() bool {
|
|
||||||
saveAssociations, ok := scope.Get("gorm:save_associations")
|
|
||||||
if ok && !saveAssociations.(bool) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true && !scope.HasError()
|
|
||||||
}
|
}
|
||||||
|
|
278
scope_private.go
278
scope_private.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||||
|
@ -75,7 +76,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
var notEqualSql string
|
var notEqualSQL string
|
||||||
var primaryKey = scope.PrimaryKey()
|
var primaryKey = scope.PrimaryKey()
|
||||||
|
|
||||||
switch value := clause["query"].(type) {
|
switch value := clause["query"].(type) {
|
||||||
|
@ -86,10 +87,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
||||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
||||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
|
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
|
||||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||||
notEqualSql = fmt.Sprintf("NOT (%v)", value)
|
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
|
||||||
} else {
|
} else {
|
||||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
||||||
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||||
}
|
}
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
||||||
|
@ -138,7 +139,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
||||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, _ = scanner.Value()
|
arg, _ = scanner.Value()
|
||||||
}
|
}
|
||||||
str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1)
|
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -172,17 +173,20 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) whereSql() (sql string) {
|
func (scope *Scope) whereSQL() (sql string) {
|
||||||
var primaryConditions, andConditions, orConditions []string
|
var (
|
||||||
|
quotedTableName = scope.QuotedTableName()
|
||||||
|
primaryConditions, andConditions, orConditions []string
|
||||||
|
)
|
||||||
|
|
||||||
if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
|
if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
|
||||||
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
|
sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName)
|
||||||
primaryConditions = append(primaryConditions, sql)
|
primaryConditions = append(primaryConditions, sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.PrimaryKeyZero() {
|
if !scope.PrimaryKeyZero() {
|
||||||
for _, field := range scope.PrimaryFields() {
|
for _, field := range scope.PrimaryFields() {
|
||||||
sql := fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
|
sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
|
||||||
primaryConditions = append(primaryConditions, sql)
|
primaryConditions = append(primaryConditions, sql)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -205,30 +209,30 @@ func (scope *Scope) whereSql() (sql string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
orSql := strings.Join(orConditions, " OR ")
|
orSQL := strings.Join(orConditions, " OR ")
|
||||||
combinedSql := strings.Join(andConditions, " AND ")
|
combinedSQL := strings.Join(andConditions, " AND ")
|
||||||
if len(combinedSql) > 0 {
|
if len(combinedSQL) > 0 {
|
||||||
if len(orSql) > 0 {
|
if len(orSQL) > 0 {
|
||||||
combinedSql = combinedSql + " OR " + orSql
|
combinedSQL = combinedSQL + " OR " + orSQL
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
combinedSql = orSql
|
combinedSQL = orSQL
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(primaryConditions) > 0 {
|
if len(primaryConditions) > 0 {
|
||||||
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
|
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
|
||||||
if len(combinedSql) > 0 {
|
if len(combinedSQL) > 0 {
|
||||||
sql = sql + " AND (" + combinedSql + ")"
|
sql = sql + " AND (" + combinedSQL + ")"
|
||||||
}
|
}
|
||||||
} else if len(combinedSql) > 0 {
|
} else if len(combinedSQL) > 0 {
|
||||||
sql = "WHERE " + combinedSql
|
sql = "WHERE " + combinedSQL
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) selectSql() string {
|
func (scope *Scope) selectSQL() string {
|
||||||
if len(scope.Search.selects) == 0 {
|
if len(scope.Search.selects) == 0 {
|
||||||
if scope.Search.joins != "" {
|
if len(scope.Search.joinConditions) > 0 {
|
||||||
return fmt.Sprintf("%v.*", scope.QuotedTableName())
|
return fmt.Sprintf("%v.*", scope.QuotedTableName())
|
||||||
}
|
}
|
||||||
return "*"
|
return "*"
|
||||||
|
@ -236,87 +240,60 @@ func (scope *Scope) selectSql() string {
|
||||||
return scope.buildSelectQuery(scope.Search.selects)
|
return scope.buildSelectQuery(scope.Search.selects)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) orderSql() string {
|
func (scope *Scope) orderSQL() string {
|
||||||
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
|
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) limitSql() string {
|
func (scope *Scope) limitAndOffsetSQL() string {
|
||||||
if !scope.Dialect().HasTop() {
|
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||||
if len(scope.Search.limit) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return " LIMIT " + scope.Search.limit
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) topSql() string {
|
func (scope *Scope) groupSQL() string {
|
||||||
if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
|
|
||||||
if len(scope.Search.limit) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return " TOP(" + scope.Search.limit + ")"
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) offsetSql() string {
|
|
||||||
if len(scope.Search.offset) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if scope.Dialect().HasTop() {
|
|
||||||
sql := " OFFSET " + scope.Search.offset + " ROW "
|
|
||||||
if len(scope.Search.limit) > 0 {
|
|
||||||
sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
|
|
||||||
}
|
|
||||||
return sql
|
|
||||||
}
|
|
||||||
return " OFFSET " + scope.Search.offset
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) groupSql() string {
|
|
||||||
if len(scope.Search.group) == 0 {
|
if len(scope.Search.group) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return " GROUP BY " + scope.Search.group
|
return " GROUP BY " + scope.Search.group
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) havingSql() string {
|
func (scope *Scope) havingSQL() string {
|
||||||
if scope.Search.havingConditions == nil {
|
if len(scope.Search.havingConditions) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
var andConditions []string
|
var andConditions []string
|
||||||
|
|
||||||
for _, clause := range scope.Search.havingConditions {
|
for _, clause := range scope.Search.havingConditions {
|
||||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||||
andConditions = append(andConditions, sql)
|
andConditions = append(andConditions, sql)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
combinedSql := strings.Join(andConditions, " AND ")
|
combinedSQL := strings.Join(andConditions, " AND ")
|
||||||
if len(combinedSql) == 0 {
|
if len(combinedSQL) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return " HAVING " + combinedSql
|
return " HAVING " + combinedSQL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) joinsSql() string {
|
func (scope *Scope) joinsSQL() string {
|
||||||
return scope.Search.joins + " "
|
var joinConditions []string
|
||||||
|
for _, clause := range scope.Search.joinConditions {
|
||||||
|
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||||
|
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(joinConditions, " ") + " "
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) prepareQuerySql() {
|
func (scope *Scope) prepareQuerySQL() {
|
||||||
if scope.Search.raw {
|
if scope.Search.raw {
|
||||||
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
||||||
} else {
|
} else {
|
||||||
scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -338,61 +315,53 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
|
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||||
if !scope.IndirectValue().CanAddr() {
|
if scope.IndirectValue().Kind() != reflect.Struct {
|
||||||
return values, true
|
return values, true
|
||||||
}
|
}
|
||||||
|
|
||||||
var hasExpr bool
|
results = map[string]interface{}{}
|
||||||
for key, value := range values {
|
for key, value := range values {
|
||||||
if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() {
|
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||||
if _, ok := value.(*expr); ok {
|
if _, ok := value.(*expr); ok {
|
||||||
hasExpr = true
|
|
||||||
} else if !equalAsString(field.Field.Interface(), value) {
|
|
||||||
hasUpdate = true
|
hasUpdate = true
|
||||||
|
results[field.DBName] = value
|
||||||
|
} else if !equalAsString(field.Field.Interface(), value) {
|
||||||
|
field.Set(value)
|
||||||
|
if field.IsNormal {
|
||||||
|
hasUpdate = true
|
||||||
|
results[field.DBName] = field.Field.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
field.Set(value)
|
field.Set(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if hasExpr {
|
|
||||||
var updateMap = map[string]interface{}{}
|
|
||||||
for key, field := range scope.Fields() {
|
|
||||||
if field.IsNormal {
|
|
||||||
if v, ok := values[key]; ok {
|
|
||||||
updateMap[key] = v
|
|
||||||
} else {
|
|
||||||
updateMap[key] = field.Field.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return updateMap, true
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) row() *sql.Row {
|
func (scope *Scope) row() *sql.Row {
|
||||||
defer scope.Trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
scope.callCallbacks(scope.db.parent.callback.rowQueries)
|
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||||
scope.prepareQuerySql()
|
scope.prepareQuerySQL()
|
||||||
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
|
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||||
defer scope.Trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
scope.callCallbacks(scope.db.parent.callback.rowQueries)
|
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||||
scope.prepareQuerySql()
|
scope.prepareQuerySQL()
|
||||||
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) initialize() *Scope {
|
func (scope *Scope) initialize() *Scope {
|
||||||
for _, clause := range scope.Search.whereConditions {
|
for _, clause := range scope.Search.whereConditions {
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
|
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
|
||||||
}
|
}
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
|
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
|
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -433,23 +402,45 @@ func (scope *Scope) typeName() string {
|
||||||
return typ.Name()
|
return typ.Name()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// trace print sql log
|
||||||
|
func (scope *Scope) trace(t time.Time) {
|
||||||
|
if len(scope.SQL) > 0 {
|
||||||
|
scope.db.slog(scope.SQL, t, scope.SQLVars...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) changeableField(field *Field) bool {
|
||||||
|
if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
|
||||||
|
for _, attr := range selectAttrs {
|
||||||
|
if field.Name == attr || field.DBName == attr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, attr := range scope.OmitAttrs() {
|
||||||
|
if field.Name == attr || field.DBName == attr {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) shouldSaveAssociations() bool {
|
||||||
|
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true && !scope.HasError()
|
||||||
|
}
|
||||||
|
|
||||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||||
toScope := scope.db.NewScope(value)
|
toScope := scope.db.NewScope(value)
|
||||||
fromFields := scope.Fields()
|
|
||||||
toFields := toScope.Fields()
|
|
||||||
|
|
||||||
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
||||||
var fromField, toField *Field
|
fromField, _ := scope.FieldByName(foreignKey)
|
||||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
toField, _ := toScope.FieldByName(foreignKey)
|
||||||
fromField = field
|
|
||||||
} else {
|
|
||||||
fromField = fromFields[ToDBName(foreignKey)]
|
|
||||||
}
|
|
||||||
if field, ok := toScope.FieldByName(foreignKey); ok {
|
|
||||||
toField = field
|
|
||||||
} else {
|
|
||||||
toField = toFields[ToDBName(foreignKey)]
|
|
||||||
}
|
|
||||||
|
|
||||||
if fromField != nil {
|
if fromField != nil {
|
||||||
if relationship := fromField.Relationship; relationship != nil {
|
if relationship := fromField.Relationship; relationship != nil {
|
||||||
|
@ -508,30 +499,26 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||||
joinTableHandler := relationship.JoinTableHandler
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
joinTable := joinTableHandler.Table(scope.db)
|
joinTable := joinTableHandler.Table(scope.db)
|
||||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
if !scope.Dialect().HasTable(joinTable) {
|
||||||
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
||||||
|
|
||||||
var sqlTypes, primaryKeys []string
|
var sqlTypes, primaryKeys []string
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
if field, ok := scope.Fields()[fieldName]; ok {
|
if field, ok := scope.FieldByName(fieldName); ok {
|
||||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
foreignKeyStruct := field.clone()
|
||||||
primaryKeySqlType := field.TagSettings["TYPE"]
|
foreignKeyStruct.IsPrimaryKey = false
|
||||||
if primaryKeySqlType == "" {
|
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||||
primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false)
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||||
}
|
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
|
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
||||||
if field, ok := toScope.Fields()[fieldName]; ok {
|
if field, ok := toScope.FieldByName(fieldName); ok {
|
||||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
foreignKeyStruct := field.clone()
|
||||||
primaryKeySqlType := field.TagSettings["TYPE"]
|
foreignKeyStruct.IsPrimaryKey = false
|
||||||
if primaryKeySqlType == "" {
|
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||||
primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false)
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||||
}
|
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
|
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -545,10 +532,10 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
func (scope *Scope) createTable() *Scope {
|
func (scope *Scope) createTable() *Scope {
|
||||||
var tags []string
|
var tags []string
|
||||||
var primaryKeys []string
|
var primaryKeys []string
|
||||||
var primaryKeyInColumnType bool = false
|
var primaryKeyInColumnType = false
|
||||||
for _, field := range scope.GetStructFields() {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
sqlTag := scope.generateSqlTag(field)
|
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||||
|
|
||||||
// Check if the primary key constraint was specified as
|
// Check if the primary key constraint was specified as
|
||||||
// part of the column type. If so, we can only support
|
// part of the column type. If so, we can only support
|
||||||
|
@ -582,13 +569,6 @@ func (scope *Scope) dropTable() *Scope {
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) dropTableIfExists() *Scope {
|
|
||||||
if scope.Dialect().HasTable(scope, scope.TableName()) {
|
|
||||||
scope.dropTable()
|
|
||||||
}
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) modifyColumn(column string, typ string) {
|
func (scope *Scope) modifyColumn(column string, typ string) {
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
|
||||||
}
|
}
|
||||||
|
@ -598,13 +578,13 @@ func (scope *Scope) dropColumn(column string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||||
if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) {
|
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var columns []string
|
var columns []string
|
||||||
for _, name := range column {
|
for _, name := range column {
|
||||||
columns = append(columns, scope.QuoteIfPossible(name))
|
columns = append(columns, scope.quoteIfPossible(name))
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlCreate := "CREATE INDEX"
|
sqlCreate := "CREATE INDEX"
|
||||||
|
@ -612,31 +592,35 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||||
sqlCreate = "CREATE UNIQUE INDEX"
|
sqlCreate = "CREATE UNIQUE INDEX"
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec()
|
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||||
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
|
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
|
||||||
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
|
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
|
||||||
|
|
||||||
|
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) removeIndex(indexName string) {
|
func (scope *Scope) removeIndex(indexName string) {
|
||||||
scope.Dialect().RemoveIndex(scope, indexName)
|
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) autoMigrate() *Scope {
|
func (scope *Scope) autoMigrate() *Scope {
|
||||||
tableName := scope.TableName()
|
tableName := scope.TableName()
|
||||||
quotedTableName := scope.QuotedTableName()
|
quotedTableName := scope.QuotedTableName()
|
||||||
|
|
||||||
if !scope.Dialect().HasTable(scope, tableName) {
|
if !scope.Dialect().HasTable(tableName) {
|
||||||
scope.createTable()
|
scope.createTable()
|
||||||
} else {
|
} else {
|
||||||
for _, field := range scope.GetStructFields() {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
if !scope.Dialect().HasColumn(tableName, field.DBName) {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
sqlTag := scope.generateSqlTag(field)
|
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
|
||||||
|
for _, value := range values {
|
||||||
|
indirectValue := reflect.ValueOf(value)
|
||||||
|
for indirectValue.Kind() == reflect.Ptr {
|
||||||
|
indirectValue = indirectValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch indirectValue.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
for i := 0; i < indirectValue.Len(); i++ {
|
||||||
|
var result []interface{}
|
||||||
|
var object = indirect(indirectValue.Index(i))
|
||||||
|
for _, column := range columns {
|
||||||
|
result = append(result, object.FieldByName(column).Interface())
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
var result []interface{}
|
||||||
|
for _, column := range columns {
|
||||||
|
result = append(result, indirectValue.FieldByName(column).Interface())
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) getColumnAsScope(column string) *Scope {
|
||||||
|
indirectScopeValue := scope.IndirectValue()
|
||||||
|
|
||||||
|
switch indirectScopeValue.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
|
||||||
|
fieldType := fieldStruct.Type
|
||||||
|
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
|
||||||
|
|
||||||
|
for i := 0; i < indirectScopeValue.Len(); i++ {
|
||||||
|
result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
|
||||||
|
|
||||||
|
if result.Kind() == reflect.Slice {
|
||||||
|
for j := 0; j < result.Len(); j++ {
|
||||||
|
if elem := result.Index(j); elem.CanAddr() {
|
||||||
|
results = reflect.Append(results, elem.Addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if result.CanAddr() {
|
||||||
|
results = reflect.Append(results, result.Addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scope.New(results.Interface())
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
|
||||||
|
return scope.New(field.Addr().Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
24
search.go
24
search.go
|
@ -8,15 +8,15 @@ type search struct {
|
||||||
orConditions []map[string]interface{}
|
orConditions []map[string]interface{}
|
||||||
notConditions []map[string]interface{}
|
notConditions []map[string]interface{}
|
||||||
havingConditions []map[string]interface{}
|
havingConditions []map[string]interface{}
|
||||||
|
joinConditions []map[string]interface{}
|
||||||
initAttrs []interface{}
|
initAttrs []interface{}
|
||||||
assignAttrs []interface{}
|
assignAttrs []interface{}
|
||||||
selects map[string]interface{}
|
selects map[string]interface{}
|
||||||
omits []string
|
omits []string
|
||||||
orders []string
|
orders []string
|
||||||
joins string
|
|
||||||
preload []searchPreload
|
preload []searchPreload
|
||||||
offset string
|
offset int
|
||||||
limit string
|
limit int
|
||||||
group string
|
group string
|
||||||
tableName string
|
tableName string
|
||||||
raw bool
|
raw bool
|
||||||
|
@ -82,18 +82,18 @@ func (s *search) Omit(columns ...string) *search {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) Limit(value interface{}) *search {
|
func (s *search) Limit(limit int) *search {
|
||||||
s.limit = s.getInterfaceAsSql(value)
|
s.limit = limit
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) Offset(value interface{}) *search {
|
func (s *search) Offset(offset int) *search {
|
||||||
s.offset = s.getInterfaceAsSql(value)
|
s.offset = offset
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) Group(query string) *search {
|
func (s *search) Group(query string) *search {
|
||||||
s.group = s.getInterfaceAsSql(query)
|
s.group = s.getInterfaceAsSQL(query)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,8 +102,8 @@ func (s *search) Having(query string, values ...interface{}) *search {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) Joins(query string) *search {
|
func (s *search) Joins(query string, values ...interface{}) *search {
|
||||||
s.joins = query
|
s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,12 +134,12 @@ func (s *search) Table(name string) *search {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) getInterfaceAsSql(value interface{}) (str string) {
|
func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
|
||||||
switch value.(type) {
|
switch value.(type) {
|
||||||
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||||
str = fmt.Sprintf("%v", value)
|
str = fmt.Sprintf("%v", value)
|
||||||
default:
|
default:
|
||||||
s.db.AddError(InvalidSql)
|
s.db.AddError(ErrInvalidSQL)
|
||||||
}
|
}
|
||||||
|
|
||||||
if str == "-1" {
|
if str == "-1" {
|
||||||
|
|
84
sqlite3.go
84
sqlite3.go
|
@ -1,84 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type sqlite3 struct {
|
|
||||||
commonDialect
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
return "bool"
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
if autoIncrease {
|
|
||||||
return "integer primary key autoincrement"
|
|
||||||
}
|
|
||||||
return "integer"
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
if autoIncrease {
|
|
||||||
return "integer primary key autoincrement"
|
|
||||||
}
|
|
||||||
return "bigint"
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
return "real"
|
|
||||||
case reflect.String:
|
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
|
||||||
}
|
|
||||||
return "text"
|
|
||||||
case reflect.Struct:
|
|
||||||
if _, ok := value.Interface().(time.Time); ok {
|
|
||||||
return "datetime"
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if _, ok := value.Interface().([]byte); ok {
|
|
||||||
return "blob"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
|
|
||||||
var count int
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
|
||||||
var count int
|
|
||||||
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
|
||||||
var count int
|
|
||||||
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
|
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
|
|
||||||
var (
|
|
||||||
ifaces = make([]interface{}, 3)
|
|
||||||
pointers = make([]*string, 3)
|
|
||||||
i int
|
|
||||||
)
|
|
||||||
for i = 0; i < 3; i++ {
|
|
||||||
ifaces[i] = &pointers[i]
|
|
||||||
}
|
|
||||||
if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if pointers[1] != nil {
|
|
||||||
name = *pointers[1]
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -42,9 +42,9 @@ type CreditCard struct {
|
||||||
ID int8
|
ID int8
|
||||||
Number string
|
Number string
|
||||||
UserId sql.NullInt64
|
UserId sql.NullInt64
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time `sql:"not null"`
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
DeletedAt time.Time
|
DeletedAt *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type Email struct {
|
type Email struct {
|
||||||
|
@ -62,7 +62,7 @@ type Address struct {
|
||||||
Post string
|
Post string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
DeletedAt time.Time
|
DeletedAt *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type Language struct {
|
type Language struct {
|
||||||
|
|
|
@ -71,13 +71,14 @@ func TestUpdate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.First(&product4, product4.Id)
|
DB.First(&product4, product4.Id)
|
||||||
|
updatedAt4 := product4.UpdatedAt
|
||||||
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
|
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
|
||||||
var product5 Product
|
var product5 Product
|
||||||
DB.First(&product5, product4.Id)
|
DB.First(&product5, product4.Id)
|
||||||
if product5.Price != product4.Price+100-50 {
|
if product5.Price != product4.Price+100-50 {
|
||||||
t.Errorf("Update with expression")
|
t.Errorf("Update with expression")
|
||||||
}
|
}
|
||||||
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
|
||||||
t.Errorf("Update with expression should update UpdatedAt")
|
t.Errorf("Update with expression should update UpdatedAt")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -170,13 +171,15 @@ func TestUpdates(t *testing.T) {
|
||||||
t.Errorf("product2's code should be updated")
|
t.Errorf("product2's code should be updated")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updatedAt4 := product4.UpdatedAt
|
||||||
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
|
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
|
||||||
var product5 Product
|
var product5 Product
|
||||||
DB.First(&product5, product4.Id)
|
DB.First(&product5, product4.Id)
|
||||||
if product5.Price != product4.Price+100 {
|
if product5.Price != product4.Price+100 {
|
||||||
t.Errorf("Updates with expression")
|
t.Errorf("Updates with expression")
|
||||||
}
|
}
|
||||||
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
// product4's UpdatedAt will be reset when updating
|
||||||
|
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
|
||||||
t.Errorf("Updates with expression should update UpdatedAt")
|
t.Errorf("Updates with expression should update UpdatedAt")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -419,3 +422,32 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) {
|
||||||
t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
|
t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdatesWithBlankValues(t *testing.T) {
|
||||||
|
product := Product{Code: "product1", Price: 10}
|
||||||
|
DB.Save(&product)
|
||||||
|
|
||||||
|
DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100})
|
||||||
|
|
||||||
|
var product1 Product
|
||||||
|
DB.First(&product1, product.Id)
|
||||||
|
|
||||||
|
if product1.Code != "product1" || product1.Price != 100 {
|
||||||
|
t.Errorf("product's code should not be updated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDecodeVirtualAttributes(t *testing.T) {
|
||||||
|
var user = User{
|
||||||
|
Name: "jinzhu",
|
||||||
|
IgnoreMe: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100})
|
||||||
|
|
||||||
|
if user.IgnoreMe != 100 {
|
||||||
|
t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
233
utils.go
233
utils.go
|
@ -2,10 +2,26 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NowFunc returns current time, this function is exported in order to be able
|
||||||
|
// to give the flexibility to the developer to customize it according to their
|
||||||
|
// needs, e.g:
|
||||||
|
// gorm.NowFunc = func() time.Time {
|
||||||
|
// return time.Now().UTC()
|
||||||
|
// }
|
||||||
|
var NowFunc = func() time.Time {
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
// Copied from golint
|
// Copied from golint
|
||||||
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||||
var commonInitialismsReplacer *strings.Replacer
|
var commonInitialismsReplacer *strings.Replacer
|
||||||
|
@ -41,30 +57,239 @@ func newSafeMap() *safeMap {
|
||||||
|
|
||||||
var smap = newSafeMap()
|
var smap = newSafeMap()
|
||||||
|
|
||||||
|
type strCase bool
|
||||||
|
|
||||||
|
const (
|
||||||
|
lower strCase = false
|
||||||
|
upper strCase = true
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToDBName convert string to db name
|
||||||
func ToDBName(name string) string {
|
func ToDBName(name string) string {
|
||||||
if v := smap.Get(name); v != "" {
|
if v := smap.Get(name); v != "" {
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
value := commonInitialismsReplacer.Replace(name)
|
if name == "" {
|
||||||
buf := bytes.NewBufferString("")
|
return ""
|
||||||
for i, v := range value {
|
}
|
||||||
if i > 0 && v >= 'A' && v <= 'Z' {
|
|
||||||
|
var (
|
||||||
|
value = commonInitialismsReplacer.Replace(name)
|
||||||
|
buf = bytes.NewBufferString("")
|
||||||
|
lastCase, currCase, nextCase strCase
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, v := range value[:len(value)-1] {
|
||||||
|
nextCase = value[i+1] >= 'A' && value[i+1] <= 'Z'
|
||||||
|
if i > 0 {
|
||||||
|
if currCase == upper {
|
||||||
|
if lastCase == upper && nextCase == upper {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
} else {
|
||||||
|
if value[i-1] != '_' && value[i+1] != '_' {
|
||||||
buf.WriteRune('_')
|
buf.WriteRune('_')
|
||||||
}
|
}
|
||||||
buf.WriteRune(v)
|
buf.WriteRune(v)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
currCase = upper
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
lastCase = currCase
|
||||||
|
currCase = nextCase
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteByte(value[len(value)-1])
|
||||||
|
|
||||||
s := strings.ToLower(buf.String())
|
s := strings.ToLower(buf.String())
|
||||||
smap.Set(name, s)
|
smap.Set(name, s)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SQL expression
|
||||||
type expr struct {
|
type expr struct {
|
||||||
expr string
|
expr string
|
||||||
args []interface{}
|
args []interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expr generate raw SQL expression, for example:
|
||||||
|
// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
|
||||||
func Expr(expression string, args ...interface{}) *expr {
|
func Expr(expression string, args ...interface{}) *expr {
|
||||||
return &expr{expr: expression, args: args}
|
return &expr{expr: expression, args: args}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func indirect(reflectValue reflect.Value) reflect.Value {
|
||||||
|
for reflectValue.Kind() == reflect.Ptr {
|
||||||
|
reflectValue = reflectValue.Elem()
|
||||||
|
}
|
||||||
|
return reflectValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||||
|
var results []string
|
||||||
|
|
||||||
|
for _, primaryValue := range primaryValues {
|
||||||
|
var marks []string
|
||||||
|
for _ = range primaryValue {
|
||||||
|
marks = append(marks, "?")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(marks) > 1 {
|
||||||
|
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||||
|
} else {
|
||||||
|
results = append(results, strings.Join(marks, ""))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(results, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryCondition(scope *Scope, columns []string) string {
|
||||||
|
var newColumns []string
|
||||||
|
for _, column := range columns {
|
||||||
|
newColumns = append(newColumns, scope.Quote(column))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(columns) > 1 {
|
||||||
|
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||||
|
}
|
||||||
|
return strings.Join(newColumns, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryValues(values [][]interface{}) (results []interface{}) {
|
||||||
|
for _, value := range values {
|
||||||
|
for _, v := range value {
|
||||||
|
results = append(results, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileWithLineNum() string {
|
||||||
|
for i := 2; i < 15; i++ {
|
||||||
|
_, file, line, ok := runtime.Caller(i)
|
||||||
|
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
||||||
|
return fmt.Sprintf("%v:%v", file, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBlank(value reflect.Value) bool {
|
||||||
|
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
func toSearchableMap(attrs ...interface{}) (result interface{}) {
|
||||||
|
if len(attrs) > 1 {
|
||||||
|
if str, ok := attrs[0].(string); ok {
|
||||||
|
result = map[string]interface{}{str: attrs[1]}
|
||||||
|
}
|
||||||
|
} else if len(attrs) == 1 {
|
||||||
|
if attr, ok := attrs[0].(map[string]interface{}); ok {
|
||||||
|
result = attr
|
||||||
|
}
|
||||||
|
|
||||||
|
if attr, ok := attrs[0].(interface{}); ok {
|
||||||
|
result = attr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
||||||
|
attrs := map[string]interface{}{}
|
||||||
|
|
||||||
|
switch value := values.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
return value
|
||||||
|
case []interface{}:
|
||||||
|
for _, v := range value {
|
||||||
|
for key, value := range convertInterfaceToMap(v) {
|
||||||
|
attrs[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case interface{}:
|
||||||
|
reflectValue := reflect.ValueOf(values)
|
||||||
|
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Map:
|
||||||
|
for _, key := range reflectValue.MapKeys() {
|
||||||
|
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
for _, field := range (&Scope{Value: values}).Fields() {
|
||||||
|
if !field.IsBlank {
|
||||||
|
attrs[field.DBName] = field.Field.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return attrs
|
||||||
|
}
|
||||||
|
|
||||||
|
func equalAsString(a interface{}, b interface{}) bool {
|
||||||
|
return toString(a) == toString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toString(str interface{}) string {
|
||||||
|
if values, ok := str.([]interface{}); ok {
|
||||||
|
var results []string
|
||||||
|
for _, value := range values {
|
||||||
|
results = append(results, toString(value))
|
||||||
|
}
|
||||||
|
return strings.Join(results, "_")
|
||||||
|
} else if bytes, ok := str.([]byte); ok {
|
||||||
|
return string(bytes)
|
||||||
|
} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
|
||||||
|
return fmt.Sprintf("%v", reflectValue.Interface())
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeSlice(elemType reflect.Type) interface{} {
|
||||||
|
if elemType.Kind() == reflect.Slice {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
sliceType := reflect.SliceOf(elemType)
|
||||||
|
slice := reflect.New(sliceType)
|
||||||
|
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
||||||
|
return slice.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
func strInSlice(a string, list []string) bool {
|
||||||
|
for _, b := range list {
|
||||||
|
if b == a {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// getValueFromFields return given fields's value
|
||||||
|
func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
|
||||||
|
// If value is a nil pointer, Indirect returns a zero Value!
|
||||||
|
// Therefor we need to check for a zero value,
|
||||||
|
// as FieldByName could panic
|
||||||
|
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
|
||||||
|
for _, fieldName := range fieldNames {
|
||||||
|
if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
|
||||||
|
result := fieldValue.Interface()
|
||||||
|
if r, ok := result.(driver.Valuer); ok {
|
||||||
|
result, _ = r.Value()
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func addExtraSpaceIfExist(str string) string {
|
||||||
|
if str != "" {
|
||||||
|
return " " + str
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
|
@ -1,98 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"regexp"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func fileWithLineNum() string {
|
|
||||||
for i := 2; i < 15; i++ {
|
|
||||||
_, file, line, ok := runtime.Caller(i)
|
|
||||||
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
|
||||||
return fmt.Sprintf("%v:%v", file, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func isBlank(value reflect.Value) bool {
|
|
||||||
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
|
|
||||||
}
|
|
||||||
|
|
||||||
func toSearchableMap(attrs ...interface{}) (result interface{}) {
|
|
||||||
if len(attrs) > 1 {
|
|
||||||
if str, ok := attrs[0].(string); ok {
|
|
||||||
result = map[string]interface{}{str: attrs[1]}
|
|
||||||
}
|
|
||||||
} else if len(attrs) == 1 {
|
|
||||||
if attr, ok := attrs[0].(map[string]interface{}); ok {
|
|
||||||
result = attr
|
|
||||||
}
|
|
||||||
|
|
||||||
if attr, ok := attrs[0].(interface{}); ok {
|
|
||||||
result = attr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
|
||||||
attrs := map[string]interface{}{}
|
|
||||||
|
|
||||||
switch value := values.(type) {
|
|
||||||
case map[string]interface{}:
|
|
||||||
for k, v := range value {
|
|
||||||
attrs[ToDBName(k)] = v
|
|
||||||
}
|
|
||||||
case []interface{}:
|
|
||||||
for _, v := range value {
|
|
||||||
for key, value := range convertInterfaceToMap(v) {
|
|
||||||
attrs[key] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case interface{}:
|
|
||||||
reflectValue := reflect.ValueOf(values)
|
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
|
||||||
case reflect.Map:
|
|
||||||
for _, key := range reflectValue.MapKeys() {
|
|
||||||
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
scope := Scope{Value: values}
|
|
||||||
for _, field := range scope.Fields() {
|
|
||||||
if !field.IsBlank && !field.IsIgnored {
|
|
||||||
attrs[field.DBName] = field.Field.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return attrs
|
|
||||||
}
|
|
||||||
|
|
||||||
func toString(str interface{}) string {
|
|
||||||
if values, ok := str.([]interface{}); ok {
|
|
||||||
var results []string
|
|
||||||
for _, value := range values {
|
|
||||||
results = append(results, toString(value))
|
|
||||||
}
|
|
||||||
return strings.Join(results, "_")
|
|
||||||
} else if bytes, ok := str.([]byte); ok {
|
|
||||||
return string(bytes)
|
|
||||||
} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
|
|
||||||
return fmt.Sprintf("%v", reflectValue.Interface())
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func strInSlice(a string, list []string) bool {
|
|
||||||
for _, b := range list {
|
|
||||||
if b == a {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
package gorm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestToDBNameGenerateFriendlyName(t *testing.T) {
|
||||||
|
var maps = map[string]string{
|
||||||
|
"": "",
|
||||||
|
"ThisIsATest": "this_is_a_test",
|
||||||
|
"PFAndESI": "pf_and_esi",
|
||||||
|
"AbcAndJkl": "abc_and_jkl",
|
||||||
|
"EmployeeID": "employee_id",
|
||||||
|
"SKU_ID": "sku_id",
|
||||||
|
"HTTPAndSMTP": "http_and_smtp",
|
||||||
|
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
|
||||||
|
"UUID": "uuid",
|
||||||
|
"HTTPURL": "http_url",
|
||||||
|
"HTTP_URL": "http_url",
|
||||||
|
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range maps {
|
||||||
|
if gorm.ToDBName(key) != value {
|
||||||
|
t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue