mirror of https://github.com/go-gorm/gorm.git
commit
9d57c6b961
|
@ -0,0 +1,2 @@
|
|||
documents
|
||||
_book
|
598
association.go
598
association.go
|
@ -4,32 +4,289 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Association Mode contains some helper methods to handle relationship things easily.
|
||||
type Association struct {
|
||||
Scope *Scope
|
||||
Column string
|
||||
Error error
|
||||
Field *Field
|
||||
scope *Scope
|
||||
column string
|
||||
field *Field
|
||||
}
|
||||
|
||||
func (association *Association) setErr(err error) *Association {
|
||||
if err != nil {
|
||||
association.Error = err
|
||||
// Find find out all related associations
|
||||
func (association *Association) Find(value interface{}) *Association {
|
||||
association.scope.related(value, association.column)
|
||||
return association.setErr(association.scope.db.Error)
|
||||
}
|
||||
|
||||
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
|
||||
func (association *Association) Append(values ...interface{}) *Association {
|
||||
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
|
||||
return association.Replace(values...)
|
||||
}
|
||||
return association.saveAssociations(values...)
|
||||
}
|
||||
|
||||
// Replace replace current associations with new one
|
||||
func (association *Association) Replace(values ...interface{}) *Association {
|
||||
var (
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
field = association.field.Field
|
||||
newDB = scope.NewDB()
|
||||
)
|
||||
|
||||
// Append new values
|
||||
association.field.Set(reflect.Zero(association.field.Field.Type()))
|
||||
association.saveAssociations(values...)
|
||||
|
||||
// Belongs To
|
||||
if relationship.Kind == "belongs_to" {
|
||||
// Set foreign key to be null when clearing value (length equals 0)
|
||||
if len(values) == 0 {
|
||||
// Set foreign key to be nil
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
}
|
||||
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
} else {
|
||||
// Polymorphic Relations
|
||||
if relationship.PolymorphicDBName != "" {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||
}
|
||||
|
||||
// Delete Relations except new created
|
||||
if len(values) > 0 {
|
||||
var associationForeignFieldNames []string
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// if many to many relations, get association fields name from association foreign keys
|
||||
associationScope := scope.New(reflect.New(field.Type()).Interface())
|
||||
for _, dbName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := associationScope.FieldByName(dbName); ok {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If other relations, use primary keys
|
||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
|
||||
|
||||
if len(newPrimaryKeys) > 0 {
|
||||
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
|
||||
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// if many to many relations, delete related relations from join table
|
||||
var sourceForeignFieldNames []string
|
||||
|
||||
for _, dbName := range relationship.ForeignFieldNames {
|
||||
if field, ok := scope.FieldByName(dbName); ok {
|
||||
sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||
}
|
||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) Find(value interface{}) *Association {
|
||||
association.Scope.related(value, association.Column)
|
||||
return association.setErr(association.Scope.db.Error)
|
||||
// Delete remove relationship between source & passed arguments, but won't delete those arguments
|
||||
func (association *Association) Delete(values ...interface{}) *Association {
|
||||
var (
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
field = association.field.Field
|
||||
newDB = scope.NewDB()
|
||||
)
|
||||
|
||||
if len(values) == 0 {
|
||||
return association
|
||||
}
|
||||
|
||||
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
|
||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
|
||||
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
|
||||
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
|
||||
}
|
||||
|
||||
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// source value's foreign keys
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
// get association's foreign fields name
|
||||
var associationScope = scope.New(reflect.New(field.Type()).Interface())
|
||||
var associationForeignFieldNames []string
|
||||
for _, associationDBName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := associationScope.FieldByName(associationDBName); ok {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// association value's foreign keys
|
||||
deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
|
||||
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
||||
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||
} else {
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
}
|
||||
|
||||
if relationship.Kind == "belongs_to" {
|
||||
// find with deleting relation's foreign keys
|
||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
|
||||
// set foreign key to be null if there are some records affected
|
||||
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
|
||||
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
|
||||
if results.RowsAffected > 0 {
|
||||
scope.updatedAttrsWithValues(foreignKeyMap)
|
||||
}
|
||||
} else {
|
||||
association.setErr(results.Error)
|
||||
}
|
||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||
// find all relations
|
||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
|
||||
// only include those deleting relations
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
|
||||
toQueryValues(deletingPrimaryKeys)...,
|
||||
)
|
||||
|
||||
// set matched relation's foreign key to be null
|
||||
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove deleted records from source's field
|
||||
if association.Error == nil {
|
||||
if field.Kind() == reflect.Slice {
|
||||
leftValues := reflect.Zero(field.Type())
|
||||
|
||||
for i := 0; i < field.Len(); i++ {
|
||||
reflectValue := field.Index(i)
|
||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
|
||||
var isDeleted = false
|
||||
for _, pk := range deletingPrimaryKeys {
|
||||
if equalAsString(primaryKey, pk) {
|
||||
isDeleted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isDeleted {
|
||||
leftValues = reflect.Append(leftValues, reflectValue)
|
||||
}
|
||||
}
|
||||
|
||||
association.field.Set(leftValues)
|
||||
} else if field.Kind() == reflect.Struct {
|
||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
|
||||
for _, pk := range deletingPrimaryKeys {
|
||||
if equalAsString(primaryKey, pk) {
|
||||
association.field.Set(reflect.Zero(field.Type()))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return association
|
||||
}
|
||||
|
||||
// Clear remove relationship between source & current associations, won't delete those associations
|
||||
func (association *Association) Clear() *Association {
|
||||
return association.Replace()
|
||||
}
|
||||
|
||||
// Count return the count of current associations
|
||||
func (association *Association) Count() int {
|
||||
var (
|
||||
count = 0
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
fieldValue = association.field.Field.Interface()
|
||||
query = scope.DB()
|
||||
)
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||
query = query.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
|
||||
query = query.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
query = query.Where(
|
||||
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
|
||||
scope.TableName(),
|
||||
)
|
||||
}
|
||||
|
||||
query.Model(fieldValue).Count(&count)
|
||||
return count
|
||||
}
|
||||
|
||||
// saveAssociations save passed values as associations
|
||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||
scope := association.Scope
|
||||
field := association.Field
|
||||
relationship := association.Field.Relationship
|
||||
var (
|
||||
scope = association.scope
|
||||
field = association.field
|
||||
relationship = field.Relationship
|
||||
)
|
||||
|
||||
saveAssociation := func(reflectValue reflect.Value) {
|
||||
// value has to been pointer
|
||||
|
@ -94,318 +351,9 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
|
|||
return association
|
||||
}
|
||||
|
||||
func (association *Association) Append(values ...interface{}) *Association {
|
||||
if relationship := association.Field.Relationship; relationship.Kind == "has_one" {
|
||||
return association.Replace(values...)
|
||||
}
|
||||
return association.saveAssociations(values...)
|
||||
}
|
||||
|
||||
func (association *Association) Replace(values ...interface{}) *Association {
|
||||
var (
|
||||
relationship = association.Field.Relationship
|
||||
scope = association.Scope
|
||||
field = association.Field.Field
|
||||
newDB = scope.NewDB()
|
||||
)
|
||||
|
||||
// Append new values
|
||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||
association.saveAssociations(values...)
|
||||
|
||||
// Belongs To
|
||||
if relationship.Kind == "belongs_to" {
|
||||
// Set foreign key to be null only when clearing value
|
||||
if len(values) == 0 {
|
||||
// Set foreign key to be nil
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
}
|
||||
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
} else {
|
||||
// Relations
|
||||
if relationship.PolymorphicDBName != "" {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||
}
|
||||
|
||||
// Relations except new created
|
||||
if len(values) > 0 {
|
||||
var newPrimaryKeys [][]interface{}
|
||||
var associationForeignFieldNames []string
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// If many to many relations, get it from foreign key
|
||||
associationForeignFieldNames = relationship.AssociationForeignFieldNames
|
||||
} else {
|
||||
// If other relations, get real primary keys
|
||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
|
||||
if field.IsPrimaryKey {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface())
|
||||
|
||||
if len(newPrimaryKeys) > 0 {
|
||||
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
|
||||
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
|
||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
func (association *Association) setErr(err error) *Association {
|
||||
if err != nil {
|
||||
association.Error = err
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) Delete(values ...interface{}) *Association {
|
||||
var (
|
||||
relationship = association.Field.Relationship
|
||||
scope = association.Scope
|
||||
field = association.Field.Field
|
||||
newDB = scope.NewDB()
|
||||
)
|
||||
|
||||
if len(values) == 0 {
|
||||
return association
|
||||
}
|
||||
|
||||
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
|
||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
|
||||
if field.IsPrimaryKey {
|
||||
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
|
||||
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
deletingPrimaryKeys := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, values...)
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// source value's foreign keys
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
// association value's foreign keys
|
||||
deletingPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
|
||||
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
||||
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||
} else {
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
}
|
||||
|
||||
if relationship.Kind == "belongs_to" {
|
||||
// find with deleting relation's foreign keys
|
||||
primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
|
||||
// set foreign key to be null
|
||||
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
|
||||
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
|
||||
if results.RowsAffected > 0 {
|
||||
scope.updatedAttrsWithValues(foreignKeyMap, false)
|
||||
}
|
||||
} else {
|
||||
association.setErr(results.Error)
|
||||
}
|
||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||
// find all relations
|
||||
primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, scope.Value)
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
|
||||
// only include those deleting relations
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
|
||||
toQueryValues(deletingPrimaryKeys)...,
|
||||
)
|
||||
|
||||
// set matched relation's foreign key to be null
|
||||
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
|
||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove deleted records from field
|
||||
if association.Error == nil {
|
||||
if association.Field.Field.Kind() == reflect.Slice {
|
||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||
|
||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||
reflectValue := association.Field.Field.Index(i)
|
||||
primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
|
||||
var included = false
|
||||
for _, pk := range deletingPrimaryKeys {
|
||||
if equalAsString(primaryKey, pk) {
|
||||
included = true
|
||||
}
|
||||
}
|
||||
if !included {
|
||||
leftValues = reflect.Append(leftValues, reflectValue)
|
||||
}
|
||||
}
|
||||
|
||||
association.Field.Set(leftValues)
|
||||
} else if association.Field.Field.Kind() == reflect.Struct {
|
||||
primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
|
||||
for _, pk := range deletingPrimaryKeys {
|
||||
if equalAsString(primaryKey, pk) {
|
||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) Clear() *Association {
|
||||
return association.Replace()
|
||||
}
|
||||
|
||||
func (association *Association) Count() int {
|
||||
var (
|
||||
count = 0
|
||||
relationship = association.Field.Relationship
|
||||
scope = association.Scope
|
||||
fieldValue = association.Field.Field.Interface()
|
||||
newScope = scope.New(fieldValue)
|
||||
)
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
query := scope.DB()
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
|
||||
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
|
||||
field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||
}
|
||||
query.Model(fieldValue).Count(&count)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
query := scope.DB()
|
||||
for idx, primaryKey := range relationship.AssociationForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok {
|
||||
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)),
|
||||
field.Field.Interface())
|
||||
}
|
||||
}
|
||||
query.Model(fieldValue).Count(&count)
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) {
|
||||
scope := association.Scope
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
if reflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
primaryKeys := []interface{}{}
|
||||
newScope := scope.New(reflectValue.Index(i).Interface())
|
||||
for _, column := range columns {
|
||||
if field, ok := newScope.FieldByName(column); ok {
|
||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
||||
} else {
|
||||
primaryKeys = append(primaryKeys, "")
|
||||
}
|
||||
}
|
||||
results = append(results, primaryKeys)
|
||||
}
|
||||
} else if reflectValue.Kind() == reflect.Struct {
|
||||
newScope := scope.New(value)
|
||||
var primaryKeys []interface{}
|
||||
for _, column := range columns {
|
||||
if field, ok := newScope.FieldByName(column); ok {
|
||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
||||
} else {
|
||||
primaryKeys = append(primaryKeys, "")
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, primaryKeys)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||
var results []string
|
||||
|
||||
for _, primaryValue := range primaryValues {
|
||||
var marks []string
|
||||
for _, _ = range primaryValue {
|
||||
marks = append(marks, "?")
|
||||
}
|
||||
|
||||
if len(marks) > 1 {
|
||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||
} else {
|
||||
results = append(results, strings.Join(marks, ""))
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ",")
|
||||
}
|
||||
|
||||
func toQueryCondition(scope *Scope, columns []string) string {
|
||||
var newColumns []string
|
||||
for _, column := range columns {
|
||||
newColumns = append(newColumns, scope.Quote(column))
|
||||
}
|
||||
|
||||
if len(columns) > 1 {
|
||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||
} else {
|
||||
return strings.Join(newColumns, ",")
|
||||
}
|
||||
}
|
||||
|
||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
||||
for _, primaryValue := range primaryValues {
|
||||
for _, value := range primaryValue {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func TestBelongsTo(t *testing.T) {
|
||||
|
@ -16,7 +18,7 @@ func TestBelongsTo(t *testing.T) {
|
|||
}
|
||||
|
||||
if err := DB.Save(&post).Error; err != nil {
|
||||
t.Errorf("Got errors when save post", err.Error())
|
||||
t.Error("Got errors when save post", err)
|
||||
}
|
||||
|
||||
if post.Category.ID == 0 || post.MainCategory.ID == 0 {
|
||||
|
@ -177,6 +179,49 @@ func TestBelongsTo(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBelongsToOverrideForeignKey1(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profile Profile `gorm:"ForeignKey:ProfileRefer"`
|
||||
ProfileRefer int
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "belongs_to" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBelongsToOverrideForeignKey2(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Refer string
|
||||
Name string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"`
|
||||
ProfileID int
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "belongs_to" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasOne(t *testing.T) {
|
||||
user := User{
|
||||
Name: "has one",
|
||||
|
@ -184,7 +229,7 @@ func TestHasOne(t *testing.T) {
|
|||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Errorf("Got errors when save user", err.Error())
|
||||
t.Error("Got errors when save user", err.Error())
|
||||
}
|
||||
|
||||
if user.CreditCard.UserId.Int64 == 0 {
|
||||
|
@ -323,6 +368,49 @@ func TestHasOne(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHasOneOverrideForeignKey1(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
UserRefer uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profile Profile `gorm:"ForeignKey:UserRefer"`
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "has_one" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasOneOverrideForeignKey2(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
UserID uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Refer string
|
||||
Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "has_one" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasMany(t *testing.T) {
|
||||
post := Post{
|
||||
Title: "post has many",
|
||||
|
@ -331,7 +419,7 @@ func TestHasMany(t *testing.T) {
|
|||
}
|
||||
|
||||
if err := DB.Save(&post).Error; err != nil {
|
||||
t.Errorf("Got errors when save post", err.Error())
|
||||
t.Error("Got errors when save post", err)
|
||||
}
|
||||
|
||||
for _, comment := range post.Comments {
|
||||
|
@ -462,6 +550,49 @@ func TestHasMany(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHasManyOverrideForeignKey1(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
UserRefer uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profile []Profile `gorm:"ForeignKey:UserRefer"`
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "has_many" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasManyOverrideForeignKey2(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
UserID uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Refer string
|
||||
Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "has_many" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManyToMany(t *testing.T) {
|
||||
DB.Raw("delete from languages")
|
||||
var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
|
||||
|
|
258
callback.go
258
callback.go
|
@ -4,34 +4,39 @@ import (
|
|||
"fmt"
|
||||
)
|
||||
|
||||
type callback struct {
|
||||
// DefaultCallback default callbacks defined by gorm
|
||||
var DefaultCallback = &Callback{}
|
||||
|
||||
// Callback is a struct that contains all CURD callbacks
|
||||
// Field `creates` contains callbacks will be call when creating object
|
||||
// Field `updates` contains callbacks will be call when updating object
|
||||
// Field `deletes` contains callbacks will be call when deleting object
|
||||
// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
|
||||
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
||||
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||
type Callback struct {
|
||||
creates []*func(scope *Scope)
|
||||
updates []*func(scope *Scope)
|
||||
deletes []*func(scope *Scope)
|
||||
queries []*func(scope *Scope)
|
||||
rowQueries []*func(scope *Scope)
|
||||
processors []*callbackProcessor
|
||||
processors []*CallbackProcessor
|
||||
}
|
||||
|
||||
type callbackProcessor struct {
|
||||
name string
|
||||
before string
|
||||
after string
|
||||
replace bool
|
||||
remove bool
|
||||
typ string
|
||||
processor *func(scope *Scope)
|
||||
callback *callback
|
||||
// CallbackProcessor contains callback informations
|
||||
type CallbackProcessor struct {
|
||||
name string // current callback's name
|
||||
before string // register current callback before a callback
|
||||
after string // register current callback after a callback
|
||||
replace bool // replace callbacks with same name
|
||||
remove bool // delete callbacks with same name
|
||||
kind string // callback type: create, update, delete, query, row_query
|
||||
processor *func(scope *Scope) // callback handler
|
||||
parent *Callback
|
||||
}
|
||||
|
||||
func (c *callback) addProcessor(typ string) *callbackProcessor {
|
||||
cp := &callbackProcessor{typ: typ, callback: c}
|
||||
c.processors = append(c.processors, cp)
|
||||
return cp
|
||||
}
|
||||
|
||||
func (c *callback) clone() *callback {
|
||||
return &callback{
|
||||
func (c *Callback) clone() *Callback {
|
||||
return &Callback{
|
||||
creates: c.creates,
|
||||
updates: c.updates,
|
||||
deletes: c.deletes,
|
||||
|
@ -40,57 +45,95 @@ func (c *callback) clone() *callback {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *callback) Create() *callbackProcessor {
|
||||
return c.addProcessor("create")
|
||||
// Create could be used to register callbacks for creating object
|
||||
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
|
||||
// // business logic
|
||||
// ...
|
||||
//
|
||||
// // set error if some thing wrong happened, will rollback the creating
|
||||
// scope.Err(errors.New("error"))
|
||||
// })
|
||||
func (c *Callback) Create() *CallbackProcessor {
|
||||
return &CallbackProcessor{kind: "create", parent: c}
|
||||
}
|
||||
|
||||
func (c *callback) Update() *callbackProcessor {
|
||||
return c.addProcessor("update")
|
||||
// Update could be used to register callbacks for updating object, refer `Create` for usage
|
||||
func (c *Callback) Update() *CallbackProcessor {
|
||||
return &CallbackProcessor{kind: "update", parent: c}
|
||||
}
|
||||
|
||||
func (c *callback) Delete() *callbackProcessor {
|
||||
return c.addProcessor("delete")
|
||||
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
|
||||
func (c *Callback) Delete() *CallbackProcessor {
|
||||
return &CallbackProcessor{kind: "delete", parent: c}
|
||||
}
|
||||
|
||||
func (c *callback) Query() *callbackProcessor {
|
||||
return c.addProcessor("query")
|
||||
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
|
||||
// Refer `Create` for usage
|
||||
func (c *Callback) Query() *CallbackProcessor {
|
||||
return &CallbackProcessor{kind: "query", parent: c}
|
||||
}
|
||||
|
||||
func (c *callback) RowQuery() *callbackProcessor {
|
||||
return c.addProcessor("row_query")
|
||||
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
|
||||
func (c *Callback) RowQuery() *CallbackProcessor {
|
||||
return &CallbackProcessor{kind: "row_query", parent: c}
|
||||
}
|
||||
|
||||
func (cp *callbackProcessor) Before(name string) *callbackProcessor {
|
||||
cp.before = name
|
||||
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
|
||||
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
|
||||
cp.after = callbackName
|
||||
return cp
|
||||
}
|
||||
|
||||
func (cp *callbackProcessor) After(name string) *callbackProcessor {
|
||||
cp.after = name
|
||||
// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
|
||||
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
||||
cp.before = callbackName
|
||||
return cp
|
||||
}
|
||||
|
||||
func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
|
||||
cp.name = name
|
||||
cp.processor = &fc
|
||||
cp.callback.sort()
|
||||
// Register a new callback, refer `Callbacks.Create`
|
||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||
cp.name = callbackName
|
||||
cp.processor = &callback
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
cp.parent.reorder()
|
||||
}
|
||||
|
||||
func (cp *callbackProcessor) Remove(name string) {
|
||||
fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
|
||||
cp.name = name
|
||||
// Remove a registered callback
|
||||
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||
fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||
cp.name = callbackName
|
||||
cp.remove = true
|
||||
cp.callback.sort()
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
cp.parent.reorder()
|
||||
}
|
||||
|
||||
func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
|
||||
fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
|
||||
cp.name = name
|
||||
cp.processor = &fc
|
||||
// Replace a registered callback with new callback
|
||||
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
||||
// scope.SetColumn("Created", now)
|
||||
// scope.SetColumn("Updated", now)
|
||||
// })
|
||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||
fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||
cp.name = callbackName
|
||||
cp.processor = &callback
|
||||
cp.replace = true
|
||||
cp.callback.sort()
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
cp.parent.reorder()
|
||||
}
|
||||
|
||||
// Get registered callback
|
||||
// db.Callback().Create().Get("gorm:create")
|
||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||
for _, p := range cp.parent.processors {
|
||||
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
||||
return *p.processor
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRIndex get right index from string slice
|
||||
func getRIndex(strs []string, str string) int {
|
||||
for i := len(strs) - 1; i >= 0; i-- {
|
||||
if strs[i] == str {
|
||||
|
@ -100,93 +143,88 @@ func getRIndex(strs []string, str string) int {
|
|||
return -1
|
||||
}
|
||||
|
||||
func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
|
||||
var sortCallbackProcessor func(c *callbackProcessor)
|
||||
var names, sortedNames = []string{}, []string{}
|
||||
// sortProcessors sort callback processors based on its before, after, remove, replace
|
||||
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||
var (
|
||||
allNames, sortedNames []string
|
||||
sortCallbackProcessor func(c *CallbackProcessor)
|
||||
)
|
||||
|
||||
for _, cp := range cps {
|
||||
if index := getRIndex(names, cp.name); index > -1 {
|
||||
if !cp.replace && !cp.remove {
|
||||
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
||||
}
|
||||
// show warning message the callback name already exists
|
||||
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
||||
}
|
||||
names = append(names, cp.name)
|
||||
allNames = append(allNames, cp.name)
|
||||
}
|
||||
|
||||
sortCallbackProcessor = func(c *callbackProcessor) {
|
||||
if getRIndex(sortedNames, c.name) > -1 {
|
||||
return
|
||||
}
|
||||
|
||||
if len(c.before) > 0 {
|
||||
if index := getRIndex(sortedNames, c.before); index > -1 {
|
||||
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
|
||||
} else if index := getRIndex(names, c.before); index > -1 {
|
||||
sortedNames = append(sortedNames, c.name)
|
||||
sortCallbackProcessor(cps[index])
|
||||
} else {
|
||||
sortedNames = append(sortedNames, c.name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.after) > 0 {
|
||||
if index := getRIndex(sortedNames, c.after); index > -1 {
|
||||
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
|
||||
} else if index := getRIndex(names, c.after); index > -1 {
|
||||
cp := cps[index]
|
||||
if len(cp.before) == 0 {
|
||||
cp.before = c.name
|
||||
sortCallbackProcessor = func(c *CallbackProcessor) {
|
||||
if getRIndex(sortedNames, c.name) == -1 { // if not sorted
|
||||
if c.before != "" { // if defined before callback
|
||||
if index := getRIndex(sortedNames, c.before); index != -1 {
|
||||
// if before callback already sorted, append current callback just after it
|
||||
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
|
||||
} else if index := getRIndex(allNames, c.before); index != -1 {
|
||||
// if before callback exists but haven't sorted, append current callback to last
|
||||
sortedNames = append(sortedNames, c.name)
|
||||
sortCallbackProcessor(cps[index])
|
||||
}
|
||||
sortCallbackProcessor(cp)
|
||||
} else {
|
||||
}
|
||||
|
||||
if c.after != "" { // if defined after callback
|
||||
if index := getRIndex(sortedNames, c.after); index != -1 {
|
||||
// if after callback already sorted, append current callback just before it
|
||||
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
|
||||
} else if index := getRIndex(allNames, c.after); index != -1 {
|
||||
// if after callback exists but haven't sorted
|
||||
cp := cps[index]
|
||||
// set after callback's before callback to current callback
|
||||
if cp.before == "" {
|
||||
cp.before = c.name
|
||||
}
|
||||
sortCallbackProcessor(cp)
|
||||
}
|
||||
}
|
||||
|
||||
// if current callback haven't been sorted, append it to last
|
||||
if getRIndex(sortedNames, c.name) == -1 {
|
||||
sortedNames = append(sortedNames, c.name)
|
||||
}
|
||||
}
|
||||
|
||||
if getRIndex(sortedNames, c.name) == -1 {
|
||||
sortedNames = append(sortedNames, c.name)
|
||||
}
|
||||
}
|
||||
|
||||
for _, cp := range cps {
|
||||
sortCallbackProcessor(cp)
|
||||
}
|
||||
|
||||
var funcs = []*func(scope *Scope){}
|
||||
var sortedFuncs = []*func(scope *Scope){}
|
||||
var sortedFuncs []*func(scope *Scope)
|
||||
for _, name := range sortedNames {
|
||||
index := getRIndex(names, name)
|
||||
if !cps[index].remove {
|
||||
if index := getRIndex(allNames, name); !cps[index].remove {
|
||||
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
||||
}
|
||||
}
|
||||
|
||||
for _, cp := range cps {
|
||||
if sindex := getRIndex(sortedNames, cp.name); sindex == -1 {
|
||||
if !cp.remove {
|
||||
funcs = append(funcs, cp.processor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return append(sortedFuncs, funcs...)
|
||||
return sortedFuncs
|
||||
}
|
||||
|
||||
func (c *callback) sort() {
|
||||
var creates, updates, deletes, queries, rowQueries []*callbackProcessor
|
||||
// reorder all registered processors, and reset CURD callbacks
|
||||
func (c *Callback) reorder() {
|
||||
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
|
||||
|
||||
for _, processor := range c.processors {
|
||||
switch processor.typ {
|
||||
case "create":
|
||||
creates = append(creates, processor)
|
||||
case "update":
|
||||
updates = append(updates, processor)
|
||||
case "delete":
|
||||
deletes = append(deletes, processor)
|
||||
case "query":
|
||||
queries = append(queries, processor)
|
||||
case "row_query":
|
||||
rowQueries = append(rowQueries, processor)
|
||||
if processor.name != "" {
|
||||
switch processor.kind {
|
||||
case "create":
|
||||
creates = append(creates, processor)
|
||||
case "update":
|
||||
updates = append(updates, processor)
|
||||
case "delete":
|
||||
deletes = append(deletes, processor)
|
||||
case "query":
|
||||
queries = append(queries, processor)
|
||||
case "row_query":
|
||||
rowQueries = append(rowQueries, processor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -196,5 +234,3 @@ func (c *callback) sort() {
|
|||
c.queries = sortProcessors(queries)
|
||||
c.rowQueries = sortProcessors(rowQueries)
|
||||
}
|
||||
|
||||
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
|
||||
|
|
|
@ -5,12 +5,31 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
func BeforeCreate(scope *Scope) {
|
||||
scope.CallMethodWithErrorCheck("BeforeSave")
|
||||
scope.CallMethodWithErrorCheck("BeforeCreate")
|
||||
// Define callbacks for creating
|
||||
func init() {
|
||||
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:create", createCallback)
|
||||
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
func UpdateTimeStampWhenCreate(scope *Scope) {
|
||||
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
||||
func beforeCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeSave")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeCreate")
|
||||
}
|
||||
}
|
||||
|
||||
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
||||
func updateTimeStampForCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
now := NowFunc()
|
||||
scope.SetColumn("CreatedAt", now)
|
||||
|
@ -18,109 +37,108 @@ func UpdateTimeStampWhenCreate(scope *Scope) {
|
|||
}
|
||||
}
|
||||
|
||||
func Create(scope *Scope) {
|
||||
defer scope.Trace(NowFunc())
|
||||
|
||||
// createCallback the callback used to insert data into database
|
||||
func createCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
// set create sql
|
||||
var sqls, columns []string
|
||||
fields := scope.Fields()
|
||||
for _, field := range fields {
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
columns, placeholders []string
|
||||
blankColumnsWithDefaultValue []string
|
||||
)
|
||||
|
||||
for _, field := range scope.Fields() {
|
||||
if scope.changeableField(field) {
|
||||
if field.IsNormal {
|
||||
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
|
||||
if !field.IsBlank || !field.HasDefaultValue {
|
||||
if !field.IsPrimaryKey || !field.IsBlank {
|
||||
if field.IsBlank && field.HasDefaultValue {
|
||||
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName)
|
||||
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
||||
} else {
|
||||
columns = append(columns, scope.Quote(field.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
||||
} else if field.HasDefaultValue {
|
||||
var hasDefaultValueColumns []string
|
||||
if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
|
||||
hasDefaultValueColumns = oldHasDefaultValueColumns.([]string)
|
||||
}
|
||||
hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName)
|
||||
scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns)
|
||||
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
|
||||
}
|
||||
}
|
||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
for _, dbName := range relationship.ForeignDBNames {
|
||||
if relationField := fields[dbName]; !scope.changeableField(relationField) {
|
||||
columns = append(columns, scope.Quote(relationField.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
|
||||
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
|
||||
for _, foreignKey := range field.Relationship.ForeignDBNames {
|
||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||
columns = append(columns, scope.Quote(foreignField.DBName))
|
||||
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
returningKey := "*"
|
||||
primaryField := scope.PrimaryField()
|
||||
if primaryField != nil {
|
||||
returningKey = scope.Quote(primaryField.DBName)
|
||||
var (
|
||||
returningColumn = "*"
|
||||
quotedTableName = scope.QuotedTableName()
|
||||
primaryField = scope.PrimaryField()
|
||||
extraOption string
|
||||
)
|
||||
|
||||
if str, ok := scope.Get("gorm:insert_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
if primaryField != nil {
|
||||
returningColumn = scope.Quote(primaryField.DBName)
|
||||
}
|
||||
|
||||
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||
|
||||
if len(columns) == 0 {
|
||||
scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v",
|
||||
scope.QuotedTableName(),
|
||||
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT INTO %v DEFAULT VALUES%v%v",
|
||||
quotedTableName,
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT INTO %v (%v) VALUES (%v) %v",
|
||||
"INSERT INTO %v (%v) VALUES (%v)%v%v",
|
||||
scope.QuotedTableName(),
|
||||
strings.Join(columns, ","),
|
||||
strings.Join(sqls, ","),
|
||||
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
|
||||
strings.Join(placeholders, ","),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
}
|
||||
|
||||
// execute create sql
|
||||
if scope.Dialect().SupportLastInsertId() {
|
||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
id, err := result.LastInsertId()
|
||||
if scope.Err(err) == nil {
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
if primaryField != nil && primaryField.IsBlank {
|
||||
scope.Err(scope.SetColumn(primaryField, id))
|
||||
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
// set rows affected count
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
// set primary value to primary field
|
||||
if primaryField != nil && primaryField.IsBlank {
|
||||
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||
scope.Err(primaryField.Set(primaryValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if primaryField == nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
|
||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
||||
} else {
|
||||
scope.Err(err)
|
||||
}
|
||||
} else {
|
||||
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
|
||||
scope.db.RowsAffected = 1
|
||||
} else {
|
||||
scope.Err(err)
|
||||
}
|
||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||
scope.db.RowsAffected = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ForceReloadAfterCreate(scope *Scope) {
|
||||
if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
|
||||
scope.DB().New().Select(columns.([]string)).First(scope.Value)
|
||||
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
||||
func forceReloadAfterCreateCallback(scope *Scope) {
|
||||
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
||||
scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func AfterCreate(scope *Scope) {
|
||||
scope.CallMethodWithErrorCheck("AfterCreate")
|
||||
scope.CallMethodWithErrorCheck("AfterSave")
|
||||
}
|
||||
|
||||
func init() {
|
||||
DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
|
||||
DefaultCallback.Create().Register("gorm:before_create", BeforeCreate)
|
||||
DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
|
||||
DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
|
||||
DefaultCallback.Create().Register("gorm:create", Create)
|
||||
DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
|
||||
DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
|
||||
DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
|
||||
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
|
||||
func afterCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterCreate")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterSave")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,35 +2,52 @@ package gorm
|
|||
|
||||
import "fmt"
|
||||
|
||||
func BeforeDelete(scope *Scope) {
|
||||
scope.CallMethodWithErrorCheck("BeforeDelete")
|
||||
// Define callbacks for deleting
|
||||
func init() {
|
||||
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
func Delete(scope *Scope) {
|
||||
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
||||
func beforeDeleteCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
|
||||
scope.Raw(
|
||||
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
|
||||
scope.QuotedTableName(),
|
||||
scope.AddToVars(NowFunc()),
|
||||
scope.CombinedConditionSql(),
|
||||
))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||
}
|
||||
|
||||
scope.Exec()
|
||||
scope.CallMethod("BeforeDelete")
|
||||
}
|
||||
}
|
||||
|
||||
func AfterDelete(scope *Scope) {
|
||||
scope.CallMethodWithErrorCheck("AfterDelete")
|
||||
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
||||
func deleteCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
var extraOption string
|
||||
if str, ok := scope.Get("gorm:delete_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v SET deleted_at=%v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
scope.AddToVars(NowFunc()),
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"DELETE FROM %v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
|
||||
DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
|
||||
DefaultCallback.Delete().Register("gorm:delete", Delete)
|
||||
DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
|
||||
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
// afterDeleteCallback will invoke `AfterDelete` method after deleting
|
||||
func afterDeleteCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterDelete")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,115 +6,89 @@ import (
|
|||
"reflect"
|
||||
)
|
||||
|
||||
func Query(scope *Scope) {
|
||||
defer scope.Trace(NowFunc())
|
||||
// Define callbacks for querying
|
||||
func init() {
|
||||
DefaultCallback.Query().Register("gorm:query", queryCallback)
|
||||
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
|
||||
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
|
||||
}
|
||||
|
||||
// queryCallback used to query data from database
|
||||
func queryCallback(scope *Scope) {
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
isSlice bool
|
||||
isPtr bool
|
||||
anyRecordFound bool
|
||||
destType reflect.Type
|
||||
isSlice bool
|
||||
isPtr bool
|
||||
results = scope.IndirectValue()
|
||||
resultType reflect.Type
|
||||
)
|
||||
|
||||
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
||||
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
|
||||
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy))
|
||||
if primaryField := scope.PrimaryField(); primaryField != nil {
|
||||
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
|
||||
}
|
||||
}
|
||||
|
||||
var dest = scope.IndirectValue()
|
||||
if value, ok := scope.Get("gorm:query_destination"); ok {
|
||||
dest = reflect.Indirect(reflect.ValueOf(value))
|
||||
results = reflect.Indirect(reflect.ValueOf(value))
|
||||
}
|
||||
|
||||
if kind := dest.Kind(); kind == reflect.Slice {
|
||||
if kind := results.Kind(); kind == reflect.Slice {
|
||||
isSlice = true
|
||||
destType = dest.Type().Elem()
|
||||
dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
|
||||
resultType = results.Type().Elem()
|
||||
results.Set(reflect.MakeSlice(results.Type(), 0, 0))
|
||||
|
||||
if destType.Kind() == reflect.Ptr {
|
||||
if resultType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
destType = destType.Elem()
|
||||
resultType = resultType.Elem()
|
||||
}
|
||||
} else if kind != reflect.Struct {
|
||||
scope.Err(errors.New("unsupported destination, should be slice or struct"))
|
||||
return
|
||||
}
|
||||
|
||||
scope.prepareQuerySql()
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
if !scope.HasError() {
|
||||
rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
||||
scope.db.RowsAffected = 0
|
||||
|
||||
if scope.Err(err) != nil {
|
||||
return
|
||||
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
scope.db.RowsAffected++
|
||||
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
defer rows.Close()
|
||||
|
||||
anyRecordFound = true
|
||||
elem := dest
|
||||
if isSlice {
|
||||
elem = reflect.New(destType).Elem()
|
||||
}
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
scope.db.RowsAffected++
|
||||
|
||||
var values = make([]interface{}, len(columns))
|
||||
elem := results
|
||||
if isSlice {
|
||||
elem = reflect.New(resultType).Elem()
|
||||
}
|
||||
|
||||
fields := scope.New(elem.Addr().Interface()).Fields()
|
||||
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).fieldsMap())
|
||||
|
||||
for index, column := range columns {
|
||||
if field, ok := fields[column]; ok {
|
||||
if field.Field.Kind() == reflect.Ptr {
|
||||
values[index] = field.Field.Addr().Interface()
|
||||
if isSlice {
|
||||
if isPtr {
|
||||
results.Set(reflect.Append(results, elem.Addr()))
|
||||
} else {
|
||||
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
|
||||
reflectValue.Elem().Set(field.Field.Addr())
|
||||
values[index] = reflectValue.Interface()
|
||||
}
|
||||
} else {
|
||||
var value interface{}
|
||||
values[index] = &value
|
||||
}
|
||||
}
|
||||
|
||||
scope.Err(rows.Scan(values...))
|
||||
|
||||
for index, column := range columns {
|
||||
value := values[index]
|
||||
if field, ok := fields[column]; ok {
|
||||
if field.Field.Kind() == reflect.Ptr {
|
||||
field.Field.Set(reflect.ValueOf(value).Elem())
|
||||
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
|
||||
field.Field.Set(v)
|
||||
results.Set(reflect.Append(results, elem))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isSlice {
|
||||
if isPtr {
|
||||
dest.Set(reflect.Append(dest, elem.Addr()))
|
||||
} else {
|
||||
dest.Set(reflect.Append(dest, elem))
|
||||
}
|
||||
if scope.db.RowsAffected == 0 && !isSlice {
|
||||
scope.Err(ErrRecordNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
if !anyRecordFound && !isSlice {
|
||||
scope.Err(RecordNotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterQuery(scope *Scope) {
|
||||
scope.CallMethodWithErrorCheck("AfterFind")
|
||||
}
|
||||
|
||||
func init() {
|
||||
DefaultCallback.Query().Register("gorm:query", Query)
|
||||
DefaultCallback.Query().Register("gorm:preload", Preload)
|
||||
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
|
||||
// afterQueryCallback will invoke `AfterFind` method after querying
|
||||
func afterQueryCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterFind")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,308 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// preloadCallback used to preload associations
|
||||
func preloadCallback(scope *Scope) {
|
||||
if scope.Search.preload == nil || scope.HasError() {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
preloadedMap = map[string]bool{}
|
||||
fields = scope.Fields()
|
||||
)
|
||||
|
||||
for _, preload := range scope.Search.preload {
|
||||
var (
|
||||
preloadFields = strings.Split(preload.schema, ".")
|
||||
currentScope = scope
|
||||
currentFields = fields
|
||||
)
|
||||
|
||||
for idx, preloadField := range preloadFields {
|
||||
var currentPreloadConditions []interface{}
|
||||
|
||||
// if not preloaded
|
||||
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
|
||||
|
||||
// assign search conditions to last preload
|
||||
if idx == len(preloadFields)-1 {
|
||||
currentPreloadConditions = preload.conditions
|
||||
}
|
||||
|
||||
for _, field := range currentFields {
|
||||
if field.Name != preloadField || field.Relationship == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch field.Relationship.Kind {
|
||||
case "has_one":
|
||||
currentScope.handleHasOnePreload(field, currentPreloadConditions)
|
||||
case "has_many":
|
||||
currentScope.handleHasManyPreload(field, currentPreloadConditions)
|
||||
case "belongs_to":
|
||||
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
||||
case "many_to_many":
|
||||
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
||||
default:
|
||||
scope.Err(errors.New("unsupported relation"))
|
||||
}
|
||||
|
||||
preloadedMap[preloadKey] = true
|
||||
break
|
||||
}
|
||||
|
||||
if !preloadedMap[preloadKey] {
|
||||
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// preload next level
|
||||
if idx < len(preloadFields)-1 {
|
||||
currentScope = currentScope.getColumnAsScope(preloadField)
|
||||
currentFields = currentScope.Fields()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
|
||||
var (
|
||||
preloadDB = scope.NewDB()
|
||||
preloadConditions []interface{}
|
||||
)
|
||||
|
||||
for _, condition := range conditions {
|
||||
if scopes, ok := condition.(func(*DB) *DB); ok {
|
||||
preloadDB = scopes(preloadDB)
|
||||
} else {
|
||||
preloadConditions = append(preloadConditions, condition)
|
||||
}
|
||||
}
|
||||
|
||||
return preloadDB, preloadConditions
|
||||
}
|
||||
|
||||
// handleHasOnePreload used to preload has one associations
|
||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// find relations
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
||||
indirectValue.FieldByName(field.Name).Set(result)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.Err(field.Set(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleHasManyPreload used to preload has many associations
|
||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// find relations
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) {
|
||||
objectField := object.FieldByName(field.Name)
|
||||
objectField.Set(reflect.Append(objectField, result))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.Err(field.Set(resultsValue))
|
||||
}
|
||||
}
|
||||
|
||||
// handleBelongsToPreload used to preload belongs to associations
|
||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// find relations
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
||||
object.FieldByName(field.Name).Set(result)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.Err(field.Set(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleManyToManyPreload used to preload many to many associations
|
||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
||||
var (
|
||||
relation = field.Relationship
|
||||
joinTableHandler = relation.JoinTableHandler
|
||||
fieldType = field.Struct.Type.Elem()
|
||||
foreignKeyValue interface{}
|
||||
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
|
||||
linkHash = map[string][]reflect.Value{}
|
||||
isPtr bool
|
||||
)
|
||||
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
var sourceKeys = []string{}
|
||||
for _, key := range joinTableHandler.SourceForeignKeys() {
|
||||
sourceKeys = append(sourceKeys, key.DBName)
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// generate query with join table
|
||||
newScope := scope.New(reflect.New(fieldType).Interface())
|
||||
preloadDB = preloadDB.Table(newScope.TableName()).Select("*")
|
||||
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
|
||||
|
||||
// preload inline conditions
|
||||
if len(preloadConditions) > 0 {
|
||||
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
|
||||
}
|
||||
|
||||
rows, err := preloadDB.Rows()
|
||||
|
||||
if scope.Err(err) != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
var (
|
||||
elem = reflect.New(fieldType).Elem()
|
||||
fields = scope.New(elem.Addr().Interface()).fieldsMap()
|
||||
)
|
||||
|
||||
// register foreign keys in join tables
|
||||
for _, sourceKey := range sourceKeys {
|
||||
fields[sourceKey] = &Field{Field: reflect.New(foreignKeyType).Elem()}
|
||||
}
|
||||
|
||||
scope.scan(rows, columns, fields)
|
||||
|
||||
// generate hashed forkey keys in join table
|
||||
var foreignKeys = make([]interface{}, len(sourceKeys))
|
||||
for idx, sourceKey := range sourceKeys {
|
||||
foreignKeys[idx] = fields[sourceKey].Field.Elem().Interface()
|
||||
}
|
||||
hashedSourceKeys := toString(foreignKeys)
|
||||
|
||||
if isPtr {
|
||||
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
|
||||
} else {
|
||||
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
|
||||
}
|
||||
}
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
fieldsSourceMap = map[string]reflect.Value{}
|
||||
foreignFieldNames = []string{}
|
||||
fields = scope.fieldsMap()
|
||||
)
|
||||
|
||||
for _, dbName := range relation.ForeignFieldNames {
|
||||
if field, ok := fields[dbName]; ok {
|
||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name)
|
||||
}
|
||||
} else if indirectScopeValue.IsValid() {
|
||||
fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name)
|
||||
}
|
||||
|
||||
for source, link := range linkHash {
|
||||
fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...))
|
||||
}
|
||||
}
|
|
@ -2,15 +2,15 @@ package gorm
|
|||
|
||||
import "reflect"
|
||||
|
||||
func BeginTransaction(scope *Scope) {
|
||||
func beginTransactionCallback(scope *Scope) {
|
||||
scope.Begin()
|
||||
}
|
||||
|
||||
func CommitOrRollbackTransaction(scope *Scope) {
|
||||
func commitOrRollbackTransactionCallback(scope *Scope) {
|
||||
scope.CommitOrRollback()
|
||||
}
|
||||
|
||||
func SaveBeforeAssociations(scope *Scope) {
|
||||
func saveBeforeAssociationsCallback(scope *Scope) {
|
||||
if !scope.shouldSaveAssociations() {
|
||||
return
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ func SaveBeforeAssociations(scope *Scope) {
|
|||
}
|
||||
}
|
||||
|
||||
func SaveAfterAssociations(scope *Scope) {
|
||||
func saveAfterAssociationsCallback(scope *Scope) {
|
||||
if !scope.shouldSaveAssociations() {
|
||||
return
|
||||
}
|
|
@ -23,7 +23,7 @@ func afterCreate1(s *Scope) {}
|
|||
func afterCreate2(s *Scope) {}
|
||||
|
||||
func TestRegisterCallback(t *testing.T) {
|
||||
var callback = &callback{processors: []*callbackProcessor{}}
|
||||
var callback = &Callback{}
|
||||
|
||||
callback.Create().Register("before_create1", beforeCreate1)
|
||||
callback.Create().Register("before_create2", beforeCreate2)
|
||||
|
@ -37,7 +37,7 @@ func TestRegisterCallback(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||
var callback1 = &callback{processors: []*callbackProcessor{}}
|
||||
var callback1 = &Callback{}
|
||||
callback1.Create().Register("before_create1", beforeCreate1)
|
||||
callback1.Create().Register("create", create)
|
||||
callback1.Create().Register("after_create1", afterCreate1)
|
||||
|
@ -46,7 +46,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
|
|||
t.Errorf("register callback with order")
|
||||
}
|
||||
|
||||
var callback2 = &callback{processors: []*callbackProcessor{}}
|
||||
var callback2 = &Callback{}
|
||||
|
||||
callback2.Update().Register("create", create)
|
||||
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
||||
|
@ -60,7 +60,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||
var callback1 = &callback{processors: []*callbackProcessor{}}
|
||||
var callback1 = &Callback{}
|
||||
|
||||
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback1.Query().Register("before_create1", beforeCreate1)
|
||||
|
@ -70,7 +70,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
|||
t.Errorf("register callback with order")
|
||||
}
|
||||
|
||||
var callback2 = &callback{processors: []*callbackProcessor{}}
|
||||
var callback2 = &Callback{}
|
||||
|
||||
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
||||
|
@ -86,7 +86,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
|||
func replaceCreate(s *Scope) {}
|
||||
|
||||
func TestReplaceCallback(t *testing.T) {
|
||||
var callback = &callback{processors: []*callbackProcessor{}}
|
||||
var callback = &Callback{}
|
||||
|
||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback.Create().Register("before_create1", beforeCreate1)
|
||||
|
@ -99,7 +99,7 @@ func TestReplaceCallback(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRemoveCallback(t *testing.T) {
|
||||
var callback = &callback{processors: []*callbackProcessor{}}
|
||||
var callback = &Callback{}
|
||||
|
||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback.Create().Register("before_create1", beforeCreate1)
|
||||
|
|
|
@ -5,91 +5,102 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
func AssignUpdateAttributes(scope *Scope) {
|
||||
// Define callbacks for updating
|
||||
func init() {
|
||||
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
|
||||
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:update", updateCallback)
|
||||
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
// assignUpdatingAttributesCallback assign updating attributes to model
|
||||
func assignUpdatingAttributesCallback(scope *Scope) {
|
||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
|
||||
protected, ok := scope.Get("gorm:ignore_protected_attrs")
|
||||
_, updateColumn := scope.Get("gorm:update_column")
|
||||
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
|
||||
|
||||
if updateColumn {
|
||||
scope.InstanceSet("gorm:update_attrs", maps)
|
||||
} else if len(updateAttrs) > 0 {
|
||||
scope.InstanceSet("gorm:update_attrs", updateAttrs)
|
||||
} else if !hasUpdate {
|
||||
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
|
||||
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||
} else {
|
||||
scope.SkipLeft()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BeforeUpdate(scope *Scope) {
|
||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||
func beforeUpdateCallback(scope *Scope) {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
scope.CallMethodWithErrorCheck("BeforeSave")
|
||||
scope.CallMethodWithErrorCheck("BeforeUpdate")
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeSave")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeUpdate")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateTimeStampWhenUpdate(scope *Scope) {
|
||||
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
||||
func updateTimeStampForUpdateCallback(scope *Scope) {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
scope.SetColumn("UpdatedAt", NowFunc())
|
||||
}
|
||||
}
|
||||
|
||||
func Update(scope *Scope) {
|
||||
// updateCallback the callback used to update data to database
|
||||
func updateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
var sqls []string
|
||||
|
||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||
for key, value := range updateAttrs.(map[string]interface{}) {
|
||||
if scope.changeableDBColumn(key) {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
|
||||
}
|
||||
for column, value := range updateAttrs.(map[string]interface{}) {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
||||
}
|
||||
} else {
|
||||
fields := scope.Fields()
|
||||
for _, field := range fields {
|
||||
if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
for _, dbName := range relationship.ForeignDBNames {
|
||||
if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
|
||||
sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))
|
||||
sqls = append(sqls, sql)
|
||||
for _, field := range scope.Fields() {
|
||||
if scope.changeableField(field) {
|
||||
if !field.IsPrimaryKey && field.IsNormal {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||
sqls = append(sqls,
|
||||
fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var extraOption string
|
||||
if str, ok := scope.Get("gorm:update_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
if len(sqls) > 0 {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v SET %v %v",
|
||||
"UPDATE %v SET %v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
strings.Join(sqls, ", "),
|
||||
scope.CombinedConditionSql(),
|
||||
))
|
||||
scope.Exec()
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterUpdate(scope *Scope) {
|
||||
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
|
||||
func afterUpdateCallback(scope *Scope) {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
scope.CallMethodWithErrorCheck("AfterUpdate")
|
||||
scope.CallMethodWithErrorCheck("AfterSave")
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterUpdate")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterSave")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
|
||||
DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
|
||||
DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
|
||||
DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
|
||||
DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
|
||||
DefaultCallback.Update().Register("gorm:update", Update)
|
||||
DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
|
||||
DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
|
||||
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
}
|
||||
|
|
|
@ -1,117 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type commonDialect struct{}
|
||||
|
||||
func (commonDialect) BinVar(i int) string {
|
||||
return "$$" // ?
|
||||
}
|
||||
|
||||
func (commonDialect) SupportLastInsertId() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (commonDialect) HasTop() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "BOOLEAN"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "INTEGER AUTO_INCREMENT"
|
||||
}
|
||||
return "INTEGER"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "BIGINT AUTO_INCREMENT"
|
||||
}
|
||||
return "BIGINT"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "FLOAT"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("VARCHAR(%d)", size)
|
||||
}
|
||||
return "VARCHAR(65532)"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "TIMESTAMP"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("BINARY(%d)", size)
|
||||
}
|
||||
return "BINARY(65532)"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (commonDialect) ReturningStr(tableName, key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) SelectFromDummyTable() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) Quote(key string) string {
|
||||
return fmt.Sprintf(`"%s"`, key)
|
||||
}
|
||||
|
||||
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
|
||||
var (
|
||||
count int
|
||||
databaseName = c.CurrentDatabase(scope)
|
||||
)
|
||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||
var (
|
||||
count int
|
||||
databaseName = c.CurrentDatabase(scope)
|
||||
)
|
||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var (
|
||||
count int
|
||||
databaseName = c.CurrentDatabase(scope)
|
||||
)
|
||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
|
||||
}
|
||||
|
||||
// RawScanInt scans the first column of the first row into the `scan' int pointer.
|
||||
// This function captures raw query errors and propagates them to the original scope.
|
||||
func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
|
||||
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
|
||||
}
|
||||
|
||||
// RawScanString scans the first column of the first row into the `scan' string pointer.
|
||||
// This function captures raw query errors and propagates them to the original scope.
|
||||
func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) {
|
||||
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
|
||||
}
|
||||
|
||||
func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
|
||||
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
|
||||
return
|
||||
}
|
|
@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) {
|
|||
DB.AutoMigrate(&CustomizeColumn{})
|
||||
|
||||
scope := DB.NewScope(&CustomizeColumn{})
|
||||
if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
|
||||
if !scope.Dialect().HasColumn(scope.TableName(), col) {
|
||||
t.Errorf("CustomizeColumn should have column %s", col)
|
||||
}
|
||||
|
||||
|
|
|
@ -17,8 +17,7 @@ func TestDdlErrors(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
DB.HasTable("foobarbaz")
|
||||
if DB.Error == nil {
|
||||
if err := DB.Find(&User{}).Error; err == nil {
|
||||
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ func TestSoftDelete(t *testing.T) {
|
|||
type User struct {
|
||||
Id int64
|
||||
Name string
|
||||
DeletedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
DB.AutoMigrate(&User{})
|
||||
|
||||
|
|
115
dialect.go
115
dialect.go
|
@ -1,41 +1,100 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Dialect interface contains behaviors that differ across SQL database
|
||||
type Dialect interface {
|
||||
BinVar(i int) string
|
||||
SupportLastInsertId() bool
|
||||
HasTop() bool
|
||||
SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
||||
ReturningStr(tableName, key string) string
|
||||
SelectFromDummyTable() string
|
||||
// GetName get dialect's name
|
||||
GetName() string
|
||||
|
||||
// SetDB set db for dialect
|
||||
SetDB(db *sql.DB)
|
||||
|
||||
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||
BindVar(i int) string
|
||||
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||
Quote(key string) string
|
||||
HasTable(scope *Scope, tableName string) bool
|
||||
HasColumn(scope *Scope, tableName string, columnName string) bool
|
||||
HasIndex(scope *Scope, tableName string, indexName string) bool
|
||||
RemoveIndex(scope *Scope, indexName string)
|
||||
CurrentDatabase(scope *Scope) string
|
||||
// DataTypeOf return data's sql type
|
||||
DataTypeOf(field *StructField) string
|
||||
|
||||
// HasIndex check has index or not
|
||||
HasIndex(tableName string, indexName string) bool
|
||||
// HasForeignKey check has foreign key or not
|
||||
HasForeignKey(tableName string, foreignKeyName string) bool
|
||||
// RemoveIndex remove index
|
||||
RemoveIndex(tableName string, indexName string) error
|
||||
// HasTable check has table or not
|
||||
HasTable(tableName string) bool
|
||||
// HasColumn check has column or not
|
||||
HasColumn(tableName string, columnName string) bool
|
||||
|
||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||
LimitAndOffsetSQL(limit, offset int) string
|
||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||
SelectFromDummyTable() string
|
||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||
}
|
||||
|
||||
func NewDialect(driver string) Dialect {
|
||||
var d Dialect
|
||||
switch driver {
|
||||
case "postgres":
|
||||
d = &postgres{}
|
||||
case "foundation":
|
||||
d = &foundation{}
|
||||
case "mysql":
|
||||
d = &mysql{}
|
||||
case "sqlite3":
|
||||
d = &sqlite3{}
|
||||
case "mssql":
|
||||
d = &mssql{}
|
||||
default:
|
||||
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
|
||||
d = &commonDialect{}
|
||||
var dialectsMap = map[string]Dialect{}
|
||||
|
||||
func newDialect(name string, db *sql.DB) Dialect {
|
||||
if value, ok := dialectsMap[name]; ok {
|
||||
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
|
||||
dialect.SetDB(db)
|
||||
return dialect
|
||||
}
|
||||
return d
|
||||
|
||||
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
|
||||
commontDialect := &commonDialect{}
|
||||
commontDialect.SetDB(db)
|
||||
return commontDialect
|
||||
}
|
||||
|
||||
// RegisterDialect register new dialect
|
||||
func RegisterDialect(name string, dialect Dialect) {
|
||||
dialectsMap[name] = dialect
|
||||
}
|
||||
|
||||
// ParseFieldStructForDialect parse field struct for dialect
|
||||
func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
||||
// Get redirected field type
|
||||
var reflectType = field.Struct.Type
|
||||
for reflectType.Kind() == reflect.Ptr {
|
||||
reflectType = reflectType.Elem()
|
||||
}
|
||||
|
||||
// Get redirected field value
|
||||
fieldValue = reflect.Indirect(reflect.New(reflectType))
|
||||
|
||||
// Get scanner's real value
|
||||
var getScannerValue func(reflect.Value)
|
||||
getScannerValue = func(value reflect.Value) {
|
||||
fieldValue = value
|
||||
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
|
||||
getScannerValue(fieldValue.Field(0))
|
||||
}
|
||||
}
|
||||
getScannerValue(fieldValue)
|
||||
|
||||
// Default Size
|
||||
if num, ok := field.TagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
} else {
|
||||
size = 255
|
||||
}
|
||||
|
||||
// Default type from tag setting
|
||||
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
additionalType = additionalType + " DEFAULT " + value
|
||||
}
|
||||
|
||||
return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type commonDialect struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("common", &commonDialect{})
|
||||
}
|
||||
|
||||
func (commonDialect) GetName() string {
|
||||
return "common"
|
||||
}
|
||||
|
||||
func (s *commonDialect) SetDB(db *sql.DB) {
|
||||
s.db = db
|
||||
}
|
||||
|
||||
func (commonDialect) BindVar(i int) string {
|
||||
return "$$" // ?
|
||||
}
|
||||
|
||||
func (commonDialect) Quote(key string) string {
|
||||
return fmt.Sprintf(`"%s"`, key)
|
||||
}
|
||||
|
||||
func (commonDialect) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "BOOLEAN"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||
sqlType = "INTEGER AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "INTEGER"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||
sqlType = "BIGINT AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "BIGINT"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "FLOAT"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("VARCHAR(%d)", size)
|
||||
} else {
|
||||
sqlType = "VARCHAR(65532)"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "TIMESTAMP"
|
||||
}
|
||||
default:
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("BINARY(%d)", size)
|
||||
} else {
|
||||
sqlType = "BINARY(65532)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
|
||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s commonDialect) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s commonDialect) currentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) {
|
||||
if limit > 0 || offset > 0 {
|
||||
if limit >= 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", limit)
|
||||
}
|
||||
if offset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", offset)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (commonDialect) SelectFromDummyTable() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||
return ""
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mysql struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("mysql", &mysql{})
|
||||
}
|
||||
|
||||
func (mysql) GetName() string {
|
||||
return "mysql"
|
||||
}
|
||||
|
||||
func (mysql) Quote(key string) string {
|
||||
return fmt.Sprintf("`%s`", key)
|
||||
}
|
||||
|
||||
// Get Data Type for MySQL Dialect
|
||||
func (mysql) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
sqlType = "int AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "int"
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
sqlType = "int unsigned AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "int unsigned"
|
||||
}
|
||||
case reflect.Int64:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
sqlType = "bigint AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Uint64:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
sqlType = "bigint unsigned AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "bigint unsigned"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "double"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "longtext"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
if _, ok := field.TagSettings["NOT NULL"]; ok {
|
||||
sqlType = "timestamp"
|
||||
} else {
|
||||
sqlType = "timestamp NULL"
|
||||
}
|
||||
}
|
||||
default:
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varbinary(%d)", size)
|
||||
} else {
|
||||
sqlType = "longblob"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s mysql) RemoveIndex(tableName string, indexName string) error {
|
||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.currentDatabase(), foreignKeyName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mysql) currentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql) SelectFromDummyTable() string {
|
||||
return "FROM DUAL"
|
||||
}
|
|
@ -0,0 +1,128 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type postgres struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("postgres", &postgres{})
|
||||
}
|
||||
|
||||
func (postgres) GetName() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
func (postgres) BindVar(i int) string {
|
||||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
func (postgres) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
sqlType = "serial"
|
||||
} else {
|
||||
sqlType = "integer"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
sqlType = "bigserial"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "numeric"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "text"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "timestamp with time zone"
|
||||
}
|
||||
case reflect.Map:
|
||||
if dataValue.Type().Name() == "Hstore" {
|
||||
sqlType = "hstore"
|
||||
}
|
||||
default:
|
||||
if isByteArrayOrSlice(dataValue) {
|
||||
sqlType = "bytea"
|
||||
} else if isUUID(dataValue) {
|
||||
sqlType = "uuid"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s postgres) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", s.currentDatabase(), foreignKeyName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) currentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||
}
|
||||
|
||||
func (postgres) SupportLastInsertID() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isByteArrayOrSlice(value reflect.Value) bool {
|
||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||
}
|
||||
|
||||
func isUUID(value reflect.Value) bool {
|
||||
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
|
||||
return false
|
||||
}
|
||||
typename := value.Type().Name()
|
||||
lower := strings.ToLower(typename)
|
||||
return "uuid" == lower || "guid" == lower
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type sqlite3 struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("sqlite", &sqlite3{})
|
||||
RegisterDialect("sqlite3", &sqlite3{})
|
||||
}
|
||||
|
||||
func (sqlite3) GetName() string {
|
||||
return "sqlite3"
|
||||
}
|
||||
|
||||
// Get Data Type for Sqlite Dialect
|
||||
func (sqlite3) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "bool"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if field.IsPrimaryKey {
|
||||
sqlType = "integer primary key autoincrement"
|
||||
} else {
|
||||
sqlType = "integer"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if field.IsPrimaryKey {
|
||||
sqlType = "integer primary key autoincrement"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "real"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "text"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "datetime"
|
||||
}
|
||||
default:
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
sqlType = "blob"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) currentDatabase() (name string) {
|
||||
var (
|
||||
ifaces = make([]interface{}, 3)
|
||||
pointers = make([]*string, 3)
|
||||
i int
|
||||
)
|
||||
for i = 0; i < 3; i++ {
|
||||
ifaces[i] = &pointers[i]
|
||||
}
|
||||
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
|
||||
return
|
||||
}
|
||||
if pointers[1] != nil {
|
||||
name = *pointers[1]
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func setIdentityInsert(scope *gorm.Scope) {
|
||||
if scope.Dialect().GetName() == "mssql" {
|
||||
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
|
||||
gorm.RegisterDialect("mssql", &mssql{})
|
||||
}
|
||||
|
||||
type mssql struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (mssql) GetName() string {
|
||||
return "mssql"
|
||||
}
|
||||
|
||||
func (s *mssql) SetDB(db *sql.DB) {
|
||||
s.db = db
|
||||
}
|
||||
|
||||
func (mssql) BindVar(i int) string {
|
||||
return "$$" // ?
|
||||
}
|
||||
|
||||
func (mssql) Quote(key string) string {
|
||||
return fmt.Sprintf(`"%s"`, key)
|
||||
}
|
||||
|
||||
func (mssql) DataTypeOf(field *gorm.StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "bit"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
sqlType = "int IDENTITY(1,1)"
|
||||
} else {
|
||||
sqlType = "int"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
sqlType = "bigint IDENTITY(1,1)"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "float"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("nvarchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "text"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "datetime2"
|
||||
}
|
||||
default:
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "text"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s mssql) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.currentDatabase()).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) currentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (mssql) LimitAndOffsetSQL(limit, offset int) (sql string) {
|
||||
if limit > 0 || offset > 0 {
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
sql += fmt.Sprintf(" OFFSET %d ROWS", offset)
|
||||
|
||||
if limit >= 0 {
|
||||
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mssql) SelectFromDummyTable() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||
return ""
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
package mysql
|
||||
|
||||
import _ "github.com/go-sql-driver/mysql"
|
|
@ -0,0 +1,52 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/lib/pq/hstore"
|
||||
)
|
||||
|
||||
type Hstore map[string]*string
|
||||
|
||||
func (h Hstore) Value() (driver.Value, error) {
|
||||
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
|
||||
if len(h) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for key, value := range h {
|
||||
var s sql.NullString
|
||||
if value != nil {
|
||||
s.String = *value
|
||||
s.Valid = true
|
||||
}
|
||||
hstore.Map[key] = s
|
||||
}
|
||||
return hstore.Value()
|
||||
}
|
||||
|
||||
func (h *Hstore) Scan(value interface{}) error {
|
||||
hstore := hstore.Hstore{}
|
||||
|
||||
if err := hstore.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(hstore.Map) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
*h = Hstore{}
|
||||
for k := range hstore.Map {
|
||||
if hstore.Map[k].Valid {
|
||||
s := hstore.Map[k].String
|
||||
(*h)[k] = &s
|
||||
} else {
|
||||
(*h)[k] = nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
package sqlite
|
||||
|
||||
import _ "github.com/mattn/go-sqlite3"
|
|
@ -1,68 +0,0 @@
|
|||
# Gorm Development
|
||||
|
||||
## Architecture
|
||||
|
||||
The most notable component of Gorm is`gorm.DB`, which hold database connection. It could be initialized like this:
|
||||
|
||||
db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
|
||||
Gorm has chainable API, `gorm.DB` is the bridge of chains, it save related information and pass it to the next chain.
|
||||
|
||||
Lets use below code to explain how it works:
|
||||
|
||||
db.Where("name = ?", "jinzhu").Find(&users)
|
||||
|
||||
// equivalent code
|
||||
newdb := db.Where("name =?", "jinzhu")
|
||||
newdb.Find(&user)
|
||||
|
||||
`newdb` is `db`'s clone, in addition, it contains search conditions from the `Where` method.
|
||||
`Find` is a query method, it creates a `Scope` instance, and pass it as argument to query callbacks.
|
||||
|
||||
There are four kinds of callbacks corresponds to sql's CURD: create callbacks, update callbacks, query callbacks, delete callbacks.
|
||||
|
||||
## Callbacks
|
||||
|
||||
### Register a new callback
|
||||
|
||||
func updateCreated(scope *Scope) {
|
||||
if scope.HasColumn("Created") {
|
||||
scope.SetColumn("Created", NowFunc())
|
||||
}
|
||||
}
|
||||
|
||||
db.Callback().Create().Register("update_created_at", updateCreated)
|
||||
// register a callback for Create process
|
||||
|
||||
### Delete an existing callback
|
||||
|
||||
db.Callback().Create().Remove("gorm:create")
|
||||
// delete callback `gorm:create` from Create callbacks
|
||||
|
||||
### Replace an existing callback
|
||||
|
||||
db.Callback().Create().Replace("gorm:create", newCreateFunction)
|
||||
// replace callback `gorm:create` with new function `newCreateFunction` for Create process
|
||||
|
||||
### Register callback orders
|
||||
|
||||
db.Callback().Create().Before("gorm:create").Register("update_created_at", updateCreated)
|
||||
db.Callback().Create().After("gorm:create").Register("update_created_at", updateCreated)
|
||||
db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery)
|
||||
db.Callback().Delete().After("gorm:delete").Register("my_plugin:after_delete", afterDelete)
|
||||
db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate)
|
||||
db.Callback().Create().Before("gorm:create").After("gorm:before_create").Register("my_plugin:before_create", beforeCreate)
|
||||
|
||||
### Callback API
|
||||
|
||||
Gorm is powered by callbacks, so you could refer below links to learn how to write callbacks
|
||||
|
||||
[Create callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
|
||||
|
||||
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
|
||||
|
||||
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
|
||||
|
||||
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)
|
||||
|
||||
View [https://github.com/jinzhu/gorm/blob/master/scope.go](https://github.com/jinzhu/gorm/blob/master/scope.go) for all available API
|
17
errors.go
17
errors.go
|
@ -6,25 +6,31 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
RecordNotFound = errors.New("record not found")
|
||||
InvalidSql = errors.New("invalid sql")
|
||||
NoNewAttrs = errors.New("no new attributes")
|
||||
NoValidTransaction = errors.New("no valid transaction")
|
||||
CantStartTransaction = errors.New("can't start transaction")
|
||||
// ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
|
||||
ErrRecordNotFound = errors.New("record not found")
|
||||
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
|
||||
ErrInvalidSQL = errors.New("invalid SQL")
|
||||
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
||||
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
|
||||
ErrCantStartTransaction = errors.New("can't start transaction")
|
||||
)
|
||||
|
||||
type errorsInterface interface {
|
||||
GetErrors() []error
|
||||
}
|
||||
|
||||
// Errors contains all happened errors
|
||||
type Errors struct {
|
||||
errors []error
|
||||
}
|
||||
|
||||
// GetErrors get all happened errors
|
||||
func (errs Errors) GetErrors() []error {
|
||||
return errs.errors
|
||||
}
|
||||
|
||||
// Add add an error
|
||||
func (errs *Errors) Add(err error) {
|
||||
if errors, ok := err.(errorsInterface); ok {
|
||||
for _, err := range errors.GetErrors() {
|
||||
|
@ -40,6 +46,7 @@ func (errs *Errors) Add(err error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Error format happened errors
|
||||
func (errs Errors) Error() string {
|
||||
var errors = []string{}
|
||||
for _, e := range errs.errors {
|
||||
|
|
49
field.go
49
field.go
|
@ -7,12 +7,14 @@ import (
|
|||
"reflect"
|
||||
)
|
||||
|
||||
// Field model field definition
|
||||
type Field struct {
|
||||
*StructField
|
||||
IsBlank bool
|
||||
Field reflect.Value
|
||||
}
|
||||
|
||||
// Set set a value to the field
|
||||
func (field *Field) Set(value interface{}) (err error) {
|
||||
if !field.Field.IsValid() {
|
||||
return errors.New("field value not valid")
|
||||
|
@ -56,35 +58,34 @@ func (field *Field) Set(value interface{}) (err error) {
|
|||
}
|
||||
|
||||
// Fields get value's fields
|
||||
func (scope *Scope) Fields() map[string]*Field {
|
||||
if scope.fields == nil {
|
||||
fields := map[string]*Field{}
|
||||
modelStruct := scope.GetModelStruct()
|
||||
func (scope *Scope) Fields() []*Field {
|
||||
var (
|
||||
fields []*Field
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
isStruct = indirectScopeValue.Kind() == reflect.Struct
|
||||
)
|
||||
|
||||
indirectValue := scope.IndirectValue()
|
||||
isStruct := indirectValue.Kind() == reflect.Struct
|
||||
for _, structField := range modelStruct.StructFields {
|
||||
if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
|
||||
if isStruct {
|
||||
fields[structField.DBName] = getField(indirectValue, structField)
|
||||
} else {
|
||||
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
|
||||
}
|
||||
for _, structField := range scope.GetModelStruct().StructFields {
|
||||
if isStruct {
|
||||
fieldValue := indirectScopeValue
|
||||
for _, name := range structField.Names {
|
||||
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
|
||||
}
|
||||
fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
|
||||
} else {
|
||||
fields = append(fields, &Field{StructField: structField, IsBlank: true})
|
||||
}
|
||||
|
||||
scope.fields = fields
|
||||
return fields
|
||||
}
|
||||
return scope.fields
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func getField(indirectValue reflect.Value, structField *StructField) *Field {
|
||||
field := &Field{StructField: structField}
|
||||
for _, name := range structField.Names {
|
||||
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
|
||||
func (scope *Scope) fieldsMap() map[string]*Field {
|
||||
var results = map[string]*Field{}
|
||||
for _, field := range scope.Fields() {
|
||||
if field.IsNormal {
|
||||
results[field.DBName] = field
|
||||
}
|
||||
}
|
||||
field.Field = indirectValue
|
||||
field.IsBlank = isBlank(indirectValue)
|
||||
return field
|
||||
return results
|
||||
}
|
||||
|
|
|
@ -32,12 +32,16 @@ type CalculateFieldCategory struct {
|
|||
|
||||
func TestCalculateField(t *testing.T) {
|
||||
var field CalculateField
|
||||
fields := DB.NewScope(&field).Fields()
|
||||
if fields["children"].Relationship == nil || fields["category"].Relationship == nil {
|
||||
var scope = DB.NewScope(&field)
|
||||
if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil {
|
||||
t.Errorf("Should calculate fields correctly for the first time")
|
||||
}
|
||||
|
||||
if field, ok := fields["embedded_name"]; !ok {
|
||||
if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil {
|
||||
t.Errorf("Should calculate fields correctly for the first time")
|
||||
}
|
||||
|
||||
if field, ok := scope.FieldByName("embedded_name"); !ok {
|
||||
t.Errorf("should find embedded field")
|
||||
} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
|
||||
t.Errorf("should find embedded field's tag settings")
|
||||
|
|
|
@ -1,83 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type foundation struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func (foundation) BinVar(i int) string {
|
||||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
func (foundation) SupportLastInsertId() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "serial"
|
||||
}
|
||||
return "int"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "bigserial"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "double"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
return "clob"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "datetime"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
return "blob"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s foundation) ReturningStr(tableName, key string) string {
|
||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||
}
|
||||
|
||||
func (s foundation) HasTable(scope *Scope, tableName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s foundation) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s foundation) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName)))
|
||||
}
|
||||
|
||||
func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s foundation) CurrentDatabase(scope *Scope) (name string) {
|
||||
s.RawScanString(scope, &name, "SELECT CURRENT_SCHEMA")
|
||||
return
|
||||
}
|
Binary file not shown.
Before Width: | Height: | Size: 65 KiB |
|
@ -7,40 +7,54 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// JoinTableHandlerInterface is an interface for how to handle many2many relations
|
||||
type JoinTableHandlerInterface interface {
|
||||
// initialize join table handler
|
||||
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
||||
// Table return join table's table name
|
||||
Table(db *DB) string
|
||||
// Add create relationship in join table for source and destination
|
||||
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||
// Delete delete relationship in join table for sources
|
||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||
// JoinWith query with `Join` conditions
|
||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||
// SourceForeignKeys return source foreign keys
|
||||
SourceForeignKeys() []JoinTableForeignKey
|
||||
// DestinationForeignKeys return destination foreign keys
|
||||
DestinationForeignKeys() []JoinTableForeignKey
|
||||
}
|
||||
|
||||
// JoinTableForeignKey join table foreign key struct
|
||||
type JoinTableForeignKey struct {
|
||||
DBName string
|
||||
AssociationDBName string
|
||||
}
|
||||
|
||||
// JoinTableSource is a struct that contains model type and foreign keys
|
||||
type JoinTableSource struct {
|
||||
ModelType reflect.Type
|
||||
ForeignKeys []JoinTableForeignKey
|
||||
}
|
||||
|
||||
// JoinTableHandler default join table handler
|
||||
type JoinTableHandler struct {
|
||||
TableName string `sql:"-"`
|
||||
Source JoinTableSource `sql:"-"`
|
||||
Destination JoinTableSource `sql:"-"`
|
||||
}
|
||||
|
||||
// SourceForeignKeys return source foreign keys
|
||||
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
|
||||
return s.Source.ForeignKeys
|
||||
}
|
||||
|
||||
// DestinationForeignKeys return destination foreign keys
|
||||
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
|
||||
return s.Destination.ForeignKeys
|
||||
}
|
||||
|
||||
// Setup initialize a default join table handler
|
||||
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
||||
s.TableName = tableName
|
||||
|
||||
|
@ -61,11 +75,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
|
|||
}
|
||||
}
|
||||
|
||||
// Table return join table's table name
|
||||
func (s JoinTableHandler) Table(db *DB) string {
|
||||
return s.TableName
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||
values := map[string]interface{}{}
|
||||
|
||||
for _, source := range sources {
|
||||
|
@ -74,20 +89,25 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
|
|||
|
||||
if s.Source.ModelType == modelType {
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
values[foreignKey.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
} else if s.Destination.ModelType == modelType {
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
values[foreignKey.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
|
||||
// Add create relationship in join table for source and destination
|
||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
||||
scope := db.NewScope("")
|
||||
searchMap := s.GetSearchMap(db, source1, source2)
|
||||
searchMap := s.getSearchMap(db, source, destination)
|
||||
|
||||
var assignColumns, binVars, conditions []string
|
||||
var values []interface{}
|
||||
|
@ -116,6 +136,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
|
|||
return db.Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
// Delete delete relationship in join table for sources
|
||||
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
|
||||
var (
|
||||
scope = db.NewScope(nil)
|
||||
|
@ -123,7 +144,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
|||
values []interface{}
|
||||
)
|
||||
|
||||
for key, value := range s.GetSearchMap(db, sources...) {
|
||||
for key, value := range s.getSearchMap(db, sources...) {
|
||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||
values = append(values, value)
|
||||
}
|
||||
|
@ -131,6 +152,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
|||
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||
}
|
||||
|
||||
// JoinWith query with `Join` conditions
|
||||
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||
var (
|
||||
scope = db.NewScope(source)
|
||||
|
@ -151,10 +173,12 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
|||
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
|
||||
foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames)
|
||||
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
|
||||
|
||||
var condString string
|
||||
if len(foreignFieldValues) > 0 {
|
||||
|
@ -165,7 +189,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
|||
|
||||
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
|
||||
|
||||
keys := scope.getColumnAsArray(foreignFieldNames)
|
||||
keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
|
||||
values = append(values, toQueryValues(keys))
|
||||
} else {
|
||||
condString = fmt.Sprintf("1 <> 1")
|
||||
|
@ -173,8 +197,8 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
|
|||
|
||||
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
|
||||
Where(condString, toQueryValues(foreignFieldValues)...)
|
||||
} else {
|
||||
db.Error = errors.New("wrong source type for join table handler")
|
||||
return db
|
||||
}
|
||||
|
||||
db.Error = errors.New("wrong source type for join table handler")
|
||||
return db
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ type PersonAddress struct {
|
|||
gorm.JoinTableHandler
|
||||
PersonID int
|
||||
AddressID int
|
||||
DeletedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
|
|
57
logger.go
57
logger.go
|
@ -8,25 +8,28 @@ import (
|
|||
"reflect"
|
||||
"regexp"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
||||
sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
||||
)
|
||||
|
||||
type logger interface {
|
||||
Print(v ...interface{})
|
||||
}
|
||||
|
||||
type LogWriter interface {
|
||||
type logWriter interface {
|
||||
Println(v ...interface{})
|
||||
}
|
||||
|
||||
// Logger default logger
|
||||
type Logger struct {
|
||||
LogWriter
|
||||
logWriter
|
||||
}
|
||||
|
||||
var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
||||
|
||||
// Format log
|
||||
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
||||
|
||||
// Print format & print log
|
||||
func (logger Logger) Print(values ...interface{}) {
|
||||
if len(values) > 1 {
|
||||
level := values[0]
|
||||
|
@ -38,29 +41,44 @@ func (logger Logger) Print(values ...interface{}) {
|
|||
// duration
|
||||
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
||||
// sql
|
||||
var formatedValues []interface{}
|
||||
var sql string
|
||||
var formattedValues []string
|
||||
|
||||
for _, value := range values[4].([]interface{}) {
|
||||
indirectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
if indirectValue.IsValid() {
|
||||
value = indirectValue.Interface()
|
||||
if t, ok := value.(time.Time); ok {
|
||||
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
|
||||
} else if b, ok := value.([]byte); ok {
|
||||
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b)))
|
||||
if str := string(b); isPrintable(str) {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
|
||||
} else {
|
||||
formattedValues = append(formattedValues, "'<binary>'")
|
||||
}
|
||||
} else if r, ok := value.(driver.Valuer); ok {
|
||||
if value, err := r.Value(); err == nil && value != nil {
|
||||
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||
} else {
|
||||
formatedValues = append(formatedValues, "NULL")
|
||||
formattedValues = append(formattedValues, "NULL")
|
||||
}
|
||||
} else {
|
||||
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||
}
|
||||
} else {
|
||||
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||
}
|
||||
}
|
||||
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...))
|
||||
|
||||
var formattedValuesLength = len(formattedValues)
|
||||
for index, value := range sqlRegexp.Split(values[3].(string), -1) {
|
||||
sql += value
|
||||
if index < formattedValuesLength {
|
||||
sql += formattedValues[index]
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, sql)
|
||||
} else {
|
||||
messages = append(messages, "\033[31;1m")
|
||||
messages = append(messages, values[2:]...)
|
||||
|
@ -69,3 +87,12 @@ func (logger Logger) Print(values ...interface{}) {
|
|||
logger.Println(messages...)
|
||||
}
|
||||
}
|
||||
|
||||
func isPrintable(s string) bool {
|
||||
for _, r := range s {
|
||||
if !unicode.IsPrint(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
285
main.go
285
main.go
|
@ -6,24 +6,14 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NowFunc returns current time, this function is exported in order to be able
|
||||
// to give the flexibility to the developer to customize it according to their
|
||||
// needs
|
||||
//
|
||||
// e.g: return time.Now().UTC()
|
||||
//
|
||||
var NowFunc = func() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// DB contains information for current db connection
|
||||
type DB struct {
|
||||
Value interface{}
|
||||
Error error
|
||||
RowsAffected int64
|
||||
callback *callback
|
||||
callbacks *Callback
|
||||
db sqlCommon
|
||||
parent *DB
|
||||
search *search
|
||||
|
@ -36,7 +26,18 @@ type DB struct {
|
|||
joinTableHandlers map[string]JoinTableHandler
|
||||
}
|
||||
|
||||
func Open(dialect string, args ...interface{}) (DB, error) {
|
||||
// Open initialize a new db connection, need to import driver first, e.g:
|
||||
//
|
||||
// import _ "github.com/go-sql-driver/mysql"
|
||||
// func main() {
|
||||
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
|
||||
// }
|
||||
// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
|
||||
// import _ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
// // import _ "github.com/jinzhu/gorm/dialects/postgres"
|
||||
// // import _ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
// // import _ "github.com/jinzhu/gorm/dialects/mssql"
|
||||
func Open(dialect string, args ...interface{}) (*DB, error) {
|
||||
var db DB
|
||||
var err error
|
||||
|
||||
|
@ -44,7 +45,7 @@ func Open(dialect string, args ...interface{}) (DB, error) {
|
|||
err = errors.New("invalid database source")
|
||||
} else {
|
||||
var source string
|
||||
var dbSql sqlCommon
|
||||
var dbSQL sqlCommon
|
||||
|
||||
switch value := args[0].(type) {
|
||||
case string:
|
||||
|
@ -55,22 +56,19 @@ func Open(dialect string, args ...interface{}) (DB, error) {
|
|||
driver = value
|
||||
source = args[1].(string)
|
||||
}
|
||||
if driver == "foundation" {
|
||||
driver = "postgres" // FoundationDB speaks a postgres-compatible protocol.
|
||||
}
|
||||
dbSql, err = sql.Open(driver, source)
|
||||
dbSQL, err = sql.Open(driver, source)
|
||||
case sqlCommon:
|
||||
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
||||
dbSql = value
|
||||
dbSQL = value
|
||||
}
|
||||
|
||||
db = DB{
|
||||
dialect: NewDialect(dialect),
|
||||
logger: defaultLogger,
|
||||
callback: DefaultCallback,
|
||||
source: source,
|
||||
values: map[string]interface{}{},
|
||||
db: dbSql,
|
||||
dialect: newDialect(dialect, dbSQL.(*sql.DB)),
|
||||
logger: defaultLogger,
|
||||
callbacks: DefaultCallback,
|
||||
source: source,
|
||||
values: map[string]interface{}{},
|
||||
db: dbSQL,
|
||||
}
|
||||
db.parent = &db
|
||||
|
||||
|
@ -79,17 +77,20 @@ func Open(dialect string, args ...interface{}) (DB, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return db, err
|
||||
return &db, err
|
||||
}
|
||||
|
||||
// Close close current db connection
|
||||
func (s *DB) Close() error {
|
||||
return s.parent.db.(*sql.DB).Close()
|
||||
}
|
||||
|
||||
// DB get `*sql.DB` from current connection
|
||||
func (s *DB) DB() *sql.DB {
|
||||
return s.db.(*sql.DB)
|
||||
}
|
||||
|
||||
// New clone a new db connection without search conditions
|
||||
func (s *DB) New() *DB {
|
||||
clone := s.clone()
|
||||
clone.search = nil
|
||||
|
@ -97,29 +98,32 @@ func (s *DB) New() *DB {
|
|||
return clone
|
||||
}
|
||||
|
||||
// NewScope create scope for callbacks, including DB's search information
|
||||
func (db *DB) NewScope(value interface{}) *Scope {
|
||||
dbClone := db.clone()
|
||||
// NewScope create a scope for current operation
|
||||
func (s *DB) NewScope(value interface{}) *Scope {
|
||||
dbClone := s.clone()
|
||||
dbClone.Value = value
|
||||
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
|
||||
}
|
||||
|
||||
// CommonDB Return the underlying sql.DB or sql.Tx instance.
|
||||
// Use of this method is discouraged. It's mainly intended to allow
|
||||
// coexistence with legacy non-GORM code.
|
||||
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
|
||||
func (s *DB) CommonDB() sqlCommon {
|
||||
return s.db
|
||||
}
|
||||
|
||||
func (s *DB) Callback() *callback {
|
||||
s.parent.callback = s.parent.callback.clone()
|
||||
return s.parent.callback
|
||||
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
|
||||
// db.Callback().Create().Register("update_created_at", updateCreated)
|
||||
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
|
||||
func (s *DB) Callback() *Callback {
|
||||
s.parent.callbacks = s.parent.callbacks.clone()
|
||||
return s.parent.callbacks
|
||||
}
|
||||
|
||||
func (s *DB) SetLogger(l logger) {
|
||||
s.logger = l
|
||||
// SetLogger replace default logger
|
||||
func (s *DB) SetLogger(log logger) {
|
||||
s.logger = log
|
||||
}
|
||||
|
||||
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
|
||||
func (s *DB) LogMode(enable bool) *DB {
|
||||
if enable {
|
||||
s.logMode = 2
|
||||
|
@ -129,55 +133,82 @@ func (s *DB) LogMode(enable bool) *DB {
|
|||
return s
|
||||
}
|
||||
|
||||
// SingularTable use singular table by default
|
||||
func (s *DB) SingularTable(enable bool) {
|
||||
modelStructsMap = newModelStructsMap()
|
||||
s.parent.singularTable = enable
|
||||
}
|
||||
|
||||
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query
|
||||
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Where(query, args...).db
|
||||
}
|
||||
|
||||
// Or filter records that match before conditions or this one, similar to `Where`
|
||||
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Or(query, args...).db
|
||||
}
|
||||
|
||||
// Not filter records that don't match current conditions, similar to `Where`
|
||||
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Not(query, args...).db
|
||||
}
|
||||
|
||||
func (s *DB) Limit(value interface{}) *DB {
|
||||
return s.clone().search.Limit(value).db
|
||||
// Limit specify the number of records to be retrieved
|
||||
func (s *DB) Limit(limit int) *DB {
|
||||
return s.clone().search.Limit(limit).db
|
||||
}
|
||||
|
||||
func (s *DB) Offset(value interface{}) *DB {
|
||||
return s.clone().search.Offset(value).db
|
||||
// Offset specify the number of records to skip before starting to return the records
|
||||
func (s *DB) Offset(offset int) *DB {
|
||||
return s.clone().search.Offset(offset).db
|
||||
}
|
||||
|
||||
// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
|
||||
func (s *DB) Order(value string, reorder ...bool) *DB {
|
||||
return s.clone().search.Order(value, reorder...).db
|
||||
}
|
||||
|
||||
// Select specify fields that you want to retrieve from database when querying, by default, will select all fields;
|
||||
// When creating/updating, specify fields that you want to save to database
|
||||
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Select(query, args...).db
|
||||
}
|
||||
|
||||
// Omit specify fields that you want to ignore when saving to database for creating, updating
|
||||
func (s *DB) Omit(columns ...string) *DB {
|
||||
return s.clone().search.Omit(columns...).db
|
||||
}
|
||||
|
||||
// Group specify the group method on the find
|
||||
func (s *DB) Group(query string) *DB {
|
||||
return s.clone().search.Group(query).db
|
||||
}
|
||||
|
||||
// Having specify HAVING conditions for GROUP BY
|
||||
func (s *DB) Having(query string, values ...interface{}) *DB {
|
||||
return s.clone().search.Having(query, values...).db
|
||||
}
|
||||
|
||||
func (s *DB) Joins(query string) *DB {
|
||||
return s.clone().search.Joins(query).db
|
||||
// Joins specify Joins conditions
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
func (s *DB) Joins(query string, args ...interface{}) *DB {
|
||||
return s.clone().search.Joins(query, args...).db
|
||||
}
|
||||
|
||||
// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
//
|
||||
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||
// return func (db *gorm.DB) *gorm.DB {
|
||||
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
// Refer https://jinzhu.github.io/gorm/curd.html#scopes
|
||||
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
||||
for _, f := range funcs {
|
||||
s = f(s)
|
||||
|
@ -185,60 +216,91 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
|||
return s
|
||||
}
|
||||
|
||||
// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/curd.html#soft-delete
|
||||
func (s *DB) Unscoped() *DB {
|
||||
return s.clone().search.unscoped().db
|
||||
}
|
||||
|
||||
// Attrs initalize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||
func (s *DB) Attrs(attrs ...interface{}) *DB {
|
||||
return s.clone().search.Attrs(attrs...).db
|
||||
}
|
||||
|
||||
// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||
func (s *DB) Assign(attrs ...interface{}) *DB {
|
||||
return s.clone().search.Assign(attrs...).db
|
||||
}
|
||||
|
||||
// First find first record that match given conditions, order by primary key
|
||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||
newScope := s.clone().NewScope(out)
|
||||
newScope.Search.Limit(1)
|
||||
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
// Last find last record that match given conditions, order by primary key
|
||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||
newScope := s.clone().NewScope(out)
|
||||
newScope.Search.Limit(1)
|
||||
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
// Find find records that match given conditions
|
||||
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
func (s *DB) Scan(dest interface{}) *DB {
|
||||
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
|
||||
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
||||
}
|
||||
|
||||
// Row return `*sql.Row` with given conditions
|
||||
func (s *DB) Row() *sql.Row {
|
||||
return s.NewScope(s.Value).row()
|
||||
}
|
||||
|
||||
// Rows return `*sql.Rows` with given conditions
|
||||
func (s *DB) Rows() (*sql.Rows, error) {
|
||||
return s.NewScope(s.Value).rows()
|
||||
}
|
||||
|
||||
// ScanRows scan `*sql.Rows` to give struct
|
||||
func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
|
||||
var (
|
||||
clone = s.clone()
|
||||
scope = clone.NewScope(result)
|
||||
columns, err = rows.Columns()
|
||||
)
|
||||
|
||||
if clone.AddError(err) == nil {
|
||||
scope.scan(rows, columns, scope.fieldsMap())
|
||||
}
|
||||
|
||||
return clone.Error
|
||||
}
|
||||
|
||||
// Pluck used to query single column from a model as a map
|
||||
// var ages []int64
|
||||
// db.Find(&users).Pluck("age", &ages)
|
||||
func (s *DB) Pluck(column string, value interface{}) *DB {
|
||||
return s.NewScope(s.Value).pluck(column, value).db
|
||||
}
|
||||
|
||||
// Count get how many records for a model
|
||||
func (s *DB) Count(value interface{}) *DB {
|
||||
return s.NewScope(s.Value).count(value).db
|
||||
}
|
||||
|
||||
// Related get related associations
|
||||
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
||||
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
|
||||
}
|
||||
|
||||
// FirstOrInit find first matched record or initalize a new one with given conditions (only works with struct, map conditions)
|
||||
// https://jinzhu.github.io/gorm/curd.html#firstorinit
|
||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||
c := s.clone()
|
||||
if result := c.First(out, where...); result.Error != nil {
|
||||
|
@ -247,82 +309,100 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
|||
}
|
||||
c.NewScope(out).inlineCondition(where...).initialize()
|
||||
} else {
|
||||
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false)
|
||||
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
|
||||
// https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||
c := s.clone()
|
||||
if result := c.First(out, where...); result.Error != nil {
|
||||
if !result.RecordNotFound() {
|
||||
return result
|
||||
}
|
||||
c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error)
|
||||
c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db.Error)
|
||||
} else if len(c.search.assignAttrs) > 0 {
|
||||
c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error)
|
||||
c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db.Error)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||
func (s *DB) Update(attrs ...interface{}) *DB {
|
||||
return s.Updates(toSearchableMap(attrs...), true)
|
||||
}
|
||||
|
||||
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
||||
return s.clone().NewScope(s.Value).
|
||||
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
||||
InstanceSet("gorm:update_interface", values).
|
||||
callCallbacks(s.parent.callback.updates).db
|
||||
callCallbacks(s.parent.callbacks.updates).db
|
||||
}
|
||||
|
||||
// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
||||
return s.UpdateColumns(toSearchableMap(attrs...))
|
||||
}
|
||||
|
||||
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||
func (s *DB) UpdateColumns(values interface{}) *DB {
|
||||
return s.clone().NewScope(s.Value).
|
||||
Set("gorm:update_column", true).
|
||||
Set("gorm:save_associations", false).
|
||||
InstanceSet("gorm:update_interface", values).
|
||||
callCallbacks(s.parent.callback.updates).db
|
||||
callCallbacks(s.parent.callbacks.updates).db
|
||||
}
|
||||
|
||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||
func (s *DB) Save(value interface{}) *DB {
|
||||
scope := s.clone().NewScope(value)
|
||||
if scope.PrimaryKeyZero() {
|
||||
return scope.callCallbacks(s.parent.callback.creates).db
|
||||
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||
}
|
||||
return scope.callCallbacks(s.parent.callback.updates).db
|
||||
return scope.callCallbacks(s.parent.callbacks.updates).db
|
||||
}
|
||||
|
||||
// Create insert the value into database
|
||||
func (s *DB) Create(value interface{}) *DB {
|
||||
scope := s.clone().NewScope(value)
|
||||
return scope.callCallbacks(s.parent.callback.creates).db
|
||||
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||
}
|
||||
|
||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
||||
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db
|
||||
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
|
||||
}
|
||||
|
||||
// Raw use raw sql as conditions, won't run it unless invoked by other methods
|
||||
// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
|
||||
func (s *DB) Raw(sql string, values ...interface{}) *DB {
|
||||
return s.clone().search.Raw(true).Where(sql, values...).db
|
||||
}
|
||||
|
||||
// Exec execute raw sql
|
||||
func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
||||
scope := s.clone().NewScope(nil)
|
||||
generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
|
||||
generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")")
|
||||
scope.Raw(generatedSql)
|
||||
generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
|
||||
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
|
||||
scope.Raw(generatedSQL)
|
||||
return scope.Exec().db
|
||||
}
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
func (s *DB) Model(value interface{}) *DB {
|
||||
c := s.clone()
|
||||
c.Value = value
|
||||
return c
|
||||
}
|
||||
|
||||
// Table specify the table you would like to run db operations
|
||||
func (s *DB) Table(name string) *DB {
|
||||
clone := s.clone()
|
||||
clone.search.Table(name)
|
||||
|
@ -330,10 +410,12 @@ func (s *DB) Table(name string) *DB {
|
|||
return clone
|
||||
}
|
||||
|
||||
// Debug start debug mode
|
||||
func (s *DB) Debug() *DB {
|
||||
return s.clone().LogMode(true)
|
||||
}
|
||||
|
||||
// Begin begin a transaction
|
||||
func (s *DB) Begin() *DB {
|
||||
c := s.clone()
|
||||
if db, ok := c.db.(sqlDb); ok {
|
||||
|
@ -341,46 +423,56 @@ func (s *DB) Begin() *DB {
|
|||
c.db = interface{}(tx).(sqlCommon)
|
||||
c.AddError(err)
|
||||
} else {
|
||||
c.AddError(CantStartTransaction)
|
||||
c.AddError(ErrCantStartTransaction)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Commit commit a transaction
|
||||
func (s *DB) Commit() *DB {
|
||||
if db, ok := s.db.(sqlTx); ok {
|
||||
s.AddError(db.Commit())
|
||||
} else {
|
||||
s.AddError(NoValidTransaction)
|
||||
s.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Rollback rollback a transaction
|
||||
func (s *DB) Rollback() *DB {
|
||||
if db, ok := s.db.(sqlTx); ok {
|
||||
s.AddError(db.Rollback())
|
||||
} else {
|
||||
s.AddError(NoValidTransaction)
|
||||
s.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// NewRecord check if value's primary key is blank
|
||||
func (s *DB) NewRecord(value interface{}) bool {
|
||||
return s.clone().NewScope(value).PrimaryKeyZero()
|
||||
}
|
||||
|
||||
// RecordNotFound check if returning ErrRecordNotFound error
|
||||
func (s *DB) RecordNotFound() bool {
|
||||
return s.Error == RecordNotFound
|
||||
for _, err := range s.GetErrors() {
|
||||
if err == ErrRecordNotFound {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Migrations
|
||||
func (s *DB) CreateTable(values ...interface{}) *DB {
|
||||
// CreateTable create table for models
|
||||
func (s *DB) CreateTable(models ...interface{}) *DB {
|
||||
db := s.clone()
|
||||
for _, value := range values {
|
||||
db = db.NewScope(value).createTable().db
|
||||
for _, model := range models {
|
||||
db = db.NewScope(model).createTable().db
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// DropTable drop table for models
|
||||
func (s *DB) DropTable(values ...interface{}) *DB {
|
||||
db := s.clone()
|
||||
for _, value := range values {
|
||||
|
@ -393,18 +485,18 @@ func (s *DB) DropTable(values ...interface{}) *DB {
|
|||
return db
|
||||
}
|
||||
|
||||
// DropTableIfExists drop table if it is exist
|
||||
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
|
||||
db := s.clone()
|
||||
for _, value := range values {
|
||||
if tableName, ok := value.(string); ok {
|
||||
db = db.Table(tableName)
|
||||
if s.HasTable(value) {
|
||||
db.AddError(s.DropTable(value).Error)
|
||||
}
|
||||
|
||||
db = db.NewScope(value).dropTableIfExists().db
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// HasTable check has table or not
|
||||
func (s *DB) HasTable(value interface{}) bool {
|
||||
var (
|
||||
scope = s.clone().NewScope(value)
|
||||
|
@ -417,69 +509,64 @@ func (s *DB) HasTable(value interface{}) bool {
|
|||
tableName = scope.TableName()
|
||||
}
|
||||
|
||||
has := scope.Dialect().HasTable(scope, tableName)
|
||||
has := scope.Dialect().HasTable(tableName)
|
||||
s.AddError(scope.db.Error)
|
||||
return has
|
||||
}
|
||||
|
||||
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
|
||||
func (s *DB) AutoMigrate(values ...interface{}) *DB {
|
||||
db := s.clone()
|
||||
for _, value := range values {
|
||||
db = db.NewScope(value).NeedPtr().autoMigrate().db
|
||||
db = db.NewScope(value).autoMigrate().db
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// ModifyColumn modify column to type
|
||||
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.modifyColumn(column, typ)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// DropColumn drop a column
|
||||
func (s *DB) DropColumn(column string) *DB {
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.dropColumn(column)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
func (s *DB) AddIndex(indexName string, column ...string) *DB {
|
||||
// AddIndex add index for columns with given name
|
||||
func (s *DB) AddIndex(indexName string, columns ...string) *DB {
|
||||
scope := s.Unscoped().NewScope(s.Value)
|
||||
scope.addIndex(false, indexName, column...)
|
||||
scope.addIndex(false, indexName, columns...)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
|
||||
// AddUniqueIndex add unique index for columns with given name
|
||||
func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.addIndex(true, indexName, column...)
|
||||
scope.addIndex(true, indexName, columns...)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// RemoveIndex remove index with name
|
||||
func (s *DB) RemoveIndex(indexName string) *DB {
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.removeIndex(indexName)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
func (s *DB) CurrentDatabase() string {
|
||||
var (
|
||||
scope = s.clone().NewScope(s.Value)
|
||||
name = s.dialect.CurrentDatabase(scope)
|
||||
)
|
||||
return name
|
||||
}
|
||||
|
||||
/*
|
||||
Add foreign key to the given scope
|
||||
|
||||
Example:
|
||||
db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
||||
*/
|
||||
// AddForeignKey Add foreign key to the given scope, e.g:
|
||||
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
||||
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.addForeignKey(field, dest, onDelete, onUpdate)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
|
||||
func (s *DB) Association(column string) *Association {
|
||||
var err error
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
|
@ -491,7 +578,7 @@ func (s *DB) Association(column string) *Association {
|
|||
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
|
||||
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
||||
} else {
|
||||
return &Association{Scope: scope, Column: column, Field: field}
|
||||
return &Association{scope: scope, column: column, field: field}
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
||||
|
@ -501,26 +588,30 @@ func (s *DB) Association(column string) *Association {
|
|||
return &Association{Error: err}
|
||||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
|
||||
return s.clone().search.Preload(column, conditions...).db
|
||||
}
|
||||
|
||||
// Set set value by name
|
||||
// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting
|
||||
func (s *DB) Set(name string, value interface{}) *DB {
|
||||
return s.clone().InstantSet(name, value)
|
||||
}
|
||||
|
||||
// InstantSet instant set setting, will affect current db
|
||||
func (s *DB) InstantSet(name string, value interface{}) *DB {
|
||||
s.values[name] = value
|
||||
return s
|
||||
}
|
||||
|
||||
// Get get value by name
|
||||
// Get get setting by name
|
||||
func (s *DB) Get(name string) (value interface{}, ok bool) {
|
||||
value, ok = s.values[name]
|
||||
return
|
||||
}
|
||||
|
||||
// SetJoinTableHandler set a model's join table handler for a relation
|
||||
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
|
||||
scope := s.NewScope(source)
|
||||
for _, field := range scope.GetModelStruct().StructFields {
|
||||
|
@ -530,7 +621,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
|||
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||
handler.Setup(field.Relationship, many2many, source, destination)
|
||||
field.Relationship.JoinTableHandler = handler
|
||||
if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
|
||||
if table := handler.Table(s); scope.Dialect().HasTable(table) {
|
||||
s.Table(table).AutoMigrate(handler)
|
||||
}
|
||||
}
|
||||
|
@ -538,9 +629,10 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
|||
}
|
||||
}
|
||||
|
||||
// AddError add error to the db
|
||||
func (s *DB) AddError(err error) error {
|
||||
if err != nil {
|
||||
if err != RecordNotFound {
|
||||
if err != ErrRecordNotFound {
|
||||
if s.logMode == 0 {
|
||||
go s.print(fileWithLineNum(), err)
|
||||
} else {
|
||||
|
@ -559,6 +651,7 @@ func (s *DB) AddError(err error) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// GetErrors get happened errors from the db
|
||||
func (s *DB) GetErrors() (errors []error) {
|
||||
if errs, ok := s.Error.(errorsInterface); ok {
|
||||
return errs.GetErrors()
|
||||
|
|
|
@ -10,7 +10,7 @@ func (s *DB) clone() *DB {
|
|||
}
|
||||
|
||||
if s.search == nil {
|
||||
db.search = &search{}
|
||||
db.search = &search{limit: -1, offset: -1}
|
||||
} else {
|
||||
db.search = s.search.clone()
|
||||
}
|
||||
|
|
100
main_test.go
100
main_test.go
|
@ -4,23 +4,23 @@ import (
|
|||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
testdb "github.com/erikstmartin/go-testdb"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/now"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/erikstmartin/go-testdb"
|
||||
"github.com/jinzhu/gorm"
|
||||
_ "github.com/jinzhu/gorm/dialects/mssql"
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
"github.com/jinzhu/gorm/dialects/postgres"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
"github.com/jinzhu/now"
|
||||
)
|
||||
|
||||
var (
|
||||
DB gorm.DB
|
||||
DB *gorm.DB
|
||||
t1, t2, t3, t4, t5 time.Time
|
||||
)
|
||||
|
||||
|
@ -42,7 +42,7 @@ func init() {
|
|||
runMigration()
|
||||
}
|
||||
|
||||
func OpenTestConnection() (db gorm.DB, err error) {
|
||||
func OpenTestConnection() (db *gorm.DB, err error) {
|
||||
switch os.Getenv("GORM_DIALECT") {
|
||||
case "mysql":
|
||||
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
|
||||
|
@ -115,7 +115,7 @@ func TestSetTable(t *testing.T) {
|
|||
DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
|
||||
|
||||
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
|
||||
t.Errorf("No errors should happen if set table for pluck", err.Error())
|
||||
t.Error("No errors should happen if set table for pluck", err)
|
||||
}
|
||||
|
||||
var users []User
|
||||
|
@ -376,7 +376,7 @@ func TestRows(t *testing.T) {
|
|||
|
||||
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
||||
if err != nil {
|
||||
t.Errorf("Not error should happen, but got")
|
||||
t.Errorf("Not error should happen, got %v", err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
|
@ -386,8 +386,39 @@ func TestRows(t *testing.T) {
|
|||
rows.Scan(&name, &age)
|
||||
count++
|
||||
}
|
||||
|
||||
if count != 2 {
|
||||
t.Errorf("Should found two records with name 3")
|
||||
t.Errorf("Should found two records")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRows(t *testing.T) {
|
||||
user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
|
||||
user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
|
||||
user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
|
||||
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||
|
||||
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
||||
if err != nil {
|
||||
t.Errorf("Not error should happen, got %v", err)
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
var results []Result
|
||||
for rows.Next() {
|
||||
var result Result
|
||||
if err := DB.ScanRows(rows, &result); err != nil {
|
||||
t.Errorf("should get no error, but got %v", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
||||
t.Errorf("Should find expected results")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -448,7 +479,7 @@ func TestRaw(t *testing.T) {
|
|||
}
|
||||
|
||||
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
|
||||
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound {
|
||||
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
|
||||
t.Error("Raw sql to update records")
|
||||
}
|
||||
}
|
||||
|
@ -469,15 +500,34 @@ func TestGroup(t *testing.T) {
|
|||
|
||||
func TestJoins(t *testing.T) {
|
||||
var user = User{
|
||||
Name: "joins",
|
||||
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
|
||||
Name: "joins",
|
||||
CreditCard: CreditCard{Number: "411111111111"},
|
||||
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
|
||||
}
|
||||
DB.Save(&user)
|
||||
|
||||
var result User
|
||||
DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result)
|
||||
if result.Name != "joins" || result.Id != user.Id {
|
||||
t.Errorf("Should find all two emails with Join")
|
||||
var users1 []User
|
||||
DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1)
|
||||
if len(users1) != 2 {
|
||||
t.Errorf("should find two users using left join")
|
||||
}
|
||||
|
||||
var users2 []User
|
||||
DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2)
|
||||
if len(users2) != 1 {
|
||||
t.Errorf("should find one users using left join with conditions")
|
||||
}
|
||||
|
||||
var users3 []User
|
||||
DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3)
|
||||
if len(users3) != 1 {
|
||||
t.Errorf("should find one users using multiple left join conditions")
|
||||
}
|
||||
|
||||
var users4 []User
|
||||
DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4)
|
||||
if len(users4) != 0 {
|
||||
t.Errorf("should find no user when searching with unexisting credit card")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -557,7 +607,7 @@ func TestTimeWithZone(t *testing.T) {
|
|||
DB.First(&findUser, "name = ?", name)
|
||||
foundBirthday = findUser.Birthday.UTC().Format(format)
|
||||
if foundBirthday != expectedBirthday {
|
||||
t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday)
|
||||
t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
|
||||
}
|
||||
|
||||
if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
|
||||
|
@ -573,7 +623,7 @@ func TestTimeWithZone(t *testing.T) {
|
|||
func TestHstore(t *testing.T) {
|
||||
type Details struct {
|
||||
Id int64
|
||||
Bulk gorm.Hstore
|
||||
Bulk postgres.Hstore
|
||||
}
|
||||
|
||||
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
|
||||
|
@ -659,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) {
|
|||
}
|
||||
|
||||
var user User
|
||||
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
|
||||
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
|
||||
t.Errorf("Should have found existing record")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ func TestIndexes(t *testing.T) {
|
|||
}
|
||||
|
||||
scope := DB.NewScope(&Email{})
|
||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
|
||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||
t.Errorf("Email should have index idx_email_email")
|
||||
}
|
||||
|
||||
|
@ -39,7 +39,7 @@ func TestIndexes(t *testing.T) {
|
|||
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||
}
|
||||
|
||||
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
|
||||
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||
t.Errorf("Email's index idx_email_email should be deleted")
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,7 @@ func TestIndexes(t *testing.T) {
|
|||
t.Errorf("Got error when tried to create index: %+v", err)
|
||||
}
|
||||
|
||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,7 @@ func TestIndexes(t *testing.T) {
|
|||
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||
}
|
||||
|
||||
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
||||
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||
}
|
||||
|
||||
|
@ -63,7 +63,7 @@ func TestIndexes(t *testing.T) {
|
|||
t.Errorf("Got error when tried to create index: %+v", err)
|
||||
}
|
||||
|
||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||
}
|
||||
|
||||
|
@ -85,7 +85,7 @@ func TestIndexes(t *testing.T) {
|
|||
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||
}
|
||||
|
||||
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
||||
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||
}
|
||||
|
||||
|
@ -117,11 +117,11 @@ func TestAutoMigration(t *testing.T) {
|
|||
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
|
||||
|
||||
scope := DB.NewScope(&BigEmail{})
|
||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") {
|
||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
|
||||
t.Errorf("Failed to create index")
|
||||
}
|
||||
|
||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") {
|
||||
if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") {
|
||||
t.Errorf("Failed to create index")
|
||||
}
|
||||
|
||||
|
|
4
model.go
4
model.go
|
@ -2,6 +2,10 @@ package gorm
|
|||
|
||||
import "time"
|
||||
|
||||
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
|
|
|
@ -3,10 +3,8 @@ package gorm
|
|||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -14,6 +12,7 @@ import (
|
|||
"github.com/jinzhu/inflection"
|
||||
)
|
||||
|
||||
// DefaultTableNameHandler default table name handler
|
||||
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||
return defaultTableName
|
||||
}
|
||||
|
@ -41,6 +40,7 @@ func newModelStructsMap() *safeModelStructsMap {
|
|||
|
||||
var modelStructsMap = newModelStructsMap()
|
||||
|
||||
// ModelStruct model definition
|
||||
type ModelStruct struct {
|
||||
PrimaryFields []*StructField
|
||||
StructFields []*StructField
|
||||
|
@ -48,10 +48,12 @@ type ModelStruct struct {
|
|||
defaultTableName string
|
||||
}
|
||||
|
||||
// TableName get model's table name
|
||||
func (s *ModelStruct) TableName(db *DB) string {
|
||||
return DefaultTableNameHandler(db, s.defaultTableName)
|
||||
}
|
||||
|
||||
// StructField model field's struct definition
|
||||
type StructField struct {
|
||||
DBName string
|
||||
Name string
|
||||
|
@ -107,7 +109,7 @@ func getForeignField(column string, fields []*StructField) *StructField {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetModelStruct generate model struct & relationships based on struct and tag definition
|
||||
// GetModelStruct get value's model struct, relationships based on struct and tag definition
|
||||
func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
var modelStruct ModelStruct
|
||||
// Scope value can't be nil
|
||||
|
@ -296,7 +298,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
if len(associationForeignKeys) == 0 {
|
||||
for _, foreignKey := range foreignKeys {
|
||||
if strings.HasPrefix(foreignKey, associationType) {
|
||||
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType))
|
||||
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
||||
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
|
||||
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||
|
@ -389,7 +394,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
if len(associationForeignKeys) == 0 {
|
||||
for _, foreignKey := range foreignKeys {
|
||||
if strings.HasPrefix(foreignKey, associationType) {
|
||||
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType))
|
||||
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
||||
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
|
||||
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||
|
@ -445,7 +453,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
if len(associationForeignKeys) == 0 {
|
||||
for _, foreignKey := range foreignKeys {
|
||||
if strings.HasPrefix(foreignKey, field.Name) {
|
||||
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, field.Name))
|
||||
associationForeignKey := strings.TrimPrefix(foreignKey, field.Name)
|
||||
if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
|
||||
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||
|
@ -508,63 +519,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
return &modelStruct
|
||||
}
|
||||
|
||||
// GetStructFields get model's field structs
|
||||
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
||||
return scope.GetModelStruct().StructFields
|
||||
}
|
||||
|
||||
func (scope *Scope) generateSqlTag(field *StructField) string {
|
||||
var sqlType string
|
||||
structType := field.Struct.Type
|
||||
if structType.Kind() == reflect.Ptr {
|
||||
structType = structType.Elem()
|
||||
}
|
||||
reflectValue := reflect.Indirect(reflect.New(structType))
|
||||
|
||||
if value, ok := field.TagSettings["TYPE"]; ok {
|
||||
sqlType = value
|
||||
}
|
||||
|
||||
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
additionalType = additionalType + " DEFAULT " + value
|
||||
}
|
||||
|
||||
if field.IsScanner {
|
||||
var getScannerValue func(reflect.Value)
|
||||
getScannerValue = func(value reflect.Value) {
|
||||
reflectValue = value
|
||||
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
|
||||
getScannerValue(reflectValue.Field(0))
|
||||
}
|
||||
}
|
||||
getScannerValue(reflectValue)
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
var size = 255
|
||||
|
||||
if value, ok := field.TagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(value)
|
||||
}
|
||||
|
||||
v, autoIncrease := field.TagSettings["AUTO_INCREMENT"]
|
||||
if field.IsPrimaryKey {
|
||||
autoIncrease = true
|
||||
}
|
||||
if v == "FALSE" {
|
||||
autoIncrease = false
|
||||
}
|
||||
|
||||
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
} else {
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
}
|
||||
|
||||
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
||||
setting := map[string]string{}
|
||||
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
||||
|
|
80
mssql.go
80
mssql.go
|
@ -1,80 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mssql struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func (mssql) HasTop() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bit"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "int IDENTITY(1,1)"
|
||||
}
|
||||
return "int"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "bigint IDENTITY(1,1)"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "float"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("nvarchar(%d)", size)
|
||||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "datetime2"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
return "text"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s mssql) HasTable(scope *Scope, tableName string) bool {
|
||||
var (
|
||||
count int
|
||||
databaseName = s.CurrentDatabase(scope)
|
||||
)
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||
var (
|
||||
count int
|
||||
databaseName = s.CurrentDatabase(scope)
|
||||
)
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) CurrentDatabase(scope *Scope) (name string) {
|
||||
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
|
||||
return
|
||||
}
|
|
@ -21,7 +21,7 @@ type Tag struct {
|
|||
ID uint `gorm:"primary_key"`
|
||||
Locale string `gorm:"primary_key"`
|
||||
Value string
|
||||
Blogs []*Blog `gorm:"many2many:"blogs_tags`
|
||||
Blogs []*Blog `gorm:"many2many:blogs_tags"`
|
||||
}
|
||||
|
||||
func compareTags(tags []Tag, contents []string) bool {
|
||||
|
|
70
mysql.go
70
mysql.go
|
@ -1,70 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mysql struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
|
||||
if autoIncrease {
|
||||
return "int AUTO_INCREMENT"
|
||||
}
|
||||
return "int"
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "int unsigned AUTO_INCREMENT"
|
||||
}
|
||||
return "int unsigned"
|
||||
case reflect.Int64:
|
||||
if autoIncrease {
|
||||
return "bigint AUTO_INCREMENT"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "bigint unsigned AUTO_INCREMENT"
|
||||
}
|
||||
return "bigint unsigned"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "double"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
return "longtext"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "timestamp NULL"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varbinary(%d)", size)
|
||||
}
|
||||
return "longblob"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (mysql) Quote(key string) string {
|
||||
return fmt.Sprintf("`%s`", key)
|
||||
}
|
||||
|
||||
func (mysql) SelectFromDummyTable() string {
|
||||
return "FROM DUAL"
|
||||
}
|
||||
|
||||
func (s mysql) CurrentDatabase(scope *Scope) (name string) {
|
||||
s.RawScanString(scope, &name, "SELECT DATABASE()")
|
||||
return
|
||||
}
|
|
@ -39,46 +39,46 @@ func TestPointerFields(t *testing.T) {
|
|||
|
||||
var nilPointerStruct = PointerStruct{}
|
||||
if err := DB.Create(&nilPointerStruct).Error; err != nil {
|
||||
t.Errorf("Failed to save nil pointer struct", err)
|
||||
t.Error("Failed to save nil pointer struct", err)
|
||||
}
|
||||
|
||||
var pointerStruct2 PointerStruct
|
||||
if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
|
||||
t.Errorf("Failed to query saved nil pointer struct", err)
|
||||
t.Error("Failed to query saved nil pointer struct", err)
|
||||
}
|
||||
|
||||
var normalStruct2 NormalStruct
|
||||
if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
|
||||
t.Errorf("Failed to query saved nil pointer struct", err)
|
||||
t.Error("Failed to query saved nil pointer struct", err)
|
||||
}
|
||||
|
||||
var partialNilPointerStruct1 = PointerStruct{Num: &num}
|
||||
if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
|
||||
t.Errorf("Failed to save partial nil pointer struct", err)
|
||||
t.Error("Failed to save partial nil pointer struct", err)
|
||||
}
|
||||
|
||||
var pointerStruct3 PointerStruct
|
||||
if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
|
||||
t.Errorf("Failed to query saved partial nil pointer struct", err)
|
||||
t.Error("Failed to query saved partial nil pointer struct", err)
|
||||
}
|
||||
|
||||
var normalStruct3 NormalStruct
|
||||
if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
|
||||
t.Errorf("Failed to query saved partial pointer struct", err)
|
||||
t.Error("Failed to query saved partial pointer struct", err)
|
||||
}
|
||||
|
||||
var partialNilPointerStruct2 = PointerStruct{Name: &name}
|
||||
if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
|
||||
t.Errorf("Failed to save partial nil pointer struct", err)
|
||||
t.Error("Failed to save partial nil pointer struct", err)
|
||||
}
|
||||
|
||||
var pointerStruct4 PointerStruct
|
||||
if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
|
||||
t.Errorf("Failed to query saved partial nil pointer struct", err)
|
||||
t.Error("Failed to query saved partial nil pointer struct", err)
|
||||
}
|
||||
|
||||
var normalStruct4 NormalStruct
|
||||
if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
|
||||
t.Errorf("Failed to query saved partial pointer struct", err)
|
||||
t.Error("Failed to query saved partial pointer struct", err)
|
||||
}
|
||||
}
|
||||
|
|
154
postgres.go
154
postgres.go
|
@ -1,154 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq/hstore"
|
||||
)
|
||||
|
||||
type postgres struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func (postgres) BinVar(i int) string {
|
||||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
func (postgres) SupportLastInsertId() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "serial"
|
||||
}
|
||||
return "integer"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "bigserial"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "numeric"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "timestamp with time zone"
|
||||
}
|
||||
case reflect.Map:
|
||||
if value.Type() == hstoreType {
|
||||
return "hstore"
|
||||
}
|
||||
default:
|
||||
if isByteArrayOrSlice(value) {
|
||||
return "bytea"
|
||||
} else if isUUID(value) {
|
||||
return "uuid"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
var byteType = reflect.TypeOf(uint8(0))
|
||||
|
||||
func isByteArrayOrSlice(value reflect.Value) bool {
|
||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == byteType
|
||||
}
|
||||
|
||||
func isUUID(value reflect.Value) bool {
|
||||
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
|
||||
return false
|
||||
}
|
||||
typename := value.Type().Name()
|
||||
lower := strings.ToLower(typename)
|
||||
return "uuid" == lower || "guid" == lower
|
||||
}
|
||||
|
||||
func (s postgres) ReturningStr(tableName, key string) string {
|
||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||
}
|
||||
|
||||
func (s postgres) HasTable(scope *Scope, tableName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (postgres) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
|
||||
}
|
||||
|
||||
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) CurrentDatabase(scope *Scope) (name string) {
|
||||
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
|
||||
return
|
||||
}
|
||||
|
||||
var hstoreType = reflect.TypeOf(Hstore{})
|
||||
|
||||
type Hstore map[string]*string
|
||||
|
||||
func (h Hstore) Value() (driver.Value, error) {
|
||||
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
|
||||
if len(h) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for key, value := range h {
|
||||
var s sql.NullString
|
||||
if value != nil {
|
||||
s.String = *value
|
||||
s.Valid = true
|
||||
}
|
||||
hstore.Map[key] = s
|
||||
}
|
||||
return hstore.Value()
|
||||
}
|
||||
|
||||
func (h *Hstore) Scan(value interface{}) error {
|
||||
hstore := hstore.Hstore{}
|
||||
|
||||
if err := hstore.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(hstore.Map) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
*h = Hstore{}
|
||||
for k := range hstore.Map {
|
||||
if hstore.Map[k].Valid {
|
||||
s := hstore.Map[k].String
|
||||
(*h)[k] = &s
|
||||
} else {
|
||||
(*h)[k] = nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
388
preload.go
388
preload.go
|
@ -1,388 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
|
||||
// If value is a nil pointer, Indirect returns a zero Value!
|
||||
// Therefor we need to check for a zero value,
|
||||
// as FieldByName could panic
|
||||
if pointedValue := reflect.Indirect(value); pointedValue.IsValid() {
|
||||
for _, column := range columns {
|
||||
if pointedValue.FieldByName(column).IsValid() {
|
||||
result := pointedValue.FieldByName(column).Interface()
|
||||
if r, ok := result.(driver.Valuer); ok {
|
||||
result, _ = r.Value()
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func equalAsString(a interface{}, b interface{}) bool {
|
||||
return toString(a) == toString(b)
|
||||
}
|
||||
|
||||
func Preload(scope *Scope) {
|
||||
if scope.Search.preload == nil || scope.HasError() {
|
||||
return
|
||||
}
|
||||
|
||||
preloadMap := map[string]bool{}
|
||||
fields := scope.Fields()
|
||||
for _, preload := range scope.Search.preload {
|
||||
schema, conditions := preload.schema, preload.conditions
|
||||
keys := strings.Split(schema, ".")
|
||||
currentScope := scope
|
||||
currentFields := fields
|
||||
originalConditions := conditions
|
||||
conditions = []interface{}{}
|
||||
for i, key := range keys {
|
||||
var found bool
|
||||
if preloadMap[strings.Join(keys[:i+1], ".")] {
|
||||
goto nextLoop
|
||||
}
|
||||
|
||||
if i == len(keys)-1 {
|
||||
conditions = originalConditions
|
||||
}
|
||||
|
||||
for _, field := range currentFields {
|
||||
if field.Name != key || field.Relationship == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
found = true
|
||||
switch field.Relationship.Kind {
|
||||
case "has_one":
|
||||
currentScope.handleHasOnePreload(field, conditions)
|
||||
case "has_many":
|
||||
currentScope.handleHasManyPreload(field, conditions)
|
||||
case "belongs_to":
|
||||
currentScope.handleBelongsToPreload(field, conditions)
|
||||
case "many_to_many":
|
||||
currentScope.handleManyToManyPreload(field, conditions)
|
||||
default:
|
||||
currentScope.Err(errors.New("not supported relation"))
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if !found {
|
||||
value := reflect.ValueOf(currentScope.Value)
|
||||
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
|
||||
value = value.Index(0).Elem()
|
||||
}
|
||||
scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
|
||||
return
|
||||
}
|
||||
|
||||
preloadMap[strings.Join(keys[:i+1], ".")] = true
|
||||
|
||||
nextLoop:
|
||||
if i < len(keys)-1 {
|
||||
currentScope = currentScope.getColumnsAsScope(key)
|
||||
currentFields = currentScope.Fields()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func makeSlice(typ reflect.Type) interface{} {
|
||||
if typ.Kind() == reflect.Slice {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
sliceType := reflect.SliceOf(typ)
|
||||
slice := reflect.New(sliceType)
|
||||
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
||||
return slice.Interface()
|
||||
}
|
||||
|
||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
value := getRealValue(result, relation.ForeignFieldNames)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) {
|
||||
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := scope.SetColumn(field, result); err != nil {
|
||||
scope.Err(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
preloadMap := make(map[string][]reflect.Value)
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
value := getRealValue(result, relation.ForeignFieldNames)
|
||||
preloadMap[toString(value)] = append(preloadMap[toString(value)], result)
|
||||
}
|
||||
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
objectRealValue := getRealValue(object, relation.AssociationForeignFieldNames)
|
||||
objectStringValue := toString(objectRealValue)
|
||||
if results, ok := preloadMap[objectStringValue]; ok {
|
||||
if object.Kind() == reflect.Ptr {
|
||||
object = object.Elem()
|
||||
}
|
||||
f := object.FieldByName(field.Name)
|
||||
f.Set(reflect.Append(f, results...))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.SetColumn(field, resultValues)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
value := getRealValue(result, relation.AssociationForeignFieldNames)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
if object.Kind() == reflect.Ptr {
|
||||
object = reflect.Indirect(objects.Index(j).Elem())
|
||||
}
|
||||
if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) {
|
||||
object.FieldByName(field.Name).Set(result)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.SetColumn(field, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
joinTableHandler := relation.JoinTableHandler
|
||||
destType := field.StructField.Struct.Type.Elem()
|
||||
var isPtr bool
|
||||
if destType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
destType = destType.Elem()
|
||||
}
|
||||
|
||||
var sourceKeys []string
|
||||
var linkHash = make(map[string][]reflect.Value)
|
||||
|
||||
for _, key := range joinTableHandler.SourceForeignKeys() {
|
||||
sourceKeys = append(sourceKeys, key.DBName)
|
||||
}
|
||||
|
||||
db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*")
|
||||
|
||||
preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value)
|
||||
|
||||
if len(conditions) > 0 {
|
||||
preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
|
||||
}
|
||||
rows, err := preloadJoinDB.Rows()
|
||||
|
||||
if scope.Err(err) != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
elem := reflect.New(destType).Elem()
|
||||
var values = make([]interface{}, len(columns))
|
||||
|
||||
fields := scope.New(elem.Addr().Interface()).Fields()
|
||||
|
||||
var foundFields = map[string]bool{}
|
||||
for index, column := range columns {
|
||||
if field, ok := fields[column]; ok && !foundFields[column] {
|
||||
if field.Field.Kind() == reflect.Ptr {
|
||||
values[index] = field.Field.Addr().Interface()
|
||||
} else {
|
||||
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
|
||||
}
|
||||
foundFields[column] = true
|
||||
} else {
|
||||
var i interface{}
|
||||
values[index] = &i
|
||||
}
|
||||
}
|
||||
|
||||
scope.Err(rows.Scan(values...))
|
||||
|
||||
var sourceKey []interface{}
|
||||
|
||||
var scannedFields = map[string]bool{}
|
||||
for index, column := range columns {
|
||||
value := values[index]
|
||||
if field, ok := fields[column]; ok && !scannedFields[column] {
|
||||
if field.Field.Kind() == reflect.Ptr {
|
||||
field.Field.Set(reflect.ValueOf(value).Elem())
|
||||
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
|
||||
field.Field.Set(v)
|
||||
}
|
||||
scannedFields[column] = true
|
||||
} else if strInSlice(column, sourceKeys) {
|
||||
sourceKey = append(sourceKey, *(value.(*interface{})))
|
||||
}
|
||||
}
|
||||
|
||||
if len(sourceKey) != 0 {
|
||||
if isPtr {
|
||||
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem.Addr())
|
||||
} else {
|
||||
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var foreignFieldNames []string
|
||||
for _, dbName := range relation.ForeignFieldNames {
|
||||
if field, ok := scope.FieldByName(dbName); ok {
|
||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
if object.Kind() == reflect.Ptr {
|
||||
object = object.Elem()
|
||||
}
|
||||
source := getRealValue(object, foreignFieldNames)
|
||||
field := object.FieldByName(field.Name)
|
||||
for _, link := range linkHash[toString(source)] {
|
||||
field.Set(reflect.Append(field, link))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if object := scope.IndirectValue(); object.IsValid() {
|
||||
source := getRealValue(object, foreignFieldNames)
|
||||
field := object.FieldByName(field.Name)
|
||||
for _, link := range linkHash[toString(source)] {
|
||||
field.Set(reflect.Append(field, link))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
|
||||
values := scope.IndirectValue()
|
||||
switch values.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
var result []interface{}
|
||||
for _, column := range columns {
|
||||
value := reflect.Indirect(values.Index(i))
|
||||
if value.Kind() == reflect.Ptr {
|
||||
value = reflect.Indirect(values.Index(i).Elem())
|
||||
}
|
||||
result = append(result, value.FieldByName(column).Interface())
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
case reflect.Struct:
|
||||
var result []interface{}
|
||||
for _, column := range columns {
|
||||
result = append(result, values.FieldByName(column).Interface())
|
||||
}
|
||||
return [][]interface{}{result}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnsAsScope(column string) *Scope {
|
||||
values := scope.IndirectValue()
|
||||
switch values.Kind() {
|
||||
case reflect.Slice:
|
||||
modelType := values.Type().Elem()
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
fieldStruct, _ := modelType.FieldByName(column)
|
||||
var columns reflect.Value
|
||||
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
|
||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
|
||||
} else {
|
||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
|
||||
}
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
column := reflect.Indirect(values.Index(i)).FieldByName(column)
|
||||
if column.Kind() == reflect.Ptr {
|
||||
column = column.Elem()
|
||||
}
|
||||
if column.Kind() == reflect.Slice {
|
||||
for i := 0; i < column.Len(); i++ {
|
||||
elem := column.Index(i)
|
||||
if elem.CanAddr() {
|
||||
columns = reflect.Append(columns, elem.Addr())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if column.CanAddr() {
|
||||
columns = reflect.Append(columns, column.Addr())
|
||||
}
|
||||
}
|
||||
}
|
||||
return scope.New(columns.Interface())
|
||||
case reflect.Struct:
|
||||
field := values.FieldByName(column)
|
||||
if !field.CanAddr() {
|
||||
return nil
|
||||
}
|
||||
return scope.New(field.Addr().Interface())
|
||||
}
|
||||
return nil
|
||||
}
|
251
preload_test.go
251
preload_test.go
|
@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) {
|
|||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
|
||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound {
|
||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
@ -818,90 +818,6 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestManyToManyPreloadForPointer(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level1s []*Level1 `gorm:"many2many:levels;"`
|
||||
}
|
||||
)
|
||||
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
DB.DropTableIfExists("levels")
|
||||
|
||||
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
want := Level2{Value: "Bob", Level1s: []*Level1{
|
||||
{Value: "ru"},
|
||||
{Value: "en"},
|
||||
}}
|
||||
if err := DB.Save(&want).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
want2 := Level2{Value: "Tom", Level1s: []*Level1{
|
||||
{Value: "zh"},
|
||||
{Value: "de"},
|
||||
}}
|
||||
if err := DB.Save(&want2).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var got Level2
|
||||
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
|
||||
var got2 Level2
|
||||
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got2, want2) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
|
||||
}
|
||||
|
||||
var got3 []Level2
|
||||
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got3, []Level2{got, got2}) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
|
||||
}
|
||||
|
||||
var got4 []Level2
|
||||
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var got5 Level2
|
||||
DB.Preload("Level1s").First(&got5, "value = ?", "bogus")
|
||||
|
||||
var ruLevel1 Level1
|
||||
var zhLevel1 Level1
|
||||
DB.First(&ruLevel1, "value = ?", "ru")
|
||||
DB.First(&zhLevel1, "value = ?", "zh")
|
||||
|
||||
got.Level1s = []*Level1{&ruLevel1}
|
||||
got2.Level1s = []*Level1{&zhLevel1}
|
||||
if !reflect.DeepEqual(got4, []Level2{got, got2}) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManyToManyPreloadForNestedPointer(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
|
@ -1065,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
|
|||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
|
||||
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
|
||||
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
@ -1122,12 +1038,87 @@ func TestNestedManyToManyPreload2(t *testing.T) {
|
|||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
|
||||
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
|
||||
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedManyToManyPreload3(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level1s []*Level1 `gorm:"many2many:level1_level2;"`
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID sql.NullInt64
|
||||
Level2 *Level2
|
||||
}
|
||||
)
|
||||
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists("level1_level2")
|
||||
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
level1Zh := &Level1{Value: "zh"}
|
||||
level1Ru := &Level1{Value: "ru"}
|
||||
level1En := &Level1{Value: "en"}
|
||||
|
||||
level21 := &Level2{
|
||||
Value: "Level2-1",
|
||||
Level1s: []*Level1{level1Zh, level1Ru},
|
||||
}
|
||||
|
||||
level22 := &Level2{
|
||||
Value: "Level2-2",
|
||||
Level1s: []*Level1{level1Zh, level1En},
|
||||
}
|
||||
|
||||
wants := []*Level3{
|
||||
{
|
||||
Value: "Level3-1",
|
||||
Level2: level21,
|
||||
},
|
||||
{
|
||||
Value: "Level3-2",
|
||||
Level2: level22,
|
||||
},
|
||||
{
|
||||
Value: "Level3-3",
|
||||
Level2: level21,
|
||||
},
|
||||
}
|
||||
|
||||
for _, want := range wants {
|
||||
if err := DB.Save(&want).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
var gots []*Level3
|
||||
if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order("level1.id ASC")
|
||||
}).Find(&gots).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gots, wants) {
|
||||
t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedManyToManyPreload4(t *testing.T) {
|
||||
type (
|
||||
Level4 struct {
|
||||
ID uint
|
||||
|
@ -1185,6 +1176,90 @@ func TestNestedManyToManyPreload3(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestManyToManyPreloadForPointer(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level1s []*Level1 `gorm:"many2many:levels;"`
|
||||
}
|
||||
)
|
||||
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
DB.DropTableIfExists("levels")
|
||||
|
||||
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
want := Level2{Value: "Bob", Level1s: []*Level1{
|
||||
{Value: "ru"},
|
||||
{Value: "en"},
|
||||
}}
|
||||
if err := DB.Save(&want).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
want2 := Level2{Value: "Tom", Level1s: []*Level1{
|
||||
{Value: "zh"},
|
||||
{Value: "de"},
|
||||
}}
|
||||
if err := DB.Save(&want2).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var got Level2
|
||||
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
|
||||
var got2 Level2
|
||||
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got2, want2) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
|
||||
}
|
||||
|
||||
var got3 []Level2
|
||||
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got3, []Level2{got, got2}) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
|
||||
}
|
||||
|
||||
var got4 []Level2
|
||||
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var got5 Level2
|
||||
DB.Preload("Level1s").First(&got5, "value = ?", "bogus")
|
||||
|
||||
var ruLevel1 Level1
|
||||
var zhLevel1 Level1
|
||||
DB.First(&ruLevel1, "value = ?", "ru")
|
||||
DB.First(&zhLevel1, "value = ?", "zh")
|
||||
|
||||
got.Level1s = []*Level1{&ruLevel1}
|
||||
got2.Level1s = []*Level1{&zhLevel1}
|
||||
if !reflect.DeepEqual(got4, []Level2{got, got2}) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilPointerSlice(t *testing.T) {
|
||||
type (
|
||||
Level3 struct {
|
||||
|
@ -1234,7 +1309,7 @@ func TestNilPointerSlice(t *testing.T) {
|
|||
}
|
||||
|
||||
if len(got) != 2 {
|
||||
t.Error("got %v items, expected 2", len(got))
|
||||
t.Errorf("got %v items, expected 2", len(got))
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
|
||||
|
|
|
@ -629,14 +629,3 @@ func TestSelectWithArrayInput(t *testing.T) {
|
|||
t.Errorf("Should have selected both age and name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCurrentDatabase(t *testing.T) {
|
||||
databaseName := DB.CurrentDatabase()
|
||||
if err := DB.Error; err != nil {
|
||||
t.Errorf("Problem getting current db name: %s", err)
|
||||
}
|
||||
if databaseName == "" {
|
||||
t.Errorf("Current db name returned empty; this should never happen!")
|
||||
}
|
||||
t.Logf("Got current db name: %v", databaseName)
|
||||
}
|
||||
|
|
320
scope.go
320
scope.go
|
@ -1,48 +1,32 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Scope contain current operation's information when you perform any operation on the database
|
||||
type Scope struct {
|
||||
Search *search
|
||||
Value interface{}
|
||||
Sql string
|
||||
SqlVars []interface{}
|
||||
SQL string
|
||||
SQLVars []interface{}
|
||||
db *DB
|
||||
indirectValue *reflect.Value
|
||||
instanceId string
|
||||
instanceID string
|
||||
primaryKeyField *Field
|
||||
skipLeft bool
|
||||
fields map[string]*Field
|
||||
selectAttrs *[]string
|
||||
}
|
||||
|
||||
// IndirectValue return scope's reflect value's indirect value
|
||||
func (scope *Scope) IndirectValue() reflect.Value {
|
||||
if scope.indirectValue == nil {
|
||||
value := reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||
if value.Kind() == reflect.Ptr {
|
||||
value = value.Elem()
|
||||
}
|
||||
scope.indirectValue = &value
|
||||
}
|
||||
return *scope.indirectValue
|
||||
}
|
||||
|
||||
func (scope *Scope) NeedPtr() *Scope {
|
||||
reflectKind := reflect.ValueOf(scope.Value).Kind()
|
||||
if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
|
||||
err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value")
|
||||
scope.Err(err)
|
||||
fmt.Printf(err.Error())
|
||||
}
|
||||
return scope
|
||||
return indirect(reflect.ValueOf(scope.Value))
|
||||
}
|
||||
|
||||
// New create a new Scope without search information
|
||||
|
@ -61,12 +45,13 @@ func (scope *Scope) NewDB() *DB {
|
|||
return nil
|
||||
}
|
||||
|
||||
// DB return scope's DB connection
|
||||
func (scope *Scope) DB() *DB {
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// SqlDB return *sql.DB
|
||||
func (scope *Scope) SqlDB() sqlCommon {
|
||||
// SQLDB return *sql.DB
|
||||
func (scope *Scope) SQLDB() sqlCommon {
|
||||
return scope.db.db
|
||||
}
|
||||
|
||||
|
@ -75,7 +60,7 @@ func (scope *Scope) SkipLeft() {
|
|||
scope.skipLeft = true
|
||||
}
|
||||
|
||||
// Quote used to quote database column name according to database dialect
|
||||
// Quote used to quote string to escape them for database
|
||||
func (scope *Scope) Quote(str string) string {
|
||||
if strings.Index(str, ".") != -1 {
|
||||
newStrs := []string{}
|
||||
|
@ -83,12 +68,12 @@ func (scope *Scope) Quote(str string) string {
|
|||
newStrs = append(newStrs, scope.Dialect().Quote(str))
|
||||
}
|
||||
return strings.Join(newStrs, ".")
|
||||
} else {
|
||||
return scope.Dialect().Quote(str)
|
||||
}
|
||||
|
||||
return scope.Dialect().Quote(str)
|
||||
}
|
||||
|
||||
func (scope *Scope) QuoteIfPossible(str string) string {
|
||||
func (scope *Scope) quoteIfPossible(str string) string {
|
||||
if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
|
||||
return scope.Quote(str)
|
||||
}
|
||||
|
@ -100,7 +85,7 @@ func (scope *Scope) Dialect() Dialect {
|
|||
return scope.db.parent.dialect
|
||||
}
|
||||
|
||||
// Err write error
|
||||
// Err add error to Scope
|
||||
func (scope *Scope) Err(err error) error {
|
||||
if err != nil {
|
||||
scope.db.AddError(err)
|
||||
|
@ -118,27 +103,30 @@ func (scope *Scope) HasError() bool {
|
|||
return scope.db.Error != nil
|
||||
}
|
||||
|
||||
func (scope *Scope) PrimaryFields() []*Field {
|
||||
var fields = []*Field{}
|
||||
for _, field := range scope.GetModelStruct().PrimaryFields {
|
||||
fields = append(fields, scope.Fields()[field.DBName])
|
||||
// PrimaryFields return scope's primary fields
|
||||
func (scope *Scope) PrimaryFields() (fields []*Field) {
|
||||
for _, field := range scope.Fields() {
|
||||
if field.IsPrimaryKey {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
|
||||
func (scope *Scope) PrimaryField() *Field {
|
||||
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
|
||||
if len(primaryFields) > 1 {
|
||||
if field, ok := scope.Fields()["id"]; ok {
|
||||
if field, ok := scope.FieldByName("id"); ok {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return scope.Fields()[primaryFields[0].DBName]
|
||||
return scope.PrimaryFields()[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrimaryKey get the primary key's column name
|
||||
// PrimaryKey get main primary field's db name
|
||||
func (scope *Scope) PrimaryKey() string {
|
||||
if field := scope.PrimaryField(); field != nil {
|
||||
return field.DBName
|
||||
|
@ -146,7 +134,7 @@ func (scope *Scope) PrimaryKey() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
// PrimaryKeyZero check the primary key is blank or not
|
||||
// PrimaryKeyZero check main primary field's value is blank or not
|
||||
func (scope *Scope) PrimaryKeyZero() bool {
|
||||
field := scope.PrimaryField()
|
||||
return field == nil || field.IsBlank
|
||||
|
@ -170,80 +158,85 @@ func (scope *Scope) HasColumn(column string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// SetColumn to set the column's value
|
||||
// SetColumn to set the column's value, column could be field or field's name/dbname
|
||||
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
||||
var updateAttrs = map[string]interface{}{}
|
||||
if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||
updateAttrs = attrs.(map[string]interface{})
|
||||
defer scope.InstanceSet("gorm:update_attrs", updateAttrs)
|
||||
}
|
||||
|
||||
if field, ok := column.(*Field); ok {
|
||||
updateAttrs[field.DBName] = value
|
||||
return field.Set(value)
|
||||
} else if name, ok := column.(string); ok {
|
||||
|
||||
if field, ok := scope.Fields()[name]; ok {
|
||||
return field.Set(value)
|
||||
var (
|
||||
dbName = ToDBName(name)
|
||||
mostMatchedField *Field
|
||||
)
|
||||
for _, field := range scope.Fields() {
|
||||
if field.DBName == value {
|
||||
updateAttrs[field.DBName] = value
|
||||
return field.Set(value)
|
||||
}
|
||||
if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) {
|
||||
mostMatchedField = field
|
||||
}
|
||||
}
|
||||
|
||||
dbName := ToDBName(name)
|
||||
if field, ok := scope.Fields()[dbName]; ok {
|
||||
return field.Set(value)
|
||||
}
|
||||
|
||||
if field, ok := scope.FieldByName(name); ok {
|
||||
return field.Set(value)
|
||||
if mostMatchedField != nil {
|
||||
updateAttrs[mostMatchedField.DBName] = value
|
||||
return mostMatchedField.Set(value)
|
||||
}
|
||||
}
|
||||
return errors.New("could not convert column to field")
|
||||
}
|
||||
|
||||
func (scope *Scope) CallMethod(name string, checkError bool) {
|
||||
if scope.Value == nil || (checkError && scope.HasError()) {
|
||||
func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
|
||||
if reflectValue.CanAddr() {
|
||||
reflectValue = reflectValue.Addr()
|
||||
}
|
||||
|
||||
if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
|
||||
switch method := methodValue.Interface().(type) {
|
||||
case func():
|
||||
method()
|
||||
case func(*Scope):
|
||||
method(scope)
|
||||
case func(*DB):
|
||||
newDB := scope.NewDB()
|
||||
method(newDB)
|
||||
scope.Err(newDB.Error)
|
||||
case func() error:
|
||||
scope.Err(method())
|
||||
case func(*Scope) error:
|
||||
scope.Err(method(scope))
|
||||
case func(*DB) error:
|
||||
newDB := scope.NewDB()
|
||||
scope.Err(method(newDB))
|
||||
scope.Err(newDB.Error)
|
||||
default:
|
||||
scope.Err(fmt.Errorf("unsupported function %v", methodName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
|
||||
func (scope *Scope) CallMethod(methodName string) {
|
||||
if scope.Value == nil {
|
||||
return
|
||||
}
|
||||
|
||||
call := func(value interface{}) {
|
||||
if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
|
||||
switch f := fm.Interface().(type) {
|
||||
case func():
|
||||
f()
|
||||
case func(s *Scope):
|
||||
f(scope)
|
||||
case func(s *DB):
|
||||
newDB := scope.NewDB()
|
||||
f(newDB)
|
||||
scope.Err(newDB.Error)
|
||||
case func() error:
|
||||
scope.Err(f())
|
||||
case func(s *Scope) error:
|
||||
scope.Err(f(scope))
|
||||
case func(s *DB) error:
|
||||
newDB := scope.NewDB()
|
||||
scope.Err(f(newDB))
|
||||
scope.Err(newDB.Error)
|
||||
default:
|
||||
scope.Err(fmt.Errorf("unsupported function %v", name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if values := scope.IndirectValue(); values.Kind() == reflect.Slice {
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
value := values.Index(i).Addr().Interface()
|
||||
if values.Index(i).Kind() == reflect.Ptr {
|
||||
value = values.Index(i).Interface()
|
||||
}
|
||||
call(value)
|
||||
if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < indirectScopeValue.Len(); i++ {
|
||||
scope.callMethod(methodName, indirectScopeValue.Index(i))
|
||||
}
|
||||
} else {
|
||||
if scope.IndirectValue().CanAddr() {
|
||||
call(scope.IndirectValue().Addr().Interface())
|
||||
} else {
|
||||
call(scope.IndirectValue().Interface())
|
||||
}
|
||||
scope.callMethod(methodName, indirectScopeValue)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) CallMethodWithErrorCheck(name string) {
|
||||
scope.CallMethod(name, true)
|
||||
}
|
||||
|
||||
// AddToVars add value as sql's vars, gorm will escape them
|
||||
// AddToVars add value as sql's vars, used to prevent SQL injection
|
||||
func (scope *Scope) AddToVars(value interface{}) string {
|
||||
if expr, ok := value.(*expr); ok {
|
||||
exp := expr.expr
|
||||
|
@ -251,10 +244,10 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
|||
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
return exp
|
||||
} else {
|
||||
scope.SqlVars = append(scope.SqlVars, value)
|
||||
return scope.Dialect().BinVar(len(scope.SqlVars))
|
||||
}
|
||||
|
||||
scope.SQLVars = append(scope.SQLVars, value)
|
||||
return scope.Dialect().BindVar(len(scope.SQLVars))
|
||||
}
|
||||
|
||||
type tabler interface {
|
||||
|
@ -265,7 +258,7 @@ type dbTabler interface {
|
|||
TableName(*DB) string
|
||||
}
|
||||
|
||||
// TableName get table name
|
||||
// TableName return table name
|
||||
func (scope *Scope) TableName() string {
|
||||
if scope.Search != nil && len(scope.Search.tableName) > 0 {
|
||||
return scope.Search.tableName
|
||||
|
@ -282,44 +275,54 @@ func (scope *Scope) TableName() string {
|
|||
return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
|
||||
}
|
||||
|
||||
// QuotedTableName return quoted table name
|
||||
func (scope *Scope) QuotedTableName() (name string) {
|
||||
if scope.Search != nil && len(scope.Search.tableName) > 0 {
|
||||
if strings.Index(scope.Search.tableName, " ") != -1 {
|
||||
return scope.Search.tableName
|
||||
}
|
||||
return scope.Quote(scope.Search.tableName)
|
||||
} else {
|
||||
return scope.Quote(scope.TableName())
|
||||
}
|
||||
|
||||
return scope.Quote(scope.TableName())
|
||||
}
|
||||
|
||||
// CombinedConditionSql get combined condition sql
|
||||
// CombinedConditionSql return combined condition sql
|
||||
func (scope *Scope) CombinedConditionSql() string {
|
||||
return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
|
||||
scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
|
||||
return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() +
|
||||
scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
|
||||
}
|
||||
|
||||
// FieldByName find `gorm.Field` with field name or db name
|
||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||
var (
|
||||
dbName = ToDBName(name)
|
||||
mostMatchedField *Field
|
||||
)
|
||||
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Name == name || field.DBName == name {
|
||||
return field, true
|
||||
}
|
||||
if field.DBName == dbName {
|
||||
mostMatchedField = field
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
return mostMatchedField, mostMatchedField != nil
|
||||
}
|
||||
|
||||
// Raw set sql
|
||||
// Raw set raw sql
|
||||
func (scope *Scope) Raw(sql string) *Scope {
|
||||
scope.Sql = strings.Replace(sql, "$$", "?", -1)
|
||||
scope.SQL = strings.Replace(sql, "$$", "?", -1)
|
||||
return scope
|
||||
}
|
||||
|
||||
// Exec invoke sql
|
||||
// Exec perform generated SQL
|
||||
func (scope *Scope) Exec() *Scope {
|
||||
defer scope.Trace(NowFunc())
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
if !scope.HasError() {
|
||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
||||
scope.db.RowsAffected = count
|
||||
}
|
||||
|
@ -334,37 +337,32 @@ func (scope *Scope) Set(name string, value interface{}) *Scope {
|
|||
return scope
|
||||
}
|
||||
|
||||
// Get get value by name
|
||||
// Get get setting by name
|
||||
func (scope *Scope) Get(name string) (interface{}, bool) {
|
||||
return scope.db.Get(name)
|
||||
}
|
||||
|
||||
// InstanceId get InstanceId for scope
|
||||
func (scope *Scope) InstanceId() string {
|
||||
if scope.instanceId == "" {
|
||||
scope.instanceId = fmt.Sprintf("%v%v", &scope, &scope.db)
|
||||
// InstanceID get InstanceID for scope
|
||||
func (scope *Scope) InstanceID() string {
|
||||
if scope.instanceID == "" {
|
||||
scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db)
|
||||
}
|
||||
return scope.instanceId
|
||||
return scope.instanceID
|
||||
}
|
||||
|
||||
// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback
|
||||
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
|
||||
return scope.Set(name+scope.InstanceId(), value)
|
||||
return scope.Set(name+scope.InstanceID(), value)
|
||||
}
|
||||
|
||||
// InstanceGet get instance setting from current operation
|
||||
func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
||||
return scope.Get(name + scope.InstanceId())
|
||||
}
|
||||
|
||||
// Trace print sql log
|
||||
func (scope *Scope) Trace(t time.Time) {
|
||||
if len(scope.Sql) > 0 {
|
||||
scope.db.slog(scope.Sql, t, scope.SqlVars...)
|
||||
}
|
||||
return scope.Get(name + scope.InstanceID())
|
||||
}
|
||||
|
||||
// Begin start a transaction
|
||||
func (scope *Scope) Begin() *Scope {
|
||||
if db, ok := scope.SqlDB().(sqlDb); ok {
|
||||
if db, ok := scope.SQLDB().(sqlDb); ok {
|
||||
if tx, err := db.Begin(); err == nil {
|
||||
scope.db.db = interface{}(tx).(sqlCommon)
|
||||
scope.InstanceSet("gorm:started_transaction", true)
|
||||
|
@ -373,7 +371,7 @@ func (scope *Scope) Begin() *Scope {
|
|||
return scope
|
||||
}
|
||||
|
||||
// CommitOrRollback commit current transaction if there is no error, otherwise rollback it
|
||||
// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it
|
||||
func (scope *Scope) CommitOrRollback() *Scope {
|
||||
if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
|
||||
if db, ok := scope.db.db.(sqlTx); ok {
|
||||
|
@ -388,6 +386,7 @@ func (scope *Scope) CommitOrRollback() *Scope {
|
|||
return scope
|
||||
}
|
||||
|
||||
// SelectAttrs return selected attributes
|
||||
func (scope *Scope) SelectAttrs() []string {
|
||||
if scope.selectAttrs == nil {
|
||||
attrs := []string{}
|
||||
|
@ -407,57 +406,38 @@ func (scope *Scope) SelectAttrs() []string {
|
|||
return *scope.selectAttrs
|
||||
}
|
||||
|
||||
// OmitAttrs return omited attributes
|
||||
func (scope *Scope) OmitAttrs() []string {
|
||||
return scope.Search.omits
|
||||
}
|
||||
|
||||
func (scope *Scope) changeableDBColumn(column string) bool {
|
||||
selectAttrs := scope.SelectAttrs()
|
||||
omitAttrs := scope.OmitAttrs()
|
||||
func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
|
||||
var values = make([]interface{}, len(columns))
|
||||
var ignored interface{}
|
||||
|
||||
if len(selectAttrs) > 0 {
|
||||
for _, attr := range selectAttrs {
|
||||
if column == ToDBName(attr) {
|
||||
return true
|
||||
for index, column := range columns {
|
||||
if field, ok := fieldsMap[column]; ok {
|
||||
if field.Field.Kind() == reflect.Ptr {
|
||||
values[index] = field.Field.Addr().Interface()
|
||||
} else {
|
||||
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
|
||||
reflectValue.Elem().Set(field.Field.Addr())
|
||||
values[index] = reflectValue.Interface()
|
||||
}
|
||||
} else {
|
||||
values[index] = &ignored
|
||||
}
|
||||
}
|
||||
|
||||
scope.Err(rows.Scan(values...))
|
||||
|
||||
for index, column := range columns {
|
||||
if field, ok := fieldsMap[column]; ok {
|
||||
if field.Field.Kind() != reflect.Ptr {
|
||||
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
|
||||
field.Field.Set(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, attr := range omitAttrs {
|
||||
if column == ToDBName(attr) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (scope *Scope) changeableField(field *Field) bool {
|
||||
selectAttrs := scope.SelectAttrs()
|
||||
omitAttrs := scope.OmitAttrs()
|
||||
|
||||
if len(selectAttrs) > 0 {
|
||||
for _, attr := range selectAttrs {
|
||||
if field.Name == attr || field.DBName == attr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, attr := range omitAttrs {
|
||||
if field.Name == attr || field.DBName == attr {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return !field.IsIgnored
|
||||
}
|
||||
|
||||
func (scope *Scope) shouldSaveAssociations() bool {
|
||||
saveAssociations, ok := scope.Get("gorm:save_associations")
|
||||
if ok && !saveAssociations.(bool) {
|
||||
return false
|
||||
}
|
||||
return true && !scope.HasError()
|
||||
}
|
||||
|
|
276
scope_private.go
276
scope_private.go
|
@ -8,6 +8,7 @@ import (
|
|||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||
|
@ -75,7 +76,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
|||
}
|
||||
|
||||
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||
var notEqualSql string
|
||||
var notEqualSQL string
|
||||
var primaryKey = scope.PrimaryKey()
|
||||
|
||||
switch value := clause["query"].(type) {
|
||||
|
@ -86,10 +87,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
|||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
|
||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||
notEqualSql = fmt.Sprintf("NOT (%v)", value)
|
||||
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
|
||||
} else {
|
||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
||||
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
||||
|
@ -138,7 +139,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
|||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = scanner.Value()
|
||||
}
|
||||
str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1)
|
||||
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -172,17 +173,20 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
|||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) whereSql() (sql string) {
|
||||
var primaryConditions, andConditions, orConditions []string
|
||||
func (scope *Scope) whereSQL() (sql string) {
|
||||
var (
|
||||
quotedTableName = scope.QuotedTableName()
|
||||
primaryConditions, andConditions, orConditions []string
|
||||
)
|
||||
|
||||
if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
|
||||
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
|
||||
if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
|
||||
sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName)
|
||||
primaryConditions = append(primaryConditions, sql)
|
||||
}
|
||||
|
||||
if !scope.PrimaryKeyZero() {
|
||||
for _, field := range scope.PrimaryFields() {
|
||||
sql := fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
|
||||
sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
|
||||
primaryConditions = append(primaryConditions, sql)
|
||||
}
|
||||
}
|
||||
|
@ -205,30 +209,30 @@ func (scope *Scope) whereSql() (sql string) {
|
|||
}
|
||||
}
|
||||
|
||||
orSql := strings.Join(orConditions, " OR ")
|
||||
combinedSql := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSql) > 0 {
|
||||
if len(orSql) > 0 {
|
||||
combinedSql = combinedSql + " OR " + orSql
|
||||
orSQL := strings.Join(orConditions, " OR ")
|
||||
combinedSQL := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSQL) > 0 {
|
||||
if len(orSQL) > 0 {
|
||||
combinedSQL = combinedSQL + " OR " + orSQL
|
||||
}
|
||||
} else {
|
||||
combinedSql = orSql
|
||||
combinedSQL = orSQL
|
||||
}
|
||||
|
||||
if len(primaryConditions) > 0 {
|
||||
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
|
||||
if len(combinedSql) > 0 {
|
||||
sql = sql + " AND (" + combinedSql + ")"
|
||||
if len(combinedSQL) > 0 {
|
||||
sql = sql + " AND (" + combinedSQL + ")"
|
||||
}
|
||||
} else if len(combinedSql) > 0 {
|
||||
sql = "WHERE " + combinedSql
|
||||
} else if len(combinedSQL) > 0 {
|
||||
sql = "WHERE " + combinedSQL
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) selectSql() string {
|
||||
func (scope *Scope) selectSQL() string {
|
||||
if len(scope.Search.selects) == 0 {
|
||||
if scope.Search.joins != "" {
|
||||
if len(scope.Search.joinConditions) > 0 {
|
||||
return fmt.Sprintf("%v.*", scope.QuotedTableName())
|
||||
}
|
||||
return "*"
|
||||
|
@ -236,87 +240,60 @@ func (scope *Scope) selectSql() string {
|
|||
return scope.buildSelectQuery(scope.Search.selects)
|
||||
}
|
||||
|
||||
func (scope *Scope) orderSql() string {
|
||||
func (scope *Scope) orderSQL() string {
|
||||
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
|
||||
return ""
|
||||
}
|
||||
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
||||
}
|
||||
|
||||
func (scope *Scope) limitSql() string {
|
||||
if !scope.Dialect().HasTop() {
|
||||
if len(scope.Search.limit) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " LIMIT " + scope.Search.limit
|
||||
}
|
||||
|
||||
return ""
|
||||
func (scope *Scope) limitAndOffsetSQL() string {
|
||||
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||
}
|
||||
|
||||
func (scope *Scope) topSql() string {
|
||||
if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
|
||||
if len(scope.Search.limit) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " TOP(" + scope.Search.limit + ")"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (scope *Scope) offsetSql() string {
|
||||
if len(scope.Search.offset) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if scope.Dialect().HasTop() {
|
||||
sql := " OFFSET " + scope.Search.offset + " ROW "
|
||||
if len(scope.Search.limit) > 0 {
|
||||
sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
|
||||
}
|
||||
return sql
|
||||
}
|
||||
return " OFFSET " + scope.Search.offset
|
||||
}
|
||||
|
||||
func (scope *Scope) groupSql() string {
|
||||
func (scope *Scope) groupSQL() string {
|
||||
if len(scope.Search.group) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " GROUP BY " + scope.Search.group
|
||||
}
|
||||
|
||||
func (scope *Scope) havingSql() string {
|
||||
if scope.Search.havingConditions == nil {
|
||||
func (scope *Scope) havingSQL() string {
|
||||
if len(scope.Search.havingConditions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var andConditions []string
|
||||
|
||||
for _, clause := range scope.Search.havingConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
andConditions = append(andConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
combinedSql := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSql) == 0 {
|
||||
combinedSQL := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSQL) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return " HAVING " + combinedSql
|
||||
return " HAVING " + combinedSQL
|
||||
}
|
||||
|
||||
func (scope *Scope) joinsSql() string {
|
||||
return scope.Search.joins + " "
|
||||
func (scope *Scope) joinsSQL() string {
|
||||
var joinConditions []string
|
||||
for _, clause := range scope.Search.joinConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(joinConditions, " ") + " "
|
||||
}
|
||||
|
||||
func (scope *Scope) prepareQuerySql() {
|
||||
func (scope *Scope) prepareQuerySQL() {
|
||||
if scope.Search.raw {
|
||||
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -338,61 +315,53 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
|||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
|
||||
if !scope.IndirectValue().CanAddr() {
|
||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||
if scope.IndirectValue().Kind() != reflect.Struct {
|
||||
return values, true
|
||||
}
|
||||
|
||||
var hasExpr bool
|
||||
results = map[string]interface{}{}
|
||||
for key, value := range values {
|
||||
if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() {
|
||||
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||
if _, ok := value.(*expr); ok {
|
||||
hasExpr = true
|
||||
} else if !equalAsString(field.Field.Interface(), value) {
|
||||
hasUpdate = true
|
||||
results[field.DBName] = value
|
||||
} else if !equalAsString(field.Field.Interface(), value) {
|
||||
field.Set(value)
|
||||
if field.IsNormal {
|
||||
hasUpdate = true
|
||||
results[field.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
field.Set(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasExpr {
|
||||
var updateMap = map[string]interface{}{}
|
||||
for key, field := range scope.Fields() {
|
||||
if field.IsNormal {
|
||||
if v, ok := values[key]; ok {
|
||||
updateMap[key] = v
|
||||
} else {
|
||||
updateMap[key] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
return updateMap, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) row() *sql.Row {
|
||||
defer scope.Trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callback.rowQueries)
|
||||
scope.prepareQuerySql()
|
||||
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
|
||||
defer scope.trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||
scope.prepareQuerySQL()
|
||||
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||
defer scope.Trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callback.rowQueries)
|
||||
scope.prepareQuerySql()
|
||||
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
||||
defer scope.trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||
scope.prepareQuerySQL()
|
||||
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) initialize() *Scope {
|
||||
for _, clause := range scope.Search.whereConditions {
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
|
||||
}
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
|
||||
return scope
|
||||
}
|
||||
|
||||
|
@ -433,23 +402,45 @@ func (scope *Scope) typeName() string {
|
|||
return typ.Name()
|
||||
}
|
||||
|
||||
// trace print sql log
|
||||
func (scope *Scope) trace(t time.Time) {
|
||||
if len(scope.SQL) > 0 {
|
||||
scope.db.slog(scope.SQL, t, scope.SQLVars...)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) changeableField(field *Field) bool {
|
||||
if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
|
||||
for _, attr := range selectAttrs {
|
||||
if field.Name == attr || field.DBName == attr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, attr := range scope.OmitAttrs() {
|
||||
if field.Name == attr || field.DBName == attr {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (scope *Scope) shouldSaveAssociations() bool {
|
||||
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) {
|
||||
return false
|
||||
}
|
||||
return true && !scope.HasError()
|
||||
}
|
||||
|
||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
toScope := scope.db.NewScope(value)
|
||||
fromFields := scope.Fields()
|
||||
toFields := toScope.Fields()
|
||||
|
||||
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
||||
var fromField, toField *Field
|
||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||
fromField = field
|
||||
} else {
|
||||
fromField = fromFields[ToDBName(foreignKey)]
|
||||
}
|
||||
if field, ok := toScope.FieldByName(foreignKey); ok {
|
||||
toField = field
|
||||
} else {
|
||||
toField = toFields[ToDBName(foreignKey)]
|
||||
}
|
||||
fromField, _ := scope.FieldByName(foreignKey)
|
||||
toField, _ := toScope.FieldByName(foreignKey)
|
||||
|
||||
if fromField != nil {
|
||||
if relationship := fromField.Relationship; relationship != nil {
|
||||
|
@ -508,30 +499,26 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
joinTable := joinTableHandler.Table(scope.db)
|
||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
||||
if !scope.Dialect().HasTable(joinTable) {
|
||||
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
||||
|
||||
var sqlTypes, primaryKeys []string
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
if field, ok := scope.Fields()[fieldName]; ok {
|
||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
primaryKeySqlType := field.TagSettings["TYPE"]
|
||||
if primaryKeySqlType == "" {
|
||||
primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false)
|
||||
}
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
|
||||
if field, ok := scope.FieldByName(fieldName); ok {
|
||||
foreignKeyStruct := field.clone()
|
||||
foreignKeyStruct.IsPrimaryKey = false
|
||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||
}
|
||||
}
|
||||
|
||||
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := toScope.Fields()[fieldName]; ok {
|
||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
primaryKeySqlType := field.TagSettings["TYPE"]
|
||||
if primaryKeySqlType == "" {
|
||||
primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false)
|
||||
}
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
|
||||
if field, ok := toScope.FieldByName(fieldName); ok {
|
||||
foreignKeyStruct := field.clone()
|
||||
foreignKeyStruct.IsPrimaryKey = false
|
||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||
}
|
||||
}
|
||||
|
@ -545,10 +532,10 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||
func (scope *Scope) createTable() *Scope {
|
||||
var tags []string
|
||||
var primaryKeys []string
|
||||
var primaryKeyInColumnType bool = false
|
||||
for _, field := range scope.GetStructFields() {
|
||||
var primaryKeyInColumnType = false
|
||||
for _, field := range scope.GetModelStruct().StructFields {
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.generateSqlTag(field)
|
||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||
|
||||
// Check if the primary key constraint was specified as
|
||||
// part of the column type. If so, we can only support
|
||||
|
@ -582,13 +569,6 @@ func (scope *Scope) dropTable() *Scope {
|
|||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) dropTableIfExists() *Scope {
|
||||
if scope.Dialect().HasTable(scope, scope.TableName()) {
|
||||
scope.dropTable()
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) modifyColumn(column string, typ string) {
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
|
||||
}
|
||||
|
@ -598,13 +578,13 @@ func (scope *Scope) dropColumn(column string) {
|
|||
}
|
||||
|
||||
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||
if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) {
|
||||
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
|
||||
return
|
||||
}
|
||||
|
||||
var columns []string
|
||||
for _, name := range column {
|
||||
columns = append(columns, scope.QuoteIfPossible(name))
|
||||
columns = append(columns, scope.quoteIfPossible(name))
|
||||
}
|
||||
|
||||
sqlCreate := "CREATE INDEX"
|
||||
|
@ -612,31 +592,35 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
|||
sqlCreate = "CREATE UNIQUE INDEX"
|
||||
}
|
||||
|
||||
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec()
|
||||
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
|
||||
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
|
||||
|
||||
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||
return
|
||||
}
|
||||
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) removeIndex(indexName string) {
|
||||
scope.Dialect().RemoveIndex(scope, indexName)
|
||||
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
||||
}
|
||||
|
||||
func (scope *Scope) autoMigrate() *Scope {
|
||||
tableName := scope.TableName()
|
||||
quotedTableName := scope.QuotedTableName()
|
||||
|
||||
if !scope.Dialect().HasTable(scope, tableName) {
|
||||
if !scope.Dialect().HasTable(tableName) {
|
||||
scope.createTable()
|
||||
} else {
|
||||
for _, field := range scope.GetStructFields() {
|
||||
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
||||
for _, field := range scope.GetModelStruct().StructFields {
|
||||
if !scope.Dialect().HasColumn(tableName, field.DBName) {
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.generateSqlTag(field)
|
||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
package gorm
|
||||
|
||||
import "reflect"
|
||||
|
||||
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
|
||||
for _, value := range values {
|
||||
indirectValue := reflect.ValueOf(value)
|
||||
for indirectValue.Kind() == reflect.Ptr {
|
||||
indirectValue = indirectValue.Elem()
|
||||
}
|
||||
|
||||
switch indirectValue.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < indirectValue.Len(); i++ {
|
||||
var result []interface{}
|
||||
var object = indirect(indirectValue.Index(i))
|
||||
for _, column := range columns {
|
||||
result = append(result, object.FieldByName(column).Interface())
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
case reflect.Struct:
|
||||
var result []interface{}
|
||||
for _, column := range columns {
|
||||
result = append(result, indirectValue.FieldByName(column).Interface())
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnAsScope(column string) *Scope {
|
||||
indirectScopeValue := scope.IndirectValue()
|
||||
|
||||
switch indirectScopeValue.Kind() {
|
||||
case reflect.Slice:
|
||||
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
|
||||
fieldType := fieldStruct.Type
|
||||
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
|
||||
|
||||
for i := 0; i < indirectScopeValue.Len(); i++ {
|
||||
result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
|
||||
|
||||
if result.Kind() == reflect.Slice {
|
||||
for j := 0; j < result.Len(); j++ {
|
||||
if elem := result.Index(j); elem.CanAddr() {
|
||||
results = reflect.Append(results, elem.Addr())
|
||||
}
|
||||
}
|
||||
} else if result.CanAddr() {
|
||||
results = reflect.Append(results, result.Addr())
|
||||
}
|
||||
}
|
||||
return scope.New(results.Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
|
||||
return scope.New(field.Addr().Interface())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
24
search.go
24
search.go
|
@ -8,15 +8,15 @@ type search struct {
|
|||
orConditions []map[string]interface{}
|
||||
notConditions []map[string]interface{}
|
||||
havingConditions []map[string]interface{}
|
||||
joinConditions []map[string]interface{}
|
||||
initAttrs []interface{}
|
||||
assignAttrs []interface{}
|
||||
selects map[string]interface{}
|
||||
omits []string
|
||||
orders []string
|
||||
joins string
|
||||
preload []searchPreload
|
||||
offset string
|
||||
limit string
|
||||
offset int
|
||||
limit int
|
||||
group string
|
||||
tableName string
|
||||
raw bool
|
||||
|
@ -82,18 +82,18 @@ func (s *search) Omit(columns ...string) *search {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *search) Limit(value interface{}) *search {
|
||||
s.limit = s.getInterfaceAsSql(value)
|
||||
func (s *search) Limit(limit int) *search {
|
||||
s.limit = limit
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Offset(value interface{}) *search {
|
||||
s.offset = s.getInterfaceAsSql(value)
|
||||
func (s *search) Offset(offset int) *search {
|
||||
s.offset = offset
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Group(query string) *search {
|
||||
s.group = s.getInterfaceAsSql(query)
|
||||
s.group = s.getInterfaceAsSQL(query)
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -102,8 +102,8 @@ func (s *search) Having(query string, values ...interface{}) *search {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *search) Joins(query string) *search {
|
||||
s.joins = query
|
||||
func (s *search) Joins(query string, values ...interface{}) *search {
|
||||
s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -134,12 +134,12 @@ func (s *search) Table(name string) *search {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *search) getInterfaceAsSql(value interface{}) (str string) {
|
||||
func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
|
||||
switch value.(type) {
|
||||
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
str = fmt.Sprintf("%v", value)
|
||||
default:
|
||||
s.db.AddError(InvalidSql)
|
||||
s.db.AddError(ErrInvalidSQL)
|
||||
}
|
||||
|
||||
if str == "-1" {
|
||||
|
|
84
sqlite3.go
84
sqlite3.go
|
@ -1,84 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type sqlite3 struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bool"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "integer primary key autoincrement"
|
||||
}
|
||||
return "integer"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "integer primary key autoincrement"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "real"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "datetime"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
return "blob"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
|
||||
}
|
||||
|
||||
func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
|
||||
var (
|
||||
ifaces = make([]interface{}, 3)
|
||||
pointers = make([]*string, 3)
|
||||
i int
|
||||
)
|
||||
for i = 0; i < 3; i++ {
|
||||
ifaces[i] = &pointers[i]
|
||||
}
|
||||
if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
|
||||
return
|
||||
}
|
||||
if pointers[1] != nil {
|
||||
name = *pointers[1]
|
||||
}
|
||||
return
|
||||
}
|
|
@ -42,9 +42,9 @@ type CreditCard struct {
|
|||
ID int8
|
||||
Number string
|
||||
UserId sql.NullInt64
|
||||
CreatedAt time.Time
|
||||
CreatedAt time.Time `sql:"not null"`
|
||||
UpdatedAt time.Time
|
||||
DeletedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
|
||||
type Email struct {
|
||||
|
@ -62,7 +62,7 @@ type Address struct {
|
|||
Post string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
|
||||
type Language struct {
|
||||
|
|
|
@ -71,13 +71,14 @@ func TestUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
DB.First(&product4, product4.Id)
|
||||
updatedAt4 := product4.UpdatedAt
|
||||
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
|
||||
var product5 Product
|
||||
DB.First(&product5, product4.Id)
|
||||
if product5.Price != product4.Price+100-50 {
|
||||
t.Errorf("Update with expression")
|
||||
}
|
||||
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
|
||||
t.Errorf("Update with expression should update UpdatedAt")
|
||||
}
|
||||
}
|
||||
|
@ -170,13 +171,15 @@ func TestUpdates(t *testing.T) {
|
|||
t.Errorf("product2's code should be updated")
|
||||
}
|
||||
|
||||
updatedAt4 := product4.UpdatedAt
|
||||
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
|
||||
var product5 Product
|
||||
DB.First(&product5, product4.Id)
|
||||
if product5.Price != product4.Price+100 {
|
||||
t.Errorf("Updates with expression")
|
||||
}
|
||||
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
// product4's UpdatedAt will be reset when updating
|
||||
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
|
||||
t.Errorf("Updates with expression should update UpdatedAt")
|
||||
}
|
||||
}
|
||||
|
@ -419,3 +422,32 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) {
|
|||
t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdatesWithBlankValues(t *testing.T) {
|
||||
product := Product{Code: "product1", Price: 10}
|
||||
DB.Save(&product)
|
||||
|
||||
DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100})
|
||||
|
||||
var product1 Product
|
||||
DB.First(&product1, product.Id)
|
||||
|
||||
if product1.Code != "product1" || product1.Price != 100 {
|
||||
t.Errorf("product's code should not be updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDecodeVirtualAttributes(t *testing.T) {
|
||||
var user = User{
|
||||
Name: "jinzhu",
|
||||
IgnoreMe: 88,
|
||||
}
|
||||
|
||||
DB.Save(&user)
|
||||
|
||||
DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100})
|
||||
|
||||
if user.IgnoreMe != 100 {
|
||||
t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks")
|
||||
}
|
||||
}
|
||||
|
|
239
utils.go
239
utils.go
|
@ -2,10 +2,26 @@ package gorm
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NowFunc returns current time, this function is exported in order to be able
|
||||
// to give the flexibility to the developer to customize it according to their
|
||||
// needs, e.g:
|
||||
// gorm.NowFunc = func() time.Time {
|
||||
// return time.Now().UTC()
|
||||
// }
|
||||
var NowFunc = func() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// Copied from golint
|
||||
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||
var commonInitialismsReplacer *strings.Replacer
|
||||
|
@ -41,30 +57,239 @@ func newSafeMap() *safeMap {
|
|||
|
||||
var smap = newSafeMap()
|
||||
|
||||
type strCase bool
|
||||
|
||||
const (
|
||||
lower strCase = false
|
||||
upper strCase = true
|
||||
)
|
||||
|
||||
// ToDBName convert string to db name
|
||||
func ToDBName(name string) string {
|
||||
if v := smap.Get(name); v != "" {
|
||||
return v
|
||||
}
|
||||
|
||||
value := commonInitialismsReplacer.Replace(name)
|
||||
buf := bytes.NewBufferString("")
|
||||
for i, v := range value {
|
||||
if i > 0 && v >= 'A' && v <= 'Z' {
|
||||
buf.WriteRune('_')
|
||||
}
|
||||
buf.WriteRune(v)
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var (
|
||||
value = commonInitialismsReplacer.Replace(name)
|
||||
buf = bytes.NewBufferString("")
|
||||
lastCase, currCase, nextCase strCase
|
||||
)
|
||||
|
||||
for i, v := range value[:len(value)-1] {
|
||||
nextCase = value[i+1] >= 'A' && value[i+1] <= 'Z'
|
||||
if i > 0 {
|
||||
if currCase == upper {
|
||||
if lastCase == upper && nextCase == upper {
|
||||
buf.WriteRune(v)
|
||||
} else {
|
||||
if value[i-1] != '_' && value[i+1] != '_' {
|
||||
buf.WriteRune('_')
|
||||
}
|
||||
buf.WriteRune(v)
|
||||
}
|
||||
} else {
|
||||
buf.WriteRune(v)
|
||||
}
|
||||
} else {
|
||||
currCase = upper
|
||||
buf.WriteRune(v)
|
||||
}
|
||||
lastCase = currCase
|
||||
currCase = nextCase
|
||||
}
|
||||
|
||||
buf.WriteByte(value[len(value)-1])
|
||||
|
||||
s := strings.ToLower(buf.String())
|
||||
smap.Set(name, s)
|
||||
return s
|
||||
}
|
||||
|
||||
// SQL expression
|
||||
type expr struct {
|
||||
expr string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
// Expr generate raw SQL expression, for example:
|
||||
// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
|
||||
func Expr(expression string, args ...interface{}) *expr {
|
||||
return &expr{expr: expression, args: args}
|
||||
}
|
||||
|
||||
func indirect(reflectValue reflect.Value) reflect.Value {
|
||||
for reflectValue.Kind() == reflect.Ptr {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
return reflectValue
|
||||
}
|
||||
|
||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||
var results []string
|
||||
|
||||
for _, primaryValue := range primaryValues {
|
||||
var marks []string
|
||||
for _ = range primaryValue {
|
||||
marks = append(marks, "?")
|
||||
}
|
||||
|
||||
if len(marks) > 1 {
|
||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||
} else {
|
||||
results = append(results, strings.Join(marks, ""))
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ",")
|
||||
}
|
||||
|
||||
func toQueryCondition(scope *Scope, columns []string) string {
|
||||
var newColumns []string
|
||||
for _, column := range columns {
|
||||
newColumns = append(newColumns, scope.Quote(column))
|
||||
}
|
||||
|
||||
if len(columns) > 1 {
|
||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||
}
|
||||
return strings.Join(newColumns, ",")
|
||||
}
|
||||
|
||||
func toQueryValues(values [][]interface{}) (results []interface{}) {
|
||||
for _, value := range values {
|
||||
for _, v := range value {
|
||||
results = append(results, v)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fileWithLineNum() string {
|
||||
for i := 2; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
||||
return fmt.Sprintf("%v:%v", file, line)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isBlank(value reflect.Value) bool {
|
||||
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
|
||||
}
|
||||
|
||||
func toSearchableMap(attrs ...interface{}) (result interface{}) {
|
||||
if len(attrs) > 1 {
|
||||
if str, ok := attrs[0].(string); ok {
|
||||
result = map[string]interface{}{str: attrs[1]}
|
||||
}
|
||||
} else if len(attrs) == 1 {
|
||||
if attr, ok := attrs[0].(map[string]interface{}); ok {
|
||||
result = attr
|
||||
}
|
||||
|
||||
if attr, ok := attrs[0].(interface{}); ok {
|
||||
result = attr
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
||||
attrs := map[string]interface{}{}
|
||||
|
||||
switch value := values.(type) {
|
||||
case map[string]interface{}:
|
||||
return value
|
||||
case []interface{}:
|
||||
for _, v := range value {
|
||||
for key, value := range convertInterfaceToMap(v) {
|
||||
attrs[key] = value
|
||||
}
|
||||
}
|
||||
case interface{}:
|
||||
reflectValue := reflect.ValueOf(values)
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Map:
|
||||
for _, key := range reflectValue.MapKeys() {
|
||||
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
||||
}
|
||||
default:
|
||||
for _, field := range (&Scope{Value: values}).Fields() {
|
||||
if !field.IsBlank {
|
||||
attrs[field.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
|
||||
func equalAsString(a interface{}, b interface{}) bool {
|
||||
return toString(a) == toString(b)
|
||||
}
|
||||
|
||||
func toString(str interface{}) string {
|
||||
if values, ok := str.([]interface{}); ok {
|
||||
var results []string
|
||||
for _, value := range values {
|
||||
results = append(results, toString(value))
|
||||
}
|
||||
return strings.Join(results, "_")
|
||||
} else if bytes, ok := str.([]byte); ok {
|
||||
return string(bytes)
|
||||
} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
|
||||
return fmt.Sprintf("%v", reflectValue.Interface())
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func makeSlice(elemType reflect.Type) interface{} {
|
||||
if elemType.Kind() == reflect.Slice {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
sliceType := reflect.SliceOf(elemType)
|
||||
slice := reflect.New(sliceType)
|
||||
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
||||
return slice.Interface()
|
||||
}
|
||||
|
||||
func strInSlice(a string, list []string) bool {
|
||||
for _, b := range list {
|
||||
if b == a {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getValueFromFields return given fields's value
|
||||
func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
|
||||
// If value is a nil pointer, Indirect returns a zero Value!
|
||||
// Therefor we need to check for a zero value,
|
||||
// as FieldByName could panic
|
||||
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
|
||||
for _, fieldName := range fieldNames {
|
||||
if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
|
||||
result := fieldValue.Interface()
|
||||
if r, ok := result.(driver.Valuer); ok {
|
||||
result, _ = r.Value()
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addExtraSpaceIfExist(str string) string {
|
||||
if str != "" {
|
||||
return " " + str
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -1,98 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func fileWithLineNum() string {
|
||||
for i := 2; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
||||
return fmt.Sprintf("%v:%v", file, line)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isBlank(value reflect.Value) bool {
|
||||
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
|
||||
}
|
||||
|
||||
func toSearchableMap(attrs ...interface{}) (result interface{}) {
|
||||
if len(attrs) > 1 {
|
||||
if str, ok := attrs[0].(string); ok {
|
||||
result = map[string]interface{}{str: attrs[1]}
|
||||
}
|
||||
} else if len(attrs) == 1 {
|
||||
if attr, ok := attrs[0].(map[string]interface{}); ok {
|
||||
result = attr
|
||||
}
|
||||
|
||||
if attr, ok := attrs[0].(interface{}); ok {
|
||||
result = attr
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
||||
attrs := map[string]interface{}{}
|
||||
|
||||
switch value := values.(type) {
|
||||
case map[string]interface{}:
|
||||
for k, v := range value {
|
||||
attrs[ToDBName(k)] = v
|
||||
}
|
||||
case []interface{}:
|
||||
for _, v := range value {
|
||||
for key, value := range convertInterfaceToMap(v) {
|
||||
attrs[key] = value
|
||||
}
|
||||
}
|
||||
case interface{}:
|
||||
reflectValue := reflect.ValueOf(values)
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Map:
|
||||
for _, key := range reflectValue.MapKeys() {
|
||||
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
||||
}
|
||||
default:
|
||||
scope := Scope{Value: values}
|
||||
for _, field := range scope.Fields() {
|
||||
if !field.IsBlank && !field.IsIgnored {
|
||||
attrs[field.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
|
||||
func toString(str interface{}) string {
|
||||
if values, ok := str.([]interface{}); ok {
|
||||
var results []string
|
||||
for _, value := range values {
|
||||
results = append(results, toString(value))
|
||||
}
|
||||
return strings.Join(results, "_")
|
||||
} else if bytes, ok := str.([]byte); ok {
|
||||
return string(bytes)
|
||||
} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
|
||||
return fmt.Sprintf("%v", reflectValue.Interface())
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func strInSlice(a string, list []string) bool {
|
||||
for _, b := range list {
|
||||
if b == a {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
package gorm_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func TestToDBNameGenerateFriendlyName(t *testing.T) {
|
||||
var maps = map[string]string{
|
||||
"": "",
|
||||
"ThisIsATest": "this_is_a_test",
|
||||
"PFAndESI": "pf_and_esi",
|
||||
"AbcAndJkl": "abc_and_jkl",
|
||||
"EmployeeID": "employee_id",
|
||||
"SKU_ID": "sku_id",
|
||||
"HTTPAndSMTP": "http_and_smtp",
|
||||
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
|
||||
"UUID": "uuid",
|
||||
"HTTPURL": "http_url",
|
||||
"HTTP_URL": "http_url",
|
||||
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
|
||||
}
|
||||
|
||||
for key, value := range maps {
|
||||
if gorm.ToDBName(key) != value {
|
||||
t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue