mirror of https://github.com/go-gorm/gorm.git
Merge branch 'master' into master
This commit is contained in:
commit
85774eb9da
|
@ -1,10 +1,4 @@
|
||||||
Before posting a bug report about a problem, please try to verify that it is a bug and that it has not been reported already, please apply corresponding GitHub labels to the issue, for feature requests, please apply `type:feature`.
|
Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one.
|
||||||
|
|
||||||
DON'T post usage related questions, ask in https://gitter.im/jinzhu/gorm or http://stackoverflow.com/questions/tagged/go-gorm,
|
|
||||||
|
|
||||||
Please answer these questions before submitting your issue. Thanks!
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### What version of Go are you using (`go version`)?
|
### What version of Go are you using (`go version`)?
|
||||||
|
|
||||||
|
@ -12,9 +6,9 @@ Please answer these questions before submitting your issue. Thanks!
|
||||||
### Which database and its version are you using?
|
### Which database and its version are you using?
|
||||||
|
|
||||||
|
|
||||||
### What did you do?
|
### Please provide a complete runnable program to reproduce your issue. **IMPORTANT**
|
||||||
|
|
||||||
Please provide a complete runnable program to reproduce your issue.
|
Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
package main
|
package main
|
||||||
|
@ -32,10 +26,9 @@ var db *gorm.DB
|
||||||
func init() {
|
func init() {
|
||||||
var err error
|
var err error
|
||||||
db, err = gorm.Open("sqlite3", "test.db")
|
db, err = gorm.Open("sqlite3", "test.db")
|
||||||
// Please use below username, password as your database's account for the script.
|
// db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable")
|
||||||
// db, err = gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
// db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True")
|
||||||
// db, err = gorm.Open("mysql", "gorm:gorm@/dbname?charset=utf8&parseTime=True")
|
// db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm")
|
||||||
// db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -43,8 +36,6 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// your code here
|
|
||||||
|
|
||||||
if /* failure condition */ {
|
if /* failure condition */ {
|
||||||
fmt.Println("failed")
|
fmt.Println("failed")
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -3,12 +3,7 @@ Make sure these boxes checked before submitting your pull request.
|
||||||
- [] Do only one thing
|
- [] Do only one thing
|
||||||
- [] No API-breaking changes
|
- [] No API-breaking changes
|
||||||
- [] New code/logic commented & tested
|
- [] New code/logic commented & tested
|
||||||
- [] Write good commit message, try to squash your commits into a single one
|
|
||||||
- [] Run `./build.sh` in `gh-pages` branch for document changes
|
|
||||||
|
|
||||||
For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it.
|
|
||||||
|
|
||||||
Thank you.
|
|
||||||
|
|
||||||
|
For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it.
|
||||||
|
|
||||||
### What did this pull request do?
|
### What did this pull request do?
|
||||||
|
|
|
@ -3,7 +3,8 @@
|
||||||
The fantastic ORM library for Golang, aims to be developer friendly.
|
The fantastic ORM library for Golang, aims to be developer friendly.
|
||||||
|
|
||||||
[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||||
[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921)
|
[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm)
|
||||||
|
[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b)
|
||||||
[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
|
[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
@ -31,7 +32,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
||||||
|
|
||||||
## Supporting the project
|
## Supporting the project
|
||||||
|
|
||||||
[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu)
|
[![http://patreon.com/jinzhu](https://c5.patreon.com/external/logo/become_a_patron_button.png)](http://patreon.com/jinzhu)
|
||||||
|
|
||||||
## Author
|
## Author
|
||||||
|
|
||||||
|
|
|
@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
||||||
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
||||||
|
|
||||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
|
||||||
}
|
}
|
||||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||||
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
|
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
|
||||||
|
@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
||||||
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
||||||
|
|
||||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
|
||||||
} else {
|
} else {
|
||||||
var foreignKeyMap = map[string]interface{}{}
|
var foreignKeyMap = map[string]interface{}{}
|
||||||
for _, foreignKey := range relationship.ForeignDBNames {
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
|
|
@ -885,7 +885,7 @@ func TestHasManyChildrenWithOneStruct(t *testing.T) {
|
||||||
DB.Save(&category)
|
DB.Save(&category)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSkipSaveAssociation(t *testing.T) {
|
func TestAutoSaveBelongsToAssociation(t *testing.T) {
|
||||||
type Company struct {
|
type Company struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
|
@ -895,13 +895,156 @@ func TestSkipSaveAssociation(t *testing.T) {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
CompanyID uint
|
CompanyID uint
|
||||||
Company Company `gorm:"save_associations:false"`
|
Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DB.Where("name = ?", "auto_save_association").Delete(&Company{})
|
||||||
DB.AutoMigrate(&Company{}, &User{})
|
DB.AutoMigrate(&Company{}, &User{})
|
||||||
|
|
||||||
DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}})
|
DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}})
|
||||||
|
|
||||||
if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() {
|
if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() {
|
||||||
t.Errorf("Company skip_save_association should not been saved")
|
t.Errorf("Company auto_save_association should not have been saved when autosave is false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// if foreign key is set, this should be saved even if association isn't
|
||||||
|
company := Company{Name: "auto_save_association"}
|
||||||
|
DB.Save(&company)
|
||||||
|
|
||||||
|
company.Name = "auto_save_association_new_name"
|
||||||
|
user := User{Name: "jinzhu", Company: company}
|
||||||
|
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company should not have been updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() {
|
||||||
|
t.Errorf("User's foreign key should have been saved")
|
||||||
|
}
|
||||||
|
|
||||||
|
user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}}
|
||||||
|
DB.Set("gorm:association_autocreate", true).Save(&user2)
|
||||||
|
if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company auto_save_association_2 should been created when autocreate is true")
|
||||||
|
}
|
||||||
|
|
||||||
|
user2.Company.Name = "auto_save_association_2_newname"
|
||||||
|
DB.Set("gorm:association_autoupdate", true).Save(&user2)
|
||||||
|
|
||||||
|
if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company should been updated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutoSaveHasOneAssociation(t *testing.T) {
|
||||||
|
type Company struct {
|
||||||
|
gorm.Model
|
||||||
|
UserID uint
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{})
|
||||||
|
DB.AutoMigrate(&Company{}, &User{})
|
||||||
|
|
||||||
|
DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}})
|
||||||
|
|
||||||
|
if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false")
|
||||||
|
}
|
||||||
|
|
||||||
|
company := Company{Name: "auto_save_has_one_association"}
|
||||||
|
DB.Save(&company)
|
||||||
|
|
||||||
|
company.Name = "auto_save_has_one_association_new_name"
|
||||||
|
user := User{Name: "jinzhu", Company: company}
|
||||||
|
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company should not have been updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company should not have been updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Company.UserID == 0 {
|
||||||
|
t.Errorf("UserID should be assigned")
|
||||||
|
}
|
||||||
|
|
||||||
|
company.Name = "auto_save_has_one_association_2_new_name"
|
||||||
|
DB.Set("gorm:association_autoupdate", true).Save(&user)
|
||||||
|
|
||||||
|
if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company should been updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}}
|
||||||
|
DB.Set("gorm:association_autocreate", true).Save(&user2)
|
||||||
|
if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutoSaveMany2ManyAssociation(t *testing.T) {
|
||||||
|
type Company struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.AutoMigrate(&Company{}, &User{})
|
||||||
|
|
||||||
|
DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}})
|
||||||
|
|
||||||
|
if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false")
|
||||||
|
}
|
||||||
|
|
||||||
|
company := Company{Name: "auto_save_m2m_association"}
|
||||||
|
DB.Save(&company)
|
||||||
|
|
||||||
|
company.Name = "auto_save_m2m_association_new_name"
|
||||||
|
user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}}
|
||||||
|
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company should not have been updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company should not been created")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("Companies").Count() != 1 {
|
||||||
|
t.Errorf("Relationship should been saved")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user)
|
||||||
|
|
||||||
|
if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company should been updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company should been created")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("Companies").Count() != 2 {
|
||||||
|
t.Errorf("Relationship should been updated")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
12
callback.go
12
callback.go
|
@ -1,8 +1,6 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import "log"
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DefaultCallback default callbacks defined by gorm
|
// DefaultCallback default callbacks defined by gorm
|
||||||
var DefaultCallback = &Callback{}
|
var DefaultCallback = &Callback{}
|
||||||
|
@ -95,7 +93,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
||||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||||
if cp.kind == "row_query" {
|
if cp.kind == "row_query" {
|
||||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||||
fmt.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
|
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
|
||||||
cp.before = "gorm:row_query"
|
cp.before = "gorm:row_query"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -109,7 +107,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
|
||||||
// Remove a registered callback
|
// Remove a registered callback
|
||||||
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||||
func (cp *CallbackProcessor) Remove(callbackName string) {
|
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||||
fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.remove = true
|
cp.remove = true
|
||||||
cp.parent.processors = append(cp.parent.processors, cp)
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
|
@ -122,7 +120,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||||
// scope.SetColumn("Updated", now)
|
// scope.SetColumn("Updated", now)
|
||||||
// })
|
// })
|
||||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||||
fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.processor = &callback
|
cp.processor = &callback
|
||||||
cp.replace = true
|
cp.replace = true
|
||||||
|
@ -161,7 +159,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||||
for _, cp := range cps {
|
for _, cp := range cps {
|
||||||
// show warning message the callback name already exists
|
// show warning message the callback name already exists
|
||||||
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||||
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
||||||
}
|
}
|
||||||
allNames = append(allNames, cp.name)
|
allNames = append(allNames, cp.name)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,10 @@ func init() {
|
||||||
|
|
||||||
// queryCallback used to query data from database
|
// queryCallback used to query data from database
|
||||||
func queryCallback(scope *Scope) {
|
func queryCallback(scope *Scope) {
|
||||||
|
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -10,6 +10,9 @@ import (
|
||||||
|
|
||||||
// preloadCallback used to preload associations
|
// preloadCallback used to preload associations
|
||||||
func preloadCallback(scope *Scope) {
|
func preloadCallback(scope *Scope) {
|
||||||
|
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if _, ok := scope.Get("gorm:auto_preload"); ok {
|
if _, ok := scope.Get("gorm:auto_preload"); ok {
|
||||||
autoPreload(scope)
|
autoPreload(scope)
|
||||||
|
@ -324,6 +327,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
|
||||||
|
|
||||||
scope.scan(rows, columns, append(fields, joinTableFields...))
|
scope.scan(rows, columns, append(fields, joinTableFields...))
|
||||||
|
|
||||||
|
scope.New(elem.Addr().Interface()).
|
||||||
|
InstanceSet("gorm:skip_query_callback", true).
|
||||||
|
callCallbacks(scope.db.parent.callbacks.queries)
|
||||||
|
|
||||||
var foreignKeys = make([]interface{}, len(sourceKeys))
|
var foreignKeys = make([]interface{}, len(sourceKeys))
|
||||||
// generate hashed forkey keys in join table
|
// generate hashed forkey keys in join table
|
||||||
for idx, joinTableField := range joinTableFields {
|
for idx, joinTableField := range joinTableFields {
|
||||||
|
|
101
callback_save.go
101
callback_save.go
|
@ -1,6 +1,9 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "reflect"
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
func beginTransactionCallback(scope *Scope) {
|
func beginTransactionCallback(scope *Scope) {
|
||||||
scope.Begin()
|
scope.Begin()
|
||||||
|
@ -10,25 +13,74 @@ func commitOrRollbackTransactionCallback(scope *Scope) {
|
||||||
scope.CommitOrRollback()
|
scope.CommitOrRollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) {
|
func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) {
|
||||||
|
checkTruth := func(value interface{}) bool {
|
||||||
|
if v, ok := value.(bool); ok && !v {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := value.(string); ok {
|
||||||
|
v = strings.ToLower(v)
|
||||||
|
if v == "false" || v != "skip" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||||
if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") {
|
if r = field.Relationship; r != nil {
|
||||||
if relationship := field.Relationship; relationship != nil {
|
autoUpdate, autoCreate, saveReference = true, true, true
|
||||||
return true, relationship
|
|
||||||
|
if value, ok := scope.Get("gorm:save_associations"); ok {
|
||||||
|
autoUpdate = checkTruth(value)
|
||||||
|
autoCreate = autoUpdate
|
||||||
|
} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
|
||||||
|
autoUpdate = checkTruth(value)
|
||||||
|
autoCreate = autoUpdate
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := scope.Get("gorm:association_autoupdate"); ok {
|
||||||
|
autoUpdate = checkTruth(value)
|
||||||
|
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
|
||||||
|
autoUpdate = checkTruth(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := scope.Get("gorm:association_autocreate"); ok {
|
||||||
|
autoCreate = checkTruth(value)
|
||||||
|
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
|
||||||
|
autoCreate = checkTruth(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := scope.Get("gorm:association_save_reference"); ok {
|
||||||
|
saveReference = checkTruth(value)
|
||||||
|
} else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
|
||||||
|
saveReference = checkTruth(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false, nil
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveBeforeAssociationsCallback(scope *Scope) {
|
func saveBeforeAssociationsCallback(scope *Scope) {
|
||||||
if !scope.shouldSaveAssociations() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" {
|
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
||||||
|
|
||||||
|
if relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
fieldValue := field.Field.Addr().Interface()
|
fieldValue := field.Field.Addr().Interface()
|
||||||
|
newScope := scope.New(fieldValue)
|
||||||
|
|
||||||
|
if newScope.PrimaryKeyZero() {
|
||||||
|
if autoCreate {
|
||||||
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
||||||
|
}
|
||||||
|
} else if autoUpdate {
|
||||||
|
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if saveReference {
|
||||||
if len(relationship.ForeignFieldNames) != 0 {
|
if len(relationship.ForeignFieldNames) != 0 {
|
||||||
// set value's foreign key
|
// set value's foreign key
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
|
@ -41,14 +93,13 @@ func saveBeforeAssociationsCallback(scope *Scope) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func saveAfterAssociationsCallback(scope *Scope) {
|
func saveAfterAssociationsCallback(scope *Scope) {
|
||||||
if !scope.shouldSaveAssociations() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if ok, relationship := saveFieldAsAssociation(scope, field); ok &&
|
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
|
||||||
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
|
|
||||||
|
if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
|
||||||
value := field.Field
|
value := field.Field
|
||||||
|
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
|
@ -58,6 +109,7 @@ func saveAfterAssociationsCallback(scope *Scope) {
|
||||||
elem := value.Index(i).Addr().Interface()
|
elem := value.Index(i).Addr().Interface()
|
||||||
newScope := newDB.NewScope(elem)
|
newScope := newDB.NewScope(elem)
|
||||||
|
|
||||||
|
if saveReference {
|
||||||
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
|
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||||
|
@ -70,16 +122,27 @@ func saveAfterAssociationsCallback(scope *Scope) {
|
||||||
if relationship.PolymorphicType != "" {
|
if relationship.PolymorphicType != "" {
|
||||||
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if newScope.PrimaryKeyZero() {
|
||||||
|
if autoCreate {
|
||||||
scope.Err(newDB.Save(elem).Error)
|
scope.Err(newDB.Save(elem).Error)
|
||||||
|
}
|
||||||
|
} else if autoUpdate {
|
||||||
|
scope.Err(newDB.Save(elem).Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference {
|
||||||
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
||||||
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
|
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
elem := value.Addr().Interface()
|
elem := value.Addr().Interface()
|
||||||
newScope := scope.New(elem)
|
newScope := scope.New(elem)
|
||||||
|
|
||||||
|
if saveReference {
|
||||||
if len(relationship.ForeignFieldNames) != 0 {
|
if len(relationship.ForeignFieldNames) != 0 {
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||||
|
@ -92,8 +155,16 @@ func saveAfterAssociationsCallback(scope *Scope) {
|
||||||
if relationship.PolymorphicType != "" {
|
if relationship.PolymorphicType != "" {
|
||||||
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if newScope.PrimaryKeyZero() {
|
||||||
|
if autoCreate {
|
||||||
|
scope.Err(scope.NewDB().Save(elem).Error)
|
||||||
|
}
|
||||||
|
} else if autoUpdate {
|
||||||
scope.Err(scope.NewDB().Save(elem).Error)
|
scope.Err(scope.NewDB().Save(elem).Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,7 +60,16 @@ func updateCallback(scope *Scope) {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
|
|
||||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||||
for column, value := range updateAttrs.(map[string]interface{}) {
|
// Sort the column names so that the generated SQL is the same every time.
|
||||||
|
updateMap := updateAttrs.(map[string]interface{})
|
||||||
|
var columns []string
|
||||||
|
for c := range updateMap {
|
||||||
|
columns = append(columns, c)
|
||||||
|
}
|
||||||
|
sort.Strings(columns)
|
||||||
|
|
||||||
|
for _, column := range columns {
|
||||||
|
value := updateMap[column]
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -279,3 +279,68 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
|
||||||
t.Errorf("should preload discount from coupon")
|
t.Errorf("should preload discount from coupon")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SelfReferencingUser struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
|
||||||
|
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
|
||||||
|
DB.AutoMigrate(&SelfReferencingUser{})
|
||||||
|
|
||||||
|
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
|
||||||
|
if err := DB.Create(&friend1).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
friend2 := SelfReferencingUser{Name: "friend2_m2m"}
|
||||||
|
if err := DB.Create(&friend2).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := SelfReferencingUser{
|
||||||
|
Name: "self_m2m",
|
||||||
|
Friends: []*SelfReferencingUser{&friend1, &friend2},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(&user).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("Friends").Count() != 2 {
|
||||||
|
t.Errorf("Should find created friends correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
var newUser = SelfReferencingUser{}
|
||||||
|
|
||||||
|
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(newUser.Friends) != 2 {
|
||||||
|
t.Errorf("Should preload created frineds for self reference m2m")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"})
|
||||||
|
if DB.Model(&user).Association("Friends").Count() != 3 {
|
||||||
|
t.Errorf("Should find created friends correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"})
|
||||||
|
if DB.Model(&user).Association("Friends").Count() != 1 {
|
||||||
|
t.Errorf("Should find created friends correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
friend := SelfReferencingUser{}
|
||||||
|
DB.Model(&newUser).Association("Friends").Find(&friend)
|
||||||
|
if friend.Name != "friend4_m2m" {
|
||||||
|
t.Errorf("Should find created friends correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&newUser).Association("Friends").Delete(friend)
|
||||||
|
if DB.Model(&user).Association("Friends").Count() != 0 {
|
||||||
|
t.Errorf("All friends should be deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
14
dialect.go
14
dialect.go
|
@ -33,6 +33,8 @@ type Dialect interface {
|
||||||
HasTable(tableName string) bool
|
HasTable(tableName string) bool
|
||||||
// HasColumn check has column or not
|
// HasColumn check has column or not
|
||||||
HasColumn(tableName string, columnName string) bool
|
HasColumn(tableName string, columnName string) bool
|
||||||
|
// ModifyColumn modify column's type
|
||||||
|
ModifyColumn(tableName string, columnName string, typ string) error
|
||||||
|
|
||||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||||
LimitAndOffsetSQL(limit, offset interface{}) string
|
LimitAndOffsetSQL(limit, offset interface{}) string
|
||||||
|
@ -41,8 +43,8 @@ type Dialect interface {
|
||||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||||
LastInsertIDReturningSuffix(tableName, columnName string) string
|
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||||
|
|
||||||
// BuildForeignKeyName returns a foreign key name for the given table, field and reference
|
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
||||||
BuildForeignKeyName(tableName, field, dest string) string
|
BuildKeyName(kind, tableName string, fields ...string) string
|
||||||
|
|
||||||
// CurrentDatabase return current database name
|
// CurrentDatabase return current database name
|
||||||
CurrentDatabase() string
|
CurrentDatabase() string
|
||||||
|
@ -114,3 +116,11 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
|
||||||
|
|
||||||
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
|
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
|
||||||
|
if strings.Contains(tableName, ".") {
|
||||||
|
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||||
|
return splitStrings[0], splitStrings[1]
|
||||||
|
}
|
||||||
|
return dialect.CurrentDatabase(), tableName
|
||||||
|
}
|
||||||
|
|
|
@ -38,6 +38,13 @@ func (commonDialect) Quote(key string) string {
|
||||||
return fmt.Sprintf(`"%s"`, key)
|
return fmt.Sprintf(`"%s"`, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
|
||||||
|
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||||
|
return strings.ToLower(value) != "false"
|
||||||
|
}
|
||||||
|
return field.IsPrimaryKey
|
||||||
|
}
|
||||||
|
|
||||||
func (s *commonDialect) DataTypeOf(field *StructField) string {
|
func (s *commonDialect) DataTypeOf(field *StructField) string {
|
||||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
||||||
|
|
||||||
|
@ -46,13 +53,13 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
sqlType = "BOOLEAN"
|
sqlType = "BOOLEAN"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
sqlType = "INTEGER AUTO_INCREMENT"
|
sqlType = "INTEGER AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "INTEGER"
|
sqlType = "INTEGER"
|
||||||
}
|
}
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
sqlType = "BIGINT AUTO_INCREMENT"
|
sqlType = "BIGINT AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "BIGINT"
|
sqlType = "BIGINT"
|
||||||
|
@ -92,7 +99,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
|
||||||
|
|
||||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count)
|
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,16 +115,23 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo
|
||||||
|
|
||||||
func (s commonDialect) HasTable(tableName string) bool {
|
func (s commonDialect) HasTable(tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count)
|
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
|
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
|
||||||
|
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s commonDialect) CurrentDatabase() (name string) {
|
func (s commonDialect) CurrentDatabase() (name string) {
|
||||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||||
return
|
return
|
||||||
|
@ -144,9 +159,10 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string {
|
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
||||||
keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest)
|
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
|
||||||
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
|
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
|
||||||
|
keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
|
||||||
return keyName
|
return keyName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,42 +44,42 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
sqlType = "boolean"
|
sqlType = "boolean"
|
||||||
case reflect.Int8:
|
case reflect.Int8:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "tinyint AUTO_INCREMENT"
|
sqlType = "tinyint AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "tinyint"
|
sqlType = "tinyint"
|
||||||
}
|
}
|
||||||
case reflect.Int, reflect.Int16, reflect.Int32:
|
case reflect.Int, reflect.Int16, reflect.Int32:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "int AUTO_INCREMENT"
|
sqlType = "int AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "int"
|
sqlType = "int"
|
||||||
}
|
}
|
||||||
case reflect.Uint8:
|
case reflect.Uint8:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "tinyint unsigned AUTO_INCREMENT"
|
sqlType = "tinyint unsigned AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "tinyint unsigned"
|
sqlType = "tinyint unsigned"
|
||||||
}
|
}
|
||||||
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "int unsigned AUTO_INCREMENT"
|
sqlType = "int unsigned AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "int unsigned"
|
sqlType = "int unsigned"
|
||||||
}
|
}
|
||||||
case reflect.Int64:
|
case reflect.Int64:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "bigint AUTO_INCREMENT"
|
sqlType = "bigint AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "bigint"
|
sqlType = "bigint"
|
||||||
}
|
}
|
||||||
case reflect.Uint64:
|
case reflect.Uint64:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "bigint unsigned AUTO_INCREMENT"
|
sqlType = "bigint unsigned AUTO_INCREMENT"
|
||||||
} else {
|
} else {
|
||||||
|
@ -95,10 +95,15 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
precision := ""
|
||||||
|
if p, ok := field.TagSettings["PRECISION"]; ok {
|
||||||
|
precision = fmt.Sprintf("(%s)", p)
|
||||||
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettings["NOT NULL"]; ok {
|
if _, ok := field.TagSettings["NOT NULL"]; ok {
|
||||||
sqlType = "timestamp"
|
sqlType = fmt.Sprintf("timestamp%v", precision)
|
||||||
} else {
|
} else {
|
||||||
sqlType = "timestamp NULL"
|
sqlType = fmt.Sprintf("timestamp%v NULL", precision)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -127,6 +132,11 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
|
||||||
|
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||||
if limit != nil {
|
if limit != nil {
|
||||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||||
|
@ -144,7 +154,8 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||||
|
|
||||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
|
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,8 +168,8 @@ func (mysql) SelectFromDummyTable() string {
|
||||||
return "FROM DUAL"
|
return "FROM DUAL"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
|
func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
|
||||||
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
|
keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...)
|
||||||
if utf8.RuneCountInString(keyName) <= 64 {
|
if utf8.RuneCountInString(keyName) <= 64 {
|
||||||
return keyName
|
return keyName
|
||||||
}
|
}
|
||||||
|
@ -166,8 +177,8 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
|
||||||
h.Write([]byte(keyName))
|
h.Write([]byte(keyName))
|
||||||
bs := h.Sum(nil)
|
bs := h.Sum(nil)
|
||||||
|
|
||||||
// sha1 is 40 digits, keep first 24 characters of destination
|
// sha1 is 40 characters, keep first 24 characters of destination
|
||||||
destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_"))
|
destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
|
||||||
if len(destRunes) > 24 {
|
if len(destRunes) > 24 {
|
||||||
destRunes = destRunes[:24]
|
destRunes = destRunes[:24]
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -13,6 +14,7 @@ type postgres struct {
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
RegisterDialect("postgres", &postgres{})
|
RegisterDialect("postgres", &postgres{})
|
||||||
|
RegisterDialect("cloudsqlpostgres", &postgres{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (postgres) GetName() string {
|
func (postgres) GetName() string {
|
||||||
|
@ -31,14 +33,14 @@ func (s *postgres) DataTypeOf(field *StructField) string {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
sqlType = "boolean"
|
sqlType = "boolean"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "serial"
|
sqlType = "serial"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "integer"
|
sqlType = "integer"
|
||||||
}
|
}
|
||||||
case reflect.Int64, reflect.Uint32, reflect.Uint64:
|
case reflect.Int64, reflect.Uint32, reflect.Uint64:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "bigserial"
|
sqlType = "bigserial"
|
||||||
} else {
|
} else {
|
||||||
|
@ -67,9 +69,15 @@ func (s *postgres) DataTypeOf(field *StructField) string {
|
||||||
default:
|
default:
|
||||||
if IsByteArrayOrSlice(dataValue) {
|
if IsByteArrayOrSlice(dataValue) {
|
||||||
sqlType = "bytea"
|
sqlType = "bytea"
|
||||||
} else if isUUID(dataValue) {
|
|
||||||
|
if isUUID(dataValue) {
|
||||||
sqlType = "uuid"
|
sqlType = "uuid"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isJSON(dataValue) {
|
||||||
|
sqlType = "jsonb"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,7 +93,7 @@ func (s *postgres) DataTypeOf(field *StructField) string {
|
||||||
|
|
||||||
func (s postgres) HasIndex(tableName string, indexName string) bool {
|
func (s postgres) HasIndex(tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
|
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,13 +105,13 @@ func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
|
||||||
func (s postgres) HasTable(tableName string) bool {
|
func (s postgres) HasTable(tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasColumn(tableName string, columnName string) bool {
|
func (s postgres) HasColumn(tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -128,3 +136,8 @@ func isUUID(value reflect.Value) bool {
|
||||||
lower := strings.ToLower(typename)
|
lower := strings.ToLower(typename)
|
||||||
return "uuid" == lower || "guid" == lower
|
return "uuid" == lower || "guid" == lower
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isJSON(value reflect.Value) bool {
|
||||||
|
_, ok := value.Interface().(json.RawMessage)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
|
@ -28,14 +28,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
sqlType = "bool"
|
sqlType = "bool"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
if field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "integer primary key autoincrement"
|
sqlType = "integer primary key autoincrement"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "integer"
|
sqlType = "integer"
|
||||||
}
|
}
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
if field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "integer primary key autoincrement"
|
sqlType = "integer primary key autoincrement"
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -54,7 +54,7 @@ func (mssql) BindVar(i int) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) Quote(key string) string {
|
func (mssql) Quote(key string) string {
|
||||||
return fmt.Sprintf(`"%s"`, key)
|
return fmt.Sprintf(`[%s]`, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) DataTypeOf(field *gorm.StructField) string {
|
func (s *mssql) DataTypeOf(field *gorm.StructField) string {
|
||||||
|
@ -65,14 +65,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
sqlType = "bit"
|
sqlType = "bit"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "int IDENTITY(1,1)"
|
sqlType = "int IDENTITY(1,1)"
|
||||||
} else {
|
} else {
|
||||||
sqlType = "int"
|
sqlType = "int"
|
||||||
}
|
}
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
if s.fieldCanAutoIncrement(field) {
|
||||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
sqlType = "bigint IDENTITY(1,1)"
|
sqlType = "bigint IDENTITY(1,1)"
|
||||||
} else {
|
} else {
|
||||||
|
@ -111,6 +111,13 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
|
||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
|
||||||
|
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||||
|
return value != "FALSE"
|
||||||
|
}
|
||||||
|
return field.IsPrimaryKey
|
||||||
|
}
|
||||||
|
|
||||||
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
||||||
|
@ -128,16 +135,23 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
|
||||||
func (s mssql) HasTable(tableName string) bool {
|
func (s mssql) HasTable(tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count)
|
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
|
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
|
||||||
|
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s mssql) CurrentDatabase() (name string) {
|
func (s mssql) CurrentDatabase() (name string) {
|
||||||
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
||||||
return
|
return
|
||||||
|
@ -168,3 +182,11 @@ func (mssql) SelectFromDummyTable() string {
|
||||||
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
|
||||||
|
if strings.Contains(tableName, ".") {
|
||||||
|
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||||
|
return splitStrings[0], splitStrings[1]
|
||||||
|
}
|
||||||
|
return dialect.CurrentDatabase(), tableName
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,9 @@ import (
|
||||||
|
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"github.com/lib/pq/hstore"
|
"github.com/lib/pq/hstore"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Hstore map[string]*string
|
type Hstore map[string]*string
|
||||||
|
@ -52,3 +55,26 @@ func (h *Hstore) Scan(value interface{}) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Jsonb Postgresql's JSONB data type
|
||||||
|
type Jsonb struct {
|
||||||
|
json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value get value of Jsonb
|
||||||
|
func (j Jsonb) Value() (driver.Value, error) {
|
||||||
|
if len(j.RawMessage) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return j.MarshalJSON()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan scan value into Jsonb
|
||||||
|
func (j *Jsonb) Scan(value interface{}) error {
|
||||||
|
bytes, ok := value.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(bytes, j)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
version: '3'
|
||||||
|
|
||||||
|
services:
|
||||||
|
mysql:
|
||||||
|
image: 'mysql:latest'
|
||||||
|
ports:
|
||||||
|
- 9910:3306
|
||||||
|
environment:
|
||||||
|
- MYSQL_DATABASE=gorm
|
||||||
|
- MYSQL_USER=gorm
|
||||||
|
- MYSQL_PASSWORD=gorm
|
||||||
|
- MYSQL_RANDOM_ROOT_PASSWORD="yes"
|
||||||
|
postgres:
|
||||||
|
image: 'postgres:latest'
|
||||||
|
ports:
|
||||||
|
- 9920:5432
|
||||||
|
environment:
|
||||||
|
- POSTGRES_USER=gorm
|
||||||
|
- POSTGRES_DB=gorm
|
||||||
|
- POSTGRES_PASSWORD=gorm
|
||||||
|
mssql:
|
||||||
|
image: 'mcmoe/mssqldocker:latest'
|
||||||
|
ports:
|
||||||
|
- 9930:1433
|
||||||
|
environment:
|
||||||
|
- ACCEPT_EULA=Y
|
||||||
|
- SA_PASSWORD=LoremIpsum86
|
||||||
|
- MSSQL_DB=gorm
|
||||||
|
- MSSQL_USER=gorm
|
||||||
|
- MSSQL_PASSWORD=LoremIpsum86
|
|
@ -29,6 +29,10 @@ func (errs Errors) GetErrors() []error {
|
||||||
// Add adds an error
|
// Add adds an error
|
||||||
func (errs Errors) Add(newErrors ...error) Errors {
|
func (errs Errors) Add(newErrors ...error) Errors {
|
||||||
for _, err := range newErrors {
|
for _, err := range newErrors {
|
||||||
|
if err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if errors, ok := err.(Errors); ok {
|
if errors, ok := err.(Errors); ok {
|
||||||
errs = errs.Add(errors...)
|
errs = errs.Add(errors...)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -82,38 +82,40 @@ func (s JoinTableHandler) Table(db *DB) string {
|
||||||
return s.TableName
|
return s.TableName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
|
||||||
values := map[string]interface{}{}
|
|
||||||
|
|
||||||
for _, source := range sources {
|
for _, source := range sources {
|
||||||
scope := db.NewScope(source)
|
scope := db.NewScope(source)
|
||||||
modelType := scope.GetModelStruct().ModelType
|
modelType := scope.GetModelStruct().ModelType
|
||||||
|
|
||||||
if s.Source.ModelType == modelType {
|
for _, joinTableSource := range joinTableSources {
|
||||||
for _, foreignKey := range s.Source.ForeignKeys {
|
if joinTableSource.ModelType == modelType {
|
||||||
|
for _, foreignKey := range joinTableSource.ForeignKeys {
|
||||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||||
values[foreignKey.DBName] = field.Field.Interface()
|
conditionMap[foreignKey.DBName] = field.Field.Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if s.Destination.ModelType == modelType {
|
break
|
||||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
|
||||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
|
||||||
values[foreignKey.DBName] = field.Field.Interface()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return values
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add create relationship in join table for source and destination
|
// Add create relationship in join table for source and destination
|
||||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
||||||
scope := db.NewScope("")
|
var (
|
||||||
searchMap := s.getSearchMap(db, source, destination)
|
scope = db.NewScope("")
|
||||||
|
conditionMap = map[string]interface{}{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Update condition map for source
|
||||||
|
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
|
||||||
|
|
||||||
|
// Update condition map for destination
|
||||||
|
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
|
||||||
|
|
||||||
var assignColumns, binVars, conditions []string
|
var assignColumns, binVars, conditions []string
|
||||||
var values []interface{}
|
var values []interface{}
|
||||||
for key, value := range searchMap {
|
for key, value := range conditionMap {
|
||||||
assignColumns = append(assignColumns, scope.Quote(key))
|
assignColumns = append(assignColumns, scope.Quote(key))
|
||||||
binVars = append(binVars, `?`)
|
binVars = append(binVars, `?`)
|
||||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
|
@ -144,9 +146,12 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
||||||
scope = db.NewScope(nil)
|
scope = db.NewScope(nil)
|
||||||
conditions []string
|
conditions []string
|
||||||
values []interface{}
|
values []interface{}
|
||||||
|
conditionMap = map[string]interface{}{}
|
||||||
)
|
)
|
||||||
|
|
||||||
for key, value := range s.getSearchMap(db, sources...) {
|
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
|
||||||
|
|
||||||
|
for key, value := range conditionMap {
|
||||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
|
|
46
main.go
46
main.go
|
@ -274,7 +274,7 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
|
||||||
|
|
||||||
// First find first record that match given conditions, order by primary key
|
// First find first record that match given conditions, order by primary key
|
||||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.clone().NewScope(out)
|
newScope := s.NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
|
@ -282,7 +282,7 @@ func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
|
|
||||||
// Last find last record that match given conditions, order by primary key
|
// Last find last record that match given conditions, order by primary key
|
||||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.clone().NewScope(out)
|
newScope := s.NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
|
@ -290,12 +290,12 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
|
|
||||||
// Find find records that match given conditions
|
// Find find records that match given conditions
|
||||||
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||||
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan scan value to a struct
|
// Scan scan value to a struct
|
||||||
func (s *DB) Scan(dest interface{}) *DB {
|
func (s *DB) Scan(dest interface{}) *DB {
|
||||||
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Row return `*sql.Row` with given conditions
|
// Row return `*sql.Row` with given conditions
|
||||||
|
@ -311,8 +311,8 @@ func (s *DB) Rows() (*sql.Rows, error) {
|
||||||
// ScanRows scan `*sql.Rows` to give struct
|
// ScanRows scan `*sql.Rows` to give struct
|
||||||
func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
|
func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
|
||||||
var (
|
var (
|
||||||
clone = s.clone()
|
scope = s.NewScope(result)
|
||||||
scope = clone.NewScope(result)
|
clone = scope.db
|
||||||
columns, err = rows.Columns()
|
columns, err = rows.Columns()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -337,7 +337,7 @@ func (s *DB) Count(value interface{}) *DB {
|
||||||
|
|
||||||
// Related get related associations
|
// Related get related associations
|
||||||
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
||||||
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
|
return s.NewScope(s.Value).related(value, foreignKeys...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
|
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
|
||||||
|
@ -377,7 +377,7 @@ func (s *DB) Update(attrs ...interface{}) *DB {
|
||||||
|
|
||||||
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||||
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
||||||
return s.clone().NewScope(s.Value).
|
return s.NewScope(s.Value).
|
||||||
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
||||||
InstanceSet("gorm:update_interface", values).
|
InstanceSet("gorm:update_interface", values).
|
||||||
callCallbacks(s.parent.callbacks.updates).db
|
callCallbacks(s.parent.callbacks.updates).db
|
||||||
|
@ -390,7 +390,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
||||||
|
|
||||||
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||||
func (s *DB) UpdateColumns(values interface{}) *DB {
|
func (s *DB) UpdateColumns(values interface{}) *DB {
|
||||||
return s.clone().NewScope(s.Value).
|
return s.NewScope(s.Value).
|
||||||
Set("gorm:update_column", true).
|
Set("gorm:update_column", true).
|
||||||
Set("gorm:save_associations", false).
|
Set("gorm:save_associations", false).
|
||||||
InstanceSet("gorm:update_interface", values).
|
InstanceSet("gorm:update_interface", values).
|
||||||
|
@ -399,7 +399,7 @@ func (s *DB) UpdateColumns(values interface{}) *DB {
|
||||||
|
|
||||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||||
func (s *DB) Save(value interface{}) *DB {
|
func (s *DB) Save(value interface{}) *DB {
|
||||||
scope := s.clone().NewScope(value)
|
scope := s.NewScope(value)
|
||||||
if !scope.PrimaryKeyZero() {
|
if !scope.PrimaryKeyZero() {
|
||||||
newDB := scope.callCallbacks(s.parent.callbacks.updates).db
|
newDB := scope.callCallbacks(s.parent.callbacks.updates).db
|
||||||
if newDB.Error == nil && newDB.RowsAffected == 0 {
|
if newDB.Error == nil && newDB.RowsAffected == 0 {
|
||||||
|
@ -412,13 +412,13 @@ func (s *DB) Save(value interface{}) *DB {
|
||||||
|
|
||||||
// Create insert the value into database
|
// Create insert the value into database
|
||||||
func (s *DB) Create(value interface{}) *DB {
|
func (s *DB) Create(value interface{}) *DB {
|
||||||
scope := s.clone().NewScope(value)
|
scope := s.NewScope(value)
|
||||||
return scope.callCallbacks(s.parent.callbacks.creates).db
|
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||||
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
||||||
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
|
return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Raw use raw sql as conditions, won't run it unless invoked by other methods
|
// Raw use raw sql as conditions, won't run it unless invoked by other methods
|
||||||
|
@ -429,7 +429,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB {
|
||||||
|
|
||||||
// Exec execute raw sql
|
// Exec execute raw sql
|
||||||
func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
||||||
scope := s.clone().NewScope(nil)
|
scope := s.NewScope(nil)
|
||||||
generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
|
generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
|
||||||
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
|
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
|
||||||
scope.Raw(generatedSQL)
|
scope.Raw(generatedSQL)
|
||||||
|
@ -495,7 +495,7 @@ func (s *DB) Rollback() *DB {
|
||||||
|
|
||||||
// NewRecord check if value's primary key is blank
|
// NewRecord check if value's primary key is blank
|
||||||
func (s *DB) NewRecord(value interface{}) bool {
|
func (s *DB) NewRecord(value interface{}) bool {
|
||||||
return s.clone().NewScope(value).PrimaryKeyZero()
|
return s.NewScope(value).PrimaryKeyZero()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordNotFound check if returning ErrRecordNotFound error
|
// RecordNotFound check if returning ErrRecordNotFound error
|
||||||
|
@ -544,7 +544,7 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB {
|
||||||
// HasTable check has table or not
|
// HasTable check has table or not
|
||||||
func (s *DB) HasTable(value interface{}) bool {
|
func (s *DB) HasTable(value interface{}) bool {
|
||||||
var (
|
var (
|
||||||
scope = s.clone().NewScope(value)
|
scope = s.NewScope(value)
|
||||||
tableName string
|
tableName string
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -570,14 +570,14 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB {
|
||||||
|
|
||||||
// ModifyColumn modify column to type
|
// ModifyColumn modify column to type
|
||||||
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.NewScope(s.Value)
|
||||||
scope.modifyColumn(column, typ)
|
scope.modifyColumn(column, typ)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropColumn drop a column
|
// DropColumn drop a column
|
||||||
func (s *DB) DropColumn(column string) *DB {
|
func (s *DB) DropColumn(column string) *DB {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.NewScope(s.Value)
|
||||||
scope.dropColumn(column)
|
scope.dropColumn(column)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
@ -598,7 +598,7 @@ func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
|
||||||
|
|
||||||
// RemoveIndex remove index with name
|
// RemoveIndex remove index with name
|
||||||
func (s *DB) RemoveIndex(indexName string) *DB {
|
func (s *DB) RemoveIndex(indexName string) *DB {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.NewScope(s.Value)
|
||||||
scope.removeIndex(indexName)
|
scope.removeIndex(indexName)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
@ -606,11 +606,19 @@ func (s *DB) RemoveIndex(indexName string) *DB {
|
||||||
// AddForeignKey Add foreign key to the given scope, e.g:
|
// AddForeignKey Add foreign key to the given scope, e.g:
|
||||||
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
||||||
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.NewScope(s.Value)
|
||||||
scope.addForeignKey(field, dest, onDelete, onUpdate)
|
scope.addForeignKey(field, dest, onDelete, onUpdate)
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveForeignKey Remove foreign key from the given scope, e.g:
|
||||||
|
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
|
||||||
|
func (s *DB) RemoveForeignKey(field string, dest string) *DB {
|
||||||
|
scope := s.clone().NewScope(s.Value)
|
||||||
|
scope.removeForeignKey(field, dest)
|
||||||
|
return scope.db
|
||||||
|
}
|
||||||
|
|
||||||
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
|
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
|
||||||
func (s *DB) Association(column string) *Association {
|
func (s *DB) Association(column string) *Association {
|
||||||
var err error
|
var err error
|
||||||
|
|
30
main_test.go
30
main_test.go
|
@ -36,27 +36,20 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenTestConnection() (db *gorm.DB, err error) {
|
func OpenTestConnection() (db *gorm.DB, err error) {
|
||||||
|
dbDSN := os.Getenv("GORM_DSN")
|
||||||
switch os.Getenv("GORM_DIALECT") {
|
switch os.Getenv("GORM_DIALECT") {
|
||||||
case "mysql":
|
case "mysql":
|
||||||
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
|
|
||||||
// CREATE DATABASE gorm;
|
|
||||||
// GRANT ALL ON gorm.* TO 'gorm'@'localhost';
|
|
||||||
fmt.Println("testing mysql...")
|
fmt.Println("testing mysql...")
|
||||||
dbhost := os.Getenv("GORM_DBADDRESS")
|
if dbDSN == "" {
|
||||||
if dbhost != "" {
|
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
|
||||||
dbhost = fmt.Sprintf("tcp(%v)", dbhost)
|
|
||||||
}
|
}
|
||||||
db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost))
|
db, err = gorm.Open("mysql", dbDSN)
|
||||||
case "postgres":
|
case "postgres":
|
||||||
fmt.Println("testing postgres...")
|
fmt.Println("testing postgres...")
|
||||||
dbhost := os.Getenv("GORM_DBHOST")
|
if dbDSN == "" {
|
||||||
if dbhost != "" {
|
dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
|
||||||
dbhost = fmt.Sprintf("host=%v ", dbhost)
|
|
||||||
}
|
}
|
||||||
db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost))
|
db, err = gorm.Open("postgres", dbDSN)
|
||||||
case "foundation":
|
|
||||||
fmt.Println("testing foundation...")
|
|
||||||
db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
|
|
||||||
case "mssql":
|
case "mssql":
|
||||||
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
|
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
|
||||||
// CREATE DATABASE gorm;
|
// CREATE DATABASE gorm;
|
||||||
|
@ -64,7 +57,10 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
||||||
// CREATE USER gorm FROM LOGIN gorm;
|
// CREATE USER gorm FROM LOGIN gorm;
|
||||||
// sp_changedbowner 'gorm';
|
// sp_changedbowner 'gorm';
|
||||||
fmt.Println("testing mssql...")
|
fmt.Println("testing mssql...")
|
||||||
db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm")
|
if dbDSN == "" {
|
||||||
|
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
|
||||||
|
}
|
||||||
|
db, err = gorm.Open("mssql", dbDSN)
|
||||||
default:
|
default:
|
||||||
fmt.Println("testing sqlite3...")
|
fmt.Println("testing sqlite3...")
|
||||||
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
|
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
|
||||||
|
@ -72,8 +68,10 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
||||||
|
|
||||||
// db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
|
// db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
|
||||||
// db.SetLogger(log.New(os.Stdout, "\r\n", 0))
|
// db.SetLogger(log.New(os.Stdout, "\r\n", 0))
|
||||||
if os.Getenv("DEBUG") == "true" {
|
if debug := os.Getenv("DEBUG"); debug == "true" {
|
||||||
db.LogMode(true)
|
db.LogMode(true)
|
||||||
|
} else if debug == "false" {
|
||||||
|
db.LogMode(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.DB().SetMaxIdleConns(10)
|
db.DB().SetMaxIdleConns(10)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -432,3 +433,24 @@ func TestMultipleIndexes(t *testing.T) {
|
||||||
t.Error("MultipleIndexes unique index failed")
|
t.Error("MultipleIndexes unique index failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestModifyColumnType(t *testing.T) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
|
||||||
|
t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModifyColumnType struct {
|
||||||
|
gorm.Model
|
||||||
|
Name1 string `gorm:"length:100"`
|
||||||
|
Name2 string `gorm:"length:200"`
|
||||||
|
}
|
||||||
|
DB.DropTable(&ModifyColumnType{})
|
||||||
|
DB.CreateTable(&ModifyColumnType{})
|
||||||
|
|
||||||
|
name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2")
|
||||||
|
name2Type := DB.Dialect().DataTypeOf(name2Field.StructField)
|
||||||
|
|
||||||
|
if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
)
|
)
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
||||||
foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
|
foreignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
|
||||||
associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
|
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
|
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||||
|
associationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
|
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
|
||||||
|
@ -264,6 +266,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
||||||
relationship.Kind = "many_to_many"
|
relationship.Kind = "many_to_many"
|
||||||
|
|
||||||
|
{ // Foreign Keys for Source
|
||||||
|
joinTableDBNames := []string{}
|
||||||
|
|
||||||
|
if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
|
||||||
|
joinTableDBNames = strings.Split(foreignKey, ",")
|
||||||
|
}
|
||||||
|
|
||||||
// if no foreign keys defined with tag
|
// if no foreign keys defined with tag
|
||||||
if len(foreignKeys) == 0 {
|
if len(foreignKeys) == 0 {
|
||||||
for _, field := range modelStruct.PrimaryFields {
|
for _, field := range modelStruct.PrimaryFields {
|
||||||
|
@ -271,15 +280,29 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, foreignKey := range foreignKeys {
|
for idx, foreignKey := range foreignKeys {
|
||||||
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
||||||
// source foreign keys (db names)
|
// source foreign keys (db names)
|
||||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
|
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
|
||||||
// join table foreign keys for source
|
|
||||||
joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
|
// setup join table foreign keys for source
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
if len(joinTableDBNames) > idx {
|
||||||
|
// if defined join table's foreign key
|
||||||
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
|
||||||
|
} else {
|
||||||
|
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
|
||||||
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{ // Foreign Keys for Association (Destination)
|
||||||
|
associationJoinTableDBNames := []string{}
|
||||||
|
|
||||||
|
if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" {
|
||||||
|
associationJoinTableDBNames = strings.Split(foreignKey, ",")
|
||||||
|
}
|
||||||
|
|
||||||
// if no association foreign keys defined with tag
|
// if no association foreign keys defined with tag
|
||||||
if len(associationForeignKeys) == 0 {
|
if len(associationForeignKeys) == 0 {
|
||||||
|
@ -288,15 +311,22 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, name := range associationForeignKeys {
|
for idx, name := range associationForeignKeys {
|
||||||
if field, ok := toScope.FieldByName(name); ok {
|
if field, ok := toScope.FieldByName(name); ok {
|
||||||
// association foreign keys (db names)
|
// association foreign keys (db names)
|
||||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
|
||||||
|
|
||||||
|
// setup join table foreign keys for association
|
||||||
|
if len(associationJoinTableDBNames) > idx {
|
||||||
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
|
||||||
|
} else {
|
||||||
// join table foreign keys for association
|
// join table foreign keys for association
|
||||||
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
joinTableHandler := JoinTableHandler{}
|
joinTableHandler := JoinTableHandler{}
|
||||||
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
||||||
|
@ -399,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
)
|
)
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
||||||
tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
|
tagForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" {
|
||||||
tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
|
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
|
} else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||||
|
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||||
|
|
|
@ -1627,6 +1627,48 @@ func TestPrefixedPreloadDuplication(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPreloadManyToManyCallbacks(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Level2 struct {
|
||||||
|
ID uint
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
Level1 struct {
|
||||||
|
ID uint
|
||||||
|
Name string
|
||||||
|
Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
DB.DropTableIfExists("level1_level2s")
|
||||||
|
DB.DropTableIfExists(new(Level1))
|
||||||
|
DB.DropTableIfExists(new(Level2))
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lvl := Level1{
|
||||||
|
Name: "l1",
|
||||||
|
Level2s: []Level2{
|
||||||
|
Level2{Name: "l2-1"}, Level2{Name: "l2-2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
DB.Save(&lvl)
|
||||||
|
|
||||||
|
called := 0
|
||||||
|
|
||||||
|
DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) {
|
||||||
|
called = called + 1
|
||||||
|
})
|
||||||
|
|
||||||
|
DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID)
|
||||||
|
|
||||||
|
if called != 3 {
|
||||||
|
t.Errorf("Wanted callback to be called 3 times but got %d", called)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func toJSONString(v interface{}) []byte {
|
func toJSONString(v interface{}) []byte {
|
||||||
r, _ := json.MarshalIndent(v, "", " ")
|
r, _ := json.MarshalIndent(v, "", " ")
|
||||||
return r
|
return r
|
||||||
|
|
|
@ -389,7 +389,7 @@ func TestOffset(t *testing.T) {
|
||||||
DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)})
|
DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)})
|
||||||
}
|
}
|
||||||
var users1, users2, users3, users4 []User
|
var users1, users2, users3, users4 []User
|
||||||
DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
|
DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
|
||||||
|
|
||||||
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
||||||
t.Errorf("Offset should work")
|
t.Errorf("Offset should work")
|
||||||
|
|
39
scope.go
39
scope.go
|
@ -998,18 +998,6 @@ func (scope *Scope) changeableField(field *Field) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) shouldSaveAssociations() bool {
|
|
||||||
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok {
|
|
||||||
if v, ok := saveAssociations.(bool); ok && !v {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if v, ok := saveAssociations.(string); ok && (v != "skip") {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true && !scope.HasError()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||||
toScope := scope.db.NewScope(value)
|
toScope := scope.db.NewScope(value)
|
||||||
tx := scope.db.Set("gorm:association:source", scope.Value)
|
tx := scope.db.Set("gorm:association:source", scope.Value)
|
||||||
|
@ -1144,7 +1132,7 @@ func (scope *Scope) dropTable() *Scope {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) modifyColumn(column string, typ string) {
|
func (scope *Scope) modifyColumn(column string, typ string) {
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
|
scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) dropColumn(column string) {
|
func (scope *Scope) dropColumn(column string) {
|
||||||
|
@ -1170,7 +1158,8 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||||
keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest)
|
// Compatible with old generated key
|
||||||
|
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
|
||||||
|
|
||||||
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||||
return
|
return
|
||||||
|
@ -1179,6 +1168,16 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
|
||||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) removeForeignKey(field string, dest string) {
|
||||||
|
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest)
|
||||||
|
|
||||||
|
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
|
||||||
|
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
|
||||||
|
}
|
||||||
|
|
||||||
func (scope *Scope) removeIndex(indexName string) {
|
func (scope *Scope) removeIndex(indexName string) {
|
||||||
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
||||||
}
|
}
|
||||||
|
@ -1214,7 +1213,7 @@ func (scope *Scope) autoIndex() *Scope {
|
||||||
|
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
if name == "INDEX" || name == "" {
|
if name == "INDEX" || name == "" {
|
||||||
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
|
name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
|
||||||
}
|
}
|
||||||
indexes[name] = append(indexes[name], field.DBName)
|
indexes[name] = append(indexes[name], field.DBName)
|
||||||
}
|
}
|
||||||
|
@ -1225,7 +1224,7 @@ func (scope *Scope) autoIndex() *Scope {
|
||||||
|
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
if name == "UNIQUE_INDEX" || name == "" {
|
if name == "UNIQUE_INDEX" || name == "" {
|
||||||
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
|
name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
|
||||||
}
|
}
|
||||||
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
||||||
}
|
}
|
||||||
|
@ -1233,11 +1232,15 @@ func (scope *Scope) autoIndex() *Scope {
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, columns := range indexes {
|
for name, columns := range indexes {
|
||||||
scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
|
if db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...); db.Error != nil {
|
||||||
|
scope.db.AddError(db.Error)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, columns := range uniqueIndexes {
|
for name, columns := range uniqueIndexes {
|
||||||
scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
|
if db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil {
|
||||||
|
scope.db.AddError(db.Error)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return scope
|
return scope
|
||||||
|
|
|
@ -2,7 +2,6 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type search struct {
|
type search struct {
|
||||||
|
@ -73,13 +72,7 @@ func (s *search) Order(value interface{}, reorder ...bool) *search {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
var distinctSQLRegexp = regexp.MustCompile(`(?i)distinct[^a-z]+[a-z]+`)
|
|
||||||
|
|
||||||
func (s *search) Select(query interface{}, args ...interface{}) *search {
|
func (s *search) Select(query interface{}, args ...interface{}) *search {
|
||||||
if distinctSQLRegexp.MatchString(fmt.Sprint(query)) {
|
|
||||||
s.ignoreOrderQuery = true
|
|
||||||
}
|
|
||||||
|
|
||||||
s.selects = map[string]interface{}{"query": query, "args": args}
|
s.selects = map[string]interface{}{"query": query, "args": args}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
dialects=("postgres" "mysql" "mssql" "sqlite")
|
dialects=("postgres" "mysql" "mssql" "sqlite")
|
||||||
|
|
||||||
for dialect in "${dialects[@]}" ; do
|
for dialect in "${dialects[@]}" ; do
|
||||||
GORM_DIALECT=${dialect} go test
|
DEBUG=false GORM_DIALECT=${dialect} go test
|
||||||
done
|
done
|
||||||
|
|
2
utils.go
2
utils.go
|
@ -23,7 +23,7 @@ var NowFunc = func() time.Time {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copied from golint
|
// Copied from golint
|
||||||
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||||
var commonInitialismsReplacer *strings.Replacer
|
var commonInitialismsReplacer *strings.Replacer
|
||||||
|
|
||||||
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
|
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`)
|
||||||
|
|
115
wercker.yml
115
wercker.yml
|
@ -2,17 +2,79 @@
|
||||||
box: golang
|
box: golang
|
||||||
|
|
||||||
services:
|
services:
|
||||||
- id: mariadb:10.0
|
- name: mariadb
|
||||||
|
id: mariadb:latest
|
||||||
env:
|
env:
|
||||||
MYSQL_DATABASE: gorm
|
MYSQL_DATABASE: gorm
|
||||||
MYSQL_USER: gorm
|
MYSQL_USER: gorm
|
||||||
MYSQL_PASSWORD: gorm
|
MYSQL_PASSWORD: gorm
|
||||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||||
- id: postgres
|
- name: mysql
|
||||||
|
id: mysql:8
|
||||||
|
env:
|
||||||
|
MYSQL_DATABASE: gorm
|
||||||
|
MYSQL_USER: gorm
|
||||||
|
MYSQL_PASSWORD: gorm
|
||||||
|
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||||
|
- name: mysql57
|
||||||
|
id: mysql:5.7
|
||||||
|
env:
|
||||||
|
MYSQL_DATABASE: gorm
|
||||||
|
MYSQL_USER: gorm
|
||||||
|
MYSQL_PASSWORD: gorm
|
||||||
|
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||||
|
- name: mysql56
|
||||||
|
id: mysql:5.6
|
||||||
|
env:
|
||||||
|
MYSQL_DATABASE: gorm
|
||||||
|
MYSQL_USER: gorm
|
||||||
|
MYSQL_PASSWORD: gorm
|
||||||
|
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||||
|
- name: mysql55
|
||||||
|
id: mysql:5.5
|
||||||
|
env:
|
||||||
|
MYSQL_DATABASE: gorm
|
||||||
|
MYSQL_USER: gorm
|
||||||
|
MYSQL_PASSWORD: gorm
|
||||||
|
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||||
|
- name: postgres
|
||||||
|
id: postgres:latest
|
||||||
env:
|
env:
|
||||||
POSTGRES_USER: gorm
|
POSTGRES_USER: gorm
|
||||||
POSTGRES_PASSWORD: gorm
|
POSTGRES_PASSWORD: gorm
|
||||||
POSTGRES_DB: gorm
|
POSTGRES_DB: gorm
|
||||||
|
- name: postgres96
|
||||||
|
id: postgres:9.6
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: gorm
|
||||||
|
POSTGRES_PASSWORD: gorm
|
||||||
|
POSTGRES_DB: gorm
|
||||||
|
- name: postgres95
|
||||||
|
id: postgres:9.5
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: gorm
|
||||||
|
POSTGRES_PASSWORD: gorm
|
||||||
|
POSTGRES_DB: gorm
|
||||||
|
- name: postgres94
|
||||||
|
id: postgres:9.4
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: gorm
|
||||||
|
POSTGRES_PASSWORD: gorm
|
||||||
|
POSTGRES_DB: gorm
|
||||||
|
- name: postgres93
|
||||||
|
id: postgres:9.3
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: gorm
|
||||||
|
POSTGRES_PASSWORD: gorm
|
||||||
|
POSTGRES_DB: gorm
|
||||||
|
- name: mssql
|
||||||
|
id: mcmoe/mssqldocker:latest
|
||||||
|
env:
|
||||||
|
ACCEPT_EULA: Y
|
||||||
|
SA_PASSWORD: LoremIpsum86
|
||||||
|
MSSQL_DB: gorm
|
||||||
|
MSSQL_USER: gorm
|
||||||
|
MSSQL_PASSWORD: LoremIpsum86
|
||||||
|
|
||||||
# The steps that will be executed in the build pipeline
|
# The steps that will be executed in the build pipeline
|
||||||
build:
|
build:
|
||||||
|
@ -42,12 +104,57 @@ build:
|
||||||
code: |
|
code: |
|
||||||
go test ./...
|
go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test mariadb
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./...
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test mysql
|
name: test mysql
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./...
|
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test mysql5.7
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test mysql5.6
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test mysql5.5
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./...
|
||||||
|
|
||||||
- script:
|
- script:
|
||||||
name: test postgres
|
name: test postgres
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./...
|
GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test postgres96
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test postgres95
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test postgres94
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test postgres93
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test mssql
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...
|
||||||
|
|
Loading…
Reference in New Issue