Merge branch 'master' into master

This commit is contained in:
Jinzhu 2018-02-10 22:25:28 +08:00 committed by GitHub
commit 85774eb9da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 864 additions and 237 deletions

View File

@ -1,10 +1,4 @@
Before posting a bug report about a problem, please try to verify that it is a bug and that it has not been reported already, please apply corresponding GitHub labels to the issue, for feature requests, please apply `type:feature`.
DON'T post usage related questions, ask in https://gitter.im/jinzhu/gorm or http://stackoverflow.com/questions/tagged/go-gorm,
Please answer these questions before submitting your issue. Thanks!
Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one.
### What version of Go are you using (`go version`)?
@ -12,9 +6,9 @@ Please answer these questions before submitting your issue. Thanks!
### Which database and its version are you using?
### What did you do?
### Please provide a complete runnable program to reproduce your issue. **IMPORTANT**
Please provide a complete runnable program to reproduce your issue.
Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config.
```go
package main
@ -32,10 +26,9 @@ var db *gorm.DB
func init() {
var err error
db, err = gorm.Open("sqlite3", "test.db")
// Please use below username, password as your database's account for the script.
// db, err = gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db, err = gorm.Open("mysql", "gorm:gorm@/dbname?charset=utf8&parseTime=True")
// db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm")
// db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable")
// db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True")
// db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm")
if err != nil {
panic(err)
}
@ -43,8 +36,6 @@ func init() {
}
func main() {
// your code here
if /* failure condition */ {
fmt.Println("failed")
} else {

View File

@ -3,12 +3,7 @@ Make sure these boxes checked before submitting your pull request.
- [] Do only one thing
- [] No API-breaking changes
- [] New code/logic commented & tested
- [] Write good commit message, try to squash your commits into a single one
- [] Run `./build.sh` in `gh-pages` branch for document changes
For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it.
Thank you.
For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it.
### What did this pull request do?

View File

@ -3,7 +3,8 @@
The fantastic ORM library for Golang, aims to be developer friendly.
[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921)
[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm)
[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b)
[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
## Overview
@ -31,7 +32,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Supporting the project
[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu)
[![http://patreon.com/jinzhu](https://c5.patreon.com/external/logo/become_a_patron_button.png)](http://patreon.com/jinzhu)
## Author

View File

@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
} else {
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {

View File

@ -885,7 +885,7 @@ func TestHasManyChildrenWithOneStruct(t *testing.T) {
DB.Save(&category)
}
func TestSkipSaveAssociation(t *testing.T) {
func TestAutoSaveBelongsToAssociation(t *testing.T) {
type Company struct {
gorm.Model
Name string
@ -895,13 +895,156 @@ func TestSkipSaveAssociation(t *testing.T) {
gorm.Model
Name string
CompanyID uint
Company Company `gorm:"save_associations:false"`
Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"`
}
DB.Where("name = ?", "auto_save_association").Delete(&Company{})
DB.AutoMigrate(&Company{}, &User{})
DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}})
DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}})
if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() {
t.Errorf("Company skip_save_association should not been saved")
if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() {
t.Errorf("Company auto_save_association should not have been saved when autosave is false")
}
// if foreign key is set, this should be saved even if association isn't
company := Company{Name: "auto_save_association"}
DB.Save(&company)
company.Name = "auto_save_association_new_name"
user := User{Name: "jinzhu", Company: company}
DB.Save(&user)
if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() {
t.Errorf("Company should not have been updated")
}
if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() {
t.Errorf("User's foreign key should have been saved")
}
user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}}
DB.Set("gorm:association_autocreate", true).Save(&user2)
if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() {
t.Errorf("Company auto_save_association_2 should been created when autocreate is true")
}
user2.Company.Name = "auto_save_association_2_newname"
DB.Set("gorm:association_autoupdate", true).Save(&user2)
if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() {
t.Errorf("Company should been updated")
}
}
func TestAutoSaveHasOneAssociation(t *testing.T) {
type Company struct {
gorm.Model
UserID uint
Name string
}
type User struct {
gorm.Model
Name string
Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"`
}
DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{})
DB.AutoMigrate(&Company{}, &User{})
DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}})
if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() {
t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false")
}
company := Company{Name: "auto_save_has_one_association"}
DB.Save(&company)
company.Name = "auto_save_has_one_association_new_name"
user := User{Name: "jinzhu", Company: company}
DB.Save(&user)
if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() {
t.Errorf("Company should not have been updated")
}
if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() {
t.Errorf("Company should not have been updated")
}
if user.Company.UserID == 0 {
t.Errorf("UserID should be assigned")
}
company.Name = "auto_save_has_one_association_2_new_name"
DB.Set("gorm:association_autoupdate", true).Save(&user)
if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() {
t.Errorf("Company should been updated")
}
user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}}
DB.Set("gorm:association_autocreate", true).Save(&user2)
if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() {
t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true")
}
}
func TestAutoSaveMany2ManyAssociation(t *testing.T) {
type Company struct {
gorm.Model
Name string
}
type User struct {
gorm.Model
Name string
Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"`
}
DB.AutoMigrate(&Company{}, &User{})
DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}})
if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() {
t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false")
}
company := Company{Name: "auto_save_m2m_association"}
DB.Save(&company)
company.Name = "auto_save_m2m_association_new_name"
user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}}
DB.Save(&user)
if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() {
t.Errorf("Company should not have been updated")
}
if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() {
t.Errorf("Company should not been created")
}
if DB.Model(&user).Association("Companies").Count() != 1 {
t.Errorf("Relationship should been saved")
}
DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user)
if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() {
t.Errorf("Company should been updated")
}
if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() {
t.Errorf("Company should been created")
}
if DB.Model(&user).Association("Companies").Count() != 2 {
t.Errorf("Relationship should been updated")
}
}

View File

@ -1,8 +1,6 @@
package gorm
import (
"fmt"
)
import "log"
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
@ -95,7 +93,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
fmt.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
cp.before = "gorm:row_query"
}
}
@ -109,7 +107,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
// Remove a registered callback
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
func (cp *CallbackProcessor) Remove(callbackName string) {
fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.remove = true
cp.parent.processors = append(cp.parent.processors, cp)
@ -122,7 +120,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
// scope.SetColumn("Updated", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.processor = &callback
cp.replace = true
@ -161,7 +159,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
for _, cp := range cps {
// show warning message the callback name already exists
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
}
allNames = append(allNames, cp.name)
}

View File

@ -15,6 +15,10 @@ func init() {
// queryCallback used to query data from database
func queryCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}
defer scope.trace(NowFunc())
var (

View File

@ -10,6 +10,9 @@ import (
// preloadCallback used to preload associations
func preloadCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}
if _, ok := scope.Get("gorm:auto_preload"); ok {
autoPreload(scope)
@ -324,6 +327,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
scope.scan(rows, columns, append(fields, joinTableFields...))
scope.New(elem.Addr().Interface()).
InstanceSet("gorm:skip_query_callback", true).
callCallbacks(scope.db.parent.callbacks.queries)
var foreignKeys = make([]interface{}, len(sourceKeys))
// generate hashed forkey keys in join table
for idx, joinTableField := range joinTableFields {

View File

@ -1,6 +1,9 @@
package gorm
import "reflect"
import (
"reflect"
"strings"
)
func beginTransactionCallback(scope *Scope) {
scope.Begin()
@ -10,31 +13,81 @@ func commitOrRollbackTransactionCallback(scope *Scope) {
scope.CommitOrRollback()
}
func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) {
func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) {
checkTruth := func(value interface{}) bool {
if v, ok := value.(bool); ok && !v {
return false
}
if v, ok := value.(string); ok {
v = strings.ToLower(v)
if v == "false" || v != "skip" {
return false
}
}
return true
}
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") {
if relationship := field.Relationship; relationship != nil {
return true, relationship
if r = field.Relationship; r != nil {
autoUpdate, autoCreate, saveReference = true, true, true
if value, ok := scope.Get("gorm:save_associations"); ok {
autoUpdate = checkTruth(value)
autoCreate = autoUpdate
} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
autoUpdate = checkTruth(value)
autoCreate = autoUpdate
}
if value, ok := scope.Get("gorm:association_autoupdate"); ok {
autoUpdate = checkTruth(value)
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
autoUpdate = checkTruth(value)
}
if value, ok := scope.Get("gorm:association_autocreate"); ok {
autoCreate = checkTruth(value)
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
autoCreate = checkTruth(value)
}
if value, ok := scope.Get("gorm:association_save_reference"); ok {
saveReference = checkTruth(value)
} else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
saveReference = checkTruth(value)
}
}
}
return false, nil
return
}
func saveBeforeAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
for _, field := range scope.Fields() {
if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" {
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
if relationship != nil && relationship.Kind == "belongs_to" {
fieldValue := field.Field.Addr().Interface()
scope.Err(scope.NewDB().Save(fieldValue).Error)
if len(relationship.ForeignFieldNames) != 0 {
// set value's foreign key
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
newScope := scope.New(fieldValue)
if newScope.PrimaryKeyZero() {
if autoCreate {
scope.Err(scope.NewDB().Save(fieldValue).Error)
}
} else if autoUpdate {
scope.Err(scope.NewDB().Save(fieldValue).Error)
}
if saveReference {
if len(relationship.ForeignFieldNames) != 0 {
// set value's foreign key
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
}
}
}
}
@ -43,12 +96,10 @@ func saveBeforeAssociationsCallback(scope *Scope) {
}
func saveAfterAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
for _, field := range scope.Fields() {
if ok, relationship := saveFieldAsAssociation(scope, field); ok &&
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
value := field.Field
switch value.Kind() {
@ -58,7 +109,41 @@ func saveAfterAssociationsCallback(scope *Scope) {
elem := value.Index(i).Addr().Interface()
newScope := newDB.NewScope(elem)
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
if saveReference {
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.FieldByName(associationForeignName); ok {
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
}
}
}
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
}
}
if newScope.PrimaryKeyZero() {
if autoCreate {
scope.Err(newDB.Save(elem).Error)
}
} else if autoUpdate {
scope.Err(newDB.Save(elem).Error)
}
if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference {
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
}
}
}
default:
elem := value.Addr().Interface()
newScope := scope.New(elem)
if saveReference {
if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.FieldByName(associationForeignName); ok {
@ -70,29 +155,15 @@ func saveAfterAssociationsCallback(scope *Scope) {
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
}
}
scope.Err(newDB.Save(elem).Error)
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
if newScope.PrimaryKeyZero() {
if autoCreate {
scope.Err(scope.NewDB().Save(elem).Error)
}
} else if autoUpdate {
scope.Err(scope.NewDB().Save(elem).Error)
}
default:
elem := value.Addr().Interface()
newScope := scope.New(elem)
if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.FieldByName(associationForeignName); ok {
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
}
}
}
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
}
scope.Err(scope.NewDB().Save(elem).Error)
}
}
}

View File

@ -3,6 +3,7 @@ package gorm
import (
"errors"
"fmt"
"sort"
"strings"
)
@ -59,7 +60,16 @@ func updateCallback(scope *Scope) {
var sqls []string
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
for column, value := range updateAttrs.(map[string]interface{}) {
// Sort the column names so that the generated SQL is the same every time.
updateMap := updateAttrs.(map[string]interface{})
var columns []string
for c := range updateMap {
columns = append(columns, c)
}
sort.Strings(columns)
for _, column := range columns {
value := updateMap[column]
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
}
} else {

View File

@ -279,3 +279,68 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
t.Errorf("should preload discount from coupon")
}
}
type SelfReferencingUser struct {
gorm.Model
Name string
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
}
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
DB.AutoMigrate(&SelfReferencingUser{})
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
if err := DB.Create(&friend1).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
friend2 := SelfReferencingUser{Name: "friend2_m2m"}
if err := DB.Create(&friend2).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
user := SelfReferencingUser{
Name: "self_m2m",
Friends: []*SelfReferencingUser{&friend1, &friend2},
}
if err := DB.Create(&user).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if DB.Model(&user).Association("Friends").Count() != 2 {
t.Errorf("Should find created friends correctly")
}
var newUser = SelfReferencingUser{}
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if len(newUser.Friends) != 2 {
t.Errorf("Should preload created frineds for self reference m2m")
}
DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"})
if DB.Model(&user).Association("Friends").Count() != 3 {
t.Errorf("Should find created friends correctly")
}
DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"})
if DB.Model(&user).Association("Friends").Count() != 1 {
t.Errorf("Should find created friends correctly")
}
friend := SelfReferencingUser{}
DB.Model(&newUser).Association("Friends").Find(&friend)
if friend.Name != "friend4_m2m" {
t.Errorf("Should find created friends correctly")
}
DB.Model(&newUser).Association("Friends").Delete(friend)
if DB.Model(&user).Association("Friends").Count() != 0 {
t.Errorf("All friends should be deleted")
}
}

View File

@ -33,6 +33,8 @@ type Dialect interface {
HasTable(tableName string) bool
// HasColumn check has column or not
HasColumn(tableName string, columnName string) bool
// ModifyColumn modify column's type
ModifyColumn(tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) string
@ -41,8 +43,8 @@ type Dialect interface {
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
// BuildForeignKeyName returns a foreign key name for the given table, field and reference
BuildForeignKeyName(tableName, field, dest string) string
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
BuildKeyName(kind, tableName string, fields ...string) string
// CurrentDatabase return current database name
CurrentDatabase() string
@ -114,3 +116,11 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
}
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}

View File

@ -38,6 +38,13 @@ func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
return strings.ToLower(value) != "false"
}
return field.IsPrimaryKey
}
func (s *commonDialect) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
@ -46,13 +53,13 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
case reflect.Bool:
sqlType = "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
if s.fieldCanAutoIncrement(field) {
sqlType = "INTEGER AUTO_INCREMENT"
} else {
sqlType = "INTEGER"
}
case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
if s.fieldCanAutoIncrement(field) {
sqlType = "BIGINT AUTO_INCREMENT"
} else {
sqlType = "BIGINT"
@ -92,7 +99,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0
}
@ -107,16 +115,23 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo
func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
return err
}
func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
@ -144,9 +159,10 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s
return ""
}
func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string {
keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
return keyName
}

View File

@ -44,42 +44,42 @@ func (s *mysql) DataTypeOf(field *StructField) string {
case reflect.Bool:
sqlType = "boolean"
case reflect.Int8:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "tinyint AUTO_INCREMENT"
} else {
sqlType = "tinyint"
}
case reflect.Int, reflect.Int16, reflect.Int32:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int AUTO_INCREMENT"
} else {
sqlType = "int"
}
case reflect.Uint8:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "tinyint unsigned AUTO_INCREMENT"
} else {
sqlType = "tinyint unsigned"
}
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int unsigned AUTO_INCREMENT"
} else {
sqlType = "int unsigned"
}
case reflect.Int64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint AUTO_INCREMENT"
} else {
sqlType = "bigint"
}
case reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint unsigned AUTO_INCREMENT"
} else {
@ -95,10 +95,15 @@ func (s *mysql) DataTypeOf(field *StructField) string {
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
precision := ""
if p, ok := field.TagSettings["PRECISION"]; ok {
precision = fmt.Sprintf("(%s)", p)
}
if _, ok := field.TagSettings["NOT NULL"]; ok {
sqlType = "timestamp"
sqlType = fmt.Sprintf("timestamp%v", precision)
} else {
sqlType = "timestamp NULL"
sqlType = fmt.Sprintf("timestamp%v NULL", precision)
}
}
default:
@ -127,6 +132,11 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error {
return err
}
func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
return err
}
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
@ -144,7 +154,8 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
return count > 0
}
@ -157,8 +168,8 @@ func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...)
if utf8.RuneCountInString(keyName) <= 64 {
return keyName
}
@ -166,8 +177,8 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
h.Write([]byte(keyName))
bs := h.Sum(nil)
// sha1 is 40 digits, keep first 24 characters of destination
destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_"))
// sha1 is 40 characters, keep first 24 characters of destination
destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
if len(destRunes) > 24 {
destRunes = destRunes[:24]
}

View File

@ -1,6 +1,7 @@
package gorm
import (
"encoding/json"
"fmt"
"reflect"
"strings"
@ -13,6 +14,7 @@ type postgres struct {
func init() {
RegisterDialect("postgres", &postgres{})
RegisterDialect("cloudsqlpostgres", &postgres{})
}
func (postgres) GetName() string {
@ -31,14 +33,14 @@ func (s *postgres) DataTypeOf(field *StructField) string {
case reflect.Bool:
sqlType = "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "serial"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint32, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigserial"
} else {
@ -67,8 +69,14 @@ func (s *postgres) DataTypeOf(field *StructField) string {
default:
if IsByteArrayOrSlice(dataValue) {
sqlType = "bytea"
} else if isUUID(dataValue) {
sqlType = "uuid"
if isUUID(dataValue) {
sqlType = "uuid"
}
if isJSON(dataValue) {
sqlType = "jsonb"
}
}
}
}
@ -85,7 +93,7 @@ func (s *postgres) DataTypeOf(field *StructField) string {
func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
return count > 0
}
@ -97,13 +105,13 @@ func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
return count > 0
}
func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
return count > 0
}
@ -128,3 +136,8 @@ func isUUID(value reflect.Value) bool {
lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower
}
func isJSON(value reflect.Value) bool {
_, ok := value.Interface().(json.RawMessage)
return ok
}

View File

@ -28,14 +28,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
case reflect.Bool:
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "integer primary key autoincrement"
} else {

View File

@ -54,7 +54,7 @@ func (mssql) BindVar(i int) string {
}
func (mssql) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
return fmt.Sprintf(`[%s]`, key)
}
func (s *mssql) DataTypeOf(field *gorm.StructField) string {
@ -65,14 +65,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
case reflect.Bool:
sqlType = "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int IDENTITY(1,1)"
} else {
sqlType = "int"
}
case reflect.Int64, reflect.Uint64:
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint IDENTITY(1,1)"
} else {
@ -111,6 +111,13 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
return value != "FALSE"
}
return field.IsPrimaryKey
}
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
@ -128,16 +135,23 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s mssql) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
return err
}
func (s mssql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
@ -168,3 +182,11 @@ func (mssql) SelectFromDummyTable() string {
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}

View File

@ -6,6 +6,9 @@ import (
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
"encoding/json"
"errors"
"fmt"
)
type Hstore map[string]*string
@ -52,3 +55,26 @@ func (h *Hstore) Scan(value interface{}) error {
return nil
}
// Jsonb Postgresql's JSONB data type
type Jsonb struct {
json.RawMessage
}
// Value get value of Jsonb
func (j Jsonb) Value() (driver.Value, error) {
if len(j.RawMessage) == 0 {
return nil, nil
}
return j.MarshalJSON()
}
// Scan scan value into Jsonb
func (j *Jsonb) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}
return json.Unmarshal(bytes, j)
}

30
docker-compose.yml Normal file
View File

@ -0,0 +1,30 @@
version: '3'
services:
mysql:
image: 'mysql:latest'
ports:
- 9910:3306
environment:
- MYSQL_DATABASE=gorm
- MYSQL_USER=gorm
- MYSQL_PASSWORD=gorm
- MYSQL_RANDOM_ROOT_PASSWORD="yes"
postgres:
image: 'postgres:latest'
ports:
- 9920:5432
environment:
- POSTGRES_USER=gorm
- POSTGRES_DB=gorm
- POSTGRES_PASSWORD=gorm
mssql:
image: 'mcmoe/mssqldocker:latest'
ports:
- 9930:1433
environment:
- ACCEPT_EULA=Y
- SA_PASSWORD=LoremIpsum86
- MSSQL_DB=gorm
- MSSQL_USER=gorm
- MSSQL_PASSWORD=LoremIpsum86

View File

@ -29,6 +29,10 @@ func (errs Errors) GetErrors() []error {
// Add adds an error
func (errs Errors) Add(newErrors ...error) Errors {
for _, err := range newErrors {
if err == nil {
continue
}
if errors, ok := err.(Errors); ok {
errs = errs.Add(errors...)
} else {

View File

@ -82,38 +82,40 @@ func (s JoinTableHandler) Table(db *DB) string {
return s.TableName
}
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{}
func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
for _, source := range sources {
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Source.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
values[foreignKey.DBName] = field.Field.Interface()
}
}
} else if s.Destination.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
values[foreignKey.DBName] = field.Field.Interface()
for _, joinTableSource := range joinTableSources {
if joinTableSource.ModelType == modelType {
for _, foreignKey := range joinTableSource.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
conditionMap[foreignKey.DBName] = field.Field.Interface()
}
}
break
}
}
}
return values
}
// Add create relationship in join table for source and destination
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
scope := db.NewScope("")
searchMap := s.getSearchMap(db, source, destination)
var (
scope = db.NewScope("")
conditionMap = map[string]interface{}{}
)
// Update condition map for source
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
// Update condition map for destination
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
var assignColumns, binVars, conditions []string
var values []interface{}
for key, value := range searchMap {
for key, value := range conditionMap {
assignColumns = append(assignColumns, scope.Quote(key))
binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
@ -141,12 +143,15 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source
// Delete delete relationship in join table for sources
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var (
scope = db.NewScope(nil)
conditions []string
values []interface{}
scope = db.NewScope(nil)
conditions []string
values []interface{}
conditionMap = map[string]interface{}{}
)
for key, value := range s.getSearchMap(db, sources...) {
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
for key, value := range conditionMap {
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}

46
main.go
View File

@ -274,7 +274,7 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
// First find first record that match given conditions, order by primary key
func (s *DB) First(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope := s.NewScope(out)
newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
@ -282,7 +282,7 @@ func (s *DB) First(out interface{}, where ...interface{}) *DB {
// Last find last record that match given conditions, order by primary key
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope := s.NewScope(out)
newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
@ -290,12 +290,12 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB {
// Find find records that match given conditions
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}
// Scan scan value to a struct
func (s *DB) Scan(dest interface{}) *DB {
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
}
// Row return `*sql.Row` with given conditions
@ -311,8 +311,8 @@ func (s *DB) Rows() (*sql.Rows, error) {
// ScanRows scan `*sql.Rows` to give struct
func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
var (
clone = s.clone()
scope = clone.NewScope(result)
scope = s.NewScope(result)
clone = scope.db
columns, err = rows.Columns()
)
@ -337,7 +337,7 @@ func (s *DB) Count(value interface{}) *DB {
// Related get related associations
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
return s.NewScope(s.Value).related(value, foreignKeys...).db
}
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
@ -377,7 +377,7 @@ func (s *DB) Update(attrs ...interface{}) *DB {
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
return s.clone().NewScope(s.Value).
return s.NewScope(s.Value).
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callbacks.updates).db
@ -390,7 +390,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) UpdateColumns(values interface{}) *DB {
return s.clone().NewScope(s.Value).
return s.NewScope(s.Value).
Set("gorm:update_column", true).
Set("gorm:save_associations", false).
InstanceSet("gorm:update_interface", values).
@ -399,7 +399,7 @@ func (s *DB) UpdateColumns(values interface{}) *DB {
// Save update value in database, if the value doesn't have primary key, will insert it
func (s *DB) Save(value interface{}) *DB {
scope := s.clone().NewScope(value)
scope := s.NewScope(value)
if !scope.PrimaryKeyZero() {
newDB := scope.callCallbacks(s.parent.callbacks.updates).db
if newDB.Error == nil && newDB.RowsAffected == 0 {
@ -412,13 +412,13 @@ func (s *DB) Save(value interface{}) *DB {
// Create insert the value into database
func (s *DB) Create(value interface{}) *DB {
scope := s.clone().NewScope(value)
scope := s.NewScope(value)
return scope.callCallbacks(s.parent.callbacks.creates).db
}
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
}
// Raw use raw sql as conditions, won't run it unless invoked by other methods
@ -429,7 +429,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB {
// Exec execute raw sql
func (s *DB) Exec(sql string, values ...interface{}) *DB {
scope := s.clone().NewScope(nil)
scope := s.NewScope(nil)
generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
scope.Raw(generatedSQL)
@ -495,7 +495,7 @@ func (s *DB) Rollback() *DB {
// NewRecord check if value's primary key is blank
func (s *DB) NewRecord(value interface{}) bool {
return s.clone().NewScope(value).PrimaryKeyZero()
return s.NewScope(value).PrimaryKeyZero()
}
// RecordNotFound check if returning ErrRecordNotFound error
@ -544,7 +544,7 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB {
// HasTable check has table or not
func (s *DB) HasTable(value interface{}) bool {
var (
scope = s.clone().NewScope(value)
scope = s.NewScope(value)
tableName string
)
@ -570,14 +570,14 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB {
// ModifyColumn modify column to type
func (s *DB) ModifyColumn(column string, typ string) *DB {
scope := s.clone().NewScope(s.Value)
scope := s.NewScope(s.Value)
scope.modifyColumn(column, typ)
return scope.db
}
// DropColumn drop a column
func (s *DB) DropColumn(column string) *DB {
scope := s.clone().NewScope(s.Value)
scope := s.NewScope(s.Value)
scope.dropColumn(column)
return scope.db
}
@ -598,7 +598,7 @@ func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
// RemoveIndex remove index with name
func (s *DB) RemoveIndex(indexName string) *DB {
scope := s.clone().NewScope(s.Value)
scope := s.NewScope(s.Value)
scope.removeIndex(indexName)
return scope.db
}
@ -606,11 +606,19 @@ func (s *DB) RemoveIndex(indexName string) *DB {
// AddForeignKey Add foreign key to the given scope, e.g:
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
scope := s.clone().NewScope(s.Value)
scope := s.NewScope(s.Value)
scope.addForeignKey(field, dest, onDelete, onUpdate)
return scope.db
}
// RemoveForeignKey Remove foreign key from the given scope, e.g:
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
func (s *DB) RemoveForeignKey(field string, dest string) *DB {
scope := s.clone().NewScope(s.Value)
scope.removeForeignKey(field, dest)
return scope.db
}
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
func (s *DB) Association(column string) *Association {
var err error

View File

@ -36,27 +36,20 @@ func init() {
}
func OpenTestConnection() (db *gorm.DB, err error) {
dbDSN := os.Getenv("GORM_DSN")
switch os.Getenv("GORM_DIALECT") {
case "mysql":
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
// CREATE DATABASE gorm;
// GRANT ALL ON gorm.* TO 'gorm'@'localhost';
fmt.Println("testing mysql...")
dbhost := os.Getenv("GORM_DBADDRESS")
if dbhost != "" {
dbhost = fmt.Sprintf("tcp(%v)", dbhost)
if dbDSN == "" {
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
}
db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost))
db, err = gorm.Open("mysql", dbDSN)
case "postgres":
fmt.Println("testing postgres...")
dbhost := os.Getenv("GORM_DBHOST")
if dbhost != "" {
dbhost = fmt.Sprintf("host=%v ", dbhost)
if dbDSN == "" {
dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
}
db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost))
case "foundation":
fmt.Println("testing foundation...")
db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
db, err = gorm.Open("postgres", dbDSN)
case "mssql":
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
// CREATE DATABASE gorm;
@ -64,7 +57,10 @@ func OpenTestConnection() (db *gorm.DB, err error) {
// CREATE USER gorm FROM LOGIN gorm;
// sp_changedbowner 'gorm';
fmt.Println("testing mssql...")
db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm")
if dbDSN == "" {
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
}
db, err = gorm.Open("mssql", dbDSN)
default:
fmt.Println("testing sqlite3...")
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
@ -72,8 +68,10 @@ func OpenTestConnection() (db *gorm.DB, err error) {
// db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
// db.SetLogger(log.New(os.Stdout, "\r\n", 0))
if os.Getenv("DEBUG") == "true" {
if debug := os.Getenv("DEBUG"); debug == "true" {
db.LogMode(true)
} else if debug == "false" {
db.LogMode(false)
}
db.DB().SetMaxIdleConns(10)

View File

@ -5,6 +5,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"os"
"reflect"
"testing"
"time"
@ -432,3 +433,24 @@ func TestMultipleIndexes(t *testing.T) {
t.Error("MultipleIndexes unique index failed")
}
}
func TestModifyColumnType(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
}
type ModifyColumnType struct {
gorm.Model
Name1 string `gorm:"length:100"`
Name2 string `gorm:"length:200"`
}
DB.DropTable(&ModifyColumnType{})
DB.CreateTable(&ModifyColumnType{})
name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2")
name2Type := DB.Dialect().DataTypeOf(name2Field.StructField)
if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil {
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
}
}

View File

@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
)
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
foreignKeys = strings.Split(foreignKey, ",")
}
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = strings.Split(foreignKey, ",")
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = strings.Split(foreignKey, ",")
}
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
@ -264,37 +266,65 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
relationship.Kind = "many_to_many"
// if no foreign keys defined with tag
if len(foreignKeys) == 0 {
for _, field := range modelStruct.PrimaryFields {
foreignKeys = append(foreignKeys, field.DBName)
{ // Foreign Keys for Source
joinTableDBNames := []string{}
if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
joinTableDBNames = strings.Split(foreignKey, ",")
}
// if no foreign keys defined with tag
if len(foreignKeys) == 0 {
for _, field := range modelStruct.PrimaryFields {
foreignKeys = append(foreignKeys, field.DBName)
}
}
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
// source foreign keys (db names)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
// setup join table foreign keys for source
if len(joinTableDBNames) > idx {
// if defined join table's foreign key
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
} else {
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
}
}
}
}
for _, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
// source foreign keys (db names)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
// join table foreign keys for source
joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
}
}
{ // Foreign Keys for Association (Destination)
associationJoinTableDBNames := []string{}
// if no association foreign keys defined with tag
if len(associationForeignKeys) == 0 {
for _, field := range toScope.PrimaryFields() {
associationForeignKeys = append(associationForeignKeys, field.DBName)
if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
associationJoinTableDBNames = strings.Split(foreignKey, ",")
}
}
for _, name := range associationForeignKeys {
if field, ok := toScope.FieldByName(name); ok {
// association foreign keys (db names)
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
// join table foreign keys for association
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
// if no association foreign keys defined with tag
if len(associationForeignKeys) == 0 {
for _, field := range toScope.PrimaryFields() {
associationForeignKeys = append(associationForeignKeys, field.DBName)
}
}
for idx, name := range associationForeignKeys {
if field, ok := toScope.FieldByName(name); ok {
// association foreign keys (db names)
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
// setup join table foreign keys for association
if len(associationJoinTableDBNames) > idx {
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
} else {
// join table foreign keys for association
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
}
}
}
}
@ -399,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
)
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
tagForeignKeys = strings.Split(foreignKey, ",")
}
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
}
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {

View File

@ -1627,6 +1627,48 @@ func TestPrefixedPreloadDuplication(t *testing.T) {
}
}
func TestPreloadManyToManyCallbacks(t *testing.T) {
type (
Level2 struct {
ID uint
Name string
}
Level1 struct {
ID uint
Name string
Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"`
}
)
DB.DropTableIfExists("level1_level2s")
DB.DropTableIfExists(new(Level1))
DB.DropTableIfExists(new(Level2))
if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil {
t.Error(err)
}
lvl := Level1{
Name: "l1",
Level2s: []Level2{
Level2{Name: "l2-1"}, Level2{Name: "l2-2"},
},
}
DB.Save(&lvl)
called := 0
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) {
called = called + 1
})
DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID)
if called != 3 {
t.Errorf("Wanted callback to be called 3 times but got %d", called)
}
}
func toJSONString(v interface{}) []byte {
r, _ := json.MarshalIndent(v, "", " ")
return r

View File

@ -389,7 +389,7 @@ func TestOffset(t *testing.T) {
DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)})
}
var users1, users2, users3, users4 []User
DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
t.Errorf("Offset should work")

View File

@ -998,18 +998,6 @@ func (scope *Scope) changeableField(field *Field) bool {
return true
}
func (scope *Scope) shouldSaveAssociations() bool {
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok {
if v, ok := saveAssociations.(bool); ok && !v {
return false
}
if v, ok := saveAssociations.(string); ok && (v != "skip") {
return false
}
}
return true && !scope.HasError()
}
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
toScope := scope.db.NewScope(value)
tx := scope.db.Set("gorm:association:source", scope.Value)
@ -1144,7 +1132,7 @@ func (scope *Scope) dropTable() *Scope {
}
func (scope *Scope) modifyColumn(column string, typ string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ))
}
func (scope *Scope) dropColumn(column string) {
@ -1170,7 +1158,8 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
}
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest)
// Compatible with old generated key
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return
@ -1179,6 +1168,16 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
}
func (scope *Scope) removeForeignKey(field string, dest string) {
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest)
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return
}
var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
}
func (scope *Scope) removeIndex(indexName string) {
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
}
@ -1214,7 +1213,7 @@ func (scope *Scope) autoIndex() *Scope {
for _, name := range names {
if name == "INDEX" || name == "" {
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
}
indexes[name] = append(indexes[name], field.DBName)
}
@ -1225,7 +1224,7 @@ func (scope *Scope) autoIndex() *Scope {
for _, name := range names {
if name == "UNIQUE_INDEX" || name == "" {
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
}
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
}
@ -1233,11 +1232,15 @@ func (scope *Scope) autoIndex() *Scope {
}
for name, columns := range indexes {
scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
if db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...); db.Error != nil {
scope.db.AddError(db.Error)
}
}
for name, columns := range uniqueIndexes {
scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
if db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil {
scope.db.AddError(db.Error)
}
}
return scope

View File

@ -2,7 +2,6 @@ package gorm
import (
"fmt"
"regexp"
)
type search struct {
@ -73,13 +72,7 @@ func (s *search) Order(value interface{}, reorder ...bool) *search {
return s
}
var distinctSQLRegexp = regexp.MustCompile(`(?i)distinct[^a-z]+[a-z]+`)
func (s *search) Select(query interface{}, args ...interface{}) *search {
if distinctSQLRegexp.MatchString(fmt.Sprint(query)) {
s.ignoreOrderQuery = true
}
s.selects = map[string]interface{}{"query": query, "args": args}
return s
}

View File

@ -1,5 +1,5 @@
dialects=("postgres" "mysql" "mssql" "sqlite")
for dialect in "${dialects[@]}" ; do
GORM_DIALECT=${dialect} go test
DEBUG=false GORM_DIALECT=${dialect} go test
done

View File

@ -23,7 +23,7 @@ var NowFunc = func() time.Time {
}
// Copied from golint
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialismsReplacer *strings.Replacer
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)

View File

@ -2,17 +2,79 @@
box: golang
services:
- id: mariadb:10.0
- name: mariadb
id: mariadb:latest
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
- id: postgres
- name: mysql
id: mysql:8
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
- name: mysql57
id: mysql:5.7
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
- name: mysql56
id: mysql:5.6
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
- name: mysql55
id: mysql:5.5
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
- name: postgres
id: postgres:latest
env:
POSTGRES_USER: gorm
POSTGRES_PASSWORD: gorm
POSTGRES_DB: gorm
- name: postgres96
id: postgres:9.6
env:
POSTGRES_USER: gorm
POSTGRES_PASSWORD: gorm
POSTGRES_DB: gorm
- name: postgres95
id: postgres:9.5
env:
POSTGRES_USER: gorm
POSTGRES_PASSWORD: gorm
POSTGRES_DB: gorm
- name: postgres94
id: postgres:9.4
env:
POSTGRES_USER: gorm
POSTGRES_PASSWORD: gorm
POSTGRES_DB: gorm
- name: postgres93
id: postgres:9.3
env:
POSTGRES_USER: gorm
POSTGRES_PASSWORD: gorm
POSTGRES_DB: gorm
- name: mssql
id: mcmoe/mssqldocker:latest
env:
ACCEPT_EULA: Y
SA_PASSWORD: LoremIpsum86
MSSQL_DB: gorm
MSSQL_USER: gorm
MSSQL_PASSWORD: LoremIpsum86
# The steps that will be executed in the build pipeline
build:
@ -42,12 +104,57 @@ build:
code: |
go test ./...
- script:
name: test mariadb
code: |
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./...
- script:
name: test mysql
code: |
GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./...
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./...
- script:
name: test mysql5.7
code: |
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./...
- script:
name: test mysql5.6
code: |
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./...
- script:
name: test mysql5.5
code: |
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./...
- script:
name: test postgres
code: |
GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./...
GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
- script:
name: test postgres96
code: |
GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
- script:
name: test postgres95
code: |
GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
- script:
name: test postgres94
code: |
GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
- script:
name: test postgres93
code: |
GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
- script:
name: test mssql
code: |
GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...