forked from mirror/gorm
Rename test files
This commit is contained in:
parent
9d57c6b961
commit
d08894cd69
|
@ -18,10 +18,9 @@ func queryCallback(scope *Scope) {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
var (
|
var (
|
||||||
isSlice bool
|
isSlice, isPtr bool
|
||||||
isPtr bool
|
resultType reflect.Type
|
||||||
results = scope.IndirectValue()
|
results = scope.IndirectValue()
|
||||||
resultType reflect.Type
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
||||||
|
|
|
@ -17,13 +17,14 @@ func saveBeforeAssociationsCallback(scope *Scope) {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
value := field.Field
|
fieldValue := field.Field.Addr().Interface()
|
||||||
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
|
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
||||||
if len(relationship.ForeignFieldNames) != 0 {
|
if len(relationship.ForeignFieldNames) != 0 {
|
||||||
|
// set value's foreign key
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||||
if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok {
|
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
|
||||||
scope.Err(scope.SetColumn(fieldName, f.Field.Interface()))
|
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -65,7 +66,7 @@ func saveAfterAssociationsCallback(scope *Scope) {
|
||||||
scope.Err(newDB.Save(elem).Error)
|
scope.Err(newDB.Save(elem).Error)
|
||||||
|
|
||||||
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
||||||
scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
|
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -92,7 +92,7 @@ func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
|
||||||
|
|
||||||
// We must fetch the value again, to have the default fields updated
|
// We must fetch the value again, to have the default fields updated
|
||||||
// (We can't do this in the update statements, since sql default can be expressions
|
// (We can't do this in the update statements, since sql default can be expressions
|
||||||
// And be different from the fields' type (eg. a time.Time fiels has a default value of "now()"
|
// And be different from the fields' type (eg. a time.Time fields has a default value of "now()"
|
||||||
DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
|
DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
|
||||||
|
|
||||||
if an.Name != "galeone" {
|
if an.Name != "galeone" {
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
package gorm_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDdlErrors(t *testing.T) {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if err = DB.Close(); err != nil {
|
|
||||||
t.Errorf("Closing DDL test db connection err=%s", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
// Reopen DB connection.
|
|
||||||
if DB, err = OpenTestConnection(); err != nil {
|
|
||||||
t.Fatalf("Failed re-opening db connection: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := DB.Find(&User{}).Error; err == nil {
|
|
||||||
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
|
|
||||||
}
|
|
||||||
}
|
|
42
main.go
42
main.go
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DB contains information for current db connection
|
// DB contains information for current db connection
|
||||||
|
@ -221,7 +222,7 @@ func (s *DB) Unscoped() *DB {
|
||||||
return s.clone().search.unscoped().db
|
return s.clone().search.unscoped().db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attrs initalize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||||
func (s *DB) Attrs(attrs ...interface{}) *DB {
|
func (s *DB) Attrs(attrs ...interface{}) *DB {
|
||||||
return s.clone().search.Attrs(attrs...).db
|
return s.clone().search.Attrs(attrs...).db
|
||||||
}
|
}
|
||||||
|
@ -299,7 +300,7 @@ func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
||||||
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
|
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
// FirstOrInit find first matched record or initalize a new one with given conditions (only works with struct, map conditions)
|
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
|
||||||
// https://jinzhu.github.io/gorm/curd.html#firstorinit
|
// https://jinzhu.github.io/gorm/curd.html#firstorinit
|
||||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
|
@ -660,3 +661,40 @@ func (s *DB) GetErrors() (errors []error) {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Private Methods For *gorm.DB
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
func (s *DB) clone() *DB {
|
||||||
|
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
|
||||||
|
|
||||||
|
for key, value := range s.values {
|
||||||
|
db.values[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.search == nil {
|
||||||
|
db.search = &search{limit: -1, offset: -1}
|
||||||
|
} else {
|
||||||
|
db.search = s.search.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
db.search.db = &db
|
||||||
|
return &db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) print(v ...interface{}) {
|
||||||
|
s.logger.(logger).Print(v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) log(v ...interface{}) {
|
||||||
|
if s != nil && s.logMode == 2 {
|
||||||
|
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
|
||||||
|
if s.logMode == 2 {
|
||||||
|
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,36 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
func (s *DB) clone() *DB {
|
|
||||||
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
|
|
||||||
|
|
||||||
for key, value := range s.values {
|
|
||||||
db.values[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.search == nil {
|
|
||||||
db.search = &search{limit: -1, offset: -1}
|
|
||||||
} else {
|
|
||||||
db.search = s.search.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
db.search.db = &db
|
|
||||||
return &db
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DB) print(v ...interface{}) {
|
|
||||||
s.logger.(logger).Print(v...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DB) log(v ...interface{}) {
|
|
||||||
if s != nil && s.logMode == 2 {
|
|
||||||
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
|
|
||||||
if s.logMode == 2 {
|
|
||||||
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
|
|
||||||
}
|
|
||||||
}
|
|
18
main_test.go
18
main_test.go
|
@ -714,6 +714,24 @@ func TestOpenExistingDB(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDdlErrors(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if err = DB.Close(); err != nil {
|
||||||
|
t.Errorf("Closing DDL test db connection err=%s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
// Reopen DB connection.
|
||||||
|
if DB, err = OpenTestConnection(); err != nil {
|
||||||
|
t.Fatalf("Failed re-opening db connection: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := DB.Find(&User{}).Error; err == nil {
|
||||||
|
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGorm(b *testing.B) {
|
func BenchmarkGorm(b *testing.B) {
|
||||||
b.N = 2000
|
b.N = 2000
|
||||||
for x := 0; x < b.N; x++ {
|
for x := 0; x < b.N; x++ {
|
||||||
|
|
|
@ -1,11 +1,226 @@
|
||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
Id int64
|
||||||
|
Age int64
|
||||||
|
UserNum Num
|
||||||
|
Name string `sql:"size:255"`
|
||||||
|
Birthday time.Time // Time
|
||||||
|
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
||||||
|
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
|
||||||
|
Emails []Email // Embedded structs
|
||||||
|
BillingAddress Address // Embedded struct
|
||||||
|
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
|
||||||
|
ShippingAddress Address // Embedded struct
|
||||||
|
ShippingAddressId int64 // Embedded struct's foreign key
|
||||||
|
CreditCard CreditCard
|
||||||
|
Latitude float64
|
||||||
|
Languages []Language `gorm:"many2many:user_languages;"`
|
||||||
|
CompanyID *int
|
||||||
|
Company Company
|
||||||
|
Role
|
||||||
|
PasswordHash []byte
|
||||||
|
IgnoreMe int64 `sql:"-"`
|
||||||
|
IgnoreStringSlice []string `sql:"-"`
|
||||||
|
Ignored struct{ Name string } `sql:"-"`
|
||||||
|
IgnoredPointer *User `sql:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreditCard struct {
|
||||||
|
ID int8
|
||||||
|
Number string
|
||||||
|
UserId sql.NullInt64
|
||||||
|
CreatedAt time.Time `sql:"not null"`
|
||||||
|
UpdatedAt time.Time
|
||||||
|
DeletedAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Email struct {
|
||||||
|
Id int16
|
||||||
|
UserId int
|
||||||
|
Email string `sql:"type:varchar(100);"`
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Address struct {
|
||||||
|
ID int
|
||||||
|
Address1 string
|
||||||
|
Address2 string
|
||||||
|
Post string
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
DeletedAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Language struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Users []User `gorm:"many2many:user_languages;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Product struct {
|
||||||
|
Id int64
|
||||||
|
Code string
|
||||||
|
Price int64
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
AfterFindCallTimes int64
|
||||||
|
BeforeCreateCallTimes int64
|
||||||
|
AfterCreateCallTimes int64
|
||||||
|
BeforeUpdateCallTimes int64
|
||||||
|
AfterUpdateCallTimes int64
|
||||||
|
BeforeSaveCallTimes int64
|
||||||
|
AfterSaveCallTimes int64
|
||||||
|
BeforeDeleteCallTimes int64
|
||||||
|
AfterDeleteCallTimes int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Company struct {
|
||||||
|
Id int64
|
||||||
|
Name string
|
||||||
|
Owner *User `sql:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Role struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role *Role) Scan(value interface{}) error {
|
||||||
|
if b, ok := value.([]uint8); ok {
|
||||||
|
role.Name = string(b)
|
||||||
|
} else {
|
||||||
|
role.Name = value.(string)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role Role) Value() (driver.Value, error) {
|
||||||
|
return role.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role Role) IsAdmin() bool {
|
||||||
|
return role.Name == "admin"
|
||||||
|
}
|
||||||
|
|
||||||
|
type Num int64
|
||||||
|
|
||||||
|
func (i *Num) Scan(src interface{}) error {
|
||||||
|
switch s := src.(type) {
|
||||||
|
case []byte:
|
||||||
|
case int64:
|
||||||
|
*i = Num(s)
|
||||||
|
default:
|
||||||
|
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Animal struct {
|
||||||
|
Counter uint64 `gorm:"primary_key:yes"`
|
||||||
|
Name string `sql:"DEFAULT:'galeone'"`
|
||||||
|
From string //test reserved sql keyword as field name
|
||||||
|
Age time.Time `sql:"DEFAULT:current_timestamp"`
|
||||||
|
unexported string // unexported value
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type JoinTable struct {
|
||||||
|
From uint64
|
||||||
|
To uint64
|
||||||
|
Time time.Time `sql:"default: null"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Post struct {
|
||||||
|
Id int64
|
||||||
|
CategoryId sql.NullInt64
|
||||||
|
MainCategoryId int64
|
||||||
|
Title string
|
||||||
|
Body string
|
||||||
|
Comments []*Comment
|
||||||
|
Category Category
|
||||||
|
MainCategory Category
|
||||||
|
}
|
||||||
|
|
||||||
|
type Category struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Comment struct {
|
||||||
|
gorm.Model
|
||||||
|
PostId int64
|
||||||
|
Content string
|
||||||
|
Post Post
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scanner
|
||||||
|
type NullValue struct {
|
||||||
|
Id int64
|
||||||
|
Name sql.NullString `sql:"not null"`
|
||||||
|
Gender *sql.NullString `sql:"not null"`
|
||||||
|
Age sql.NullInt64
|
||||||
|
Male sql.NullBool
|
||||||
|
Height sql.NullFloat64
|
||||||
|
AddedAt NullTime
|
||||||
|
}
|
||||||
|
|
||||||
|
type NullTime struct {
|
||||||
|
Time time.Time
|
||||||
|
Valid bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nt *NullTime) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
nt.Valid = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
nt.Time, nt.Valid = value.(time.Time), true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nt NullTime) Value() (driver.Value, error) {
|
||||||
|
if !nt.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nt.Time, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPreparedUser(name string, role string) *User {
|
||||||
|
var company Company
|
||||||
|
DB.Where(Company{Name: role}).FirstOrCreate(&company)
|
||||||
|
|
||||||
|
return &User{
|
||||||
|
Name: name,
|
||||||
|
Age: 20,
|
||||||
|
Role: Role{role},
|
||||||
|
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
||||||
|
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
||||||
|
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
||||||
|
Emails: []Email{
|
||||||
|
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
||||||
|
},
|
||||||
|
Company: company,
|
||||||
|
Languages: []Language{
|
||||||
|
{Name: fmt.Sprintf("lang_1_%v", name)},
|
||||||
|
{Name: fmt.Sprintf("lang_2_%v", name)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func runMigration() {
|
func runMigration() {
|
||||||
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
|
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
|
||||||
fmt.Printf("Got error when try to delete table users, %+v\n", err)
|
fmt.Printf("Got error when try to delete table users, %+v\n", err)
|
||||||
|
|
2
model.go
2
model.go
|
@ -2,7 +2,7 @@ package gorm
|
||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models
|
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
|
||||||
// type User struct {
|
// type User struct {
|
||||||
// gorm.Model
|
// gorm.Model
|
||||||
// }
|
// }
|
||||||
|
|
721
scope.go
721
scope.go
|
@ -2,10 +2,13 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
@ -406,11 +409,15 @@ func (scope *Scope) SelectAttrs() []string {
|
||||||
return *scope.selectAttrs
|
return *scope.selectAttrs
|
||||||
}
|
}
|
||||||
|
|
||||||
// OmitAttrs return omited attributes
|
// OmitAttrs return omitted attributes
|
||||||
func (scope *Scope) OmitAttrs() []string {
|
func (scope *Scope) OmitAttrs() []string {
|
||||||
return scope.Search.omits
|
return scope.Search.omits
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Private Methods For *gorm.Scope
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
|
func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
|
||||||
var values = make([]interface{}, len(columns))
|
var values = make([]interface{}, len(columns))
|
||||||
var ignored interface{}
|
var ignored interface{}
|
||||||
|
@ -441,3 +448,715 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||||
|
return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
|
switch value := clause["query"].(type) {
|
||||||
|
case string:
|
||||||
|
// if string is number
|
||||||
|
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||||
|
return scope.primaryCondition(scope.AddToVars(value))
|
||||||
|
} else if value != "" {
|
||||||
|
str = fmt.Sprintf("(%v)", value)
|
||||||
|
}
|
||||||
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||||
|
return scope.primaryCondition(scope.AddToVars(value))
|
||||||
|
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
||||||
|
str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey()))
|
||||||
|
clause["args"] = []interface{}{value}
|
||||||
|
case map[string]interface{}:
|
||||||
|
var sqls []string
|
||||||
|
for key, value := range value {
|
||||||
|
if value != nil {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
|
||||||
|
} else {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(sqls, " AND ")
|
||||||
|
case interface{}:
|
||||||
|
var sqls []string
|
||||||
|
for _, field := range scope.New(value).Fields() {
|
||||||
|
if !field.IsIgnored && !field.IsBlank {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(sqls, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
args := clause["args"].([]interface{})
|
||||||
|
for _, arg := range args {
|
||||||
|
switch reflect.ValueOf(arg).Kind() {
|
||||||
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||||
|
if bytes, ok := arg.([]byte); ok {
|
||||||
|
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
||||||
|
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||||
|
var tempMarks []string
|
||||||
|
for i := 0; i < values.Len(); i++ {
|
||||||
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||||
|
}
|
||||||
|
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
||||||
|
} else {
|
||||||
|
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
|
arg, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
|
var notEqualSQL string
|
||||||
|
var primaryKey = scope.PrimaryKey()
|
||||||
|
|
||||||
|
switch value := clause["query"].(type) {
|
||||||
|
case string:
|
||||||
|
// is number
|
||||||
|
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||||
|
id, _ := strconv.Atoi(value)
|
||||||
|
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
||||||
|
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
|
||||||
|
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||||
|
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
|
||||||
|
} else {
|
||||||
|
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
||||||
|
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||||
|
}
|
||||||
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||||
|
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
||||||
|
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
|
||||||
|
if reflect.ValueOf(value).Len() > 0 {
|
||||||
|
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
|
||||||
|
clause["args"] = []interface{}{value}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
case map[string]interface{}:
|
||||||
|
var sqls []string
|
||||||
|
for key, value := range value {
|
||||||
|
if value != nil {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
|
||||||
|
} else {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(sqls, " AND ")
|
||||||
|
case interface{}:
|
||||||
|
var sqls []string
|
||||||
|
for _, field := range scope.New(value).Fields() {
|
||||||
|
if !field.IsBlank {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(sqls, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
args := clause["args"].([]interface{})
|
||||||
|
for _, arg := range args {
|
||||||
|
switch reflect.ValueOf(arg).Kind() {
|
||||||
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||||
|
if bytes, ok := arg.([]byte); ok {
|
||||||
|
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
||||||
|
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||||
|
var tempMarks []string
|
||||||
|
for i := 0; i < values.Len(); i++ {
|
||||||
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||||
|
}
|
||||||
|
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
||||||
|
} else {
|
||||||
|
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
|
arg, _ = scanner.Value()
|
||||||
|
}
|
||||||
|
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
|
||||||
|
switch value := clause["query"].(type) {
|
||||||
|
case string:
|
||||||
|
str = value
|
||||||
|
case []string:
|
||||||
|
str = strings.Join(value, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
args := clause["args"].([]interface{})
|
||||||
|
for _, arg := range args {
|
||||||
|
switch reflect.ValueOf(arg).Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
values := reflect.ValueOf(arg)
|
||||||
|
var tempMarks []string
|
||||||
|
for i := 0; i < values.Len(); i++ {
|
||||||
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||||
|
}
|
||||||
|
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
||||||
|
default:
|
||||||
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
|
arg, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) whereSQL() (sql string) {
|
||||||
|
var (
|
||||||
|
quotedTableName = scope.QuotedTableName()
|
||||||
|
primaryConditions, andConditions, orConditions []string
|
||||||
|
)
|
||||||
|
|
||||||
|
if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
|
||||||
|
sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName)
|
||||||
|
primaryConditions = append(primaryConditions, sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.PrimaryKeyZero() {
|
||||||
|
for _, field := range scope.PrimaryFields() {
|
||||||
|
sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
|
||||||
|
primaryConditions = append(primaryConditions, sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, clause := range scope.Search.whereConditions {
|
||||||
|
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||||
|
andConditions = append(andConditions, sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, clause := range scope.Search.orConditions {
|
||||||
|
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||||
|
orConditions = append(orConditions, sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, clause := range scope.Search.notConditions {
|
||||||
|
if sql := scope.buildNotCondition(clause); sql != "" {
|
||||||
|
andConditions = append(andConditions, sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
orSQL := strings.Join(orConditions, " OR ")
|
||||||
|
combinedSQL := strings.Join(andConditions, " AND ")
|
||||||
|
if len(combinedSQL) > 0 {
|
||||||
|
if len(orSQL) > 0 {
|
||||||
|
combinedSQL = combinedSQL + " OR " + orSQL
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
combinedSQL = orSQL
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(primaryConditions) > 0 {
|
||||||
|
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
|
||||||
|
if len(combinedSQL) > 0 {
|
||||||
|
sql = sql + " AND (" + combinedSQL + ")"
|
||||||
|
}
|
||||||
|
} else if len(combinedSQL) > 0 {
|
||||||
|
sql = "WHERE " + combinedSQL
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) selectSQL() string {
|
||||||
|
if len(scope.Search.selects) == 0 {
|
||||||
|
if len(scope.Search.joinConditions) > 0 {
|
||||||
|
return fmt.Sprintf("%v.*", scope.QuotedTableName())
|
||||||
|
}
|
||||||
|
return "*"
|
||||||
|
}
|
||||||
|
return scope.buildSelectQuery(scope.Search.selects)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) orderSQL() string {
|
||||||
|
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) limitAndOffsetSQL() string {
|
||||||
|
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) groupSQL() string {
|
||||||
|
if len(scope.Search.group) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return " GROUP BY " + scope.Search.group
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) havingSQL() string {
|
||||||
|
if len(scope.Search.havingConditions) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var andConditions []string
|
||||||
|
for _, clause := range scope.Search.havingConditions {
|
||||||
|
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||||
|
andConditions = append(andConditions, sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedSQL := strings.Join(andConditions, " AND ")
|
||||||
|
if len(combinedSQL) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return " HAVING " + combinedSQL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) joinsSQL() string {
|
||||||
|
var joinConditions []string
|
||||||
|
for _, clause := range scope.Search.joinConditions {
|
||||||
|
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||||
|
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(joinConditions, " ") + " "
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) prepareQuerySQL() {
|
||||||
|
if scope.Search.raw {
|
||||||
|
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
||||||
|
} else {
|
||||||
|
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
|
||||||
|
if len(values) > 0 {
|
||||||
|
scope.Search.Where(values[0], values[1:]...)
|
||||||
|
}
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
||||||
|
for _, f := range funcs {
|
||||||
|
(*f)(scope)
|
||||||
|
if scope.skipLeft {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||||
|
if scope.IndirectValue().Kind() != reflect.Struct {
|
||||||
|
return values, true
|
||||||
|
}
|
||||||
|
|
||||||
|
results = map[string]interface{}{}
|
||||||
|
for key, value := range values {
|
||||||
|
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||||
|
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||||
|
if _, ok := value.(*expr); ok {
|
||||||
|
hasUpdate = true
|
||||||
|
results[field.DBName] = value
|
||||||
|
} else if !equalAsString(field.Field.Interface(), value) {
|
||||||
|
field.Set(value)
|
||||||
|
if field.IsNormal {
|
||||||
|
hasUpdate = true
|
||||||
|
results[field.DBName] = field.Field.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
field.Set(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) row() *sql.Row {
|
||||||
|
defer scope.trace(NowFunc())
|
||||||
|
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||||
|
scope.prepareQuerySQL()
|
||||||
|
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||||
|
defer scope.trace(NowFunc())
|
||||||
|
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||||
|
scope.prepareQuerySQL()
|
||||||
|
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) initialize() *Scope {
|
||||||
|
for _, clause := range scope.Search.whereConditions {
|
||||||
|
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
|
||||||
|
}
|
||||||
|
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
|
||||||
|
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
||||||
|
dest := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
scope.Search.Select(column)
|
||||||
|
if dest.Kind() != reflect.Slice {
|
||||||
|
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := scope.rows()
|
||||||
|
if scope.Err(err) == nil {
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
elem := reflect.New(dest.Type().Elem()).Interface()
|
||||||
|
scope.Err(rows.Scan(elem))
|
||||||
|
dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) count(value interface{}) *Scope {
|
||||||
|
scope.Search.Select("count(*)")
|
||||||
|
scope.Search.countingQuery = true
|
||||||
|
scope.Err(scope.row().Scan(value))
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) typeName() string {
|
||||||
|
typ := scope.IndirectValue().Type()
|
||||||
|
|
||||||
|
for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr {
|
||||||
|
typ = typ.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
return typ.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
// trace print sql log
|
||||||
|
func (scope *Scope) trace(t time.Time) {
|
||||||
|
if len(scope.SQL) > 0 {
|
||||||
|
scope.db.slog(scope.SQL, t, scope.SQLVars...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) changeableField(field *Field) bool {
|
||||||
|
if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
|
||||||
|
for _, attr := range selectAttrs {
|
||||||
|
if field.Name == attr || field.DBName == attr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, attr := range scope.OmitAttrs() {
|
||||||
|
if field.Name == attr || field.DBName == attr {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) shouldSaveAssociations() bool {
|
||||||
|
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true && !scope.HasError()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||||
|
toScope := scope.db.NewScope(value)
|
||||||
|
|
||||||
|
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
||||||
|
fromField, _ := scope.FieldByName(foreignKey)
|
||||||
|
toField, _ := toScope.FieldByName(foreignKey)
|
||||||
|
|
||||||
|
if fromField != nil {
|
||||||
|
if relationship := fromField.Relationship; relationship != nil {
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
|
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
|
||||||
|
} else if relationship.Kind == "belongs_to" {
|
||||||
|
query := toScope.db
|
||||||
|
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||||
|
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scope.Err(query.Find(value).Error)
|
||||||
|
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||||
|
query := toScope.db
|
||||||
|
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
|
||||||
|
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.PolymorphicType != "" {
|
||||||
|
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||||
|
}
|
||||||
|
scope.Err(query.Find(value).Error)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||||
|
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
|
||||||
|
}
|
||||||
|
return scope
|
||||||
|
} else if toField != nil {
|
||||||
|
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
|
||||||
|
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTableOptions return the table options string or an empty string if the table options does not exist
|
||||||
|
func (scope *Scope) getTableOptions() string {
|
||||||
|
tableOptions, ok := scope.Get("gorm:table_options")
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return tableOptions.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
|
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||||
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
|
joinTable := joinTableHandler.Table(scope.db)
|
||||||
|
if !scope.Dialect().HasTable(joinTable) {
|
||||||
|
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
||||||
|
|
||||||
|
var sqlTypes, primaryKeys []string
|
||||||
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
|
if field, ok := scope.FieldByName(fieldName); ok {
|
||||||
|
foreignKeyStruct := field.clone()
|
||||||
|
foreignKeyStruct.IsPrimaryKey = false
|
||||||
|
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||||
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||||
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
||||||
|
if field, ok := toScope.FieldByName(fieldName); ok {
|
||||||
|
foreignKeyStruct := field.clone()
|
||||||
|
foreignKeyStruct.IsPrimaryKey = false
|
||||||
|
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||||
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||||
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
|
||||||
|
}
|
||||||
|
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) createTable() *Scope {
|
||||||
|
var tags []string
|
||||||
|
var primaryKeys []string
|
||||||
|
var primaryKeyInColumnType = false
|
||||||
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
|
if field.IsNormal {
|
||||||
|
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||||
|
|
||||||
|
// Check if the primary key constraint was specified as
|
||||||
|
// part of the column type. If so, we can only support
|
||||||
|
// one column as the primary key.
|
||||||
|
if strings.Contains(strings.ToLower(sqlTag), "primary key") {
|
||||||
|
primaryKeyInColumnType = true
|
||||||
|
}
|
||||||
|
|
||||||
|
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
|
||||||
|
}
|
||||||
|
scope.createJoinTable(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
var primaryKeyStr string
|
||||||
|
if len(primaryKeys) > 0 && !primaryKeyInColumnType {
|
||||||
|
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
|
||||||
|
|
||||||
|
scope.autoIndex()
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) dropTable() *Scope {
|
||||||
|
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) modifyColumn(column string, typ string) {
|
||||||
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) dropColumn(column string) {
|
||||||
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||||
|
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var columns []string
|
||||||
|
for _, name := range column {
|
||||||
|
columns = append(columns, scope.quoteIfPossible(name))
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlCreate := "CREATE INDEX"
|
||||||
|
if unique {
|
||||||
|
sqlCreate = "CREATE UNIQUE INDEX"
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||||
|
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
|
||||||
|
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
|
||||||
|
|
||||||
|
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||||
|
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) removeIndex(indexName string) {
|
||||||
|
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) autoMigrate() *Scope {
|
||||||
|
tableName := scope.TableName()
|
||||||
|
quotedTableName := scope.QuotedTableName()
|
||||||
|
|
||||||
|
if !scope.Dialect().HasTable(tableName) {
|
||||||
|
scope.createTable()
|
||||||
|
} else {
|
||||||
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
|
if !scope.Dialect().HasColumn(tableName, field.DBName) {
|
||||||
|
if field.IsNormal {
|
||||||
|
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||||
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scope.createJoinTable(field)
|
||||||
|
}
|
||||||
|
scope.autoIndex()
|
||||||
|
}
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) autoIndex() *Scope {
|
||||||
|
var indexes = map[string][]string{}
|
||||||
|
var uniqueIndexes = map[string][]string{}
|
||||||
|
|
||||||
|
for _, field := range scope.GetStructFields() {
|
||||||
|
if name, ok := field.TagSettings["INDEX"]; ok {
|
||||||
|
if name == "INDEX" {
|
||||||
|
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
|
||||||
|
}
|
||||||
|
indexes[name] = append(indexes[name], field.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
|
||||||
|
if name == "UNIQUE_INDEX" {
|
||||||
|
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
|
||||||
|
}
|
||||||
|
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, columns := range indexes {
|
||||||
|
scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, columns := range uniqueIndexes {
|
||||||
|
scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
|
||||||
|
for _, value := range values {
|
||||||
|
indirectValue := reflect.ValueOf(value)
|
||||||
|
for indirectValue.Kind() == reflect.Ptr {
|
||||||
|
indirectValue = indirectValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch indirectValue.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
for i := 0; i < indirectValue.Len(); i++ {
|
||||||
|
var result []interface{}
|
||||||
|
var object = indirect(indirectValue.Index(i))
|
||||||
|
for _, column := range columns {
|
||||||
|
result = append(result, object.FieldByName(column).Interface())
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
var result []interface{}
|
||||||
|
for _, column := range columns {
|
||||||
|
result = append(result, indirectValue.FieldByName(column).Interface())
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) getColumnAsScope(column string) *Scope {
|
||||||
|
indirectScopeValue := scope.IndirectValue()
|
||||||
|
|
||||||
|
switch indirectScopeValue.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
|
||||||
|
fieldType := fieldStruct.Type
|
||||||
|
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
|
||||||
|
|
||||||
|
for i := 0; i < indirectScopeValue.Len(); i++ {
|
||||||
|
result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
|
||||||
|
|
||||||
|
if result.Kind() == reflect.Slice {
|
||||||
|
for j := 0; j < result.Len(); j++ {
|
||||||
|
if elem := result.Index(j); elem.CanAddr() {
|
||||||
|
results = reflect.Append(results, elem.Addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if result.CanAddr() {
|
||||||
|
results = reflect.Append(results, result.Addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scope.New(results.Interface())
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
|
||||||
|
return scope.New(field.Addr().Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
663
scope_private.go
663
scope_private.go
|
@ -1,663 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"database/sql/driver"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
|
||||||
return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|
||||||
switch value := clause["query"].(type) {
|
|
||||||
case string:
|
|
||||||
// if string is number
|
|
||||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
|
||||||
return scope.primaryCondition(scope.AddToVars(value))
|
|
||||||
} else if value != "" {
|
|
||||||
str = fmt.Sprintf("(%v)", value)
|
|
||||||
}
|
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
|
||||||
return scope.primaryCondition(scope.AddToVars(value))
|
|
||||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
|
||||||
str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey()))
|
|
||||||
clause["args"] = []interface{}{value}
|
|
||||||
case map[string]interface{}:
|
|
||||||
var sqls []string
|
|
||||||
for key, value := range value {
|
|
||||||
if value != nil {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
|
|
||||||
} else {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(sqls, " AND ")
|
|
||||||
case interface{}:
|
|
||||||
var sqls []string
|
|
||||||
for _, field := range scope.New(value).Fields() {
|
|
||||||
if !field.IsIgnored && !field.IsBlank {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(sqls, " AND ")
|
|
||||||
}
|
|
||||||
|
|
||||||
args := clause["args"].([]interface{})
|
|
||||||
for _, arg := range args {
|
|
||||||
switch reflect.ValueOf(arg).Kind() {
|
|
||||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
|
||||||
if bytes, ok := arg.([]byte); ok {
|
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
|
||||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
|
||||||
var tempMarks []string
|
|
||||||
for i := 0; i < values.Len(); i++ {
|
|
||||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
|
||||||
}
|
|
||||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
|
||||||
} else {
|
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
|
||||||
arg, _ = valuer.Value()
|
|
||||||
}
|
|
||||||
|
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
|
||||||
var notEqualSQL string
|
|
||||||
var primaryKey = scope.PrimaryKey()
|
|
||||||
|
|
||||||
switch value := clause["query"].(type) {
|
|
||||||
case string:
|
|
||||||
// is number
|
|
||||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
|
||||||
id, _ := strconv.Atoi(value)
|
|
||||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
|
||||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
|
|
||||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
|
||||||
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
|
|
||||||
} else {
|
|
||||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
|
||||||
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
|
||||||
}
|
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
|
||||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
|
||||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
|
|
||||||
if reflect.ValueOf(value).Len() > 0 {
|
|
||||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
|
|
||||||
clause["args"] = []interface{}{value}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
case map[string]interface{}:
|
|
||||||
var sqls []string
|
|
||||||
for key, value := range value {
|
|
||||||
if value != nil {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
|
|
||||||
} else {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(sqls, " AND ")
|
|
||||||
case interface{}:
|
|
||||||
var sqls []string
|
|
||||||
for _, field := range scope.New(value).Fields() {
|
|
||||||
if !field.IsBlank {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(sqls, " AND ")
|
|
||||||
}
|
|
||||||
|
|
||||||
args := clause["args"].([]interface{})
|
|
||||||
for _, arg := range args {
|
|
||||||
switch reflect.ValueOf(arg).Kind() {
|
|
||||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
|
||||||
if bytes, ok := arg.([]byte); ok {
|
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
|
||||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
|
||||||
var tempMarks []string
|
|
||||||
for i := 0; i < values.Len(); i++ {
|
|
||||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
|
||||||
}
|
|
||||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
|
||||||
} else {
|
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
|
||||||
arg, _ = scanner.Value()
|
|
||||||
}
|
|
||||||
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
|
|
||||||
switch value := clause["query"].(type) {
|
|
||||||
case string:
|
|
||||||
str = value
|
|
||||||
case []string:
|
|
||||||
str = strings.Join(value, ", ")
|
|
||||||
}
|
|
||||||
|
|
||||||
args := clause["args"].([]interface{})
|
|
||||||
for _, arg := range args {
|
|
||||||
switch reflect.ValueOf(arg).Kind() {
|
|
||||||
case reflect.Slice:
|
|
||||||
values := reflect.ValueOf(arg)
|
|
||||||
var tempMarks []string
|
|
||||||
for i := 0; i < values.Len(); i++ {
|
|
||||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
|
||||||
}
|
|
||||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
|
||||||
default:
|
|
||||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
|
||||||
arg, _ = valuer.Value()
|
|
||||||
}
|
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) whereSQL() (sql string) {
|
|
||||||
var (
|
|
||||||
quotedTableName = scope.QuotedTableName()
|
|
||||||
primaryConditions, andConditions, orConditions []string
|
|
||||||
)
|
|
||||||
|
|
||||||
if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
|
|
||||||
sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName)
|
|
||||||
primaryConditions = append(primaryConditions, sql)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !scope.PrimaryKeyZero() {
|
|
||||||
for _, field := range scope.PrimaryFields() {
|
|
||||||
sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
|
|
||||||
primaryConditions = append(primaryConditions, sql)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, clause := range scope.Search.whereConditions {
|
|
||||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
|
||||||
andConditions = append(andConditions, sql)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, clause := range scope.Search.orConditions {
|
|
||||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
|
||||||
orConditions = append(orConditions, sql)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, clause := range scope.Search.notConditions {
|
|
||||||
if sql := scope.buildNotCondition(clause); sql != "" {
|
|
||||||
andConditions = append(andConditions, sql)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
orSQL := strings.Join(orConditions, " OR ")
|
|
||||||
combinedSQL := strings.Join(andConditions, " AND ")
|
|
||||||
if len(combinedSQL) > 0 {
|
|
||||||
if len(orSQL) > 0 {
|
|
||||||
combinedSQL = combinedSQL + " OR " + orSQL
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
combinedSQL = orSQL
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(primaryConditions) > 0 {
|
|
||||||
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
|
|
||||||
if len(combinedSQL) > 0 {
|
|
||||||
sql = sql + " AND (" + combinedSQL + ")"
|
|
||||||
}
|
|
||||||
} else if len(combinedSQL) > 0 {
|
|
||||||
sql = "WHERE " + combinedSQL
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) selectSQL() string {
|
|
||||||
if len(scope.Search.selects) == 0 {
|
|
||||||
if len(scope.Search.joinConditions) > 0 {
|
|
||||||
return fmt.Sprintf("%v.*", scope.QuotedTableName())
|
|
||||||
}
|
|
||||||
return "*"
|
|
||||||
}
|
|
||||||
return scope.buildSelectQuery(scope.Search.selects)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) orderSQL() string {
|
|
||||||
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) limitAndOffsetSQL() string {
|
|
||||||
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) groupSQL() string {
|
|
||||||
if len(scope.Search.group) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return " GROUP BY " + scope.Search.group
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) havingSQL() string {
|
|
||||||
if len(scope.Search.havingConditions) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var andConditions []string
|
|
||||||
for _, clause := range scope.Search.havingConditions {
|
|
||||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
|
||||||
andConditions = append(andConditions, sql)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
combinedSQL := strings.Join(andConditions, " AND ")
|
|
||||||
if len(combinedSQL) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return " HAVING " + combinedSQL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) joinsSQL() string {
|
|
||||||
var joinConditions []string
|
|
||||||
for _, clause := range scope.Search.joinConditions {
|
|
||||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
|
||||||
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.Join(joinConditions, " ") + " "
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) prepareQuerySQL() {
|
|
||||||
if scope.Search.raw {
|
|
||||||
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
|
||||||
} else {
|
|
||||||
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
|
|
||||||
if len(values) > 0 {
|
|
||||||
scope.Search.Where(values[0], values[1:]...)
|
|
||||||
}
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
|
||||||
for _, f := range funcs {
|
|
||||||
(*f)(scope)
|
|
||||||
if scope.skipLeft {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
|
|
||||||
if scope.IndirectValue().Kind() != reflect.Struct {
|
|
||||||
return values, true
|
|
||||||
}
|
|
||||||
|
|
||||||
results = map[string]interface{}{}
|
|
||||||
for key, value := range values {
|
|
||||||
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
|
||||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
|
||||||
if _, ok := value.(*expr); ok {
|
|
||||||
hasUpdate = true
|
|
||||||
results[field.DBName] = value
|
|
||||||
} else if !equalAsString(field.Field.Interface(), value) {
|
|
||||||
field.Set(value)
|
|
||||||
if field.IsNormal {
|
|
||||||
hasUpdate = true
|
|
||||||
results[field.DBName] = field.Field.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
field.Set(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) row() *sql.Row {
|
|
||||||
defer scope.trace(NowFunc())
|
|
||||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
|
||||||
scope.prepareQuerySQL()
|
|
||||||
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
|
||||||
defer scope.trace(NowFunc())
|
|
||||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
|
||||||
scope.prepareQuerySQL()
|
|
||||||
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) initialize() *Scope {
|
|
||||||
for _, clause := range scope.Search.whereConditions {
|
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
|
|
||||||
}
|
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
|
|
||||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
|
||||||
dest := reflect.Indirect(reflect.ValueOf(value))
|
|
||||||
scope.Search.Select(column)
|
|
||||||
if dest.Kind() != reflect.Slice {
|
|
||||||
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := scope.rows()
|
|
||||||
if scope.Err(err) == nil {
|
|
||||||
defer rows.Close()
|
|
||||||
for rows.Next() {
|
|
||||||
elem := reflect.New(dest.Type().Elem()).Interface()
|
|
||||||
scope.Err(rows.Scan(elem))
|
|
||||||
dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) count(value interface{}) *Scope {
|
|
||||||
scope.Search.Select("count(*)")
|
|
||||||
scope.Search.countingQuery = true
|
|
||||||
scope.Err(scope.row().Scan(value))
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) typeName() string {
|
|
||||||
typ := scope.IndirectValue().Type()
|
|
||||||
|
|
||||||
for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr {
|
|
||||||
typ = typ.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
return typ.Name()
|
|
||||||
}
|
|
||||||
|
|
||||||
// trace print sql log
|
|
||||||
func (scope *Scope) trace(t time.Time) {
|
|
||||||
if len(scope.SQL) > 0 {
|
|
||||||
scope.db.slog(scope.SQL, t, scope.SQLVars...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) changeableField(field *Field) bool {
|
|
||||||
if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
|
|
||||||
for _, attr := range selectAttrs {
|
|
||||||
if field.Name == attr || field.DBName == attr {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, attr := range scope.OmitAttrs() {
|
|
||||||
if field.Name == attr || field.DBName == attr {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) shouldSaveAssociations() bool {
|
|
||||||
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true && !scope.HasError()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
|
||||||
toScope := scope.db.NewScope(value)
|
|
||||||
|
|
||||||
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
|
||||||
fromField, _ := scope.FieldByName(foreignKey)
|
|
||||||
toField, _ := toScope.FieldByName(foreignKey)
|
|
||||||
|
|
||||||
if fromField != nil {
|
|
||||||
if relationship := fromField.Relationship; relationship != nil {
|
|
||||||
if relationship.Kind == "many_to_many" {
|
|
||||||
joinTableHandler := relationship.JoinTableHandler
|
|
||||||
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
|
|
||||||
} else if relationship.Kind == "belongs_to" {
|
|
||||||
query := toScope.db
|
|
||||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
|
||||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
|
||||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
scope.Err(query.Find(value).Error)
|
|
||||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
|
||||||
query := toScope.db
|
|
||||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
|
||||||
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
|
|
||||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if relationship.PolymorphicType != "" {
|
|
||||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
|
||||||
}
|
|
||||||
scope.Err(query.Find(value).Error)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
|
||||||
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
|
|
||||||
}
|
|
||||||
return scope
|
|
||||||
} else if toField != nil {
|
|
||||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
|
|
||||||
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Return the table options string or an empty string if the table options does not exist
|
|
||||||
*/
|
|
||||||
func (scope *Scope) getTableOptions() string {
|
|
||||||
tableOptions, ok := scope.Get("gorm:table_options")
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return tableOptions.(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) createJoinTable(field *StructField) {
|
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
|
||||||
joinTableHandler := relationship.JoinTableHandler
|
|
||||||
joinTable := joinTableHandler.Table(scope.db)
|
|
||||||
if !scope.Dialect().HasTable(joinTable) {
|
|
||||||
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
|
||||||
|
|
||||||
var sqlTypes, primaryKeys []string
|
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
|
||||||
if field, ok := scope.FieldByName(fieldName); ok {
|
|
||||||
foreignKeyStruct := field.clone()
|
|
||||||
foreignKeyStruct.IsPrimaryKey = false
|
|
||||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
|
||||||
if field, ok := toScope.FieldByName(fieldName); ok {
|
|
||||||
foreignKeyStruct := field.clone()
|
|
||||||
foreignKeyStruct.IsPrimaryKey = false
|
|
||||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
|
|
||||||
}
|
|
||||||
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) createTable() *Scope {
|
|
||||||
var tags []string
|
|
||||||
var primaryKeys []string
|
|
||||||
var primaryKeyInColumnType = false
|
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
|
||||||
if field.IsNormal {
|
|
||||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
|
||||||
|
|
||||||
// Check if the primary key constraint was specified as
|
|
||||||
// part of the column type. If so, we can only support
|
|
||||||
// one column as the primary key.
|
|
||||||
if strings.Contains(strings.ToLower(sqlTag), "primary key") {
|
|
||||||
primaryKeyInColumnType = true
|
|
||||||
}
|
|
||||||
|
|
||||||
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
|
|
||||||
}
|
|
||||||
|
|
||||||
if field.IsPrimaryKey {
|
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
|
|
||||||
}
|
|
||||||
scope.createJoinTable(field)
|
|
||||||
}
|
|
||||||
|
|
||||||
var primaryKeyStr string
|
|
||||||
if len(primaryKeys) > 0 && !primaryKeyInColumnType {
|
|
||||||
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
|
|
||||||
|
|
||||||
scope.autoIndex()
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) dropTable() *Scope {
|
|
||||||
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) modifyColumn(column string, typ string) {
|
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) dropColumn(column string) {
|
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
|
||||||
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var columns []string
|
|
||||||
for _, name := range column {
|
|
||||||
columns = append(columns, scope.quoteIfPossible(name))
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlCreate := "CREATE INDEX"
|
|
||||||
if unique {
|
|
||||||
sqlCreate = "CREATE UNIQUE INDEX"
|
|
||||||
}
|
|
||||||
|
|
||||||
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
|
||||||
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
|
|
||||||
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
|
|
||||||
|
|
||||||
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
|
||||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) removeIndex(indexName string) {
|
|
||||||
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) autoMigrate() *Scope {
|
|
||||||
tableName := scope.TableName()
|
|
||||||
quotedTableName := scope.QuotedTableName()
|
|
||||||
|
|
||||||
if !scope.Dialect().HasTable(tableName) {
|
|
||||||
scope.createTable()
|
|
||||||
} else {
|
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
|
||||||
if !scope.Dialect().HasColumn(tableName, field.DBName) {
|
|
||||||
if field.IsNormal {
|
|
||||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
scope.createJoinTable(field)
|
|
||||||
}
|
|
||||||
scope.autoIndex()
|
|
||||||
}
|
|
||||||
return scope
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) autoIndex() *Scope {
|
|
||||||
var indexes = map[string][]string{}
|
|
||||||
var uniqueIndexes = map[string][]string{}
|
|
||||||
|
|
||||||
for _, field := range scope.GetStructFields() {
|
|
||||||
if name, ok := field.TagSettings["INDEX"]; ok {
|
|
||||||
if name == "INDEX" {
|
|
||||||
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
|
|
||||||
}
|
|
||||||
indexes[name] = append(indexes[name], field.DBName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
|
|
||||||
if name == "UNIQUE_INDEX" {
|
|
||||||
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
|
|
||||||
}
|
|
||||||
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, columns := range indexes {
|
|
||||||
scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, columns := range uniqueIndexes {
|
|
||||||
scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return scope
|
|
||||||
}
|
|
|
@ -1,67 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import "reflect"
|
|
||||||
|
|
||||||
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
|
|
||||||
for _, value := range values {
|
|
||||||
indirectValue := reflect.ValueOf(value)
|
|
||||||
for indirectValue.Kind() == reflect.Ptr {
|
|
||||||
indirectValue = indirectValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
switch indirectValue.Kind() {
|
|
||||||
case reflect.Slice:
|
|
||||||
for i := 0; i < indirectValue.Len(); i++ {
|
|
||||||
var result []interface{}
|
|
||||||
var object = indirect(indirectValue.Index(i))
|
|
||||||
for _, column := range columns {
|
|
||||||
result = append(result, object.FieldByName(column).Interface())
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
case reflect.Struct:
|
|
||||||
var result []interface{}
|
|
||||||
for _, column := range columns {
|
|
||||||
result = append(result, indirectValue.FieldByName(column).Interface())
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) getColumnAsScope(column string) *Scope {
|
|
||||||
indirectScopeValue := scope.IndirectValue()
|
|
||||||
|
|
||||||
switch indirectScopeValue.Kind() {
|
|
||||||
case reflect.Slice:
|
|
||||||
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
|
|
||||||
fieldType := fieldStruct.Type
|
|
||||||
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
|
|
||||||
fieldType = fieldType.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
|
|
||||||
|
|
||||||
for i := 0; i < indirectScopeValue.Len(); i++ {
|
|
||||||
result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
|
|
||||||
|
|
||||||
if result.Kind() == reflect.Slice {
|
|
||||||
for j := 0; j < result.Len(); j++ {
|
|
||||||
if elem := result.Index(j); elem.CanAddr() {
|
|
||||||
results = reflect.Append(results, elem.Addr())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if result.CanAddr() {
|
|
||||||
results = reflect.Append(results, result.Addr())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return scope.New(results.Interface())
|
|
||||||
}
|
|
||||||
case reflect.Struct:
|
|
||||||
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
|
|
||||||
return scope.New(field.Addr().Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
222
structs_test.go
222
structs_test.go
|
@ -1,222 +0,0 @@
|
||||||
package gorm_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"database/sql/driver"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
Id int64
|
|
||||||
Age int64
|
|
||||||
UserNum Num
|
|
||||||
Name string `sql:"size:255"`
|
|
||||||
Birthday time.Time // Time
|
|
||||||
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
|
||||||
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
|
|
||||||
Emails []Email // Embedded structs
|
|
||||||
BillingAddress Address // Embedded struct
|
|
||||||
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
|
|
||||||
ShippingAddress Address // Embedded struct
|
|
||||||
ShippingAddressId int64 // Embedded struct's foreign key
|
|
||||||
CreditCard CreditCard
|
|
||||||
Latitude float64
|
|
||||||
Languages []Language `gorm:"many2many:user_languages;"`
|
|
||||||
CompanyID *int
|
|
||||||
Company Company
|
|
||||||
Role
|
|
||||||
PasswordHash []byte
|
|
||||||
IgnoreMe int64 `sql:"-"`
|
|
||||||
IgnoreStringSlice []string `sql:"-"`
|
|
||||||
Ignored struct{ Name string } `sql:"-"`
|
|
||||||
IgnoredPointer *User `sql:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CreditCard struct {
|
|
||||||
ID int8
|
|
||||||
Number string
|
|
||||||
UserId sql.NullInt64
|
|
||||||
CreatedAt time.Time `sql:"not null"`
|
|
||||||
UpdatedAt time.Time
|
|
||||||
DeletedAt *time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type Email struct {
|
|
||||||
Id int16
|
|
||||||
UserId int
|
|
||||||
Email string `sql:"type:varchar(100);"`
|
|
||||||
CreatedAt time.Time
|
|
||||||
UpdatedAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type Address struct {
|
|
||||||
ID int
|
|
||||||
Address1 string
|
|
||||||
Address2 string
|
|
||||||
Post string
|
|
||||||
CreatedAt time.Time
|
|
||||||
UpdatedAt time.Time
|
|
||||||
DeletedAt *time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type Language struct {
|
|
||||||
gorm.Model
|
|
||||||
Name string
|
|
||||||
Users []User `gorm:"many2many:user_languages;"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Product struct {
|
|
||||||
Id int64
|
|
||||||
Code string
|
|
||||||
Price int64
|
|
||||||
CreatedAt time.Time
|
|
||||||
UpdatedAt time.Time
|
|
||||||
AfterFindCallTimes int64
|
|
||||||
BeforeCreateCallTimes int64
|
|
||||||
AfterCreateCallTimes int64
|
|
||||||
BeforeUpdateCallTimes int64
|
|
||||||
AfterUpdateCallTimes int64
|
|
||||||
BeforeSaveCallTimes int64
|
|
||||||
AfterSaveCallTimes int64
|
|
||||||
BeforeDeleteCallTimes int64
|
|
||||||
AfterDeleteCallTimes int64
|
|
||||||
}
|
|
||||||
|
|
||||||
type Company struct {
|
|
||||||
Id int64
|
|
||||||
Name string
|
|
||||||
Owner *User `sql:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Role struct {
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (role *Role) Scan(value interface{}) error {
|
|
||||||
if b, ok := value.([]uint8); ok {
|
|
||||||
role.Name = string(b)
|
|
||||||
} else {
|
|
||||||
role.Name = value.(string)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (role Role) Value() (driver.Value, error) {
|
|
||||||
return role.Name, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (role Role) IsAdmin() bool {
|
|
||||||
return role.Name == "admin"
|
|
||||||
}
|
|
||||||
|
|
||||||
type Num int64
|
|
||||||
|
|
||||||
func (i *Num) Scan(src interface{}) error {
|
|
||||||
switch s := src.(type) {
|
|
||||||
case []byte:
|
|
||||||
case int64:
|
|
||||||
*i = Num(s)
|
|
||||||
default:
|
|
||||||
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Animal struct {
|
|
||||||
Counter uint64 `gorm:"primary_key:yes"`
|
|
||||||
Name string `sql:"DEFAULT:'galeone'"`
|
|
||||||
From string //test reserved sql keyword as field name
|
|
||||||
Age time.Time `sql:"DEFAULT:current_timestamp"`
|
|
||||||
unexported string // unexported value
|
|
||||||
CreatedAt time.Time
|
|
||||||
UpdatedAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type JoinTable struct {
|
|
||||||
From uint64
|
|
||||||
To uint64
|
|
||||||
Time time.Time `sql:"default: null"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Post struct {
|
|
||||||
Id int64
|
|
||||||
CategoryId sql.NullInt64
|
|
||||||
MainCategoryId int64
|
|
||||||
Title string
|
|
||||||
Body string
|
|
||||||
Comments []*Comment
|
|
||||||
Category Category
|
|
||||||
MainCategory Category
|
|
||||||
}
|
|
||||||
|
|
||||||
type Category struct {
|
|
||||||
gorm.Model
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Comment struct {
|
|
||||||
gorm.Model
|
|
||||||
PostId int64
|
|
||||||
Content string
|
|
||||||
Post Post
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scanner
|
|
||||||
type NullValue struct {
|
|
||||||
Id int64
|
|
||||||
Name sql.NullString `sql:"not null"`
|
|
||||||
Gender *sql.NullString `sql:"not null"`
|
|
||||||
Age sql.NullInt64
|
|
||||||
Male sql.NullBool
|
|
||||||
Height sql.NullFloat64
|
|
||||||
AddedAt NullTime
|
|
||||||
}
|
|
||||||
|
|
||||||
type NullTime struct {
|
|
||||||
Time time.Time
|
|
||||||
Valid bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (nt *NullTime) Scan(value interface{}) error {
|
|
||||||
if value == nil {
|
|
||||||
nt.Valid = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
nt.Time, nt.Valid = value.(time.Time), true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (nt NullTime) Value() (driver.Value, error) {
|
|
||||||
if !nt.Valid {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nt.Time, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPreparedUser(name string, role string) *User {
|
|
||||||
var company Company
|
|
||||||
DB.Where(Company{Name: role}).FirstOrCreate(&company)
|
|
||||||
|
|
||||||
return &User{
|
|
||||||
Name: name,
|
|
||||||
Age: 20,
|
|
||||||
Role: Role{role},
|
|
||||||
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
|
||||||
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
|
||||||
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
|
||||||
Emails: []Email{
|
|
||||||
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
|
||||||
},
|
|
||||||
Company: company,
|
|
||||||
Languages: []Language{
|
|
||||||
{Name: fmt.Sprintf("lang_1_%v", name)},
|
|
||||||
{Name: fmt.Sprintf("lang_2_%v", name)},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue