Merge pull request #794 from jinzhu/v1.0_dev

The first public release
This commit is contained in:
Jinzhu 2016-03-08 12:20:12 +08:00
commit 9d57c6b961
58 changed files with 3154 additions and 3856 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
documents
_book

1330
README.md

File diff suppressed because it is too large Load Diff

View File

@ -4,32 +4,289 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strings"
) )
// Association Mode contains some helper methods to handle relationship things easily.
type Association struct { type Association struct {
Scope *Scope
Column string
Error error Error error
Field *Field scope *Scope
column string
field *Field
} }
func (association *Association) setErr(err error) *Association { // Find find out all related associations
if err != nil { func (association *Association) Find(value interface{}) *Association {
association.Error = err association.scope.related(value, association.column)
return association.setErr(association.scope.db.Error)
}
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
func (association *Association) Append(values ...interface{}) *Association {
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
return association.Replace(values...)
}
return association.saveAssociations(values...)
}
// Replace replace current associations with new one
func (association *Association) Replace(values ...interface{}) *Association {
var (
relationship = association.field.Relationship
scope = association.scope
field = association.field.Field
newDB = scope.NewDB()
)
// Append new values
association.field.Set(reflect.Zero(association.field.Field.Type()))
association.saveAssociations(values...)
// Belongs To
if relationship.Kind == "belongs_to" {
// Set foreign key to be null when clearing value (length equals 0)
if len(values) == 0 {
// Set foreign key to be nil
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
}
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
}
} else {
// Polymorphic Relations
if relationship.PolymorphicDBName != "" {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
}
// Delete Relations except new created
if len(values) > 0 {
var associationForeignFieldNames []string
if relationship.Kind == "many_to_many" {
// if many to many relations, get association fields name from association foreign keys
associationScope := scope.New(reflect.New(field.Type()).Interface())
for _, dbName := range relationship.AssociationForeignFieldNames {
if field, ok := associationScope.FieldByName(dbName); ok {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
}
}
} else {
// If other relations, use primary keys
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
}
}
newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
if len(newPrimaryKeys) > 0 {
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
}
}
if relationship.Kind == "many_to_many" {
// if many to many relations, delete related relations from join table
var sourceForeignFieldNames []string
for _, dbName := range relationship.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
}
}
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
var foreignKeyMap = map[string]interface{}{}
for idx, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
fieldValue := reflect.New(association.field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
} }
return association return association
} }
func (association *Association) Find(value interface{}) *Association { // Delete remove relationship between source & passed arguments, but won't delete those arguments
association.Scope.related(value, association.Column) func (association *Association) Delete(values ...interface{}) *Association {
return association.setErr(association.Scope.db.Error) var (
relationship = association.field.Relationship
scope = association.scope
field = association.field.Field
newDB = scope.NewDB()
)
if len(values) == 0 {
return association
}
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
}
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
if relationship.Kind == "many_to_many" {
// source value's foreign keys
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
// get association's foreign fields name
var associationScope = scope.New(reflect.New(field.Type()).Interface())
var associationForeignFieldNames []string
for _, associationDBName := range relationship.AssociationForeignFieldNames {
if field, ok := associationScope.FieldByName(associationDBName); ok {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
}
}
// association value's foreign keys
deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
} else {
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
}
if relationship.Kind == "belongs_to" {
// find with deleting relation's foreign keys
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
// set foreign key to be null if there are some records affected
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
if results.RowsAffected > 0 {
scope.updatedAttrsWithValues(foreignKeyMap)
}
} else {
association.setErr(results.Error)
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// find all relations
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
// only include those deleting relations
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
toQueryValues(deletingPrimaryKeys)...,
)
// set matched relation's foreign key to be null
fieldValue := reflect.New(association.field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
}
// Remove deleted records from source's field
if association.Error == nil {
if field.Kind() == reflect.Slice {
leftValues := reflect.Zero(field.Type())
for i := 0; i < field.Len(); i++ {
reflectValue := field.Index(i)
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
var isDeleted = false
for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) {
isDeleted = true
break
}
}
if !isDeleted {
leftValues = reflect.Append(leftValues, reflectValue)
}
}
association.field.Set(leftValues)
} else if field.Kind() == reflect.Struct {
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) {
association.field.Set(reflect.Zero(field.Type()))
break
}
}
}
}
return association
} }
// Clear remove relationship between source & current associations, won't delete those associations
func (association *Association) Clear() *Association {
return association.Replace()
}
// Count return the count of current associations
func (association *Association) Count() int {
var (
count = 0
relationship = association.field.Relationship
scope = association.scope
fieldValue = association.field.Field.Interface()
query = scope.DB()
)
if relationship.Kind == "many_to_many" {
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
} else if relationship.Kind == "belongs_to" {
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
}
if relationship.PolymorphicType != "" {
query = query.Where(
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
scope.TableName(),
)
}
query.Model(fieldValue).Count(&count)
return count
}
// saveAssociations save passed values as associations
func (association *Association) saveAssociations(values ...interface{}) *Association { func (association *Association) saveAssociations(values ...interface{}) *Association {
scope := association.Scope var (
field := association.Field scope = association.scope
relationship := association.Field.Relationship field = association.field
relationship = field.Relationship
)
saveAssociation := func(reflectValue reflect.Value) { saveAssociation := func(reflectValue reflect.Value) {
// value has to been pointer // value has to been pointer
@ -94,318 +351,9 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
return association return association
} }
func (association *Association) Append(values ...interface{}) *Association { func (association *Association) setErr(err error) *Association {
if relationship := association.Field.Relationship; relationship.Kind == "has_one" { if err != nil {
return association.Replace(values...) association.Error = err
}
return association.saveAssociations(values...)
}
func (association *Association) Replace(values ...interface{}) *Association {
var (
relationship = association.Field.Relationship
scope = association.Scope
field = association.Field.Field
newDB = scope.NewDB()
)
// Append new values
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
association.saveAssociations(values...)
// Belongs To
if relationship.Kind == "belongs_to" {
// Set foreign key to be null only when clearing value
if len(values) == 0 {
// Set foreign key to be nil
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
}
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
}
} else {
// Relations
if relationship.PolymorphicDBName != "" {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
}
// Relations except new created
if len(values) > 0 {
var newPrimaryKeys [][]interface{}
var associationForeignFieldNames []string
if relationship.Kind == "many_to_many" {
// If many to many relations, get it from foreign key
associationForeignFieldNames = relationship.AssociationForeignFieldNames
} else {
// If other relations, get real primary keys
for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
if field.IsPrimaryKey {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
}
}
}
newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface())
if len(newPrimaryKeys) > 0 {
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
}
}
if relationship.Kind == "many_to_many" {
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
var foreignKeyMap = map[string]interface{}{}
for idx, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
} }
return association return association
} }
func (association *Association) Delete(values ...interface{}) *Association {
var (
relationship = association.Field.Relationship
scope = association.Scope
field = association.Field.Field
newDB = scope.NewDB()
)
if len(values) == 0 {
return association
}
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
if field.IsPrimaryKey {
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
}
}
deletingPrimaryKeys := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, values...)
if relationship.Kind == "many_to_many" {
// source value's foreign keys
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
// association value's foreign keys
deletingPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
} else {
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
}
if relationship.Kind == "belongs_to" {
// find with deleting relation's foreign keys
primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
// set foreign key to be null
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
if results.RowsAffected > 0 {
scope.updatedAttrsWithValues(foreignKeyMap, false)
}
} else {
association.setErr(results.Error)
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// find all relations
primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, scope.Value)
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
// only include those deleting relations
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
toQueryValues(deletingPrimaryKeys)...,
)
// set matched relation's foreign key to be null
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
}
// Remove deleted records from field
if association.Error == nil {
if association.Field.Field.Kind() == reflect.Slice {
leftValues := reflect.Zero(association.Field.Field.Type())
for i := 0; i < association.Field.Field.Len(); i++ {
reflectValue := association.Field.Field.Index(i)
primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
var included = false
for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) {
included = true
}
}
if !included {
leftValues = reflect.Append(leftValues, reflectValue)
}
}
association.Field.Set(leftValues)
} else if association.Field.Field.Kind() == reflect.Struct {
primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) {
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
break
}
}
}
}
return association
}
func (association *Association) Clear() *Association {
return association.Replace()
}
func (association *Association) Count() int {
var (
count = 0
relationship = association.Field.Relationship
scope = association.Scope
fieldValue = association.Field.Field.Interface()
newScope = scope.New(fieldValue)
)
if relationship.Kind == "many_to_many" {
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
query := scope.DB()
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
field.Field.Interface())
}
}
if relationship.PolymorphicType != "" {
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
}
query.Model(fieldValue).Count(&count)
} else if relationship.Kind == "belongs_to" {
query := scope.DB()
for idx, primaryKey := range relationship.AssociationForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)),
field.Field.Interface())
}
}
query.Model(fieldValue).Count(&count)
}
return count
}
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) {
scope := association.Scope
for _, value := range values {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
if reflectValue.Kind() == reflect.Slice {
for i := 0; i < reflectValue.Len(); i++ {
primaryKeys := []interface{}{}
newScope := scope.New(reflectValue.Index(i).Interface())
for _, column := range columns {
if field, ok := newScope.FieldByName(column); ok {
primaryKeys = append(primaryKeys, field.Field.Interface())
} else {
primaryKeys = append(primaryKeys, "")
}
}
results = append(results, primaryKeys)
}
} else if reflectValue.Kind() == reflect.Struct {
newScope := scope.New(value)
var primaryKeys []interface{}
for _, column := range columns {
if field, ok := newScope.FieldByName(column); ok {
primaryKeys = append(primaryKeys, field.Field.Interface())
} else {
primaryKeys = append(primaryKeys, "")
}
}
results = append(results, primaryKeys)
}
}
return
}
func toQueryMarks(primaryValues [][]interface{}) string {
var results []string
for _, primaryValue := range primaryValues {
var marks []string
for _, _ = range primaryValue {
marks = append(marks, "?")
}
if len(marks) > 1 {
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
} else {
results = append(results, strings.Join(marks, ""))
}
}
return strings.Join(results, ",")
}
func toQueryCondition(scope *Scope, columns []string) string {
var newColumns []string
for _, column := range columns {
newColumns = append(newColumns, scope.Quote(column))
}
if len(columns) > 1 {
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
} else {
return strings.Join(newColumns, ",")
}
}
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
for _, primaryValue := range primaryValues {
for _, value := range primaryValue {
values = append(values, value)
}
}
return values
}

View File

@ -5,6 +5,8 @@ import (
"reflect" "reflect"
"sort" "sort"
"testing" "testing"
"github.com/jinzhu/gorm"
) )
func TestBelongsTo(t *testing.T) { func TestBelongsTo(t *testing.T) {
@ -16,7 +18,7 @@ func TestBelongsTo(t *testing.T) {
} }
if err := DB.Save(&post).Error; err != nil { if err := DB.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post", err.Error()) t.Error("Got errors when save post", err)
} }
if post.Category.ID == 0 || post.MainCategory.ID == 0 { if post.Category.ID == 0 || post.MainCategory.ID == 0 {
@ -177,6 +179,49 @@ func TestBelongsTo(t *testing.T) {
} }
} }
func TestBelongsToOverrideForeignKey1(t *testing.T) {
type Profile struct {
gorm.Model
Name string
}
type User struct {
gorm.Model
Profile Profile `gorm:"ForeignKey:ProfileRefer"`
ProfileRefer int
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "belongs_to" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestBelongsToOverrideForeignKey2(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type User struct {
gorm.Model
Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"`
ProfileID int
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "belongs_to" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasOne(t *testing.T) { func TestHasOne(t *testing.T) {
user := User{ user := User{
Name: "has one", Name: "has one",
@ -184,7 +229,7 @@ func TestHasOne(t *testing.T) {
} }
if err := DB.Save(&user).Error; err != nil { if err := DB.Save(&user).Error; err != nil {
t.Errorf("Got errors when save user", err.Error()) t.Error("Got errors when save user", err.Error())
} }
if user.CreditCard.UserId.Int64 == 0 { if user.CreditCard.UserId.Int64 == 0 {
@ -323,6 +368,49 @@ func TestHasOne(t *testing.T) {
} }
} }
func TestHasOneOverrideForeignKey1(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Profile Profile `gorm:"ForeignKey:UserRefer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_one" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasOneOverrideForeignKey2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserID uint
}
type User struct {
gorm.Model
Refer string
Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_one" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasMany(t *testing.T) { func TestHasMany(t *testing.T) {
post := Post{ post := Post{
Title: "post has many", Title: "post has many",
@ -331,7 +419,7 @@ func TestHasMany(t *testing.T) {
} }
if err := DB.Save(&post).Error; err != nil { if err := DB.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post", err.Error()) t.Error("Got errors when save post", err)
} }
for _, comment := range post.Comments { for _, comment := range post.Comments {
@ -462,6 +550,49 @@ func TestHasMany(t *testing.T) {
} }
} }
func TestHasManyOverrideForeignKey1(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Profile []Profile `gorm:"ForeignKey:UserRefer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_many" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestHasManyOverrideForeignKey2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserID uint
}
type User struct {
gorm.Model
Refer string
Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
}
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
if relation.Relationship.Kind != "has_many" ||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
t.Errorf("Override belongs to foreign key with tag")
}
}
}
func TestManyToMany(t *testing.T) { func TestManyToMany(t *testing.T) {
DB.Raw("delete from languages") DB.Raw("delete from languages")
var languages = []Language{{Name: "ZH"}, {Name: "EN"}} var languages = []Language{{Name: "ZH"}, {Name: "EN"}}

View File

@ -4,34 +4,39 @@ import (
"fmt" "fmt"
) )
type callback struct { // DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
// Callback is a struct that contains all CURD callbacks
// Field `creates` contains callbacks will be call when creating object
// Field `updates` contains callbacks will be call when updating object
// Field `deletes` contains callbacks will be call when deleting object
// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
type Callback struct {
creates []*func(scope *Scope) creates []*func(scope *Scope)
updates []*func(scope *Scope) updates []*func(scope *Scope)
deletes []*func(scope *Scope) deletes []*func(scope *Scope)
queries []*func(scope *Scope) queries []*func(scope *Scope)
rowQueries []*func(scope *Scope) rowQueries []*func(scope *Scope)
processors []*callbackProcessor processors []*CallbackProcessor
} }
type callbackProcessor struct { // CallbackProcessor contains callback informations
name string type CallbackProcessor struct {
before string name string // current callback's name
after string before string // register current callback before a callback
replace bool after string // register current callback after a callback
remove bool replace bool // replace callbacks with same name
typ string remove bool // delete callbacks with same name
processor *func(scope *Scope) kind string // callback type: create, update, delete, query, row_query
callback *callback processor *func(scope *Scope) // callback handler
parent *Callback
} }
func (c *callback) addProcessor(typ string) *callbackProcessor { func (c *Callback) clone() *Callback {
cp := &callbackProcessor{typ: typ, callback: c} return &Callback{
c.processors = append(c.processors, cp)
return cp
}
func (c *callback) clone() *callback {
return &callback{
creates: c.creates, creates: c.creates,
updates: c.updates, updates: c.updates,
deletes: c.deletes, deletes: c.deletes,
@ -40,57 +45,95 @@ func (c *callback) clone() *callback {
} }
} }
func (c *callback) Create() *callbackProcessor { // Create could be used to register callbacks for creating object
return c.addProcessor("create") // db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
// // business logic
// ...
//
// // set error if some thing wrong happened, will rollback the creating
// scope.Err(errors.New("error"))
// })
func (c *Callback) Create() *CallbackProcessor {
return &CallbackProcessor{kind: "create", parent: c}
} }
func (c *callback) Update() *callbackProcessor { // Update could be used to register callbacks for updating object, refer `Create` for usage
return c.addProcessor("update") func (c *Callback) Update() *CallbackProcessor {
return &CallbackProcessor{kind: "update", parent: c}
} }
func (c *callback) Delete() *callbackProcessor { // Delete could be used to register callbacks for deleting object, refer `Create` for usage
return c.addProcessor("delete") func (c *Callback) Delete() *CallbackProcessor {
return &CallbackProcessor{kind: "delete", parent: c}
} }
func (c *callback) Query() *callbackProcessor { // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
return c.addProcessor("query") // Refer `Create` for usage
func (c *Callback) Query() *CallbackProcessor {
return &CallbackProcessor{kind: "query", parent: c}
} }
func (c *callback) RowQuery() *callbackProcessor { // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
return c.addProcessor("row_query") func (c *Callback) RowQuery() *CallbackProcessor {
return &CallbackProcessor{kind: "row_query", parent: c}
} }
func (cp *callbackProcessor) Before(name string) *callbackProcessor { // After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
cp.before = name func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
cp.after = callbackName
return cp return cp
} }
func (cp *callbackProcessor) After(name string) *callbackProcessor { // Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
cp.after = name func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
cp.before = callbackName
return cp return cp
} }
func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) { // Register a new callback, refer `Callbacks.Create`
cp.name = name func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
cp.processor = &fc cp.name = callbackName
cp.callback.sort() cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
} }
func (cp *callbackProcessor) Remove(name string) { // Remove a registered callback
fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum()) // db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
cp.name = name func (cp *CallbackProcessor) Remove(callbackName string) {
fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.remove = true cp.remove = true
cp.callback.sort() cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
} }
func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) { // Replace a registered callback with new callback
fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum()) // db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
cp.name = name // scope.SetColumn("Created", now)
cp.processor = &fc // scope.SetColumn("Updated", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.processor = &callback
cp.replace = true cp.replace = true
cp.callback.sort() cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
} }
// Get registered callback
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
return *p.processor
}
}
return nil
}
// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int { func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- { for i := len(strs) - 1; i >= 0; i-- {
if strs[i] == str { if strs[i] == str {
@ -100,83 +143,77 @@ func getRIndex(strs []string, str string) int {
return -1 return -1
} }
func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) { // sortProcessors sort callback processors based on its before, after, remove, replace
var sortCallbackProcessor func(c *callbackProcessor) func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
var names, sortedNames = []string{}, []string{} var (
allNames, sortedNames []string
sortCallbackProcessor func(c *CallbackProcessor)
)
for _, cp := range cps { for _, cp := range cps {
if index := getRIndex(names, cp.name); index > -1 { // show warning message the callback name already exists
if !cp.replace && !cp.remove { if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
} }
} allNames = append(allNames, cp.name)
names = append(names, cp.name)
} }
sortCallbackProcessor = func(c *callbackProcessor) { sortCallbackProcessor = func(c *CallbackProcessor) {
if getRIndex(sortedNames, c.name) > -1 { if getRIndex(sortedNames, c.name) == -1 { // if not sorted
return if c.before != "" { // if defined before callback
} if index := getRIndex(sortedNames, c.before); index != -1 {
// if before callback already sorted, append current callback just after it
if len(c.before) > 0 {
if index := getRIndex(sortedNames, c.before); index > -1 {
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
} else if index := getRIndex(names, c.before); index > -1 { } else if index := getRIndex(allNames, c.before); index != -1 {
// if before callback exists but haven't sorted, append current callback to last
sortedNames = append(sortedNames, c.name) sortedNames = append(sortedNames, c.name)
sortCallbackProcessor(cps[index]) sortCallbackProcessor(cps[index])
} else {
sortedNames = append(sortedNames, c.name)
} }
} }
if len(c.after) > 0 { if c.after != "" { // if defined after callback
if index := getRIndex(sortedNames, c.after); index > -1 { if index := getRIndex(sortedNames, c.after); index != -1 {
// if after callback already sorted, append current callback just before it
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
} else if index := getRIndex(names, c.after); index > -1 { } else if index := getRIndex(allNames, c.after); index != -1 {
// if after callback exists but haven't sorted
cp := cps[index] cp := cps[index]
if len(cp.before) == 0 { // set after callback's before callback to current callback
if cp.before == "" {
cp.before = c.name cp.before = c.name
} }
sortCallbackProcessor(cp) sortCallbackProcessor(cp)
} else {
sortedNames = append(sortedNames, c.name)
} }
} }
// if current callback haven't been sorted, append it to last
if getRIndex(sortedNames, c.name) == -1 { if getRIndex(sortedNames, c.name) == -1 {
sortedNames = append(sortedNames, c.name) sortedNames = append(sortedNames, c.name)
} }
} }
}
for _, cp := range cps { for _, cp := range cps {
sortCallbackProcessor(cp) sortCallbackProcessor(cp)
} }
var funcs = []*func(scope *Scope){} var sortedFuncs []*func(scope *Scope)
var sortedFuncs = []*func(scope *Scope){}
for _, name := range sortedNames { for _, name := range sortedNames {
index := getRIndex(names, name) if index := getRIndex(allNames, name); !cps[index].remove {
if !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor) sortedFuncs = append(sortedFuncs, cps[index].processor)
} }
} }
for _, cp := range cps { return sortedFuncs
if sindex := getRIndex(sortedNames, cp.name); sindex == -1 {
if !cp.remove {
funcs = append(funcs, cp.processor)
}
}
}
return append(sortedFuncs, funcs...)
} }
func (c *callback) sort() { // reorder all registered processors, and reset CURD callbacks
var creates, updates, deletes, queries, rowQueries []*callbackProcessor func (c *Callback) reorder() {
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
for _, processor := range c.processors { for _, processor := range c.processors {
switch processor.typ { if processor.name != "" {
switch processor.kind {
case "create": case "create":
creates = append(creates, processor) creates = append(creates, processor)
case "update": case "update":
@ -189,6 +226,7 @@ func (c *callback) sort() {
rowQueries = append(rowQueries, processor) rowQueries = append(rowQueries, processor)
} }
} }
}
c.creates = sortProcessors(creates) c.creates = sortProcessors(creates)
c.updates = sortProcessors(updates) c.updates = sortProcessors(updates)
@ -196,5 +234,3 @@ func (c *callback) sort() {
c.queries = sortProcessors(queries) c.queries = sortProcessors(queries)
c.rowQueries = sortProcessors(rowQueries) c.rowQueries = sortProcessors(rowQueries)
} }
var DefaultCallback = &callback{processors: []*callbackProcessor{}}

View File

@ -5,12 +5,31 @@ import (
"strings" "strings"
) )
func BeforeCreate(scope *Scope) { // Define callbacks for creating
scope.CallMethodWithErrorCheck("BeforeSave") func init() {
scope.CallMethodWithErrorCheck("BeforeCreate") DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
DefaultCallback.Create().Register("gorm:create", createCallback)
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
} }
func UpdateTimeStampWhenCreate(scope *Scope) { // beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
func beforeCreateCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("BeforeSave")
}
if !scope.HasError() {
scope.CallMethod("BeforeCreate")
}
}
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
func updateTimeStampForCreateCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
now := NowFunc() now := NowFunc()
scope.SetColumn("CreatedAt", now) scope.SetColumn("CreatedAt", now)
@ -18,109 +37,108 @@ func UpdateTimeStampWhenCreate(scope *Scope) {
} }
} }
func Create(scope *Scope) { // createCallback the callback used to insert data into database
defer scope.Trace(NowFunc()) func createCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
// set create sql defer scope.trace(NowFunc())
var sqls, columns []string
fields := scope.Fields() var (
for _, field := range fields { columns, placeholders []string
blankColumnsWithDefaultValue []string
)
for _, field := range scope.Fields() {
if scope.changeableField(field) { if scope.changeableField(field) {
if field.IsNormal { if field.IsNormal {
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { if !field.IsPrimaryKey || !field.IsBlank {
if !field.IsBlank || !field.HasDefaultValue { if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName)
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
} else {
columns = append(columns, scope.Quote(field.DBName)) columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface())) placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
} else if field.HasDefaultValue {
var hasDefaultValueColumns []string
if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
hasDefaultValueColumns = oldHasDefaultValueColumns.([]string)
}
hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName)
scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns)
} }
} }
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
for _, dbName := range relationship.ForeignDBNames { for _, foreignKey := range field.Relationship.ForeignDBNames {
if relationField := fields[dbName]; !scope.changeableField(relationField) { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
columns = append(columns, scope.Quote(relationField.DBName)) columns = append(columns, scope.Quote(foreignField.DBName))
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
} }
} }
} }
} }
} }
returningKey := "*" var (
primaryField := scope.PrimaryField() returningColumn = "*"
if primaryField != nil { quotedTableName = scope.QuotedTableName()
returningKey = scope.Quote(primaryField.DBName) primaryField = scope.PrimaryField()
extraOption string
)
if str, ok := scope.Get("gorm:insert_option"); ok {
extraOption = fmt.Sprint(str)
} }
if primaryField != nil {
returningColumn = scope.Quote(primaryField.DBName)
}
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
if len(columns) == 0 { if len(columns) == 0 {
scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", scope.Raw(fmt.Sprintf(
scope.QuotedTableName(), "INSERT INTO %v DEFAULT VALUES%v%v",
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), quotedTableName,
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
)) ))
} else { } else {
scope.Raw(fmt.Sprintf( scope.Raw(fmt.Sprintf(
"INSERT INTO %v (%v) VALUES (%v) %v", "INSERT INTO %v (%v) VALUES (%v)%v%v",
scope.QuotedTableName(), scope.QuotedTableName(),
strings.Join(columns, ","), strings.Join(columns, ","),
strings.Join(sqls, ","), strings.Join(placeholders, ","),
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
)) ))
} }
// execute create sql // execute create sql
if scope.Dialect().SupportLastInsertId() { if lastInsertIDReturningSuffix == "" || primaryField == nil {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
id, err := result.LastInsertId() // set rows affected count
if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected() scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank { if primaryField != nil && primaryField.IsBlank {
scope.Err(scope.SetColumn(primaryField, id)) if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
} }
} }
} }
} else { } else {
if primaryField == nil { if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
} else {
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
scope.db.RowsAffected = 1 scope.db.RowsAffected = 1
} else {
scope.Err(err)
}
} }
} }
} }
} }
func ForceReloadAfterCreate(scope *Scope) { // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok { func forceReloadAfterCreateCallback(scope *Scope) {
scope.DB().New().Select(columns.([]string)).First(scope.Value) if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value)
} }
} }
func AfterCreate(scope *Scope) { // afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
scope.CallMethodWithErrorCheck("AfterCreate") func afterCreateCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterSave") if !scope.HasError() {
} scope.CallMethod("AfterCreate")
}
func init() { if !scope.HasError() {
DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction) scope.CallMethod("AfterSave")
DefaultCallback.Create().Register("gorm:before_create", BeforeCreate) }
DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
DefaultCallback.Create().Register("gorm:create", Create)
DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
} }

View File

@ -2,35 +2,52 @@ package gorm
import "fmt" import "fmt"
func BeforeDelete(scope *Scope) { // Define callbacks for deleting
scope.CallMethodWithErrorCheck("BeforeDelete") func init() {
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
} }
func Delete(scope *Scope) { // beforeDeleteCallback will invoke `BeforeDelete` method before deleting
func beforeDeleteCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
scope.CallMethod("BeforeDelete")
}
}
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
func deleteCallback(scope *Scope) {
if !scope.HasError() {
var extraOption string
if str, ok := scope.Get("gorm:delete_option"); ok {
extraOption = fmt.Sprint(str)
}
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
scope.Raw( scope.Raw(fmt.Sprintf(
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", "UPDATE %v SET deleted_at=%v%v%v",
scope.QuotedTableName(), scope.QuotedTableName(),
scope.AddToVars(NowFunc()), scope.AddToVars(NowFunc()),
scope.CombinedConditionSql(), addExtraSpaceIfExist(scope.CombinedConditionSql()),
)) addExtraSpaceIfExist(extraOption),
)).Exec()
} else { } else {
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql())) scope.Raw(fmt.Sprintf(
"DELETE FROM %v%v%v",
scope.QuotedTableName(),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
} }
scope.Exec()
} }
} }
func AfterDelete(scope *Scope) { // afterDeleteCallback will invoke `AfterDelete` method after deleting
scope.CallMethodWithErrorCheck("AfterDelete") func afterDeleteCallback(scope *Scope) {
} if !scope.HasError() {
scope.CallMethod("AfterDelete")
func init() { }
DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
DefaultCallback.Delete().Register("gorm:delete", Delete)
DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
} }

View File

@ -6,115 +6,89 @@ import (
"reflect" "reflect"
) )
func Query(scope *Scope) { // Define callbacks for querying
defer scope.Trace(NowFunc()) func init() {
DefaultCallback.Query().Register("gorm:query", queryCallback)
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
}
// queryCallback used to query data from database
func queryCallback(scope *Scope) {
defer scope.trace(NowFunc())
var ( var (
isSlice bool isSlice bool
isPtr bool isPtr bool
anyRecordFound bool results = scope.IndirectValue()
destType reflect.Type resultType reflect.Type
) )
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
if primaryKey := scope.PrimaryKey(); primaryKey != "" { if primaryField := scope.PrimaryField(); primaryField != nil {
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy)) scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
} }
} }
var dest = scope.IndirectValue()
if value, ok := scope.Get("gorm:query_destination"); ok { if value, ok := scope.Get("gorm:query_destination"); ok {
dest = reflect.Indirect(reflect.ValueOf(value)) results = reflect.Indirect(reflect.ValueOf(value))
} }
if kind := dest.Kind(); kind == reflect.Slice { if kind := results.Kind(); kind == reflect.Slice {
isSlice = true isSlice = true
destType = dest.Type().Elem() resultType = results.Type().Elem()
dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) results.Set(reflect.MakeSlice(results.Type(), 0, 0))
if destType.Kind() == reflect.Ptr { if resultType.Kind() == reflect.Ptr {
isPtr = true isPtr = true
destType = destType.Elem() resultType = resultType.Elem()
} }
} else if kind != reflect.Struct { } else if kind != reflect.Struct {
scope.Err(errors.New("unsupported destination, should be slice or struct")) scope.Err(errors.New("unsupported destination, should be slice or struct"))
return return
} }
scope.prepareQuerySql() scope.prepareQuerySQL()
if !scope.HasError() { if !scope.HasError() {
rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
scope.db.RowsAffected = 0 scope.db.RowsAffected = 0
if str, ok := scope.Get("gorm:query_option"); ok {
if scope.Err(err) != nil { scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
return
} }
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close() defer rows.Close()
columns, _ := rows.Columns() columns, _ := rows.Columns()
for rows.Next() { for rows.Next() {
scope.db.RowsAffected++ scope.db.RowsAffected++
anyRecordFound = true elem := results
elem := dest
if isSlice { if isSlice {
elem = reflect.New(destType).Elem() elem = reflect.New(resultType).Elem()
} }
var values = make([]interface{}, len(columns)) scope.scan(rows, columns, scope.New(elem.Addr().Interface()).fieldsMap())
fields := scope.New(elem.Addr().Interface()).Fields()
for index, column := range columns {
if field, ok := fields[column]; ok {
if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface()
} else {
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
reflectValue.Elem().Set(field.Field.Addr())
values[index] = reflectValue.Interface()
}
} else {
var value interface{}
values[index] = &value
}
}
scope.Err(rows.Scan(values...))
for index, column := range columns {
value := values[index]
if field, ok := fields[column]; ok {
if field.Field.Kind() == reflect.Ptr {
field.Field.Set(reflect.ValueOf(value).Elem())
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
field.Field.Set(v)
}
}
}
if isSlice { if isSlice {
if isPtr { if isPtr {
dest.Set(reflect.Append(dest, elem.Addr())) results.Set(reflect.Append(results, elem.Addr()))
} else { } else {
dest.Set(reflect.Append(dest, elem)) results.Set(reflect.Append(results, elem))
} }
} }
} }
if !anyRecordFound && !isSlice { if scope.db.RowsAffected == 0 && !isSlice {
scope.Err(RecordNotFound) scope.Err(ErrRecordNotFound)
}
} }
} }
} }
func AfterQuery(scope *Scope) { // afterQueryCallback will invoke `AfterFind` method after querying
scope.CallMethodWithErrorCheck("AfterFind") func afterQueryCallback(scope *Scope) {
} if !scope.HasError() {
scope.CallMethod("AfterFind")
func init() { }
DefaultCallback.Query().Register("gorm:query", Query)
DefaultCallback.Query().Register("gorm:preload", Preload)
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
} }

308
callback_query_preload.go Normal file
View File

@ -0,0 +1,308 @@
package gorm
import (
"errors"
"fmt"
"reflect"
"strings"
)
// preloadCallback used to preload associations
func preloadCallback(scope *Scope) {
if scope.Search.preload == nil || scope.HasError() {
return
}
var (
preloadedMap = map[string]bool{}
fields = scope.Fields()
)
for _, preload := range scope.Search.preload {
var (
preloadFields = strings.Split(preload.schema, ".")
currentScope = scope
currentFields = fields
)
for idx, preloadField := range preloadFields {
var currentPreloadConditions []interface{}
// if not preloaded
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
// assign search conditions to last preload
if idx == len(preloadFields)-1 {
currentPreloadConditions = preload.conditions
}
for _, field := range currentFields {
if field.Name != preloadField || field.Relationship == nil {
continue
}
switch field.Relationship.Kind {
case "has_one":
currentScope.handleHasOnePreload(field, currentPreloadConditions)
case "has_many":
currentScope.handleHasManyPreload(field, currentPreloadConditions)
case "belongs_to":
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
default:
scope.Err(errors.New("unsupported relation"))
}
preloadedMap[preloadKey] = true
break
}
if !preloadedMap[preloadKey] {
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
return
}
}
// preload next level
if idx < len(preloadFields)-1 {
currentScope = currentScope.getColumnAsScope(preloadField)
currentFields = currentScope.Fields()
}
}
}
}
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
var (
preloadDB = scope.NewDB()
preloadConditions []interface{}
)
for _, condition := range conditions {
if scopes, ok := condition.(func(*DB) *DB); ok {
preloadDB = scopes(preloadDB)
} else {
preloadConditions = append(preloadConditions, condition)
}
}
return preloadDB, preloadConditions
}
// handleHasOnePreload used to preload has one associations
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// find relations
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
// assign find results
var (
resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ {
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
indirectValue.FieldByName(field.Name).Set(result)
break
}
}
} else {
scope.Err(field.Set(result))
}
}
}
// handleHasManyPreload used to preload has many associations
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// find relations
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
// assign find results
var (
resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
if indirectScopeValue.Kind() == reflect.Slice {
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) {
objectField := object.FieldByName(field.Name)
objectField.Set(reflect.Append(objectField, result))
break
}
}
}
} else {
scope.Err(field.Set(resultsValue))
}
}
// handleBelongsToPreload used to preload belongs to associations
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return
}
// find relations
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
// assign find results
var (
resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
object.FieldByName(field.Name).Set(result)
}
}
} else {
scope.Err(field.Set(result))
}
}
}
// handleManyToManyPreload used to preload many to many associations
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
var (
relation = field.Relationship
joinTableHandler = relation.JoinTableHandler
fieldType = field.Struct.Type.Elem()
foreignKeyValue interface{}
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
linkHash = map[string][]reflect.Value{}
isPtr bool
)
if fieldType.Kind() == reflect.Ptr {
isPtr = true
fieldType = fieldType.Elem()
}
var sourceKeys = []string{}
for _, key := range joinTableHandler.SourceForeignKeys() {
sourceKeys = append(sourceKeys, key.DBName)
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// generate query with join table
newScope := scope.New(reflect.New(fieldType).Interface())
preloadDB = preloadDB.Table(newScope.TableName()).Select("*")
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
// preload inline conditions
if len(preloadConditions) > 0 {
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
}
rows, err := preloadDB.Rows()
if scope.Err(err) != nil {
return
}
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
var (
elem = reflect.New(fieldType).Elem()
fields = scope.New(elem.Addr().Interface()).fieldsMap()
)
// register foreign keys in join tables
for _, sourceKey := range sourceKeys {
fields[sourceKey] = &Field{Field: reflect.New(foreignKeyType).Elem()}
}
scope.scan(rows, columns, fields)
// generate hashed forkey keys in join table
var foreignKeys = make([]interface{}, len(sourceKeys))
for idx, sourceKey := range sourceKeys {
foreignKeys[idx] = fields[sourceKey].Field.Elem().Interface()
}
hashedSourceKeys := toString(foreignKeys)
if isPtr {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
} else {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
}
}
// assign find results
var (
indirectScopeValue = scope.IndirectValue()
fieldsSourceMap = map[string]reflect.Value{}
foreignFieldNames = []string{}
fields = scope.fieldsMap()
)
for _, dbName := range relation.ForeignFieldNames {
if field, ok := fields[dbName]; ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name)
}
} else if indirectScopeValue.IsValid() {
fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name)
}
for source, link := range linkHash {
fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...))
}
}

View File

@ -2,15 +2,15 @@ package gorm
import "reflect" import "reflect"
func BeginTransaction(scope *Scope) { func beginTransactionCallback(scope *Scope) {
scope.Begin() scope.Begin()
} }
func CommitOrRollbackTransaction(scope *Scope) { func commitOrRollbackTransactionCallback(scope *Scope) {
scope.CommitOrRollback() scope.CommitOrRollback()
} }
func SaveBeforeAssociations(scope *Scope) { func saveBeforeAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() { if !scope.shouldSaveAssociations() {
return return
} }
@ -32,7 +32,7 @@ func SaveBeforeAssociations(scope *Scope) {
} }
} }
func SaveAfterAssociations(scope *Scope) { func saveAfterAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() { if !scope.shouldSaveAssociations() {
return return
} }

View File

@ -23,7 +23,7 @@ func afterCreate1(s *Scope) {}
func afterCreate2(s *Scope) {} func afterCreate2(s *Scope) {}
func TestRegisterCallback(t *testing.T) { func TestRegisterCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}} var callback = &Callback{}
callback.Create().Register("before_create1", beforeCreate1) callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("before_create2", beforeCreate2) callback.Create().Register("before_create2", beforeCreate2)
@ -37,7 +37,7 @@ func TestRegisterCallback(t *testing.T) {
} }
func TestRegisterCallbackWithOrder(t *testing.T) { func TestRegisterCallbackWithOrder(t *testing.T) {
var callback1 = &callback{processors: []*callbackProcessor{}} var callback1 = &Callback{}
callback1.Create().Register("before_create1", beforeCreate1) callback1.Create().Register("before_create1", beforeCreate1)
callback1.Create().Register("create", create) callback1.Create().Register("create", create)
callback1.Create().Register("after_create1", afterCreate1) callback1.Create().Register("after_create1", afterCreate1)
@ -46,7 +46,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
var callback2 = &callback{processors: []*callbackProcessor{}} var callback2 = &Callback{}
callback2.Update().Register("create", create) callback2.Update().Register("create", create)
callback2.Update().Before("create").Register("before_create1", beforeCreate1) callback2.Update().Before("create").Register("before_create1", beforeCreate1)
@ -60,7 +60,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) {
} }
func TestRegisterCallbackWithComplexOrder(t *testing.T) { func TestRegisterCallbackWithComplexOrder(t *testing.T) {
var callback1 = &callback{processors: []*callbackProcessor{}} var callback1 = &Callback{}
callback1.Query().Before("after_create1").After("before_create1").Register("create", create) callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
callback1.Query().Register("before_create1", beforeCreate1) callback1.Query().Register("before_create1", beforeCreate1)
@ -70,7 +70,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
var callback2 = &callback{processors: []*callbackProcessor{}} var callback2 = &Callback{}
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
callback2.Delete().Before("create").Register("before_create1", beforeCreate1) callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
@ -86,7 +86,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
func replaceCreate(s *Scope) {} func replaceCreate(s *Scope) {}
func TestReplaceCallback(t *testing.T) { func TestReplaceCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}} var callback = &Callback{}
callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1) callback.Create().Register("before_create1", beforeCreate1)
@ -99,7 +99,7 @@ func TestReplaceCallback(t *testing.T) {
} }
func TestRemoveCallback(t *testing.T) { func TestRemoveCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}} var callback = &Callback{}
callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1) callback.Create().Register("before_create1", beforeCreate1)

View File

@ -5,91 +5,102 @@ import (
"strings" "strings"
) )
func AssignUpdateAttributes(scope *Scope) { // Define callbacks for updating
func init() {
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
DefaultCallback.Update().Register("gorm:update", updateCallback)
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}
// assignUpdatingAttributesCallback assign updating attributes to model
func assignUpdatingAttributesCallback(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if maps := convertInterfaceToMap(attrs); len(maps) > 0 { if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
protected, ok := scope.Get("gorm:ignore_protected_attrs") if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate {
_, updateColumn := scope.Get("gorm:update_column") scope.InstanceSet("gorm:update_attrs", updateMaps)
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool)) } else {
if updateColumn {
scope.InstanceSet("gorm:update_attrs", maps)
} else if len(updateAttrs) > 0 {
scope.InstanceSet("gorm:update_attrs", updateAttrs)
} else if !hasUpdate {
scope.SkipLeft() scope.SkipLeft()
return
} }
} }
} }
} }
func BeforeUpdate(scope *Scope) { // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok { if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("BeforeSave") if !scope.HasError() {
scope.CallMethodWithErrorCheck("BeforeUpdate") scope.CallMethod("BeforeSave")
}
if !scope.HasError() {
scope.CallMethod("BeforeUpdate")
}
} }
} }
func UpdateTimeStampWhenUpdate(scope *Scope) { // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
func updateTimeStampForUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok { if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc()) scope.SetColumn("UpdatedAt", NowFunc())
} }
} }
func Update(scope *Scope) { // updateCallback the callback used to update data to database
func updateCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
var sqls []string var sqls []string
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
for key, value := range updateAttrs.(map[string]interface{}) { for column, value := range updateAttrs.(map[string]interface{}) {
if scope.changeableDBColumn(key) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
}
} }
} else { } else {
fields := scope.Fields() for _, field := range scope.Fields() {
for _, field := range fields { if scope.changeableField(field) {
if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { if !field.IsPrimaryKey && field.IsNormal {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
for _, dbName := range relationship.ForeignDBNames { for _, foreignKey := range relationship.ForeignDBNames {
if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())) sqls = append(sqls,
sqls = append(sqls, sql) fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
} }
} }
} }
} }
} }
}
var extraOption string
if str, ok := scope.Get("gorm:update_option"); ok {
extraOption = fmt.Sprint(str)
}
if len(sqls) > 0 { if len(sqls) > 0 {
scope.Raw(fmt.Sprintf( scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v", "UPDATE %v SET %v%v%v",
scope.QuotedTableName(), scope.QuotedTableName(),
strings.Join(sqls, ", "), strings.Join(sqls, ", "),
scope.CombinedConditionSql(), addExtraSpaceIfExist(scope.CombinedConditionSql()),
)) addExtraSpaceIfExist(extraOption),
scope.Exec() )).Exec()
} }
} }
} }
func AfterUpdate(scope *Scope) { // afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
func afterUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok { if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("AfterUpdate") if !scope.HasError() {
scope.CallMethodWithErrorCheck("AfterSave") scope.CallMethod("AfterUpdate")
}
if !scope.HasError() {
scope.CallMethod("AfterSave")
}
} }
} }
func init() {
DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
DefaultCallback.Update().Register("gorm:update", Update)
DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
}

View File

@ -1,117 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"time"
)
type commonDialect struct{}
func (commonDialect) BinVar(i int) string {
return "$$" // ?
}
func (commonDialect) SupportLastInsertId() bool {
return true
}
func (commonDialect) HasTop() bool {
return false
}
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "INTEGER AUTO_INCREMENT"
}
return "INTEGER"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "BIGINT AUTO_INCREMENT"
}
return "BIGINT"
case reflect.Float32, reflect.Float64:
return "FLOAT"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("VARCHAR(%d)", size)
}
return "VARCHAR(65532)"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "TIMESTAMP"
}
default:
if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("BINARY(%d)", size)
}
return "BINARY(65532)"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
}
func (commonDialect) ReturningStr(tableName, key string) string {
return ""
}
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
var (
count int
databaseName = c.CurrentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
return count > 0
}
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var (
count int
databaseName = c.CurrentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
return count > 0
}
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var (
count int
databaseName = c.CurrentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
return count > 0
}
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
}
// RawScanInt scans the first column of the first row into the `scan' int pointer.
// This function captures raw query errors and propagates them to the original scope.
func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
}
// RawScanString scans the first column of the first row into the `scan' string pointer.
// This function captures raw query errors and propagates them to the original scope.
func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) {
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
}
func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
return
}

View File

@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) {
DB.AutoMigrate(&CustomizeColumn{}) DB.AutoMigrate(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{}) scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope, scope.TableName(), col) { if !scope.Dialect().HasColumn(scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col) t.Errorf("CustomizeColumn should have column %s", col)
} }

View File

@ -17,8 +17,7 @@ func TestDdlErrors(t *testing.T) {
} }
}() }()
DB.HasTable("foobarbaz") if err := DB.Find(&User{}).Error; err == nil {
if DB.Error == nil {
t.Errorf("Expected operation on closed db to produce an error, but err was nil") t.Errorf("Expected operation on closed db to produce an error, but err was nil")
} }
} }

View File

@ -45,7 +45,7 @@ func TestSoftDelete(t *testing.T) {
type User struct { type User struct {
Id int64 Id int64
Name string Name string
DeletedAt time.Time DeletedAt *time.Time
} }
DB.AutoMigrate(&User{}) DB.AutoMigrate(&User{})

View File

@ -1,41 +1,100 @@
package gorm package gorm
import ( import (
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"strings"
) )
// Dialect interface contains behaviors that differ across SQL database
type Dialect interface { type Dialect interface {
BinVar(i int) string // GetName get dialect's name
SupportLastInsertId() bool GetName() string
HasTop() bool
SqlTag(value reflect.Value, size int, autoIncrease bool) string // SetDB set db for dialect
ReturningStr(tableName, key string) string SetDB(db *sql.DB)
SelectFromDummyTable() string
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
BindVar(i int) string
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
Quote(key string) string Quote(key string) string
HasTable(scope *Scope, tableName string) bool // DataTypeOf return data's sql type
HasColumn(scope *Scope, tableName string, columnName string) bool DataTypeOf(field *StructField) string
HasIndex(scope *Scope, tableName string, indexName string) bool
RemoveIndex(scope *Scope, indexName string) // HasIndex check has index or not
CurrentDatabase(scope *Scope) string HasIndex(tableName string, indexName string) bool
// HasForeignKey check has foreign key or not
HasForeignKey(tableName string, foreignKeyName string) bool
// RemoveIndex remove index
RemoveIndex(tableName string, indexName string) error
// HasTable check has table or not
HasTable(tableName string) bool
// HasColumn check has column or not
HasColumn(tableName string, columnName string) bool
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset int) string
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
} }
func NewDialect(driver string) Dialect { var dialectsMap = map[string]Dialect{}
var d Dialect
switch driver { func newDialect(name string, db *sql.DB) Dialect {
case "postgres": if value, ok := dialectsMap[name]; ok {
d = &postgres{} dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
case "foundation": dialect.SetDB(db)
d = &foundation{} return dialect
case "mysql":
d = &mysql{}
case "sqlite3":
d = &sqlite3{}
case "mssql":
d = &mssql{}
default:
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
d = &commonDialect{}
} }
return d
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
commontDialect := &commonDialect{}
commontDialect.SetDB(db)
return commontDialect
}
// RegisterDialect register new dialect
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
// ParseFieldStructForDialect parse field struct for dialect
func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
// Get redirected field type
var reflectType = field.Struct.Type
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
// Get redirected field value
fieldValue = reflect.Indirect(reflect.New(reflectType))
// Get scanner's real value
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
fieldValue = value
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
getScannerValue(fieldValue.Field(0))
}
}
getScannerValue(fieldValue)
// Default Size
if num, ok := field.TagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
} else {
size = 255
}
// Default type from tag setting
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = additionalType + " DEFAULT " + value
}
return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType)
} }

137
dialect_common.go Normal file
View File

@ -0,0 +1,137 @@
package gorm
import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
)
type commonDialect struct {
db *sql.DB
}
func init() {
RegisterDialect("common", &commonDialect{})
}
func (commonDialect) GetName() string {
return "common"
}
func (s *commonDialect) SetDB(db *sql.DB) {
s.db = db
}
func (commonDialect) BindVar(i int) string {
return "$$" // ?
}
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (commonDialect) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
sqlType = "INTEGER AUTO_INCREMENT"
} else {
sqlType = "INTEGER"
}
case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
sqlType = "BIGINT AUTO_INCREMENT"
} else {
sqlType = "BIGINT"
}
case reflect.Float32, reflect.Float64:
sqlType = "FLOAT"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("VARCHAR(%d)", size)
} else {
sqlType = "VARCHAR(65532)"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "TIMESTAMP"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("BINARY(%d)", size)
} else {
sqlType = "BINARY(65532)"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count)
return count > 0
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
return err
}
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}
func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) currentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) {
if limit > 0 || offset > 0 {
if limit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", limit)
}
if offset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", offset)
}
}
return
}
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}

113
dialect_mysql.go Normal file
View File

@ -0,0 +1,113 @@
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type mysql struct {
commonDialect
}
func init() {
RegisterDialect("mysql", &mysql{})
}
func (mysql) GetName() string {
return "mysql"
}
func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
// Get Data Type for MySQL Dialect
func (mysql) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
sqlType = "int AUTO_INCREMENT"
} else {
sqlType = "int"
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
sqlType = "int unsigned AUTO_INCREMENT"
} else {
sqlType = "int unsigned"
}
case reflect.Int64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
sqlType = "bigint AUTO_INCREMENT"
} else {
sqlType = "bigint"
}
case reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
sqlType = "bigint unsigned AUTO_INCREMENT"
} else {
sqlType = "bigint unsigned"
}
case reflect.Float32, reflect.Float64:
sqlType = "double"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "longtext"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
if _, ok := field.TagSettings["NOT NULL"]; ok {
sqlType = "timestamp"
} else {
sqlType = "timestamp NULL"
}
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varbinary(%d)", size)
} else {
sqlType = "longblob"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.currentDatabase(), foreignKeyName).Scan(&count)
return count > 0
}
func (s mysql) currentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}

128
dialect_postgres.go Normal file
View File

@ -0,0 +1,128 @@
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type postgres struct {
commonDialect
}
func init() {
RegisterDialect("postgres", &postgres{})
}
func (postgres) GetName() string {
return "postgres"
}
func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (postgres) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
sqlType = "serial"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
sqlType = "bigserial"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "numeric"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "timestamp with time zone"
}
case reflect.Map:
if dataValue.Type().Name() == "Hstore" {
sqlType = "hstore"
}
default:
if isByteArrayOrSlice(dataValue) {
sqlType = "bytea"
} else if isUUID(dataValue) {
sqlType = "uuid"
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
return count > 0
}
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", s.currentDatabase(), foreignKeyName).Scan(&count)
return count > 0
}
func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
return count > 0
}
func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
return count > 0
}
func (s postgres) currentDatabase() (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
return
}
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}
func (postgres) SupportLastInsertID() bool {
return false
}
func isByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}
func isUUID(value reflect.Value) bool {
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
return false
}
typename := value.Type().Name()
lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower
}

106
dialect_sqlite3.go Normal file
View File

@ -0,0 +1,106 @@
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type sqlite3 struct {
commonDialect
}
func init() {
RegisterDialect("sqlite", &sqlite3{})
RegisterDialect("sqlite3", &sqlite3{})
}
func (sqlite3) GetName() string {
return "sqlite3"
}
// Get Data Type for Sqlite Dialect
func (sqlite3) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if field.IsPrimaryKey {
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if field.IsPrimaryKey {
sqlType = "integer primary key autoincrement"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "real"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetime"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
sqlType = "blob"
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) currentDatabase() (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
i int
)
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {
name = *pointers[1]
}
return
}

150
dialects/mssql/mssql.go Normal file
View File

@ -0,0 +1,150 @@
package mssql
import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
_ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm"
)
func setIdentityInsert(scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
}
}
func init() {
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
gorm.RegisterDialect("mssql", &mssql{})
}
type mssql struct {
db *sql.DB
}
func (mssql) GetName() string {
return "mssql"
}
func (s *mssql) SetDB(db *sql.DB) {
s.db = db
}
func (mssql) BindVar(i int) string {
return "$$" // ?
}
func (mssql) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (mssql) DataTypeOf(field *gorm.StructField) string {
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
sqlType = "int IDENTITY(1,1)"
} else {
sqlType = "int"
}
case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
sqlType = "bigint IDENTITY(1,1)"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "float"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("nvarchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetime2"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
return count > 0
}
func (s mssql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}
func (s mssql) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.currentDatabase()).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) currentDatabase() (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
}
func (mssql) LimitAndOffsetSQL(limit, offset int) (sql string) {
if limit > 0 || offset > 0 {
if offset < 0 {
offset = 0
}
sql += fmt.Sprintf(" OFFSET %d ROWS", offset)
if limit >= 0 {
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit)
}
}
return
}
func (mssql) SelectFromDummyTable() string {
return ""
}
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}

3
dialects/mysql/mysql.go Normal file
View File

@ -0,0 +1,3 @@
package mysql
import _ "github.com/go-sql-driver/mysql"

View File

@ -0,0 +1,52 @@
package postgres
import (
"database/sql"
"database/sql/driver"
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
)
type Hstore map[string]*string
func (h Hstore) Value() (driver.Value, error) {
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
if len(h) == 0 {
return nil, nil
}
for key, value := range h {
var s sql.NullString
if value != nil {
s.String = *value
s.Valid = true
}
hstore.Map[key] = s
}
return hstore.Value()
}
func (h *Hstore) Scan(value interface{}) error {
hstore := hstore.Hstore{}
if err := hstore.Scan(value); err != nil {
return err
}
if len(hstore.Map) == 0 {
return nil
}
*h = Hstore{}
for k := range hstore.Map {
if hstore.Map[k].Valid {
s := hstore.Map[k].String
(*h)[k] = &s
} else {
(*h)[k] = nil
}
}
return nil
}

View File

@ -0,0 +1,3 @@
package sqlite
import _ "github.com/mattn/go-sqlite3"

View File

@ -1,68 +0,0 @@
# Gorm Development
## Architecture
The most notable component of Gorm is`gorm.DB`, which hold database connection. It could be initialized like this:
db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
Gorm has chainable API, `gorm.DB` is the bridge of chains, it save related information and pass it to the next chain.
Lets use below code to explain how it works:
db.Where("name = ?", "jinzhu").Find(&users)
// equivalent code
newdb := db.Where("name =?", "jinzhu")
newdb.Find(&user)
`newdb` is `db`'s clone, in addition, it contains search conditions from the `Where` method.
`Find` is a query method, it creates a `Scope` instance, and pass it as argument to query callbacks.
There are four kinds of callbacks corresponds to sql's CURD: create callbacks, update callbacks, query callbacks, delete callbacks.
## Callbacks
### Register a new callback
func updateCreated(scope *Scope) {
if scope.HasColumn("Created") {
scope.SetColumn("Created", NowFunc())
}
}
db.Callback().Create().Register("update_created_at", updateCreated)
// register a callback for Create process
### Delete an existing callback
db.Callback().Create().Remove("gorm:create")
// delete callback `gorm:create` from Create callbacks
### Replace an existing callback
db.Callback().Create().Replace("gorm:create", newCreateFunction)
// replace callback `gorm:create` with new function `newCreateFunction` for Create process
### Register callback orders
db.Callback().Create().Before("gorm:create").Register("update_created_at", updateCreated)
db.Callback().Create().After("gorm:create").Register("update_created_at", updateCreated)
db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery)
db.Callback().Delete().After("gorm:delete").Register("my_plugin:after_delete", afterDelete)
db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate)
db.Callback().Create().Before("gorm:create").After("gorm:before_create").Register("my_plugin:before_create", beforeCreate)
### Callback API
Gorm is powered by callbacks, so you could refer below links to learn how to write callbacks
[Create callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)
View [https://github.com/jinzhu/gorm/blob/master/scope.go](https://github.com/jinzhu/gorm/blob/master/scope.go) for all available API

View File

@ -6,25 +6,31 @@ import (
) )
var ( var (
RecordNotFound = errors.New("record not found") // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
InvalidSql = errors.New("invalid sql") ErrRecordNotFound = errors.New("record not found")
NoNewAttrs = errors.New("no new attributes") // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
NoValidTransaction = errors.New("no valid transaction") ErrInvalidSQL = errors.New("invalid SQL")
CantStartTransaction = errors.New("can't start transaction") // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction")
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
ErrCantStartTransaction = errors.New("can't start transaction")
) )
type errorsInterface interface { type errorsInterface interface {
GetErrors() []error GetErrors() []error
} }
// Errors contains all happened errors
type Errors struct { type Errors struct {
errors []error errors []error
} }
// GetErrors get all happened errors
func (errs Errors) GetErrors() []error { func (errs Errors) GetErrors() []error {
return errs.errors return errs.errors
} }
// Add add an error
func (errs *Errors) Add(err error) { func (errs *Errors) Add(err error) {
if errors, ok := err.(errorsInterface); ok { if errors, ok := err.(errorsInterface); ok {
for _, err := range errors.GetErrors() { for _, err := range errors.GetErrors() {
@ -40,6 +46,7 @@ func (errs *Errors) Add(err error) {
} }
} }
// Error format happened errors
func (errs Errors) Error() string { func (errs Errors) Error() string {
var errors = []string{} var errors = []string{}
for _, e := range errs.errors { for _, e := range errs.errors {

View File

@ -7,12 +7,14 @@ import (
"reflect" "reflect"
) )
// Field model field definition
type Field struct { type Field struct {
*StructField *StructField
IsBlank bool IsBlank bool
Field reflect.Value Field reflect.Value
} }
// Set set a value to the field
func (field *Field) Set(value interface{}) (err error) { func (field *Field) Set(value interface{}) (err error) {
if !field.Field.IsValid() { if !field.Field.IsValid() {
return errors.New("field value not valid") return errors.New("field value not valid")
@ -56,35 +58,34 @@ func (field *Field) Set(value interface{}) (err error) {
} }
// Fields get value's fields // Fields get value's fields
func (scope *Scope) Fields() map[string]*Field { func (scope *Scope) Fields() []*Field {
if scope.fields == nil { var (
fields := map[string]*Field{} fields []*Field
modelStruct := scope.GetModelStruct() indirectScopeValue = scope.IndirectValue()
isStruct = indirectScopeValue.Kind() == reflect.Struct
)
indirectValue := scope.IndirectValue() for _, structField := range scope.GetModelStruct().StructFields {
isStruct := indirectValue.Kind() == reflect.Struct
for _, structField := range modelStruct.StructFields {
if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
if isStruct { if isStruct {
fields[structField.DBName] = getField(indirectValue, structField) fieldValue := indirectScopeValue
} else {
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
}
}
}
scope.fields = fields
return fields
}
return scope.fields
}
func getField(indirectValue reflect.Value, structField *StructField) *Field {
field := &Field{StructField: structField}
for _, name := range structField.Names { for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue).FieldByName(name) fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
} }
field.Field = indirectValue fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
field.IsBlank = isBlank(indirectValue) } else {
return field fields = append(fields, &Field{StructField: structField, IsBlank: true})
}
}
return fields
}
func (scope *Scope) fieldsMap() map[string]*Field {
var results = map[string]*Field{}
for _, field := range scope.Fields() {
if field.IsNormal {
results[field.DBName] = field
}
}
return results
} }

View File

@ -32,12 +32,16 @@ type CalculateFieldCategory struct {
func TestCalculateField(t *testing.T) { func TestCalculateField(t *testing.T) {
var field CalculateField var field CalculateField
fields := DB.NewScope(&field).Fields() var scope = DB.NewScope(&field)
if fields["children"].Relationship == nil || fields["category"].Relationship == nil { if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil {
t.Errorf("Should calculate fields correctly for the first time") t.Errorf("Should calculate fields correctly for the first time")
} }
if field, ok := fields["embedded_name"]; !ok { if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil {
t.Errorf("Should calculate fields correctly for the first time")
}
if field, ok := scope.FieldByName("embedded_name"); !ok {
t.Errorf("should find embedded field") t.Errorf("should find embedded field")
} else if _, ok := field.TagSettings["NOT NULL"]; !ok { } else if _, ok := field.TagSettings["NOT NULL"]; !ok {
t.Errorf("should find embedded field's tag settings") t.Errorf("should find embedded field's tag settings")

View File

@ -1,83 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"time"
)
type foundation struct {
commonDialect
}
func (foundation) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (foundation) SupportLastInsertId() bool {
return false
}
func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "serial"
}
return "int"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigserial"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "double"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "clob"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "datetime"
}
default:
if _, ok := value.Interface().([]byte); ok {
return "blob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String()))
}
func (s foundation) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}
func (s foundation) HasTable(scope *Scope, tableName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName)
return count > 0
}
func (s foundation) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName)
return count > 0
}
func (s foundation) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName)))
}
func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName)
return count > 0
}
func (s foundation) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT CURRENT_SCHEMA")
return
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 65 KiB

View File

@ -7,40 +7,54 @@ import (
"strings" "strings"
) )
// JoinTableHandlerInterface is an interface for how to handle many2many relations
type JoinTableHandlerInterface interface { type JoinTableHandlerInterface interface {
// initialize join table handler
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
// Table return join table's table name
Table(db *DB) string Table(db *DB) string
// Add create relationship in join table for source and destination
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
// Delete delete relationship in join table for sources
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
// JoinWith query with `Join` conditions
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
// SourceForeignKeys return source foreign keys
SourceForeignKeys() []JoinTableForeignKey SourceForeignKeys() []JoinTableForeignKey
// DestinationForeignKeys return destination foreign keys
DestinationForeignKeys() []JoinTableForeignKey DestinationForeignKeys() []JoinTableForeignKey
} }
// JoinTableForeignKey join table foreign key struct
type JoinTableForeignKey struct { type JoinTableForeignKey struct {
DBName string DBName string
AssociationDBName string AssociationDBName string
} }
// JoinTableSource is a struct that contains model type and foreign keys
type JoinTableSource struct { type JoinTableSource struct {
ModelType reflect.Type ModelType reflect.Type
ForeignKeys []JoinTableForeignKey ForeignKeys []JoinTableForeignKey
} }
// JoinTableHandler default join table handler
type JoinTableHandler struct { type JoinTableHandler struct {
TableName string `sql:"-"` TableName string `sql:"-"`
Source JoinTableSource `sql:"-"` Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"` Destination JoinTableSource `sql:"-"`
} }
// SourceForeignKeys return source foreign keys
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
return s.Source.ForeignKeys return s.Source.ForeignKeys
} }
// DestinationForeignKeys return destination foreign keys
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
return s.Destination.ForeignKeys return s.Destination.ForeignKeys
} }
// Setup initialize a default join table handler
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
s.TableName = tableName s.TableName = tableName
@ -61,11 +75,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
} }
} }
// Table return join table's table name
func (s JoinTableHandler) Table(db *DB) string { func (s JoinTableHandler) Table(db *DB) string {
return s.TableName return s.TableName
} }
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{} values := map[string]interface{}{}
for _, source := range sources { for _, source := range sources {
@ -74,20 +89,25 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
if s.Source.ModelType == modelType { if s.Source.ModelType == modelType {
for _, foreignKey := range s.Source.ForeignKeys { for _, foreignKey := range s.Source.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
values[foreignKey.DBName] = field.Field.Interface()
}
} }
} else if s.Destination.ModelType == modelType { } else if s.Destination.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys { for _, foreignKey := range s.Destination.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
values[foreignKey.DBName] = field.Field.Interface()
}
} }
} }
} }
return values return values
} }
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error { // Add create relationship in join table for source and destination
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
scope := db.NewScope("") scope := db.NewScope("")
searchMap := s.GetSearchMap(db, source1, source2) searchMap := s.getSearchMap(db, source, destination)
var assignColumns, binVars, conditions []string var assignColumns, binVars, conditions []string
var values []interface{} var values []interface{}
@ -116,6 +136,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
return db.Exec(sql, values...).Error return db.Exec(sql, values...).Error
} }
// Delete delete relationship in join table for sources
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var ( var (
scope = db.NewScope(nil) scope = db.NewScope(nil)
@ -123,7 +144,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
values []interface{} values []interface{}
) )
for key, value := range s.GetSearchMap(db, sources...) { for key, value := range s.getSearchMap(db, sources...) {
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value) values = append(values, value)
} }
@ -131,6 +152,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
} }
// JoinWith query with `Join` conditions
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
var ( var (
scope = db.NewScope(source) scope = db.NewScope(source)
@ -151,10 +173,12 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
for _, foreignKey := range s.Source.ForeignKeys { for _, foreignKey := range s.Source.ForeignKeys {
foreignDBNames = append(foreignDBNames, foreignKey.DBName) foreignDBNames = append(foreignDBNames, foreignKey.DBName)
foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name) if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
} }
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames) foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
var condString string var condString string
if len(foreignFieldValues) > 0 { if len(foreignFieldValues) > 0 {
@ -165,7 +189,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
keys := scope.getColumnAsArray(foreignFieldNames) keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
values = append(values, toQueryValues(keys)) values = append(values, toQueryValues(keys))
} else { } else {
condString = fmt.Sprintf("1 <> 1") condString = fmt.Sprintf("1 <> 1")
@ -173,8 +197,8 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
Where(condString, toQueryValues(foreignFieldValues)...) Where(condString, toQueryValues(foreignFieldValues)...)
} else { }
db.Error = errors.New("wrong source type for join table handler") db.Error = errors.New("wrong source type for join table handler")
return db return db
}
} }

View File

@ -18,7 +18,7 @@ type PersonAddress struct {
gorm.JoinTableHandler gorm.JoinTableHandler
PersonID int PersonID int
AddressID int AddressID int
DeletedAt time.Time DeletedAt *time.Time
CreatedAt time.Time CreatedAt time.Time
} }

View File

@ -8,25 +8,28 @@ import (
"reflect" "reflect"
"regexp" "regexp"
"time" "time"
"unicode"
)
var (
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
) )
type logger interface { type logger interface {
Print(v ...interface{}) Print(v ...interface{})
} }
type LogWriter interface { type logWriter interface {
Println(v ...interface{}) Println(v ...interface{})
} }
// Logger default logger
type Logger struct { type Logger struct {
LogWriter logWriter
} }
var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} // Print format & print log
// Format log
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
func (logger Logger) Print(values ...interface{}) { func (logger Logger) Print(values ...interface{}) {
if len(values) > 1 { if len(values) > 1 {
level := values[0] level := values[0]
@ -38,29 +41,44 @@ func (logger Logger) Print(values ...interface{}) {
// duration // duration
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
// sql // sql
var formatedValues []interface{} var sql string
var formattedValues []string
for _, value := range values[4].([]interface{}) { for _, value := range values[4].([]interface{}) {
indirectValue := reflect.Indirect(reflect.ValueOf(value)) indirectValue := reflect.Indirect(reflect.ValueOf(value))
if indirectValue.IsValid() { if indirectValue.IsValid() {
value = indirectValue.Interface() value = indirectValue.Interface()
if t, ok := value.(time.Time); ok { if t, ok := value.(time.Time); ok {
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339))) formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
} else if b, ok := value.([]byte); ok { } else if b, ok := value.([]byte); ok {
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b))) if str := string(b); isPrintable(str) {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
} else {
formattedValues = append(formattedValues, "'<binary>'")
}
} else if r, ok := value.(driver.Valuer); ok { } else if r, ok := value.(driver.Valuer); ok {
if value, err := r.Value(); err == nil && value != nil { if value, err := r.Value(); err == nil && value != nil {
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
} else { } else {
formatedValues = append(formatedValues, "NULL") formattedValues = append(formattedValues, "NULL")
} }
} else { } else {
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
} }
} else { } else {
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
} }
} }
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...))
var formattedValuesLength = len(formattedValues)
for index, value := range sqlRegexp.Split(values[3].(string), -1) {
sql += value
if index < formattedValuesLength {
sql += formattedValues[index]
}
}
messages = append(messages, sql)
} else { } else {
messages = append(messages, "\033[31;1m") messages = append(messages, "\033[31;1m")
messages = append(messages, values[2:]...) messages = append(messages, values[2:]...)
@ -69,3 +87,12 @@ func (logger Logger) Print(values ...interface{}) {
logger.Println(messages...) logger.Println(messages...)
} }
} }
func isPrintable(s string) bool {
for _, r := range s {
if !unicode.IsPrint(r) {
return false
}
}
return true
}

279
main.go
View File

@ -6,24 +6,14 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time"
) )
// NowFunc returns current time, this function is exported in order to be able // DB contains information for current db connection
// to give the flexibility to the developer to customize it according to their
// needs
//
// e.g: return time.Now().UTC()
//
var NowFunc = func() time.Time {
return time.Now()
}
type DB struct { type DB struct {
Value interface{} Value interface{}
Error error Error error
RowsAffected int64 RowsAffected int64
callback *callback callbacks *Callback
db sqlCommon db sqlCommon
parent *DB parent *DB
search *search search *search
@ -36,7 +26,18 @@ type DB struct {
joinTableHandlers map[string]JoinTableHandler joinTableHandlers map[string]JoinTableHandler
} }
func Open(dialect string, args ...interface{}) (DB, error) { // Open initialize a new db connection, need to import driver first, e.g:
//
// import _ "github.com/go-sql-driver/mysql"
// func main() {
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
// }
// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
// import _ "github.com/jinzhu/gorm/dialects/mysql"
// // import _ "github.com/jinzhu/gorm/dialects/postgres"
// // import _ "github.com/jinzhu/gorm/dialects/sqlite"
// // import _ "github.com/jinzhu/gorm/dialects/mssql"
func Open(dialect string, args ...interface{}) (*DB, error) {
var db DB var db DB
var err error var err error
@ -44,7 +45,7 @@ func Open(dialect string, args ...interface{}) (DB, error) {
err = errors.New("invalid database source") err = errors.New("invalid database source")
} else { } else {
var source string var source string
var dbSql sqlCommon var dbSQL sqlCommon
switch value := args[0].(type) { switch value := args[0].(type) {
case string: case string:
@ -55,22 +56,19 @@ func Open(dialect string, args ...interface{}) (DB, error) {
driver = value driver = value
source = args[1].(string) source = args[1].(string)
} }
if driver == "foundation" { dbSQL, err = sql.Open(driver, source)
driver = "postgres" // FoundationDB speaks a postgres-compatible protocol.
}
dbSql, err = sql.Open(driver, source)
case sqlCommon: case sqlCommon:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
dbSql = value dbSQL = value
} }
db = DB{ db = DB{
dialect: NewDialect(dialect), dialect: newDialect(dialect, dbSQL.(*sql.DB)),
logger: defaultLogger, logger: defaultLogger,
callback: DefaultCallback, callbacks: DefaultCallback,
source: source, source: source,
values: map[string]interface{}{}, values: map[string]interface{}{},
db: dbSql, db: dbSQL,
} }
db.parent = &db db.parent = &db
@ -79,17 +77,20 @@ func Open(dialect string, args ...interface{}) (DB, error) {
} }
} }
return db, err return &db, err
} }
// Close close current db connection
func (s *DB) Close() error { func (s *DB) Close() error {
return s.parent.db.(*sql.DB).Close() return s.parent.db.(*sql.DB).Close()
} }
// DB get `*sql.DB` from current connection
func (s *DB) DB() *sql.DB { func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB) return s.db.(*sql.DB)
} }
// New clone a new db connection without search conditions
func (s *DB) New() *DB { func (s *DB) New() *DB {
clone := s.clone() clone := s.clone()
clone.search = nil clone.search = nil
@ -97,29 +98,32 @@ func (s *DB) New() *DB {
return clone return clone
} }
// NewScope create scope for callbacks, including DB's search information // NewScope create a scope for current operation
func (db *DB) NewScope(value interface{}) *Scope { func (s *DB) NewScope(value interface{}) *Scope {
dbClone := db.clone() dbClone := s.clone()
dbClone.Value = value dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
} }
// CommonDB Return the underlying sql.DB or sql.Tx instance. // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
// Use of this method is discouraged. It's mainly intended to allow
// coexistence with legacy non-GORM code.
func (s *DB) CommonDB() sqlCommon { func (s *DB) CommonDB() sqlCommon {
return s.db return s.db
} }
func (s *DB) Callback() *callback { // Callback return `Callbacks` container, you could add/change/delete callbacks with it
s.parent.callback = s.parent.callback.clone() // db.Callback().Create().Register("update_created_at", updateCreated)
return s.parent.callback // Refer https://jinzhu.github.io/gorm/development.html#callbacks
func (s *DB) Callback() *Callback {
s.parent.callbacks = s.parent.callbacks.clone()
return s.parent.callbacks
} }
func (s *DB) SetLogger(l logger) { // SetLogger replace default logger
s.logger = l func (s *DB) SetLogger(log logger) {
s.logger = log
} }
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
func (s *DB) LogMode(enable bool) *DB { func (s *DB) LogMode(enable bool) *DB {
if enable { if enable {
s.logMode = 2 s.logMode = 2
@ -129,55 +133,82 @@ func (s *DB) LogMode(enable bool) *DB {
return s return s
} }
// SingularTable use singular table by default
func (s *DB) SingularTable(enable bool) { func (s *DB) SingularTable(enable bool) {
modelStructsMap = newModelStructsMap() modelStructsMap = newModelStructsMap()
s.parent.singularTable = enable s.parent.singularTable = enable
} }
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query
func (s *DB) Where(query interface{}, args ...interface{}) *DB { func (s *DB) Where(query interface{}, args ...interface{}) *DB {
return s.clone().search.Where(query, args...).db return s.clone().search.Where(query, args...).db
} }
// Or filter records that match before conditions or this one, similar to `Where`
func (s *DB) Or(query interface{}, args ...interface{}) *DB { func (s *DB) Or(query interface{}, args ...interface{}) *DB {
return s.clone().search.Or(query, args...).db return s.clone().search.Or(query, args...).db
} }
// Not filter records that don't match current conditions, similar to `Where`
func (s *DB) Not(query interface{}, args ...interface{}) *DB { func (s *DB) Not(query interface{}, args ...interface{}) *DB {
return s.clone().search.Not(query, args...).db return s.clone().search.Not(query, args...).db
} }
func (s *DB) Limit(value interface{}) *DB { // Limit specify the number of records to be retrieved
return s.clone().search.Limit(value).db func (s *DB) Limit(limit int) *DB {
return s.clone().search.Limit(limit).db
} }
func (s *DB) Offset(value interface{}) *DB { // Offset specify the number of records to skip before starting to return the records
return s.clone().search.Offset(value).db func (s *DB) Offset(offset int) *DB {
return s.clone().search.Offset(offset).db
} }
// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
func (s *DB) Order(value string, reorder ...bool) *DB { func (s *DB) Order(value string, reorder ...bool) *DB {
return s.clone().search.Order(value, reorder...).db return s.clone().search.Order(value, reorder...).db
} }
// Select specify fields that you want to retrieve from database when querying, by default, will select all fields;
// When creating/updating, specify fields that you want to save to database
func (s *DB) Select(query interface{}, args ...interface{}) *DB { func (s *DB) Select(query interface{}, args ...interface{}) *DB {
return s.clone().search.Select(query, args...).db return s.clone().search.Select(query, args...).db
} }
// Omit specify fields that you want to ignore when saving to database for creating, updating
func (s *DB) Omit(columns ...string) *DB { func (s *DB) Omit(columns ...string) *DB {
return s.clone().search.Omit(columns...).db return s.clone().search.Omit(columns...).db
} }
// Group specify the group method on the find
func (s *DB) Group(query string) *DB { func (s *DB) Group(query string) *DB {
return s.clone().search.Group(query).db return s.clone().search.Group(query).db
} }
// Having specify HAVING conditions for GROUP BY
func (s *DB) Having(query string, values ...interface{}) *DB { func (s *DB) Having(query string, values ...interface{}) *DB {
return s.clone().search.Having(query, values...).db return s.clone().search.Having(query, values...).db
} }
func (s *DB) Joins(query string) *DB { // Joins specify Joins conditions
return s.clone().search.Joins(query).db // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
func (s *DB) Joins(query string, args ...interface{}) *DB {
return s.clone().search.Joins(query, args...).db
} }
// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
//
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
// Refer https://jinzhu.github.io/gorm/curd.html#scopes
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
for _, f := range funcs { for _, f := range funcs {
s = f(s) s = f(s)
@ -185,60 +216,91 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
return s return s
} }
// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/curd.html#soft-delete
func (s *DB) Unscoped() *DB { func (s *DB) Unscoped() *DB {
return s.clone().search.unscoped().db return s.clone().search.unscoped().db
} }
// Attrs initalize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
func (s *DB) Attrs(attrs ...interface{}) *DB { func (s *DB) Attrs(attrs ...interface{}) *DB {
return s.clone().search.Attrs(attrs...).db return s.clone().search.Attrs(attrs...).db
} }
// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
func (s *DB) Assign(attrs ...interface{}) *DB { func (s *DB) Assign(attrs ...interface{}) *DB {
return s.clone().search.Assign(attrs...).db return s.clone().search.Assign(attrs...).db
} }
// First find first record that match given conditions, order by primary key
func (s *DB) First(out interface{}, where ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out) newScope := s.clone().NewScope(out)
newScope.Search.Limit(1) newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "ASC"). return newScope.Set("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
} }
// Last find last record that match given conditions, order by primary key
func (s *DB) Last(out interface{}, where ...interface{}) *DB { func (s *DB) Last(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out) newScope := s.clone().NewScope(out)
newScope.Search.Limit(1) newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "DESC"). return newScope.Set("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
} }
// Find find records that match given conditions
func (s *DB) Find(out interface{}, where ...interface{}) *DB { func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
} }
// Scan scan value to a struct
func (s *DB) Scan(dest interface{}) *DB { func (s *DB) Scan(dest interface{}) *DB {
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
} }
// Row return `*sql.Row` with given conditions
func (s *DB) Row() *sql.Row { func (s *DB) Row() *sql.Row {
return s.NewScope(s.Value).row() return s.NewScope(s.Value).row()
} }
// Rows return `*sql.Rows` with given conditions
func (s *DB) Rows() (*sql.Rows, error) { func (s *DB) Rows() (*sql.Rows, error) {
return s.NewScope(s.Value).rows() return s.NewScope(s.Value).rows()
} }
// ScanRows scan `*sql.Rows` to give struct
func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
var (
clone = s.clone()
scope = clone.NewScope(result)
columns, err = rows.Columns()
)
if clone.AddError(err) == nil {
scope.scan(rows, columns, scope.fieldsMap())
}
return clone.Error
}
// Pluck used to query single column from a model as a map
// var ages []int64
// db.Find(&users).Pluck("age", &ages)
func (s *DB) Pluck(column string, value interface{}) *DB { func (s *DB) Pluck(column string, value interface{}) *DB {
return s.NewScope(s.Value).pluck(column, value).db return s.NewScope(s.Value).pluck(column, value).db
} }
// Count get how many records for a model
func (s *DB) Count(value interface{}) *DB { func (s *DB) Count(value interface{}) *DB {
return s.NewScope(s.Value).count(value).db return s.NewScope(s.Value).count(value).db
} }
// Related get related associations
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
} }
// FirstOrInit find first matched record or initalize a new one with given conditions (only works with struct, map conditions)
// https://jinzhu.github.io/gorm/curd.html#firstorinit
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
c := s.clone() c := s.clone()
if result := c.First(out, where...); result.Error != nil { if result := c.First(out, where...); result.Error != nil {
@ -247,82 +309,100 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
} }
c.NewScope(out).inlineCondition(where...).initialize() c.NewScope(out).inlineCondition(where...).initialize()
} else { } else {
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false) c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs))
} }
return c return c
} }
// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
// https://jinzhu.github.io/gorm/curd.html#firstorcreate
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
c := s.clone() c := s.clone()
if result := c.First(out, where...); result.Error != nil { if result := c.First(out, where...); result.Error != nil {
if !result.RecordNotFound() { if !result.RecordNotFound() {
return result return result
} }
c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error) c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db.Error)
} else if len(c.search.assignAttrs) > 0 { } else if len(c.search.assignAttrs) > 0 {
c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error) c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db.Error)
} }
return c return c
} }
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
func (s *DB) Update(attrs ...interface{}) *DB { func (s *DB) Update(attrs ...interface{}) *DB {
return s.Updates(toSearchableMap(attrs...), true) return s.Updates(toSearchableMap(attrs...), true)
} }
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
return s.clone().NewScope(s.Value). return s.clone().NewScope(s.Value).
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
InstanceSet("gorm:update_interface", values). InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callback.updates).db callCallbacks(s.parent.callbacks.updates).db
} }
// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
func (s *DB) UpdateColumn(attrs ...interface{}) *DB { func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
return s.UpdateColumns(toSearchableMap(attrs...)) return s.UpdateColumns(toSearchableMap(attrs...))
} }
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
func (s *DB) UpdateColumns(values interface{}) *DB { func (s *DB) UpdateColumns(values interface{}) *DB {
return s.clone().NewScope(s.Value). return s.clone().NewScope(s.Value).
Set("gorm:update_column", true). Set("gorm:update_column", true).
Set("gorm:save_associations", false). Set("gorm:save_associations", false).
InstanceSet("gorm:update_interface", values). InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callback.updates).db callCallbacks(s.parent.callbacks.updates).db
} }
// Save update value in database, if the value doesn't have primary key, will insert it
func (s *DB) Save(value interface{}) *DB { func (s *DB) Save(value interface{}) *DB {
scope := s.clone().NewScope(value) scope := s.clone().NewScope(value)
if scope.PrimaryKeyZero() { if scope.PrimaryKeyZero() {
return scope.callCallbacks(s.parent.callback.creates).db return scope.callCallbacks(s.parent.callbacks.creates).db
} }
return scope.callCallbacks(s.parent.callback.updates).db return scope.callCallbacks(s.parent.callbacks.updates).db
} }
// Create insert the value into database
func (s *DB) Create(value interface{}) *DB { func (s *DB) Create(value interface{}) *DB {
scope := s.clone().NewScope(value) scope := s.clone().NewScope(value)
return scope.callCallbacks(s.parent.callback.creates).db return scope.callCallbacks(s.parent.callbacks.creates).db
} }
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
func (s *DB) Delete(value interface{}, where ...interface{}) *DB { func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
} }
// Raw use raw sql as conditions, won't run it unless invoked by other methods
// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
func (s *DB) Raw(sql string, values ...interface{}) *DB { func (s *DB) Raw(sql string, values ...interface{}) *DB {
return s.clone().search.Raw(true).Where(sql, values...).db return s.clone().search.Raw(true).Where(sql, values...).db
} }
// Exec execute raw sql
func (s *DB) Exec(sql string, values ...interface{}) *DB { func (s *DB) Exec(sql string, values ...interface{}) *DB {
scope := s.clone().NewScope(nil) scope := s.clone().NewScope(nil)
generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")") generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
scope.Raw(generatedSql) scope.Raw(generatedSQL)
return scope.Exec().db return scope.Exec().db
} }
// Model specify the model you would like to run db operations
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
// db.Model(&user).Update("name", "hello")
func (s *DB) Model(value interface{}) *DB { func (s *DB) Model(value interface{}) *DB {
c := s.clone() c := s.clone()
c.Value = value c.Value = value
return c return c
} }
// Table specify the table you would like to run db operations
func (s *DB) Table(name string) *DB { func (s *DB) Table(name string) *DB {
clone := s.clone() clone := s.clone()
clone.search.Table(name) clone.search.Table(name)
@ -330,10 +410,12 @@ func (s *DB) Table(name string) *DB {
return clone return clone
} }
// Debug start debug mode
func (s *DB) Debug() *DB { func (s *DB) Debug() *DB {
return s.clone().LogMode(true) return s.clone().LogMode(true)
} }
// Begin begin a transaction
func (s *DB) Begin() *DB { func (s *DB) Begin() *DB {
c := s.clone() c := s.clone()
if db, ok := c.db.(sqlDb); ok { if db, ok := c.db.(sqlDb); ok {
@ -341,46 +423,56 @@ func (s *DB) Begin() *DB {
c.db = interface{}(tx).(sqlCommon) c.db = interface{}(tx).(sqlCommon)
c.AddError(err) c.AddError(err)
} else { } else {
c.AddError(CantStartTransaction) c.AddError(ErrCantStartTransaction)
} }
return c return c
} }
// Commit commit a transaction
func (s *DB) Commit() *DB { func (s *DB) Commit() *DB {
if db, ok := s.db.(sqlTx); ok { if db, ok := s.db.(sqlTx); ok {
s.AddError(db.Commit()) s.AddError(db.Commit())
} else { } else {
s.AddError(NoValidTransaction) s.AddError(ErrInvalidTransaction)
} }
return s return s
} }
// Rollback rollback a transaction
func (s *DB) Rollback() *DB { func (s *DB) Rollback() *DB {
if db, ok := s.db.(sqlTx); ok { if db, ok := s.db.(sqlTx); ok {
s.AddError(db.Rollback()) s.AddError(db.Rollback())
} else { } else {
s.AddError(NoValidTransaction) s.AddError(ErrInvalidTransaction)
} }
return s return s
} }
// NewRecord check if value's primary key is blank
func (s *DB) NewRecord(value interface{}) bool { func (s *DB) NewRecord(value interface{}) bool {
return s.clone().NewScope(value).PrimaryKeyZero() return s.clone().NewScope(value).PrimaryKeyZero()
} }
// RecordNotFound check if returning ErrRecordNotFound error
func (s *DB) RecordNotFound() bool { func (s *DB) RecordNotFound() bool {
return s.Error == RecordNotFound for _, err := range s.GetErrors() {
if err == ErrRecordNotFound {
return true
}
}
return false
} }
// Migrations // CreateTable create table for models
func (s *DB) CreateTable(values ...interface{}) *DB { func (s *DB) CreateTable(models ...interface{}) *DB {
db := s.clone() db := s.clone()
for _, value := range values { for _, model := range models {
db = db.NewScope(value).createTable().db db = db.NewScope(model).createTable().db
} }
return db return db
} }
// DropTable drop table for models
func (s *DB) DropTable(values ...interface{}) *DB { func (s *DB) DropTable(values ...interface{}) *DB {
db := s.clone() db := s.clone()
for _, value := range values { for _, value := range values {
@ -393,18 +485,18 @@ func (s *DB) DropTable(values ...interface{}) *DB {
return db return db
} }
// DropTableIfExists drop table if it is exist
func (s *DB) DropTableIfExists(values ...interface{}) *DB { func (s *DB) DropTableIfExists(values ...interface{}) *DB {
db := s.clone() db := s.clone()
for _, value := range values { for _, value := range values {
if tableName, ok := value.(string); ok { if s.HasTable(value) {
db = db.Table(tableName) db.AddError(s.DropTable(value).Error)
} }
db = db.NewScope(value).dropTableIfExists().db
} }
return db return db
} }
// HasTable check has table or not
func (s *DB) HasTable(value interface{}) bool { func (s *DB) HasTable(value interface{}) bool {
var ( var (
scope = s.clone().NewScope(value) scope = s.clone().NewScope(value)
@ -417,69 +509,64 @@ func (s *DB) HasTable(value interface{}) bool {
tableName = scope.TableName() tableName = scope.TableName()
} }
has := scope.Dialect().HasTable(scope, tableName) has := scope.Dialect().HasTable(tableName)
s.AddError(scope.db.Error) s.AddError(scope.db.Error)
return has return has
} }
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
func (s *DB) AutoMigrate(values ...interface{}) *DB { func (s *DB) AutoMigrate(values ...interface{}) *DB {
db := s.clone() db := s.clone()
for _, value := range values { for _, value := range values {
db = db.NewScope(value).NeedPtr().autoMigrate().db db = db.NewScope(value).autoMigrate().db
} }
return db return db
} }
// ModifyColumn modify column to type
func (s *DB) ModifyColumn(column string, typ string) *DB { func (s *DB) ModifyColumn(column string, typ string) *DB {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.modifyColumn(column, typ) scope.modifyColumn(column, typ)
return scope.db return scope.db
} }
// DropColumn drop a column
func (s *DB) DropColumn(column string) *DB { func (s *DB) DropColumn(column string) *DB {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.dropColumn(column) scope.dropColumn(column)
return scope.db return scope.db
} }
func (s *DB) AddIndex(indexName string, column ...string) *DB { // AddIndex add index for columns with given name
func (s *DB) AddIndex(indexName string, columns ...string) *DB {
scope := s.Unscoped().NewScope(s.Value) scope := s.Unscoped().NewScope(s.Value)
scope.addIndex(false, indexName, column...) scope.addIndex(false, indexName, columns...)
return scope.db return scope.db
} }
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB { // AddUniqueIndex add unique index for columns with given name
func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.addIndex(true, indexName, column...) scope.addIndex(true, indexName, columns...)
return scope.db return scope.db
} }
// RemoveIndex remove index with name
func (s *DB) RemoveIndex(indexName string) *DB { func (s *DB) RemoveIndex(indexName string) *DB {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.removeIndex(indexName) scope.removeIndex(indexName)
return scope.db return scope.db
} }
func (s *DB) CurrentDatabase() string { // AddForeignKey Add foreign key to the given scope, e.g:
var ( // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
scope = s.clone().NewScope(s.Value)
name = s.dialect.CurrentDatabase(scope)
)
return name
}
/*
Add foreign key to the given scope
Example:
db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
*/
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.addForeignKey(field, dest, onDelete, onUpdate) scope.addForeignKey(field, dest, onDelete, onUpdate)
return scope.db return scope.db
} }
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
func (s *DB) Association(column string) *Association { func (s *DB) Association(column string) *Association {
var err error var err error
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
@ -491,7 +578,7 @@ func (s *DB) Association(column string) *Association {
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
} else { } else {
return &Association{Scope: scope, Column: column, Field: field} return &Association{scope: scope, column: column, field: field}
} }
} else { } else {
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
@ -501,26 +588,30 @@ func (s *DB) Association(column string) *Association {
return &Association{Error: err} return &Association{Error: err}
} }
// Preload preload associations with given conditions
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (s *DB) Preload(column string, conditions ...interface{}) *DB { func (s *DB) Preload(column string, conditions ...interface{}) *DB {
return s.clone().search.Preload(column, conditions...).db return s.clone().search.Preload(column, conditions...).db
} }
// Set set value by name // Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting
func (s *DB) Set(name string, value interface{}) *DB { func (s *DB) Set(name string, value interface{}) *DB {
return s.clone().InstantSet(name, value) return s.clone().InstantSet(name, value)
} }
// InstantSet instant set setting, will affect current db
func (s *DB) InstantSet(name string, value interface{}) *DB { func (s *DB) InstantSet(name string, value interface{}) *DB {
s.values[name] = value s.values[name] = value
return s return s
} }
// Get get value by name // Get get setting by name
func (s *DB) Get(name string) (value interface{}, ok bool) { func (s *DB) Get(name string) (value interface{}, ok bool) {
value, ok = s.values[name] value, ok = s.values[name]
return return
} }
// SetJoinTableHandler set a model's join table handler for a relation
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
scope := s.NewScope(source) scope := s.NewScope(source)
for _, field := range scope.GetModelStruct().StructFields { for _, field := range scope.GetModelStruct().StructFields {
@ -530,7 +621,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination) handler.Setup(field.Relationship, many2many, source, destination)
field.Relationship.JoinTableHandler = handler field.Relationship.JoinTableHandler = handler
if table := handler.Table(s); scope.Dialect().HasTable(scope, table) { if table := handler.Table(s); scope.Dialect().HasTable(table) {
s.Table(table).AutoMigrate(handler) s.Table(table).AutoMigrate(handler)
} }
} }
@ -538,9 +629,10 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
} }
} }
// AddError add error to the db
func (s *DB) AddError(err error) error { func (s *DB) AddError(err error) error {
if err != nil { if err != nil {
if err != RecordNotFound { if err != ErrRecordNotFound {
if s.logMode == 0 { if s.logMode == 0 {
go s.print(fileWithLineNum(), err) go s.print(fileWithLineNum(), err)
} else { } else {
@ -559,6 +651,7 @@ func (s *DB) AddError(err error) error {
return err return err
} }
// GetErrors get happened errors from the db
func (s *DB) GetErrors() (errors []error) { func (s *DB) GetErrors() (errors []error) {
if errs, ok := s.Error.(errorsInterface); ok { if errs, ok := s.Error.(errorsInterface); ok {
return errs.GetErrors() return errs.GetErrors()

View File

@ -10,7 +10,7 @@ func (s *DB) clone() *DB {
} }
if s.search == nil { if s.search == nil {
db.search = &search{} db.search = &search{limit: -1, offset: -1}
} else { } else {
db.search = s.search.clone() db.search = s.search.clone()
} }

View File

@ -4,23 +4,23 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"strconv"
_ "github.com/denisenkom/go-mssqldb"
testdb "github.com/erikstmartin/go-testdb"
_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
"github.com/jinzhu/now"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"os" "os"
"reflect"
"strconv"
"testing" "testing"
"time" "time"
"github.com/erikstmartin/go-testdb"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mssql"
_ "github.com/jinzhu/gorm/dialects/mysql"
"github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
"github.com/jinzhu/now"
) )
var ( var (
DB gorm.DB DB *gorm.DB
t1, t2, t3, t4, t5 time.Time t1, t2, t3, t4, t5 time.Time
) )
@ -42,7 +42,7 @@ func init() {
runMigration() runMigration()
} }
func OpenTestConnection() (db gorm.DB, err error) { func OpenTestConnection() (db *gorm.DB, err error) {
switch os.Getenv("GORM_DIALECT") { switch os.Getenv("GORM_DIALECT") {
case "mysql": case "mysql":
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
@ -115,7 +115,7 @@ func TestSetTable(t *testing.T) {
DB.Create(getPreparedUser("pluck_user3", "pluck_user")) DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
t.Errorf("No errors should happen if set table for pluck", err.Error()) t.Error("No errors should happen if set table for pluck", err)
} }
var users []User var users []User
@ -376,7 +376,7 @@ func TestRows(t *testing.T) {
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
if err != nil { if err != nil {
t.Errorf("Not error should happen, but got") t.Errorf("Not error should happen, got %v", err)
} }
count := 0 count := 0
@ -386,8 +386,39 @@ func TestRows(t *testing.T) {
rows.Scan(&name, &age) rows.Scan(&name, &age)
count++ count++
} }
if count != 2 { if count != 2 {
t.Errorf("Should found two records with name 3") t.Errorf("Should found two records")
}
}
func TestScanRows(t *testing.T) {
user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3)
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
if err != nil {
t.Errorf("Not error should happen, got %v", err)
}
type Result struct {
Name string
Age int
}
var results []Result
for rows.Next() {
var result Result
if err := DB.ScanRows(rows, &result); err != nil {
t.Errorf("should get no error, but got %v", err)
}
results = append(results, result)
}
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
t.Errorf("Should find expected results")
} }
} }
@ -448,7 +479,7 @@ func TestRaw(t *testing.T) {
} }
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound { if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
t.Error("Raw sql to update records") t.Error("Raw sql to update records")
} }
} }
@ -470,14 +501,33 @@ func TestGroup(t *testing.T) {
func TestJoins(t *testing.T) { func TestJoins(t *testing.T) {
var user = User{ var user = User{
Name: "joins", Name: "joins",
CreditCard: CreditCard{Number: "411111111111"},
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
} }
DB.Save(&user) DB.Save(&user)
var result User var users1 []User
DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result) DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1)
if result.Name != "joins" || result.Id != user.Id { if len(users1) != 2 {
t.Errorf("Should find all two emails with Join") t.Errorf("should find two users using left join")
}
var users2 []User
DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2)
if len(users2) != 1 {
t.Errorf("should find one users using left join with conditions")
}
var users3 []User
DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3)
if len(users3) != 1 {
t.Errorf("should find one users using multiple left join conditions")
}
var users4 []User
DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4)
if len(users4) != 0 {
t.Errorf("should find no user when searching with unexisting credit card")
} }
} }
@ -557,7 +607,7 @@ func TestTimeWithZone(t *testing.T) {
DB.First(&findUser, "name = ?", name) DB.First(&findUser, "name = ?", name)
foundBirthday = findUser.Birthday.UTC().Format(format) foundBirthday = findUser.Birthday.UTC().Format(format)
if foundBirthday != expectedBirthday { if foundBirthday != expectedBirthday {
t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday) t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
} }
if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
@ -573,7 +623,7 @@ func TestTimeWithZone(t *testing.T) {
func TestHstore(t *testing.T) { func TestHstore(t *testing.T) {
type Details struct { type Details struct {
Id int64 Id int64
Bulk gorm.Hstore Bulk postgres.Hstore
} }
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
@ -659,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) {
} }
var user User var user User
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound { if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
t.Errorf("Should have found existing record") t.Errorf("Should have found existing record")
} }
} }

View File

@ -31,7 +31,7 @@ func TestIndexes(t *testing.T) {
} }
scope := DB.NewScope(&Email{}) scope := DB.NewScope(&Email{})
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
t.Errorf("Email should have index idx_email_email") t.Errorf("Email should have index idx_email_email")
} }
@ -39,7 +39,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
t.Errorf("Email's index idx_email_email should be deleted") t.Errorf("Email's index idx_email_email should be deleted")
} }
@ -47,7 +47,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id") t.Errorf("Email should have index idx_email_email_and_user_id")
} }
@ -55,7 +55,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted") t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
} }
@ -63,7 +63,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id") t.Errorf("Email should have index idx_email_email_and_user_id")
} }
@ -85,7 +85,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted") t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
} }
@ -117,11 +117,11 @@ func TestAutoMigration(t *testing.T) {
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}) DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
scope := DB.NewScope(&BigEmail{}) scope := DB.NewScope(&BigEmail{})
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") { if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }

View File

@ -2,6 +2,10 @@ package gorm
import "time" import "time"
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models
// type User struct {
// gorm.Model
// }
type Model struct { type Model struct {
ID uint `gorm:"primary_key"` ID uint `gorm:"primary_key"`
CreatedAt time.Time CreatedAt time.Time

View File

@ -3,10 +3,8 @@ package gorm
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"go/ast" "go/ast"
"reflect" "reflect"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -14,6 +12,7 @@ import (
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
) )
// DefaultTableNameHandler default table name handler
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName return defaultTableName
} }
@ -41,6 +40,7 @@ func newModelStructsMap() *safeModelStructsMap {
var modelStructsMap = newModelStructsMap() var modelStructsMap = newModelStructsMap()
// ModelStruct model definition
type ModelStruct struct { type ModelStruct struct {
PrimaryFields []*StructField PrimaryFields []*StructField
StructFields []*StructField StructFields []*StructField
@ -48,10 +48,12 @@ type ModelStruct struct {
defaultTableName string defaultTableName string
} }
// TableName get model's table name
func (s *ModelStruct) TableName(db *DB) string { func (s *ModelStruct) TableName(db *DB) string {
return DefaultTableNameHandler(db, s.defaultTableName) return DefaultTableNameHandler(db, s.defaultTableName)
} }
// StructField model field's struct definition
type StructField struct { type StructField struct {
DBName string DBName string
Name string Name string
@ -107,7 +109,7 @@ func getForeignField(column string, fields []*StructField) *StructField {
return nil return nil
} }
// GetModelStruct generate model struct & relationships based on struct and tag definition // GetModelStruct get value's model struct, relationships based on struct and tag definition
func (scope *Scope) GetModelStruct() *ModelStruct { func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct var modelStruct ModelStruct
// Scope value can't be nil // Scope value can't be nil
@ -296,7 +298,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if len(associationForeignKeys) == 0 { if len(associationForeignKeys) == 0 {
for _, foreignKey := range foreignKeys { for _, foreignKey := range foreignKeys {
if strings.HasPrefix(foreignKey, associationType) { if strings.HasPrefix(foreignKey, associationType) {
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
}
} }
} }
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
@ -389,7 +394,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if len(associationForeignKeys) == 0 { if len(associationForeignKeys) == 0 {
for _, foreignKey := range foreignKeys { for _, foreignKey := range foreignKeys {
if strings.HasPrefix(foreignKey, associationType) { if strings.HasPrefix(foreignKey, associationType) {
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
}
} }
} }
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
@ -445,7 +453,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if len(associationForeignKeys) == 0 { if len(associationForeignKeys) == 0 {
for _, foreignKey := range foreignKeys { for _, foreignKey := range foreignKeys {
if strings.HasPrefix(foreignKey, field.Name) { if strings.HasPrefix(foreignKey, field.Name) {
associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, field.Name)) associationForeignKey := strings.TrimPrefix(foreignKey, field.Name)
if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
}
} }
} }
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
@ -508,63 +519,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
return &modelStruct return &modelStruct
} }
// GetStructFields get model's field structs
func (scope *Scope) GetStructFields() (fields []*StructField) { func (scope *Scope) GetStructFields() (fields []*StructField) {
return scope.GetModelStruct().StructFields return scope.GetModelStruct().StructFields
} }
func (scope *Scope) generateSqlTag(field *StructField) string {
var sqlType string
structType := field.Struct.Type
if structType.Kind() == reflect.Ptr {
structType = structType.Elem()
}
reflectValue := reflect.Indirect(reflect.New(structType))
if value, ok := field.TagSettings["TYPE"]; ok {
sqlType = value
}
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = additionalType + " DEFAULT " + value
}
if field.IsScanner {
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
reflectValue = value
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
getScannerValue(reflectValue.Field(0))
}
}
getScannerValue(reflectValue)
}
if sqlType == "" {
var size = 255
if value, ok := field.TagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(value)
}
v, autoIncrease := field.TagSettings["AUTO_INCREMENT"]
if field.IsPrimaryKey {
autoIncrease = true
}
if v == "FALSE" {
autoIncrease = false
}
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
} else {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
}
func parseTagSetting(tags reflect.StructTag) map[string]string { func parseTagSetting(tags reflect.StructTag) map[string]string {
setting := map[string]string{} setting := map[string]string{}
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {

View File

@ -1,80 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"time"
)
type mssql struct {
commonDialect
}
func (mssql) HasTop() bool {
return true
}
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int IDENTITY(1,1)"
}
return "int"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigint IDENTITY(1,1)"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "float"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("nvarchar(%d)", size)
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "datetime2"
}
default:
if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "text"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
}
func (s mssql) HasTable(scope *Scope, tableName string) bool {
var (
count int
databaseName = s.CurrentDatabase(scope)
)
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
return count > 0
}
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var (
count int
databaseName = s.CurrentDatabase(scope)
)
s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
return count > 0
}
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
return count > 0
}
func (s mssql) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
return
}

View File

@ -21,7 +21,7 @@ type Tag struct {
ID uint `gorm:"primary_key"` ID uint `gorm:"primary_key"`
Locale string `gorm:"primary_key"` Locale string `gorm:"primary_key"`
Value string Value string
Blogs []*Blog `gorm:"many2many:"blogs_tags` Blogs []*Blog `gorm:"many2many:blogs_tags"`
} }
func compareTags(tags []Tag, contents []string) bool { func compareTags(tags []Tag, contents []string) bool {

View File

@ -1,70 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"time"
)
type mysql struct {
commonDialect
}
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
if autoIncrease {
return "int AUTO_INCREMENT"
}
return "int"
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int unsigned AUTO_INCREMENT"
}
return "int unsigned"
case reflect.Int64:
if autoIncrease {
return "bigint AUTO_INCREMENT"
}
return "bigint"
case reflect.Uint64:
if autoIncrease {
return "bigint unsigned AUTO_INCREMENT"
}
return "bigint unsigned"
case reflect.Float32, reflect.Float64:
return "double"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "longtext"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "timestamp NULL"
}
default:
if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
}
return "longblob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
}
func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s mysql) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DATABASE()")
return
}

View File

@ -39,46 +39,46 @@ func TestPointerFields(t *testing.T) {
var nilPointerStruct = PointerStruct{} var nilPointerStruct = PointerStruct{}
if err := DB.Create(&nilPointerStruct).Error; err != nil { if err := DB.Create(&nilPointerStruct).Error; err != nil {
t.Errorf("Failed to save nil pointer struct", err) t.Error("Failed to save nil pointer struct", err)
} }
var pointerStruct2 PointerStruct var pointerStruct2 PointerStruct
if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
t.Errorf("Failed to query saved nil pointer struct", err) t.Error("Failed to query saved nil pointer struct", err)
} }
var normalStruct2 NormalStruct var normalStruct2 NormalStruct
if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
t.Errorf("Failed to query saved nil pointer struct", err) t.Error("Failed to query saved nil pointer struct", err)
} }
var partialNilPointerStruct1 = PointerStruct{Num: &num} var partialNilPointerStruct1 = PointerStruct{Num: &num}
if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
t.Errorf("Failed to save partial nil pointer struct", err) t.Error("Failed to save partial nil pointer struct", err)
} }
var pointerStruct3 PointerStruct var pointerStruct3 PointerStruct
if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
t.Errorf("Failed to query saved partial nil pointer struct", err) t.Error("Failed to query saved partial nil pointer struct", err)
} }
var normalStruct3 NormalStruct var normalStruct3 NormalStruct
if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
t.Errorf("Failed to query saved partial pointer struct", err) t.Error("Failed to query saved partial pointer struct", err)
} }
var partialNilPointerStruct2 = PointerStruct{Name: &name} var partialNilPointerStruct2 = PointerStruct{Name: &name}
if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
t.Errorf("Failed to save partial nil pointer struct", err) t.Error("Failed to save partial nil pointer struct", err)
} }
var pointerStruct4 PointerStruct var pointerStruct4 PointerStruct
if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
t.Errorf("Failed to query saved partial nil pointer struct", err) t.Error("Failed to query saved partial nil pointer struct", err)
} }
var normalStruct4 NormalStruct var normalStruct4 NormalStruct
if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
t.Errorf("Failed to query saved partial pointer struct", err) t.Error("Failed to query saved partial pointer struct", err)
} }
} }

View File

@ -1,154 +0,0 @@
package gorm
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"strings"
"time"
"github.com/lib/pq/hstore"
)
type postgres struct {
commonDialect
}
func (postgres) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (postgres) SupportLastInsertId() bool {
return false
}
func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigserial"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "numeric"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "timestamp with time zone"
}
case reflect.Map:
if value.Type() == hstoreType {
return "hstore"
}
default:
if isByteArrayOrSlice(value) {
return "bytea"
} else if isUUID(value) {
return "uuid"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
}
var byteType = reflect.TypeOf(uint8(0))
func isByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == byteType
}
func isUUID(value reflect.Value) bool {
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
return false
}
typename := value.Type().Name()
lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower
}
func (s postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}
func (s postgres) HasTable(scope *Scope, tableName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName)
return count > 0
}
func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName)
return count > 0
}
func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
return count > 0
}
func (s postgres) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
return
}
var hstoreType = reflect.TypeOf(Hstore{})
type Hstore map[string]*string
func (h Hstore) Value() (driver.Value, error) {
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
if len(h) == 0 {
return nil, nil
}
for key, value := range h {
var s sql.NullString
if value != nil {
s.String = *value
s.Valid = true
}
hstore.Map[key] = s
}
return hstore.Value()
}
func (h *Hstore) Scan(value interface{}) error {
hstore := hstore.Hstore{}
if err := hstore.Scan(value); err != nil {
return err
}
if len(hstore.Map) == 0 {
return nil
}
*h = Hstore{}
for k := range hstore.Map {
if hstore.Map[k].Valid {
s := hstore.Map[k].String
(*h)[k] = &s
} else {
(*h)[k] = nil
}
}
return nil
}

View File

@ -1,388 +0,0 @@
package gorm
import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strings"
)
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
// If value is a nil pointer, Indirect returns a zero Value!
// Therefor we need to check for a zero value,
// as FieldByName could panic
if pointedValue := reflect.Indirect(value); pointedValue.IsValid() {
for _, column := range columns {
if pointedValue.FieldByName(column).IsValid() {
result := pointedValue.FieldByName(column).Interface()
if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value()
}
results = append(results, result)
}
}
}
return
}
func equalAsString(a interface{}, b interface{}) bool {
return toString(a) == toString(b)
}
func Preload(scope *Scope) {
if scope.Search.preload == nil || scope.HasError() {
return
}
preloadMap := map[string]bool{}
fields := scope.Fields()
for _, preload := range scope.Search.preload {
schema, conditions := preload.schema, preload.conditions
keys := strings.Split(schema, ".")
currentScope := scope
currentFields := fields
originalConditions := conditions
conditions = []interface{}{}
for i, key := range keys {
var found bool
if preloadMap[strings.Join(keys[:i+1], ".")] {
goto nextLoop
}
if i == len(keys)-1 {
conditions = originalConditions
}
for _, field := range currentFields {
if field.Name != key || field.Relationship == nil {
continue
}
found = true
switch field.Relationship.Kind {
case "has_one":
currentScope.handleHasOnePreload(field, conditions)
case "has_many":
currentScope.handleHasManyPreload(field, conditions)
case "belongs_to":
currentScope.handleBelongsToPreload(field, conditions)
case "many_to_many":
currentScope.handleManyToManyPreload(field, conditions)
default:
currentScope.Err(errors.New("not supported relation"))
}
break
}
if !found {
value := reflect.ValueOf(currentScope.Value)
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
value = value.Index(0).Elem()
}
scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
return
}
preloadMap[strings.Join(keys[:i+1], ".")] = true
nextLoop:
if i < len(keys)-1 {
currentScope = currentScope.getColumnsAsScope(key)
currentFields = currentScope.Fields()
}
}
}
}
func makeSlice(typ reflect.Type) interface{} {
if typ.Kind() == reflect.Slice {
typ = typ.Elem()
}
sliceType := reflect.SliceOf(typ)
slice := reflect.New(sliceType)
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
return slice.Interface()
}
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
relation := field.Relationship
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, relation.ForeignFieldNames)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break
}
}
} else {
if err := scope.SetColumn(field, result); err != nil {
scope.Err(err)
return
}
}
}
}
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
if scope.IndirectValue().Kind() == reflect.Slice {
preloadMap := make(map[string][]reflect.Value)
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
value := getRealValue(result, relation.ForeignFieldNames)
preloadMap[toString(value)] = append(preloadMap[toString(value)], result)
}
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
objectRealValue := getRealValue(object, relation.AssociationForeignFieldNames)
objectStringValue := toString(objectRealValue)
if results, ok := preloadMap[objectStringValue]; ok {
if object.Kind() == reflect.Ptr {
object = object.Elem()
}
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, results...))
}
}
} else {
scope.SetColumn(field, resultValues)
}
}
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, relation.AssociationForeignFieldNames)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if object.Kind() == reflect.Ptr {
object = reflect.Indirect(objects.Index(j).Elem())
}
if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) {
object.FieldByName(field.Name).Set(result)
}
}
} else {
scope.SetColumn(field, result)
}
}
}
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
joinTableHandler := relation.JoinTableHandler
destType := field.StructField.Struct.Type.Elem()
var isPtr bool
if destType.Kind() == reflect.Ptr {
isPtr = true
destType = destType.Elem()
}
var sourceKeys []string
var linkHash = make(map[string][]reflect.Value)
for _, key := range joinTableHandler.SourceForeignKeys() {
sourceKeys = append(sourceKeys, key.DBName)
}
db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*")
preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value)
if len(conditions) > 0 {
preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...)
}
rows, err := preloadJoinDB.Rows()
if scope.Err(err) != nil {
return
}
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
elem := reflect.New(destType).Elem()
var values = make([]interface{}, len(columns))
fields := scope.New(elem.Addr().Interface()).Fields()
var foundFields = map[string]bool{}
for index, column := range columns {
if field, ok := fields[column]; ok && !foundFields[column] {
if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface()
} else {
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
}
foundFields[column] = true
} else {
var i interface{}
values[index] = &i
}
}
scope.Err(rows.Scan(values...))
var sourceKey []interface{}
var scannedFields = map[string]bool{}
for index, column := range columns {
value := values[index]
if field, ok := fields[column]; ok && !scannedFields[column] {
if field.Field.Kind() == reflect.Ptr {
field.Field.Set(reflect.ValueOf(value).Elem())
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
field.Field.Set(v)
}
scannedFields[column] = true
} else if strInSlice(column, sourceKeys) {
sourceKey = append(sourceKey, *(value.(*interface{})))
}
}
if len(sourceKey) != 0 {
if isPtr {
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem.Addr())
} else {
linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem)
}
}
}
var foreignFieldNames []string
for _, dbName := range relation.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
if scope.IndirectValue().Kind() == reflect.Slice {
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if object.Kind() == reflect.Ptr {
object = object.Elem()
}
source := getRealValue(object, foreignFieldNames)
field := object.FieldByName(field.Name)
for _, link := range linkHash[toString(source)] {
field.Set(reflect.Append(field, link))
}
}
} else {
if object := scope.IndirectValue(); object.IsValid() {
source := getRealValue(object, foreignFieldNames)
field := object.FieldByName(field.Name)
for _, link := range linkHash[toString(source)] {
field.Set(reflect.Append(field, link))
}
}
}
}
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
for i := 0; i < values.Len(); i++ {
var result []interface{}
for _, column := range columns {
value := reflect.Indirect(values.Index(i))
if value.Kind() == reflect.Ptr {
value = reflect.Indirect(values.Index(i).Elem())
}
result = append(result, value.FieldByName(column).Interface())
}
results = append(results, result)
}
case reflect.Struct:
var result []interface{}
for _, column := range columns {
result = append(result, values.FieldByName(column).Interface())
}
return [][]interface{}{result}
}
return
}
func (scope *Scope) getColumnsAsScope(column string) *Scope {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
modelType := values.Type().Elem()
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
fieldStruct, _ := modelType.FieldByName(column)
var columns reflect.Value
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
} else {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
}
for i := 0; i < values.Len(); i++ {
column := reflect.Indirect(values.Index(i)).FieldByName(column)
if column.Kind() == reflect.Ptr {
column = column.Elem()
}
if column.Kind() == reflect.Slice {
for i := 0; i < column.Len(); i++ {
elem := column.Index(i)
if elem.CanAddr() {
columns = reflect.Append(columns, elem.Addr())
}
}
} else {
if column.CanAddr() {
columns = reflect.Append(columns, column.Addr())
}
}
}
return scope.New(columns.Interface())
case reflect.Struct:
field := values.FieldByName(column)
if !field.CanAddr() {
return nil
}
return scope.New(field.Addr().Interface())
}
return nil
}

View File

@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound { if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
t.Error(err) t.Error(err)
} }
} }
@ -818,90 +818,6 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
} }
} }
func TestManyToManyPreloadForPointer(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
}
Level2 struct {
ID uint
Value string
Level1s []*Level1 `gorm:"many2many:levels;"`
}
)
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
t.Error(err)
}
want := Level2{Value: "Bob", Level1s: []*Level1{
{Value: "ru"},
{Value: "en"},
}}
if err := DB.Save(&want).Error; err != nil {
t.Error(err)
}
want2 := Level2{Value: "Tom", Level1s: []*Level1{
{Value: "zh"},
{Value: "de"},
}}
if err := DB.Save(&want2).Error; err != nil {
t.Error(err)
}
var got Level2
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
t.Error(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
var got2 Level2
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
t.Error(err)
}
if !reflect.DeepEqual(got2, want2) {
t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
}
var got3 []Level2
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
t.Error(err)
}
if !reflect.DeepEqual(got3, []Level2{got, got2}) {
t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
}
var got4 []Level2
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
t.Error(err)
}
var got5 Level2
DB.Preload("Level1s").First(&got5, "value = ?", "bogus")
var ruLevel1 Level1
var zhLevel1 Level1
DB.First(&ruLevel1, "value = ?", "ru")
DB.First(&zhLevel1, "value = ?", "zh")
got.Level1s = []*Level1{&ruLevel1}
got2.Level1s = []*Level1{&zhLevel1}
if !reflect.DeepEqual(got4, []Level2{got, got2}) {
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
}
}
func TestManyToManyPreloadForNestedPointer(t *testing.T) { func TestManyToManyPreloadForNestedPointer(t *testing.T) {
type ( type (
Level1 struct { Level1 struct {
@ -1065,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound { if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
t.Error(err) t.Error(err)
} }
} }
@ -1122,12 +1038,87 @@ func TestNestedManyToManyPreload2(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound { if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
t.Error(err) t.Error(err)
} }
} }
func TestNestedManyToManyPreload3(t *testing.T) { func TestNestedManyToManyPreload3(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
}
Level2 struct {
ID uint
Value string
Level1s []*Level1 `gorm:"many2many:level1_level2;"`
}
Level3 struct {
ID uint
Value string
Level2ID sql.NullInt64
Level2 *Level2
}
)
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists("level1_level2")
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
t.Error(err)
}
level1Zh := &Level1{Value: "zh"}
level1Ru := &Level1{Value: "ru"}
level1En := &Level1{Value: "en"}
level21 := &Level2{
Value: "Level2-1",
Level1s: []*Level1{level1Zh, level1Ru},
}
level22 := &Level2{
Value: "Level2-2",
Level1s: []*Level1{level1Zh, level1En},
}
wants := []*Level3{
{
Value: "Level3-1",
Level2: level21,
},
{
Value: "Level3-2",
Level2: level22,
},
{
Value: "Level3-3",
Level2: level21,
},
}
for _, want := range wants {
if err := DB.Save(&want).Error; err != nil {
t.Error(err)
}
}
var gots []*Level3
if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB {
return db.Order("level1.id ASC")
}).Find(&gots).Error; err != nil {
t.Error(err)
}
if !reflect.DeepEqual(gots, wants) {
t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants))
}
}
func TestNestedManyToManyPreload4(t *testing.T) {
type ( type (
Level4 struct { Level4 struct {
ID uint ID uint
@ -1185,6 +1176,90 @@ func TestNestedManyToManyPreload3(t *testing.T) {
} }
} }
func TestManyToManyPreloadForPointer(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
}
Level2 struct {
ID uint
Value string
Level1s []*Level1 `gorm:"many2many:levels;"`
}
)
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
t.Error(err)
}
want := Level2{Value: "Bob", Level1s: []*Level1{
{Value: "ru"},
{Value: "en"},
}}
if err := DB.Save(&want).Error; err != nil {
t.Error(err)
}
want2 := Level2{Value: "Tom", Level1s: []*Level1{
{Value: "zh"},
{Value: "de"},
}}
if err := DB.Save(&want2).Error; err != nil {
t.Error(err)
}
var got Level2
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
t.Error(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
var got2 Level2
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
t.Error(err)
}
if !reflect.DeepEqual(got2, want2) {
t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
}
var got3 []Level2
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
t.Error(err)
}
if !reflect.DeepEqual(got3, []Level2{got, got2}) {
t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
}
var got4 []Level2
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
t.Error(err)
}
var got5 Level2
DB.Preload("Level1s").First(&got5, "value = ?", "bogus")
var ruLevel1 Level1
var zhLevel1 Level1
DB.First(&ruLevel1, "value = ?", "ru")
DB.First(&zhLevel1, "value = ?", "zh")
got.Level1s = []*Level1{&ruLevel1}
got2.Level1s = []*Level1{&zhLevel1}
if !reflect.DeepEqual(got4, []Level2{got, got2}) {
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
}
}
func TestNilPointerSlice(t *testing.T) { func TestNilPointerSlice(t *testing.T) {
type ( type (
Level3 struct { Level3 struct {
@ -1234,7 +1309,7 @@ func TestNilPointerSlice(t *testing.T) {
} }
if len(got) != 2 { if len(got) != 2 {
t.Error("got %v items, expected 2", len(got)) t.Errorf("got %v items, expected 2", len(got))
} }
if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {

View File

@ -629,14 +629,3 @@ func TestSelectWithArrayInput(t *testing.T) {
t.Errorf("Should have selected both age and name") t.Errorf("Should have selected both age and name")
} }
} }
func TestCurrentDatabase(t *testing.T) {
databaseName := DB.CurrentDatabase()
if err := DB.Error; err != nil {
t.Errorf("Problem getting current db name: %s", err)
}
if databaseName == "" {
t.Errorf("Current db name returned empty; this should never happen!")
}
t.Logf("Got current db name: %v", databaseName)
}

310
scope.go
View File

@ -1,48 +1,32 @@
package gorm package gorm
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
"strings" "strings"
"time"
"reflect" "reflect"
) )
// Scope contain current operation's information when you perform any operation on the database
type Scope struct { type Scope struct {
Search *search Search *search
Value interface{} Value interface{}
Sql string SQL string
SqlVars []interface{} SQLVars []interface{}
db *DB db *DB
indirectValue *reflect.Value instanceID string
instanceId string
primaryKeyField *Field primaryKeyField *Field
skipLeft bool skipLeft bool
fields map[string]*Field fields map[string]*Field
selectAttrs *[]string selectAttrs *[]string
} }
// IndirectValue return scope's reflect value's indirect value
func (scope *Scope) IndirectValue() reflect.Value { func (scope *Scope) IndirectValue() reflect.Value {
if scope.indirectValue == nil { return indirect(reflect.ValueOf(scope.Value))
value := reflect.Indirect(reflect.ValueOf(scope.Value))
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
scope.indirectValue = &value
}
return *scope.indirectValue
}
func (scope *Scope) NeedPtr() *Scope {
reflectKind := reflect.ValueOf(scope.Value).Kind()
if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value")
scope.Err(err)
fmt.Printf(err.Error())
}
return scope
} }
// New create a new Scope without search information // New create a new Scope without search information
@ -61,12 +45,13 @@ func (scope *Scope) NewDB() *DB {
return nil return nil
} }
// DB return scope's DB connection
func (scope *Scope) DB() *DB { func (scope *Scope) DB() *DB {
return scope.db return scope.db
} }
// SqlDB return *sql.DB // SQLDB return *sql.DB
func (scope *Scope) SqlDB() sqlCommon { func (scope *Scope) SQLDB() sqlCommon {
return scope.db.db return scope.db.db
} }
@ -75,7 +60,7 @@ func (scope *Scope) SkipLeft() {
scope.skipLeft = true scope.skipLeft = true
} }
// Quote used to quote database column name according to database dialect // Quote used to quote string to escape them for database
func (scope *Scope) Quote(str string) string { func (scope *Scope) Quote(str string) string {
if strings.Index(str, ".") != -1 { if strings.Index(str, ".") != -1 {
newStrs := []string{} newStrs := []string{}
@ -83,12 +68,12 @@ func (scope *Scope) Quote(str string) string {
newStrs = append(newStrs, scope.Dialect().Quote(str)) newStrs = append(newStrs, scope.Dialect().Quote(str))
} }
return strings.Join(newStrs, ".") return strings.Join(newStrs, ".")
} else {
return scope.Dialect().Quote(str)
} }
return scope.Dialect().Quote(str)
} }
func (scope *Scope) QuoteIfPossible(str string) string { func (scope *Scope) quoteIfPossible(str string) string {
if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) { if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
return scope.Quote(str) return scope.Quote(str)
} }
@ -100,7 +85,7 @@ func (scope *Scope) Dialect() Dialect {
return scope.db.parent.dialect return scope.db.parent.dialect
} }
// Err write error // Err add error to Scope
func (scope *Scope) Err(err error) error { func (scope *Scope) Err(err error) error {
if err != nil { if err != nil {
scope.db.AddError(err) scope.db.AddError(err)
@ -118,27 +103,30 @@ func (scope *Scope) HasError() bool {
return scope.db.Error != nil return scope.db.Error != nil
} }
func (scope *Scope) PrimaryFields() []*Field { // PrimaryFields return scope's primary fields
var fields = []*Field{} func (scope *Scope) PrimaryFields() (fields []*Field) {
for _, field := range scope.GetModelStruct().PrimaryFields { for _, field := range scope.Fields() {
fields = append(fields, scope.Fields()[field.DBName]) if field.IsPrimaryKey {
fields = append(fields, field)
}
} }
return fields return fields
} }
// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
func (scope *Scope) PrimaryField() *Field { func (scope *Scope) PrimaryField() *Field {
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
if len(primaryFields) > 1 { if len(primaryFields) > 1 {
if field, ok := scope.Fields()["id"]; ok { if field, ok := scope.FieldByName("id"); ok {
return field return field
} }
} }
return scope.Fields()[primaryFields[0].DBName] return scope.PrimaryFields()[0]
} }
return nil return nil
} }
// PrimaryKey get the primary key's column name // PrimaryKey get main primary field's db name
func (scope *Scope) PrimaryKey() string { func (scope *Scope) PrimaryKey() string {
if field := scope.PrimaryField(); field != nil { if field := scope.PrimaryField(); field != nil {
return field.DBName return field.DBName
@ -146,7 +134,7 @@ func (scope *Scope) PrimaryKey() string {
return "" return ""
} }
// PrimaryKeyZero check the primary key is blank or not // PrimaryKeyZero check main primary field's value is blank or not
func (scope *Scope) PrimaryKeyZero() bool { func (scope *Scope) PrimaryKeyZero() bool {
field := scope.PrimaryField() field := scope.PrimaryField()
return field == nil || field.IsBlank return field == nil || field.IsBlank
@ -170,80 +158,85 @@ func (scope *Scope) HasColumn(column string) bool {
return false return false
} }
// SetColumn to set the column's value // SetColumn to set the column's value, column could be field or field's name/dbname
func (scope *Scope) SetColumn(column interface{}, value interface{}) error { func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
var updateAttrs = map[string]interface{}{}
if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
updateAttrs = attrs.(map[string]interface{})
defer scope.InstanceSet("gorm:update_attrs", updateAttrs)
}
if field, ok := column.(*Field); ok { if field, ok := column.(*Field); ok {
updateAttrs[field.DBName] = value
return field.Set(value) return field.Set(value)
} else if name, ok := column.(string); ok { } else if name, ok := column.(string); ok {
var (
if field, ok := scope.Fields()[name]; ok { dbName = ToDBName(name)
mostMatchedField *Field
)
for _, field := range scope.Fields() {
if field.DBName == value {
updateAttrs[field.DBName] = value
return field.Set(value) return field.Set(value)
} }
if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) {
dbName := ToDBName(name) mostMatchedField = field
if field, ok := scope.Fields()[dbName]; ok { }
return field.Set(value)
} }
if field, ok := scope.FieldByName(name); ok { if mostMatchedField != nil {
return field.Set(value) updateAttrs[mostMatchedField.DBName] = value
return mostMatchedField.Set(value)
} }
} }
return errors.New("could not convert column to field") return errors.New("could not convert column to field")
} }
func (scope *Scope) CallMethod(name string, checkError bool) { func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
if scope.Value == nil || (checkError && scope.HasError()) { if reflectValue.CanAddr() {
reflectValue = reflectValue.Addr()
}
if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
switch method := methodValue.Interface().(type) {
case func():
method()
case func(*Scope):
method(scope)
case func(*DB):
newDB := scope.NewDB()
method(newDB)
scope.Err(newDB.Error)
case func() error:
scope.Err(method())
case func(*Scope) error:
scope.Err(method(scope))
case func(*DB) error:
newDB := scope.NewDB()
scope.Err(method(newDB))
scope.Err(newDB.Error)
default:
scope.Err(fmt.Errorf("unsupported function %v", methodName))
}
}
}
// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
func (scope *Scope) CallMethod(methodName string) {
if scope.Value == nil {
return return
} }
call := func(value interface{}) { if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice {
if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() { for i := 0; i < indirectScopeValue.Len(); i++ {
switch f := fm.Interface().(type) { scope.callMethod(methodName, indirectScopeValue.Index(i))
case func():
f()
case func(s *Scope):
f(scope)
case func(s *DB):
newDB := scope.NewDB()
f(newDB)
scope.Err(newDB.Error)
case func() error:
scope.Err(f())
case func(s *Scope) error:
scope.Err(f(scope))
case func(s *DB) error:
newDB := scope.NewDB()
scope.Err(f(newDB))
scope.Err(newDB.Error)
default:
scope.Err(fmt.Errorf("unsupported function %v", name))
}
}
}
if values := scope.IndirectValue(); values.Kind() == reflect.Slice {
for i := 0; i < values.Len(); i++ {
value := values.Index(i).Addr().Interface()
if values.Index(i).Kind() == reflect.Ptr {
value = values.Index(i).Interface()
}
call(value)
} }
} else { } else {
if scope.IndirectValue().CanAddr() { scope.callMethod(methodName, indirectScopeValue)
call(scope.IndirectValue().Addr().Interface())
} else {
call(scope.IndirectValue().Interface())
}
} }
} }
func (scope *Scope) CallMethodWithErrorCheck(name string) { // AddToVars add value as sql's vars, used to prevent SQL injection
scope.CallMethod(name, true)
}
// AddToVars add value as sql's vars, gorm will escape them
func (scope *Scope) AddToVars(value interface{}) string { func (scope *Scope) AddToVars(value interface{}) string {
if expr, ok := value.(*expr); ok { if expr, ok := value.(*expr); ok {
exp := expr.expr exp := expr.expr
@ -251,10 +244,10 @@ func (scope *Scope) AddToVars(value interface{}) string {
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
} }
return exp return exp
} else {
scope.SqlVars = append(scope.SqlVars, value)
return scope.Dialect().BinVar(len(scope.SqlVars))
} }
scope.SQLVars = append(scope.SQLVars, value)
return scope.Dialect().BindVar(len(scope.SQLVars))
} }
type tabler interface { type tabler interface {
@ -265,7 +258,7 @@ type dbTabler interface {
TableName(*DB) string TableName(*DB) string
} }
// TableName get table name // TableName return table name
func (scope *Scope) TableName() string { func (scope *Scope) TableName() string {
if scope.Search != nil && len(scope.Search.tableName) > 0 { if scope.Search != nil && len(scope.Search.tableName) > 0 {
return scope.Search.tableName return scope.Search.tableName
@ -282,44 +275,54 @@ func (scope *Scope) TableName() string {
return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
} }
// QuotedTableName return quoted table name
func (scope *Scope) QuotedTableName() (name string) { func (scope *Scope) QuotedTableName() (name string) {
if scope.Search != nil && len(scope.Search.tableName) > 0 { if scope.Search != nil && len(scope.Search.tableName) > 0 {
if strings.Index(scope.Search.tableName, " ") != -1 { if strings.Index(scope.Search.tableName, " ") != -1 {
return scope.Search.tableName return scope.Search.tableName
} }
return scope.Quote(scope.Search.tableName) return scope.Quote(scope.Search.tableName)
} else {
return scope.Quote(scope.TableName())
} }
return scope.Quote(scope.TableName())
} }
// CombinedConditionSql get combined condition sql // CombinedConditionSql return combined condition sql
func (scope *Scope) CombinedConditionSql() string { func (scope *Scope) CombinedConditionSql() string {
return scope.joinsSql() + scope.whereSql() + scope.groupSql() + return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() +
scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql() scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
} }
// FieldByName find `gorm.Field` with field name or db name
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
var (
dbName = ToDBName(name)
mostMatchedField *Field
)
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if field.Name == name || field.DBName == name { if field.Name == name || field.DBName == name {
return field, true return field, true
} }
if field.DBName == dbName {
mostMatchedField = field
} }
return nil, false }
return mostMatchedField, mostMatchedField != nil
} }
// Raw set sql // Raw set raw sql
func (scope *Scope) Raw(sql string) *Scope { func (scope *Scope) Raw(sql string) *Scope {
scope.Sql = strings.Replace(sql, "$$", "?", -1) scope.SQL = strings.Replace(sql, "$$", "?", -1)
return scope return scope
} }
// Exec invoke sql // Exec perform generated SQL
func (scope *Scope) Exec() *Scope { func (scope *Scope) Exec() *Scope {
defer scope.Trace(NowFunc()) defer scope.trace(NowFunc())
if !scope.HasError() { if !scope.HasError() {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil { if count, err := result.RowsAffected(); scope.Err(err) == nil {
scope.db.RowsAffected = count scope.db.RowsAffected = count
} }
@ -334,37 +337,32 @@ func (scope *Scope) Set(name string, value interface{}) *Scope {
return scope return scope
} }
// Get get value by name // Get get setting by name
func (scope *Scope) Get(name string) (interface{}, bool) { func (scope *Scope) Get(name string) (interface{}, bool) {
return scope.db.Get(name) return scope.db.Get(name)
} }
// InstanceId get InstanceId for scope // InstanceID get InstanceID for scope
func (scope *Scope) InstanceId() string { func (scope *Scope) InstanceID() string {
if scope.instanceId == "" { if scope.instanceID == "" {
scope.instanceId = fmt.Sprintf("%v%v", &scope, &scope.db) scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db)
} }
return scope.instanceId return scope.instanceID
} }
// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
return scope.Set(name+scope.InstanceId(), value) return scope.Set(name+scope.InstanceID(), value)
} }
// InstanceGet get instance setting from current operation
func (scope *Scope) InstanceGet(name string) (interface{}, bool) { func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
return scope.Get(name + scope.InstanceId()) return scope.Get(name + scope.InstanceID())
}
// Trace print sql log
func (scope *Scope) Trace(t time.Time) {
if len(scope.Sql) > 0 {
scope.db.slog(scope.Sql, t, scope.SqlVars...)
}
} }
// Begin start a transaction // Begin start a transaction
func (scope *Scope) Begin() *Scope { func (scope *Scope) Begin() *Scope {
if db, ok := scope.SqlDB().(sqlDb); ok { if db, ok := scope.SQLDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil { if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon) scope.db.db = interface{}(tx).(sqlCommon)
scope.InstanceSet("gorm:started_transaction", true) scope.InstanceSet("gorm:started_transaction", true)
@ -373,7 +371,7 @@ func (scope *Scope) Begin() *Scope {
return scope return scope
} }
// CommitOrRollback commit current transaction if there is no error, otherwise rollback it // CommitOrRollback commit current transaction if no error happened, otherwise will rollback it
func (scope *Scope) CommitOrRollback() *Scope { func (scope *Scope) CommitOrRollback() *Scope {
if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
if db, ok := scope.db.db.(sqlTx); ok { if db, ok := scope.db.db.(sqlTx); ok {
@ -388,6 +386,7 @@ func (scope *Scope) CommitOrRollback() *Scope {
return scope return scope
} }
// SelectAttrs return selected attributes
func (scope *Scope) SelectAttrs() []string { func (scope *Scope) SelectAttrs() []string {
if scope.selectAttrs == nil { if scope.selectAttrs == nil {
attrs := []string{} attrs := []string{}
@ -407,57 +406,38 @@ func (scope *Scope) SelectAttrs() []string {
return *scope.selectAttrs return *scope.selectAttrs
} }
// OmitAttrs return omited attributes
func (scope *Scope) OmitAttrs() []string { func (scope *Scope) OmitAttrs() []string {
return scope.Search.omits return scope.Search.omits
} }
func (scope *Scope) changeableDBColumn(column string) bool { func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
selectAttrs := scope.SelectAttrs() var values = make([]interface{}, len(columns))
omitAttrs := scope.OmitAttrs() var ignored interface{}
if len(selectAttrs) > 0 { for index, column := range columns {
for _, attr := range selectAttrs { if field, ok := fieldsMap[column]; ok {
if column == ToDBName(attr) { if field.Field.Kind() == reflect.Ptr {
return true values[index] = field.Field.Addr().Interface()
} else {
reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
reflectValue.Elem().Set(field.Field.Addr())
values[index] = reflectValue.Interface()
} }
} else {
values[index] = &ignored
} }
return false
} }
for _, attr := range omitAttrs { scope.Err(rows.Scan(values...))
if column == ToDBName(attr) {
return false for index, column := range columns {
if field, ok := fieldsMap[column]; ok {
if field.Field.Kind() != reflect.Ptr {
if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
field.Field.Set(v)
}
}
} }
} }
return true
}
func (scope *Scope) changeableField(field *Field) bool {
selectAttrs := scope.SelectAttrs()
omitAttrs := scope.OmitAttrs()
if len(selectAttrs) > 0 {
for _, attr := range selectAttrs {
if field.Name == attr || field.DBName == attr {
return true
}
}
return false
}
for _, attr := range omitAttrs {
if field.Name == attr || field.DBName == attr {
return false
}
}
return !field.IsIgnored
}
func (scope *Scope) shouldSaveAssociations() bool {
saveAssociations, ok := scope.Get("gorm:save_associations")
if ok && !saveAssociations.(bool) {
return false
}
return true && !scope.HasError()
} }

View File

@ -8,6 +8,7 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time"
) )
func (scope *Scope) primaryCondition(value interface{}) string { func (scope *Scope) primaryCondition(value interface{}) string {
@ -75,7 +76,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
} }
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
var notEqualSql string var notEqualSQL string
var primaryKey = scope.PrimaryKey() var primaryKey = scope.PrimaryKey()
switch value := clause["query"].(type) { switch value := clause["query"].(type) {
@ -86,10 +87,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value) str = fmt.Sprintf(" NOT (%v) ", value)
notEqualSql = fmt.Sprintf("NOT (%v)", value) notEqualSQL = fmt.Sprintf("NOT (%v)", value)
} else { } else {
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
} }
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
@ -138,7 +139,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
if scanner, ok := interface{}(arg).(driver.Valuer); ok { if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value() arg, _ = scanner.Value()
} }
str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1) str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
} }
} }
return return
@ -172,17 +173,20 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
return return
} }
func (scope *Scope) whereSql() (sql string) { func (scope *Scope) whereSQL() (sql string) {
var primaryConditions, andConditions, orConditions []string var (
quotedTableName = scope.QuotedTableName()
primaryConditions, andConditions, orConditions []string
)
if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil { if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName)
primaryConditions = append(primaryConditions, sql) primaryConditions = append(primaryConditions, sql)
} }
if !scope.PrimaryKeyZero() { if !scope.PrimaryKeyZero() {
for _, field := range scope.PrimaryFields() { for _, field := range scope.PrimaryFields() {
sql := fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
primaryConditions = append(primaryConditions, sql) primaryConditions = append(primaryConditions, sql)
} }
} }
@ -205,30 +209,30 @@ func (scope *Scope) whereSql() (sql string) {
} }
} }
orSql := strings.Join(orConditions, " OR ") orSQL := strings.Join(orConditions, " OR ")
combinedSql := strings.Join(andConditions, " AND ") combinedSQL := strings.Join(andConditions, " AND ")
if len(combinedSql) > 0 { if len(combinedSQL) > 0 {
if len(orSql) > 0 { if len(orSQL) > 0 {
combinedSql = combinedSql + " OR " + orSql combinedSQL = combinedSQL + " OR " + orSQL
} }
} else { } else {
combinedSql = orSql combinedSQL = orSQL
} }
if len(primaryConditions) > 0 { if len(primaryConditions) > 0 {
sql = "WHERE " + strings.Join(primaryConditions, " AND ") sql = "WHERE " + strings.Join(primaryConditions, " AND ")
if len(combinedSql) > 0 { if len(combinedSQL) > 0 {
sql = sql + " AND (" + combinedSql + ")" sql = sql + " AND (" + combinedSQL + ")"
} }
} else if len(combinedSql) > 0 { } else if len(combinedSQL) > 0 {
sql = "WHERE " + combinedSql sql = "WHERE " + combinedSQL
} }
return return
} }
func (scope *Scope) selectSql() string { func (scope *Scope) selectSQL() string {
if len(scope.Search.selects) == 0 { if len(scope.Search.selects) == 0 {
if scope.Search.joins != "" { if len(scope.Search.joinConditions) > 0 {
return fmt.Sprintf("%v.*", scope.QuotedTableName()) return fmt.Sprintf("%v.*", scope.QuotedTableName())
} }
return "*" return "*"
@ -236,87 +240,60 @@ func (scope *Scope) selectSql() string {
return scope.buildSelectQuery(scope.Search.selects) return scope.buildSelectQuery(scope.Search.selects)
} }
func (scope *Scope) orderSql() string { func (scope *Scope) orderSQL() string {
if len(scope.Search.orders) == 0 || scope.Search.countingQuery { if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
return "" return ""
} }
return " ORDER BY " + strings.Join(scope.Search.orders, ",") return " ORDER BY " + strings.Join(scope.Search.orders, ",")
} }
func (scope *Scope) limitSql() string { func (scope *Scope) limitAndOffsetSQL() string {
if !scope.Dialect().HasTop() { return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
if len(scope.Search.limit) == 0 {
return ""
}
return " LIMIT " + scope.Search.limit
}
return ""
} }
func (scope *Scope) topSql() string { func (scope *Scope) groupSQL() string {
if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
if len(scope.Search.limit) == 0 {
return ""
}
return " TOP(" + scope.Search.limit + ")"
}
return ""
}
func (scope *Scope) offsetSql() string {
if len(scope.Search.offset) == 0 {
return ""
}
if scope.Dialect().HasTop() {
sql := " OFFSET " + scope.Search.offset + " ROW "
if len(scope.Search.limit) > 0 {
sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
}
return sql
}
return " OFFSET " + scope.Search.offset
}
func (scope *Scope) groupSql() string {
if len(scope.Search.group) == 0 { if len(scope.Search.group) == 0 {
return "" return ""
} }
return " GROUP BY " + scope.Search.group return " GROUP BY " + scope.Search.group
} }
func (scope *Scope) havingSql() string { func (scope *Scope) havingSQL() string {
if scope.Search.havingConditions == nil { if len(scope.Search.havingConditions) == 0 {
return "" return ""
} }
var andConditions []string var andConditions []string
for _, clause := range scope.Search.havingConditions { for _, clause := range scope.Search.havingConditions {
if sql := scope.buildWhereCondition(clause); sql != "" { if sql := scope.buildWhereCondition(clause); sql != "" {
andConditions = append(andConditions, sql) andConditions = append(andConditions, sql)
} }
} }
combinedSql := strings.Join(andConditions, " AND ") combinedSQL := strings.Join(andConditions, " AND ")
if len(combinedSql) == 0 { if len(combinedSQL) == 0 {
return "" return ""
} }
return " HAVING " + combinedSql return " HAVING " + combinedSQL
} }
func (scope *Scope) joinsSql() string { func (scope *Scope) joinsSQL() string {
return scope.Search.joins + " " var joinConditions []string
for _, clause := range scope.Search.joinConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
}
}
return strings.Join(joinConditions, " ") + " "
} }
func (scope *Scope) prepareQuerySql() { func (scope *Scope) prepareQuerySQL() {
if scope.Search.raw { if scope.Search.raw {
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
} else { } else {
scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
} }
return return
} }
@ -338,61 +315,53 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
return scope return scope
} }
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
if !scope.IndirectValue().CanAddr() { if scope.IndirectValue().Kind() != reflect.Struct {
return values, true return values, true
} }
var hasExpr bool results = map[string]interface{}{}
for key, value := range values { for key, value := range values {
if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() { if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
if _, ok := value.(*expr); ok { if _, ok := value.(*expr); ok {
hasExpr = true
} else if !equalAsString(field.Field.Interface(), value) {
hasUpdate = true hasUpdate = true
results[field.DBName] = value
} else if !equalAsString(field.Field.Interface(), value) {
field.Set(value)
if field.IsNormal {
hasUpdate = true
results[field.DBName] = field.Field.Interface()
}
}
} else {
field.Set(value) field.Set(value)
} }
} }
} }
}
if hasExpr {
var updateMap = map[string]interface{}{}
for key, field := range scope.Fields() {
if field.IsNormal {
if v, ok := values[key]; ok {
updateMap[key] = v
} else {
updateMap[key] = field.Field.Interface()
}
}
}
return updateMap, true
}
return return
} }
func (scope *Scope) row() *sql.Row { func (scope *Scope) row() *sql.Row {
defer scope.Trace(NowFunc()) defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callback.rowQueries) scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.prepareQuerySql() scope.prepareQuerySQL()
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...) return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
} }
func (scope *Scope) rows() (*sql.Rows, error) { func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.Trace(NowFunc()) defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callback.rowQueries) scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.prepareQuerySql() scope.prepareQuerySQL()
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...) return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
} }
func (scope *Scope) initialize() *Scope { func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.whereConditions { for _, clause := range scope.Search.whereConditions {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
} }
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false) scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false) scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
return scope return scope
} }
@ -433,23 +402,45 @@ func (scope *Scope) typeName() string {
return typ.Name() return typ.Name()
} }
// trace print sql log
func (scope *Scope) trace(t time.Time) {
if len(scope.SQL) > 0 {
scope.db.slog(scope.SQL, t, scope.SQLVars...)
}
}
func (scope *Scope) changeableField(field *Field) bool {
if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
for _, attr := range selectAttrs {
if field.Name == attr || field.DBName == attr {
return true
}
}
return false
}
for _, attr := range scope.OmitAttrs() {
if field.Name == attr || field.DBName == attr {
return false
}
}
return true
}
func (scope *Scope) shouldSaveAssociations() bool {
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) {
return false
}
return true && !scope.HasError()
}
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
toScope := scope.db.NewScope(value) toScope := scope.db.NewScope(value)
fromFields := scope.Fields()
toFields := toScope.Fields()
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
var fromField, toField *Field fromField, _ := scope.FieldByName(foreignKey)
if field, ok := scope.FieldByName(foreignKey); ok { toField, _ := toScope.FieldByName(foreignKey)
fromField = field
} else {
fromField = fromFields[ToDBName(foreignKey)]
}
if field, ok := toScope.FieldByName(foreignKey); ok {
toField = field
} else {
toField = toFields[ToDBName(foreignKey)]
}
if fromField != nil { if fromField != nil {
if relationship := fromField.Relationship; relationship != nil { if relationship := fromField.Relationship; relationship != nil {
@ -508,30 +499,26 @@ func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db) joinTable := joinTableHandler.Table(scope.db)
if !scope.Dialect().HasTable(scope, joinTable) { if !scope.Dialect().HasTable(joinTable) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes, primaryKeys []string var sqlTypes, primaryKeys []string
for idx, fieldName := range relationship.ForeignFieldNames { for idx, fieldName := range relationship.ForeignFieldNames {
if field, ok := scope.Fields()[fieldName]; ok { if field, ok := scope.FieldByName(fieldName); ok {
value := reflect.Indirect(reflect.New(field.Struct.Type)) foreignKeyStruct := field.clone()
primaryKeySqlType := field.TagSettings["TYPE"] foreignKeyStruct.IsPrimaryKey = false
if primaryKeySqlType == "" { foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false) sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
}
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
} }
} }
for idx, fieldName := range relationship.AssociationForeignFieldNames { for idx, fieldName := range relationship.AssociationForeignFieldNames {
if field, ok := toScope.Fields()[fieldName]; ok { if field, ok := toScope.FieldByName(fieldName); ok {
value := reflect.Indirect(reflect.New(field.Struct.Type)) foreignKeyStruct := field.clone()
primaryKeySqlType := field.TagSettings["TYPE"] foreignKeyStruct.IsPrimaryKey = false
if primaryKeySqlType == "" { foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false) sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
}
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
} }
} }
@ -545,10 +532,10 @@ func (scope *Scope) createJoinTable(field *StructField) {
func (scope *Scope) createTable() *Scope { func (scope *Scope) createTable() *Scope {
var tags []string var tags []string
var primaryKeys []string var primaryKeys []string
var primaryKeyInColumnType bool = false var primaryKeyInColumnType = false
for _, field := range scope.GetStructFields() { for _, field := range scope.GetModelStruct().StructFields {
if field.IsNormal { if field.IsNormal {
sqlTag := scope.generateSqlTag(field) sqlTag := scope.Dialect().DataTypeOf(field)
// Check if the primary key constraint was specified as // Check if the primary key constraint was specified as
// part of the column type. If so, we can only support // part of the column type. If so, we can only support
@ -582,13 +569,6 @@ func (scope *Scope) dropTable() *Scope {
return scope return scope
} }
func (scope *Scope) dropTableIfExists() *Scope {
if scope.Dialect().HasTable(scope, scope.TableName()) {
scope.dropTable()
}
return scope
}
func (scope *Scope) modifyColumn(column string, typ string) { func (scope *Scope) modifyColumn(column string, typ string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
} }
@ -598,13 +578,13 @@ func (scope *Scope) dropColumn(column string) {
} }
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) { if scope.Dialect().HasIndex(scope.TableName(), indexName) {
return return
} }
var columns []string var columns []string
for _, name := range column { for _, name := range column {
columns = append(columns, scope.QuoteIfPossible(name)) columns = append(columns, scope.quoteIfPossible(name))
} }
sqlCreate := "CREATE INDEX" sqlCreate := "CREATE INDEX"
@ -612,31 +592,35 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
sqlCreate = "CREATE UNIQUE INDEX" sqlCreate = "CREATE UNIQUE INDEX"
} }
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec() scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
} }
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return
}
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), dest, onDelete, onUpdate)).Exec() scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
} }
func (scope *Scope) removeIndex(indexName string) { func (scope *Scope) removeIndex(indexName string) {
scope.Dialect().RemoveIndex(scope, indexName) scope.Dialect().RemoveIndex(scope.TableName(), indexName)
} }
func (scope *Scope) autoMigrate() *Scope { func (scope *Scope) autoMigrate() *Scope {
tableName := scope.TableName() tableName := scope.TableName()
quotedTableName := scope.QuotedTableName() quotedTableName := scope.QuotedTableName()
if !scope.Dialect().HasTable(scope, tableName) { if !scope.Dialect().HasTable(tableName) {
scope.createTable() scope.createTable()
} else { } else {
for _, field := range scope.GetStructFields() { for _, field := range scope.GetModelStruct().StructFields {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { if !scope.Dialect().HasColumn(tableName, field.DBName) {
if field.IsNormal { if field.IsNormal {
sqlTag := scope.generateSqlTag(field) sqlTag := scope.Dialect().DataTypeOf(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
} }
} }

67
scope_utils.go Normal file
View File

@ -0,0 +1,67 @@
package gorm
import "reflect"
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
for _, value := range values {
indirectValue := reflect.ValueOf(value)
for indirectValue.Kind() == reflect.Ptr {
indirectValue = indirectValue.Elem()
}
switch indirectValue.Kind() {
case reflect.Slice:
for i := 0; i < indirectValue.Len(); i++ {
var result []interface{}
var object = indirect(indirectValue.Index(i))
for _, column := range columns {
result = append(result, object.FieldByName(column).Interface())
}
results = append(results, result)
}
case reflect.Struct:
var result []interface{}
for _, column := range columns {
result = append(result, indirectValue.FieldByName(column).Interface())
}
results = append(results, result)
}
}
return
}
func (scope *Scope) getColumnAsScope(column string) *Scope {
indirectScopeValue := scope.IndirectValue()
switch indirectScopeValue.Kind() {
case reflect.Slice:
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
fieldType := fieldStruct.Type
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
for i := 0; i < indirectScopeValue.Len(); i++ {
result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
if result.Kind() == reflect.Slice {
for j := 0; j < result.Len(); j++ {
if elem := result.Index(j); elem.CanAddr() {
results = reflect.Append(results, elem.Addr())
}
}
} else if result.CanAddr() {
results = reflect.Append(results, result.Addr())
}
}
return scope.New(results.Interface())
}
case reflect.Struct:
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
return scope.New(field.Addr().Interface())
}
}
return nil
}

View File

@ -8,15 +8,15 @@ type search struct {
orConditions []map[string]interface{} orConditions []map[string]interface{}
notConditions []map[string]interface{} notConditions []map[string]interface{}
havingConditions []map[string]interface{} havingConditions []map[string]interface{}
joinConditions []map[string]interface{}
initAttrs []interface{} initAttrs []interface{}
assignAttrs []interface{} assignAttrs []interface{}
selects map[string]interface{} selects map[string]interface{}
omits []string omits []string
orders []string orders []string
joins string
preload []searchPreload preload []searchPreload
offset string offset int
limit string limit int
group string group string
tableName string tableName string
raw bool raw bool
@ -82,18 +82,18 @@ func (s *search) Omit(columns ...string) *search {
return s return s
} }
func (s *search) Limit(value interface{}) *search { func (s *search) Limit(limit int) *search {
s.limit = s.getInterfaceAsSql(value) s.limit = limit
return s return s
} }
func (s *search) Offset(value interface{}) *search { func (s *search) Offset(offset int) *search {
s.offset = s.getInterfaceAsSql(value) s.offset = offset
return s return s
} }
func (s *search) Group(query string) *search { func (s *search) Group(query string) *search {
s.group = s.getInterfaceAsSql(query) s.group = s.getInterfaceAsSQL(query)
return s return s
} }
@ -102,8 +102,8 @@ func (s *search) Having(query string, values ...interface{}) *search {
return s return s
} }
func (s *search) Joins(query string) *search { func (s *search) Joins(query string, values ...interface{}) *search {
s.joins = query s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
return s return s
} }
@ -134,12 +134,12 @@ func (s *search) Table(name string) *search {
return s return s
} }
func (s *search) getInterfaceAsSql(value interface{}) (str string) { func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
switch value.(type) { switch value.(type) {
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
str = fmt.Sprintf("%v", value) str = fmt.Sprintf("%v", value)
default: default:
s.db.AddError(InvalidSql) s.db.AddError(ErrInvalidSQL)
} }
if str == "-1" { if str == "-1" {

View File

@ -1,84 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"time"
)
type sqlite3 struct {
commonDialect
}
func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "integer primary key autoincrement"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "integer primary key autoincrement"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "datetime"
}
default:
if _, ok := value.Interface().([]byte); ok {
return "blob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
}
func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
return count > 0
}
func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName)
return count > 0
}
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
return count > 0
}
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}
func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
i int
)
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
return
}
if pointers[1] != nil {
name = *pointers[1]
}
return
}

View File

@ -42,9 +42,9 @@ type CreditCard struct {
ID int8 ID int8
Number string Number string
UserId sql.NullInt64 UserId sql.NullInt64
CreatedAt time.Time CreatedAt time.Time `sql:"not null"`
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt time.Time DeletedAt *time.Time
} }
type Email struct { type Email struct {
@ -62,7 +62,7 @@ type Address struct {
Post string Post string
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt time.Time DeletedAt *time.Time
} }
type Language struct { type Language struct {

View File

@ -71,13 +71,14 @@ func TestUpdate(t *testing.T) {
} }
DB.First(&product4, product4.Id) DB.First(&product4, product4.Id)
updatedAt4 := product4.UpdatedAt
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
var product5 Product var product5 Product
DB.First(&product5, product4.Id) DB.First(&product5, product4.Id)
if product5.Price != product4.Price+100-50 { if product5.Price != product4.Price+100-50 {
t.Errorf("Update with expression") t.Errorf("Update with expression")
} }
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
t.Errorf("Update with expression should update UpdatedAt") t.Errorf("Update with expression should update UpdatedAt")
} }
} }
@ -170,13 +171,15 @@ func TestUpdates(t *testing.T) {
t.Errorf("product2's code should be updated") t.Errorf("product2's code should be updated")
} }
updatedAt4 := product4.UpdatedAt
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
var product5 Product var product5 Product
DB.First(&product5, product4.Id) DB.First(&product5, product4.Id)
if product5.Price != product4.Price+100 { if product5.Price != product4.Price+100 {
t.Errorf("Updates with expression") t.Errorf("Updates with expression")
} }
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { // product4's UpdatedAt will be reset when updating
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
t.Errorf("Updates with expression should update UpdatedAt") t.Errorf("Updates with expression should update UpdatedAt")
} }
} }
@ -419,3 +422,32 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) {
t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
} }
} }
func TestUpdatesWithBlankValues(t *testing.T) {
product := Product{Code: "product1", Price: 10}
DB.Save(&product)
DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100})
var product1 Product
DB.First(&product1, product.Id)
if product1.Code != "product1" || product1.Price != 100 {
t.Errorf("product's code should not be updated")
}
}
func TestUpdateDecodeVirtualAttributes(t *testing.T) {
var user = User{
Name: "jinzhu",
IgnoreMe: 88,
}
DB.Save(&user)
DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100})
if user.IgnoreMe != 100 {
t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks")
}
}

233
utils.go
View File

@ -2,10 +2,26 @@ package gorm
import ( import (
"bytes" "bytes"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"runtime"
"strings" "strings"
"sync" "sync"
"time"
) )
// NowFunc returns current time, this function is exported in order to be able
// to give the flexibility to the developer to customize it according to their
// needs, e.g:
// gorm.NowFunc = func() time.Time {
// return time.Now().UTC()
// }
var NowFunc = func() time.Time {
return time.Now()
}
// Copied from golint // Copied from golint
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialismsReplacer *strings.Replacer var commonInitialismsReplacer *strings.Replacer
@ -41,30 +57,239 @@ func newSafeMap() *safeMap {
var smap = newSafeMap() var smap = newSafeMap()
type strCase bool
const (
lower strCase = false
upper strCase = true
)
// ToDBName convert string to db name
func ToDBName(name string) string { func ToDBName(name string) string {
if v := smap.Get(name); v != "" { if v := smap.Get(name); v != "" {
return v return v
} }
value := commonInitialismsReplacer.Replace(name) if name == "" {
buf := bytes.NewBufferString("") return ""
for i, v := range value { }
if i > 0 && v >= 'A' && v <= 'Z' {
var (
value = commonInitialismsReplacer.Replace(name)
buf = bytes.NewBufferString("")
lastCase, currCase, nextCase strCase
)
for i, v := range value[:len(value)-1] {
nextCase = value[i+1] >= 'A' && value[i+1] <= 'Z'
if i > 0 {
if currCase == upper {
if lastCase == upper && nextCase == upper {
buf.WriteRune(v)
} else {
if value[i-1] != '_' && value[i+1] != '_' {
buf.WriteRune('_') buf.WriteRune('_')
} }
buf.WriteRune(v) buf.WriteRune(v)
} }
} else {
buf.WriteRune(v)
}
} else {
currCase = upper
buf.WriteRune(v)
}
lastCase = currCase
currCase = nextCase
}
buf.WriteByte(value[len(value)-1])
s := strings.ToLower(buf.String()) s := strings.ToLower(buf.String())
smap.Set(name, s) smap.Set(name, s)
return s return s
} }
// SQL expression
type expr struct { type expr struct {
expr string expr string
args []interface{} args []interface{}
} }
// Expr generate raw SQL expression, for example:
// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
func Expr(expression string, args ...interface{}) *expr { func Expr(expression string, args ...interface{}) *expr {
return &expr{expr: expression, args: args} return &expr{expr: expression, args: args}
} }
func indirect(reflectValue reflect.Value) reflect.Value {
for reflectValue.Kind() == reflect.Ptr {
reflectValue = reflectValue.Elem()
}
return reflectValue
}
func toQueryMarks(primaryValues [][]interface{}) string {
var results []string
for _, primaryValue := range primaryValues {
var marks []string
for _ = range primaryValue {
marks = append(marks, "?")
}
if len(marks) > 1 {
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
} else {
results = append(results, strings.Join(marks, ""))
}
}
return strings.Join(results, ",")
}
func toQueryCondition(scope *Scope, columns []string) string {
var newColumns []string
for _, column := range columns {
newColumns = append(newColumns, scope.Quote(column))
}
if len(columns) > 1 {
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
}
return strings.Join(newColumns, ",")
}
func toQueryValues(values [][]interface{}) (results []interface{}) {
for _, value := range values {
for _, v := range value {
results = append(results, v)
}
}
return
}
func fileWithLineNum() string {
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", file, line)
}
}
return ""
}
func isBlank(value reflect.Value) bool {
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
}
func toSearchableMap(attrs ...interface{}) (result interface{}) {
if len(attrs) > 1 {
if str, ok := attrs[0].(string); ok {
result = map[string]interface{}{str: attrs[1]}
}
} else if len(attrs) == 1 {
if attr, ok := attrs[0].(map[string]interface{}); ok {
result = attr
}
if attr, ok := attrs[0].(interface{}); ok {
result = attr
}
}
return
}
func convertInterfaceToMap(values interface{}) map[string]interface{} {
attrs := map[string]interface{}{}
switch value := values.(type) {
case map[string]interface{}:
return value
case []interface{}:
for _, v := range value {
for key, value := range convertInterfaceToMap(v) {
attrs[key] = value
}
}
case interface{}:
reflectValue := reflect.ValueOf(values)
switch reflectValue.Kind() {
case reflect.Map:
for _, key := range reflectValue.MapKeys() {
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
}
default:
for _, field := range (&Scope{Value: values}).Fields() {
if !field.IsBlank {
attrs[field.DBName] = field.Field.Interface()
}
}
}
}
return attrs
}
func equalAsString(a interface{}, b interface{}) bool {
return toString(a) == toString(b)
}
func toString(str interface{}) string {
if values, ok := str.([]interface{}); ok {
var results []string
for _, value := range values {
results = append(results, toString(value))
}
return strings.Join(results, "_")
} else if bytes, ok := str.([]byte); ok {
return string(bytes)
} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
return fmt.Sprintf("%v", reflectValue.Interface())
}
return ""
}
func makeSlice(elemType reflect.Type) interface{} {
if elemType.Kind() == reflect.Slice {
elemType = elemType.Elem()
}
sliceType := reflect.SliceOf(elemType)
slice := reflect.New(sliceType)
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
return slice.Interface()
}
func strInSlice(a string, list []string) bool {
for _, b := range list {
if b == a {
return true
}
}
return false
}
// getValueFromFields return given fields's value
func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
// If value is a nil pointer, Indirect returns a zero Value!
// Therefor we need to check for a zero value,
// as FieldByName could panic
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
for _, fieldName := range fieldNames {
if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
result := fieldValue.Interface()
if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value()
}
results = append(results, result)
}
}
}
return
}
func addExtraSpaceIfExist(str string) string {
if str != "" {
return " " + str
}
return ""
}

View File

@ -1,98 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"regexp"
"runtime"
"strings"
)
func fileWithLineNum() string {
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", file, line)
}
}
return ""
}
func isBlank(value reflect.Value) bool {
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
}
func toSearchableMap(attrs ...interface{}) (result interface{}) {
if len(attrs) > 1 {
if str, ok := attrs[0].(string); ok {
result = map[string]interface{}{str: attrs[1]}
}
} else if len(attrs) == 1 {
if attr, ok := attrs[0].(map[string]interface{}); ok {
result = attr
}
if attr, ok := attrs[0].(interface{}); ok {
result = attr
}
}
return
}
func convertInterfaceToMap(values interface{}) map[string]interface{} {
attrs := map[string]interface{}{}
switch value := values.(type) {
case map[string]interface{}:
for k, v := range value {
attrs[ToDBName(k)] = v
}
case []interface{}:
for _, v := range value {
for key, value := range convertInterfaceToMap(v) {
attrs[key] = value
}
}
case interface{}:
reflectValue := reflect.ValueOf(values)
switch reflectValue.Kind() {
case reflect.Map:
for _, key := range reflectValue.MapKeys() {
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
}
default:
scope := Scope{Value: values}
for _, field := range scope.Fields() {
if !field.IsBlank && !field.IsIgnored {
attrs[field.DBName] = field.Field.Interface()
}
}
}
}
return attrs
}
func toString(str interface{}) string {
if values, ok := str.([]interface{}); ok {
var results []string
for _, value := range values {
results = append(results, toString(value))
}
return strings.Join(results, "_")
} else if bytes, ok := str.([]byte); ok {
return string(bytes)
} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
return fmt.Sprintf("%v", reflectValue.Interface())
}
return ""
}
func strInSlice(a string, list []string) bool {
for _, b := range list {
if b == a {
return true
}
}
return false
}

30
utils_test.go Normal file
View File

@ -0,0 +1,30 @@
package gorm_test
import (
"testing"
"github.com/jinzhu/gorm"
)
func TestToDBNameGenerateFriendlyName(t *testing.T) {
var maps = map[string]string{
"": "",
"ThisIsATest": "this_is_a_test",
"PFAndESI": "pf_and_esi",
"AbcAndJkl": "abc_and_jkl",
"EmployeeID": "employee_id",
"SKU_ID": "sku_id",
"HTTPAndSMTP": "http_and_smtp",
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
"UUID": "uuid",
"HTTPURL": "http_url",
"HTTP_URL": "http_url",
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
}
for key, value := range maps {
if gorm.ToDBName(key) != value {
t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
}
}
}