forked from mirror/gorm
Rename test files
This commit is contained in:
parent
9d57c6b961
commit
d08894cd69
|
@ -18,10 +18,9 @@ func queryCallback(scope *Scope) {
|
|||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
isSlice bool
|
||||
isPtr bool
|
||||
results = scope.IndirectValue()
|
||||
resultType reflect.Type
|
||||
isSlice, isPtr bool
|
||||
resultType reflect.Type
|
||||
results = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
||||
|
|
|
@ -17,13 +17,14 @@ func saveBeforeAssociationsCallback(scope *Scope) {
|
|||
for _, field := range scope.Fields() {
|
||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
value := field.Field
|
||||
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
|
||||
fieldValue := field.Field.Addr().Interface()
|
||||
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
// set value's foreign key
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok {
|
||||
scope.Err(scope.SetColumn(fieldName, f.Field.Interface()))
|
||||
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
|
||||
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -65,7 +66,7 @@ func saveAfterAssociationsCallback(scope *Scope) {
|
|||
scope.Err(newDB.Save(elem).Error)
|
||||
|
||||
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
||||
scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
|
||||
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
|
|
@ -92,7 +92,7 @@ func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
|
|||
|
||||
// We must fetch the value again, to have the default fields updated
|
||||
// (We can't do this in the update statements, since sql default can be expressions
|
||||
// And be different from the fields' type (eg. a time.Time fiels has a default value of "now()"
|
||||
// And be different from the fields' type (eg. a time.Time fields has a default value of "now()"
|
||||
DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
|
||||
|
||||
if an.Name != "galeone" {
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
package gorm_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDdlErrors(t *testing.T) {
|
||||
var err error
|
||||
|
||||
if err = DB.Close(); err != nil {
|
||||
t.Errorf("Closing DDL test db connection err=%s", err)
|
||||
}
|
||||
defer func() {
|
||||
// Reopen DB connection.
|
||||
if DB, err = OpenTestConnection(); err != nil {
|
||||
t.Fatalf("Failed re-opening db connection: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := DB.Find(&User{}).Error; err == nil {
|
||||
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
|
||||
}
|
||||
}
|
42
main.go
42
main.go
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DB contains information for current db connection
|
||||
|
@ -221,7 +222,7 @@ func (s *DB) Unscoped() *DB {
|
|||
return s.clone().search.unscoped().db
|
||||
}
|
||||
|
||||
// Attrs initalize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||
// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||
func (s *DB) Attrs(attrs ...interface{}) *DB {
|
||||
return s.clone().search.Attrs(attrs...).db
|
||||
}
|
||||
|
@ -299,7 +300,7 @@ func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
|||
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
|
||||
}
|
||||
|
||||
// FirstOrInit find first matched record or initalize a new one with given conditions (only works with struct, map conditions)
|
||||
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
|
||||
// https://jinzhu.github.io/gorm/curd.html#firstorinit
|
||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||
c := s.clone()
|
||||
|
@ -660,3 +661,40 @@ func (s *DB) GetErrors() (errors []error) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Private Methods For *gorm.DB
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func (s *DB) clone() *DB {
|
||||
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
|
||||
|
||||
for key, value := range s.values {
|
||||
db.values[key] = value
|
||||
}
|
||||
|
||||
if s.search == nil {
|
||||
db.search = &search{limit: -1, offset: -1}
|
||||
} else {
|
||||
db.search = s.search.clone()
|
||||
}
|
||||
|
||||
db.search.db = &db
|
||||
return &db
|
||||
}
|
||||
|
||||
func (s *DB) print(v ...interface{}) {
|
||||
s.logger.(logger).Print(v...)
|
||||
}
|
||||
|
||||
func (s *DB) log(v ...interface{}) {
|
||||
if s != nil && s.logMode == 2 {
|
||||
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
|
||||
if s.logMode == 2 {
|
||||
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import "time"
|
||||
|
||||
func (s *DB) clone() *DB {
|
||||
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
|
||||
|
||||
for key, value := range s.values {
|
||||
db.values[key] = value
|
||||
}
|
||||
|
||||
if s.search == nil {
|
||||
db.search = &search{limit: -1, offset: -1}
|
||||
} else {
|
||||
db.search = s.search.clone()
|
||||
}
|
||||
|
||||
db.search.db = &db
|
||||
return &db
|
||||
}
|
||||
|
||||
func (s *DB) print(v ...interface{}) {
|
||||
s.logger.(logger).Print(v...)
|
||||
}
|
||||
|
||||
func (s *DB) log(v ...interface{}) {
|
||||
if s != nil && s.logMode == 2 {
|
||||
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
|
||||
if s.logMode == 2 {
|
||||
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
|
||||
}
|
||||
}
|
18
main_test.go
18
main_test.go
|
@ -714,6 +714,24 @@ func TestOpenExistingDB(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDdlErrors(t *testing.T) {
|
||||
var err error
|
||||
|
||||
if err = DB.Close(); err != nil {
|
||||
t.Errorf("Closing DDL test db connection err=%s", err)
|
||||
}
|
||||
defer func() {
|
||||
// Reopen DB connection.
|
||||
if DB, err = OpenTestConnection(); err != nil {
|
||||
t.Fatalf("Failed re-opening db connection: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := DB.Find(&User{}).Error; err == nil {
|
||||
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGorm(b *testing.B) {
|
||||
b.N = 2000
|
||||
for x := 0; x < b.N; x++ {
|
||||
|
|
|
@ -1,11 +1,226 @@
|
|||
package gorm_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Id int64
|
||||
Age int64
|
||||
UserNum Num
|
||||
Name string `sql:"size:255"`
|
||||
Birthday time.Time // Time
|
||||
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
||||
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
|
||||
Emails []Email // Embedded structs
|
||||
BillingAddress Address // Embedded struct
|
||||
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
|
||||
ShippingAddress Address // Embedded struct
|
||||
ShippingAddressId int64 // Embedded struct's foreign key
|
||||
CreditCard CreditCard
|
||||
Latitude float64
|
||||
Languages []Language `gorm:"many2many:user_languages;"`
|
||||
CompanyID *int
|
||||
Company Company
|
||||
Role
|
||||
PasswordHash []byte
|
||||
IgnoreMe int64 `sql:"-"`
|
||||
IgnoreStringSlice []string `sql:"-"`
|
||||
Ignored struct{ Name string } `sql:"-"`
|
||||
IgnoredPointer *User `sql:"-"`
|
||||
}
|
||||
|
||||
type CreditCard struct {
|
||||
ID int8
|
||||
Number string
|
||||
UserId sql.NullInt64
|
||||
CreatedAt time.Time `sql:"not null"`
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
|
||||
type Email struct {
|
||||
Id int16
|
||||
UserId int
|
||||
Email string `sql:"type:varchar(100);"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Address struct {
|
||||
ID int
|
||||
Address1 string
|
||||
Address2 string
|
||||
Post string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
|
||||
type Language struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Users []User `gorm:"many2many:user_languages;"`
|
||||
}
|
||||
|
||||
type Product struct {
|
||||
Id int64
|
||||
Code string
|
||||
Price int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
AfterFindCallTimes int64
|
||||
BeforeCreateCallTimes int64
|
||||
AfterCreateCallTimes int64
|
||||
BeforeUpdateCallTimes int64
|
||||
AfterUpdateCallTimes int64
|
||||
BeforeSaveCallTimes int64
|
||||
AfterSaveCallTimes int64
|
||||
BeforeDeleteCallTimes int64
|
||||
AfterDeleteCallTimes int64
|
||||
}
|
||||
|
||||
type Company struct {
|
||||
Id int64
|
||||
Name string
|
||||
Owner *User `sql:"-"`
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (role *Role) Scan(value interface{}) error {
|
||||
if b, ok := value.([]uint8); ok {
|
||||
role.Name = string(b)
|
||||
} else {
|
||||
role.Name = value.(string)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (role Role) Value() (driver.Value, error) {
|
||||
return role.Name, nil
|
||||
}
|
||||
|
||||
func (role Role) IsAdmin() bool {
|
||||
return role.Name == "admin"
|
||||
}
|
||||
|
||||
type Num int64
|
||||
|
||||
func (i *Num) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
case int64:
|
||||
*i = Num(s)
|
||||
default:
|
||||
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Animal struct {
|
||||
Counter uint64 `gorm:"primary_key:yes"`
|
||||
Name string `sql:"DEFAULT:'galeone'"`
|
||||
From string //test reserved sql keyword as field name
|
||||
Age time.Time `sql:"DEFAULT:current_timestamp"`
|
||||
unexported string // unexported value
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type JoinTable struct {
|
||||
From uint64
|
||||
To uint64
|
||||
Time time.Time `sql:"default: null"`
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
Id int64
|
||||
CategoryId sql.NullInt64
|
||||
MainCategoryId int64
|
||||
Title string
|
||||
Body string
|
||||
Comments []*Comment
|
||||
Category Category
|
||||
MainCategory Category
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
}
|
||||
|
||||
type Comment struct {
|
||||
gorm.Model
|
||||
PostId int64
|
||||
Content string
|
||||
Post Post
|
||||
}
|
||||
|
||||
// Scanner
|
||||
type NullValue struct {
|
||||
Id int64
|
||||
Name sql.NullString `sql:"not null"`
|
||||
Gender *sql.NullString `sql:"not null"`
|
||||
Age sql.NullInt64
|
||||
Male sql.NullBool
|
||||
Height sql.NullFloat64
|
||||
AddedAt NullTime
|
||||
}
|
||||
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
nt.Valid = false
|
||||
return nil
|
||||
}
|
||||
nt.Time, nt.Valid = value.(time.Time), true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
||||
|
||||
func getPreparedUser(name string, role string) *User {
|
||||
var company Company
|
||||
DB.Where(Company{Name: role}).FirstOrCreate(&company)
|
||||
|
||||
return &User{
|
||||
Name: name,
|
||||
Age: 20,
|
||||
Role: Role{role},
|
||||
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
||||
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
||||
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
||||
Emails: []Email{
|
||||
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
||||
},
|
||||
Company: company,
|
||||
Languages: []Language{
|
||||
{Name: fmt.Sprintf("lang_1_%v", name)},
|
||||
{Name: fmt.Sprintf("lang_2_%v", name)},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func runMigration() {
|
||||
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
|
||||
fmt.Printf("Got error when try to delete table users, %+v\n", err)
|
||||
|
|
2
model.go
2
model.go
|
@ -2,7 +2,7 @@ package gorm
|
|||
|
||||
import "time"
|
||||
|
||||
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models
|
||||
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
|
|
721
scope.go
721
scope.go
|
@ -2,10 +2,13 @@ package gorm
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"reflect"
|
||||
)
|
||||
|
@ -406,11 +409,15 @@ func (scope *Scope) SelectAttrs() []string {
|
|||
return *scope.selectAttrs
|
||||
}
|
||||
|
||||
// OmitAttrs return omited attributes
|
||||
// OmitAttrs return omitted attributes
|
||||
func (scope *Scope) OmitAttrs() []string {
|
||||
return scope.Search.omits
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Private Methods For *gorm.Scope
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
|
||||
var values = make([]interface{}, len(columns))
|
||||
var ignored interface{}
|
||||
|
@ -441,3 +448,715 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]
|
|||
}
|
||||
}
|
||||
}
|
||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||
return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
|
||||
}
|
||||
|
||||
func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||
switch value := clause["query"].(type) {
|
||||
case string:
|
||||
// if string is number
|
||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||
return scope.primaryCondition(scope.AddToVars(value))
|
||||
} else if value != "" {
|
||||
str = fmt.Sprintf("(%v)", value)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||
return scope.primaryCondition(scope.AddToVars(value))
|
||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
||||
str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey()))
|
||||
clause["args"] = []interface{}{value}
|
||||
case map[string]interface{}:
|
||||
var sqls []string
|
||||
for key, value := range value {
|
||||
if value != nil {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
|
||||
} else {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key)))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
var sqls []string
|
||||
for _, field := range scope.New(value).Fields() {
|
||||
if !field.IsIgnored && !field.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
}
|
||||
|
||||
args := clause["args"].([]interface{})
|
||||
for _, arg := range args {
|
||||
switch reflect.ValueOf(arg).Kind() {
|
||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||
if bytes, ok := arg.([]byte); ok {
|
||||
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||
var tempMarks []string
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||
}
|
||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
||||
} else {
|
||||
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
||||
}
|
||||
default:
|
||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
|
||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||
var notEqualSQL string
|
||||
var primaryKey = scope.PrimaryKey()
|
||||
|
||||
switch value := clause["query"].(type) {
|
||||
case string:
|
||||
// is number
|
||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||
id, _ := strconv.Atoi(value)
|
||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
|
||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
|
||||
} else {
|
||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
||||
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
|
||||
if reflect.ValueOf(value).Len() > 0 {
|
||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
|
||||
clause["args"] = []interface{}{value}
|
||||
}
|
||||
return ""
|
||||
case map[string]interface{}:
|
||||
var sqls []string
|
||||
for key, value := range value {
|
||||
if value != nil {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
|
||||
} else {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key)))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
var sqls []string
|
||||
for _, field := range scope.New(value).Fields() {
|
||||
if !field.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
}
|
||||
|
||||
args := clause["args"].([]interface{})
|
||||
for _, arg := range args {
|
||||
switch reflect.ValueOf(arg).Kind() {
|
||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||
if bytes, ok := arg.([]byte); ok {
|
||||
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||
var tempMarks []string
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||
}
|
||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
||||
} else {
|
||||
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
||||
}
|
||||
default:
|
||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = scanner.Value()
|
||||
}
|
||||
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
|
||||
switch value := clause["query"].(type) {
|
||||
case string:
|
||||
str = value
|
||||
case []string:
|
||||
str = strings.Join(value, ", ")
|
||||
}
|
||||
|
||||
args := clause["args"].([]interface{})
|
||||
for _, arg := range args {
|
||||
switch reflect.ValueOf(arg).Kind() {
|
||||
case reflect.Slice:
|
||||
values := reflect.ValueOf(arg)
|
||||
var tempMarks []string
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||
}
|
||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
||||
default:
|
||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) whereSQL() (sql string) {
|
||||
var (
|
||||
quotedTableName = scope.QuotedTableName()
|
||||
primaryConditions, andConditions, orConditions []string
|
||||
)
|
||||
|
||||
if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
|
||||
sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName)
|
||||
primaryConditions = append(primaryConditions, sql)
|
||||
}
|
||||
|
||||
if !scope.PrimaryKeyZero() {
|
||||
for _, field := range scope.PrimaryFields() {
|
||||
sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
|
||||
primaryConditions = append(primaryConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
for _, clause := range scope.Search.whereConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
andConditions = append(andConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
for _, clause := range scope.Search.orConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
orConditions = append(orConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
for _, clause := range scope.Search.notConditions {
|
||||
if sql := scope.buildNotCondition(clause); sql != "" {
|
||||
andConditions = append(andConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
orSQL := strings.Join(orConditions, " OR ")
|
||||
combinedSQL := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSQL) > 0 {
|
||||
if len(orSQL) > 0 {
|
||||
combinedSQL = combinedSQL + " OR " + orSQL
|
||||
}
|
||||
} else {
|
||||
combinedSQL = orSQL
|
||||
}
|
||||
|
||||
if len(primaryConditions) > 0 {
|
||||
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
|
||||
if len(combinedSQL) > 0 {
|
||||
sql = sql + " AND (" + combinedSQL + ")"
|
||||
}
|
||||
} else if len(combinedSQL) > 0 {
|
||||
sql = "WHERE " + combinedSQL
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) selectSQL() string {
|
||||
if len(scope.Search.selects) == 0 {
|
||||
if len(scope.Search.joinConditions) > 0 {
|
||||
return fmt.Sprintf("%v.*", scope.QuotedTableName())
|
||||
}
|
||||
return "*"
|
||||
}
|
||||
return scope.buildSelectQuery(scope.Search.selects)
|
||||
}
|
||||
|
||||
func (scope *Scope) orderSQL() string {
|
||||
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
|
||||
return ""
|
||||
}
|
||||
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
||||
}
|
||||
|
||||
func (scope *Scope) limitAndOffsetSQL() string {
|
||||
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||
}
|
||||
|
||||
func (scope *Scope) groupSQL() string {
|
||||
if len(scope.Search.group) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " GROUP BY " + scope.Search.group
|
||||
}
|
||||
|
||||
func (scope *Scope) havingSQL() string {
|
||||
if len(scope.Search.havingConditions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var andConditions []string
|
||||
for _, clause := range scope.Search.havingConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
andConditions = append(andConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
combinedSQL := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSQL) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return " HAVING " + combinedSQL
|
||||
}
|
||||
|
||||
func (scope *Scope) joinsSQL() string {
|
||||
var joinConditions []string
|
||||
for _, clause := range scope.Search.joinConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(joinConditions, " ") + " "
|
||||
}
|
||||
|
||||
func (scope *Scope) prepareQuerySQL() {
|
||||
if scope.Search.raw {
|
||||
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
|
||||
if len(values) > 0 {
|
||||
scope.Search.Where(values[0], values[1:]...)
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
||||
for _, f := range funcs {
|
||||
(*f)(scope)
|
||||
if scope.skipLeft {
|
||||
break
|
||||
}
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||
if scope.IndirectValue().Kind() != reflect.Struct {
|
||||
return values, true
|
||||
}
|
||||
|
||||
results = map[string]interface{}{}
|
||||
for key, value := range values {
|
||||
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||
if _, ok := value.(*expr); ok {
|
||||
hasUpdate = true
|
||||
results[field.DBName] = value
|
||||
} else if !equalAsString(field.Field.Interface(), value) {
|
||||
field.Set(value)
|
||||
if field.IsNormal {
|
||||
hasUpdate = true
|
||||
results[field.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
field.Set(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) row() *sql.Row {
|
||||
defer scope.trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||
scope.prepareQuerySQL()
|
||||
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||
defer scope.trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||
scope.prepareQuerySQL()
|
||||
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) initialize() *Scope {
|
||||
for _, clause := range scope.Search.whereConditions {
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
|
||||
}
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
||||
dest := reflect.Indirect(reflect.ValueOf(value))
|
||||
scope.Search.Select(column)
|
||||
if dest.Kind() != reflect.Slice {
|
||||
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
|
||||
return scope
|
||||
}
|
||||
|
||||
rows, err := scope.rows()
|
||||
if scope.Err(err) == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
elem := reflect.New(dest.Type().Elem()).Interface()
|
||||
scope.Err(rows.Scan(elem))
|
||||
dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
|
||||
}
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) count(value interface{}) *Scope {
|
||||
scope.Search.Select("count(*)")
|
||||
scope.Search.countingQuery = true
|
||||
scope.Err(scope.row().Scan(value))
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) typeName() string {
|
||||
typ := scope.IndirectValue().Type()
|
||||
|
||||
for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
return typ.Name()
|
||||
}
|
||||
|
||||
// trace print sql log
|
||||
func (scope *Scope) trace(t time.Time) {
|
||||
if len(scope.SQL) > 0 {
|
||||
scope.db.slog(scope.SQL, t, scope.SQLVars...)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) changeableField(field *Field) bool {
|
||||
if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
|
||||
for _, attr := range selectAttrs {
|
||||
if field.Name == attr || field.DBName == attr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, attr := range scope.OmitAttrs() {
|
||||
if field.Name == attr || field.DBName == attr {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (scope *Scope) shouldSaveAssociations() bool {
|
||||
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) {
|
||||
return false
|
||||
}
|
||||
return true && !scope.HasError()
|
||||
}
|
||||
|
||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
toScope := scope.db.NewScope(value)
|
||||
|
||||
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
||||
fromField, _ := scope.FieldByName(foreignKey)
|
||||
toField, _ := toScope.FieldByName(foreignKey)
|
||||
|
||||
if fromField != nil {
|
||||
if relationship := fromField.Relationship; relationship != nil {
|
||||
if relationship.Kind == "many_to_many" {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
query := toScope.db
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
scope.Err(query.Find(value).Error)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
query := toScope.db
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||
}
|
||||
scope.Err(query.Find(value).Error)
|
||||
}
|
||||
} else {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
|
||||
}
|
||||
return scope
|
||||
} else if toField != nil {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
|
||||
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
|
||||
return scope
|
||||
}
|
||||
}
|
||||
|
||||
scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
|
||||
return scope
|
||||
}
|
||||
|
||||
// getTableOptions return the table options string or an empty string if the table options does not exist
|
||||
func (scope *Scope) getTableOptions() string {
|
||||
tableOptions, ok := scope.Get("gorm:table_options")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return tableOptions.(string)
|
||||
}
|
||||
|
||||
func (scope *Scope) createJoinTable(field *StructField) {
|
||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
joinTable := joinTableHandler.Table(scope.db)
|
||||
if !scope.Dialect().HasTable(joinTable) {
|
||||
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
||||
|
||||
var sqlTypes, primaryKeys []string
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
if field, ok := scope.FieldByName(fieldName); ok {
|
||||
foreignKeyStruct := field.clone()
|
||||
foreignKeyStruct.IsPrimaryKey = false
|
||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||
}
|
||||
}
|
||||
|
||||
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := toScope.FieldByName(fieldName); ok {
|
||||
foreignKeyStruct := field.clone()
|
||||
foreignKeyStruct.IsPrimaryKey = false
|
||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||
}
|
||||
}
|
||||
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
|
||||
}
|
||||
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) createTable() *Scope {
|
||||
var tags []string
|
||||
var primaryKeys []string
|
||||
var primaryKeyInColumnType = false
|
||||
for _, field := range scope.GetModelStruct().StructFields {
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||
|
||||
// Check if the primary key constraint was specified as
|
||||
// part of the column type. If so, we can only support
|
||||
// one column as the primary key.
|
||||
if strings.Contains(strings.ToLower(sqlTag), "primary key") {
|
||||
primaryKeyInColumnType = true
|
||||
}
|
||||
|
||||
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
|
||||
}
|
||||
|
||||
if field.IsPrimaryKey {
|
||||
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
}
|
||||
|
||||
var primaryKeyStr string
|
||||
if len(primaryKeys) > 0 && !primaryKeyInColumnType {
|
||||
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
||||
}
|
||||
|
||||
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
|
||||
|
||||
scope.autoIndex()
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) dropTable() *Scope {
|
||||
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) modifyColumn(column string, typ string) {
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) dropColumn(column string) {
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
|
||||
return
|
||||
}
|
||||
|
||||
var columns []string
|
||||
for _, name := range column {
|
||||
columns = append(columns, scope.quoteIfPossible(name))
|
||||
}
|
||||
|
||||
sqlCreate := "CREATE INDEX"
|
||||
if unique {
|
||||
sqlCreate = "CREATE UNIQUE INDEX"
|
||||
}
|
||||
|
||||
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
|
||||
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
|
||||
|
||||
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||
return
|
||||
}
|
||||
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) removeIndex(indexName string) {
|
||||
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
||||
}
|
||||
|
||||
func (scope *Scope) autoMigrate() *Scope {
|
||||
tableName := scope.TableName()
|
||||
quotedTableName := scope.QuotedTableName()
|
||||
|
||||
if !scope.Dialect().HasTable(tableName) {
|
||||
scope.createTable()
|
||||
} else {
|
||||
for _, field := range scope.GetModelStruct().StructFields {
|
||||
if !scope.Dialect().HasColumn(tableName, field.DBName) {
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
||||
}
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
}
|
||||
scope.autoIndex()
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) autoIndex() *Scope {
|
||||
var indexes = map[string][]string{}
|
||||
var uniqueIndexes = map[string][]string{}
|
||||
|
||||
for _, field := range scope.GetStructFields() {
|
||||
if name, ok := field.TagSettings["INDEX"]; ok {
|
||||
if name == "INDEX" {
|
||||
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
|
||||
}
|
||||
indexes[name] = append(indexes[name], field.DBName)
|
||||
}
|
||||
|
||||
if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
|
||||
if name == "UNIQUE_INDEX" {
|
||||
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
|
||||
}
|
||||
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
for name, columns := range indexes {
|
||||
scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
|
||||
}
|
||||
|
||||
for name, columns := range uniqueIndexes {
|
||||
scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
|
||||
}
|
||||
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
|
||||
for _, value := range values {
|
||||
indirectValue := reflect.ValueOf(value)
|
||||
for indirectValue.Kind() == reflect.Ptr {
|
||||
indirectValue = indirectValue.Elem()
|
||||
}
|
||||
|
||||
switch indirectValue.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < indirectValue.Len(); i++ {
|
||||
var result []interface{}
|
||||
var object = indirect(indirectValue.Index(i))
|
||||
for _, column := range columns {
|
||||
result = append(result, object.FieldByName(column).Interface())
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
case reflect.Struct:
|
||||
var result []interface{}
|
||||
for _, column := range columns {
|
||||
result = append(result, indirectValue.FieldByName(column).Interface())
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnAsScope(column string) *Scope {
|
||||
indirectScopeValue := scope.IndirectValue()
|
||||
|
||||
switch indirectScopeValue.Kind() {
|
||||
case reflect.Slice:
|
||||
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
|
||||
fieldType := fieldStruct.Type
|
||||
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
|
||||
|
||||
for i := 0; i < indirectScopeValue.Len(); i++ {
|
||||
result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
|
||||
|
||||
if result.Kind() == reflect.Slice {
|
||||
for j := 0; j < result.Len(); j++ {
|
||||
if elem := result.Index(j); elem.CanAddr() {
|
||||
results = reflect.Append(results, elem.Addr())
|
||||
}
|
||||
}
|
||||
} else if result.CanAddr() {
|
||||
results = reflect.Append(results, result.Addr())
|
||||
}
|
||||
}
|
||||
return scope.New(results.Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
|
||||
return scope.New(field.Addr().Interface())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
663
scope_private.go
663
scope_private.go
|
@ -1,663 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||
return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
|
||||
}
|
||||
|
||||
func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||
switch value := clause["query"].(type) {
|
||||
case string:
|
||||
// if string is number
|
||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||
return scope.primaryCondition(scope.AddToVars(value))
|
||||
} else if value != "" {
|
||||
str = fmt.Sprintf("(%v)", value)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||
return scope.primaryCondition(scope.AddToVars(value))
|
||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
||||
str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey()))
|
||||
clause["args"] = []interface{}{value}
|
||||
case map[string]interface{}:
|
||||
var sqls []string
|
||||
for key, value := range value {
|
||||
if value != nil {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
|
||||
} else {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key)))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
var sqls []string
|
||||
for _, field := range scope.New(value).Fields() {
|
||||
if !field.IsIgnored && !field.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
}
|
||||
|
||||
args := clause["args"].([]interface{})
|
||||
for _, arg := range args {
|
||||
switch reflect.ValueOf(arg).Kind() {
|
||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||
if bytes, ok := arg.([]byte); ok {
|
||||
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||
var tempMarks []string
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||
}
|
||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
||||
} else {
|
||||
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
||||
}
|
||||
default:
|
||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
|
||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||
var notEqualSQL string
|
||||
var primaryKey = scope.PrimaryKey()
|
||||
|
||||
switch value := clause["query"].(type) {
|
||||
case string:
|
||||
// is number
|
||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||
id, _ := strconv.Atoi(value)
|
||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
|
||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
|
||||
} else {
|
||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
||||
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
|
||||
if reflect.ValueOf(value).Len() > 0 {
|
||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
|
||||
clause["args"] = []interface{}{value}
|
||||
}
|
||||
return ""
|
||||
case map[string]interface{}:
|
||||
var sqls []string
|
||||
for key, value := range value {
|
||||
if value != nil {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
|
||||
} else {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key)))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
var sqls []string
|
||||
for _, field := range scope.New(value).Fields() {
|
||||
if !field.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
}
|
||||
|
||||
args := clause["args"].([]interface{})
|
||||
for _, arg := range args {
|
||||
switch reflect.ValueOf(arg).Kind() {
|
||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||
if bytes, ok := arg.([]byte); ok {
|
||||
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||
var tempMarks []string
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||
}
|
||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
||||
} else {
|
||||
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
||||
}
|
||||
default:
|
||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = scanner.Value()
|
||||
}
|
||||
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
|
||||
switch value := clause["query"].(type) {
|
||||
case string:
|
||||
str = value
|
||||
case []string:
|
||||
str = strings.Join(value, ", ")
|
||||
}
|
||||
|
||||
args := clause["args"].([]interface{})
|
||||
for _, arg := range args {
|
||||
switch reflect.ValueOf(arg).Kind() {
|
||||
case reflect.Slice:
|
||||
values := reflect.ValueOf(arg)
|
||||
var tempMarks []string
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||
}
|
||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
||||
default:
|
||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) whereSQL() (sql string) {
|
||||
var (
|
||||
quotedTableName = scope.QuotedTableName()
|
||||
primaryConditions, andConditions, orConditions []string
|
||||
)
|
||||
|
||||
if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
|
||||
sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName)
|
||||
primaryConditions = append(primaryConditions, sql)
|
||||
}
|
||||
|
||||
if !scope.PrimaryKeyZero() {
|
||||
for _, field := range scope.PrimaryFields() {
|
||||
sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
|
||||
primaryConditions = append(primaryConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
for _, clause := range scope.Search.whereConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
andConditions = append(andConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
for _, clause := range scope.Search.orConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
orConditions = append(orConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
for _, clause := range scope.Search.notConditions {
|
||||
if sql := scope.buildNotCondition(clause); sql != "" {
|
||||
andConditions = append(andConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
orSQL := strings.Join(orConditions, " OR ")
|
||||
combinedSQL := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSQL) > 0 {
|
||||
if len(orSQL) > 0 {
|
||||
combinedSQL = combinedSQL + " OR " + orSQL
|
||||
}
|
||||
} else {
|
||||
combinedSQL = orSQL
|
||||
}
|
||||
|
||||
if len(primaryConditions) > 0 {
|
||||
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
|
||||
if len(combinedSQL) > 0 {
|
||||
sql = sql + " AND (" + combinedSQL + ")"
|
||||
}
|
||||
} else if len(combinedSQL) > 0 {
|
||||
sql = "WHERE " + combinedSQL
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) selectSQL() string {
|
||||
if len(scope.Search.selects) == 0 {
|
||||
if len(scope.Search.joinConditions) > 0 {
|
||||
return fmt.Sprintf("%v.*", scope.QuotedTableName())
|
||||
}
|
||||
return "*"
|
||||
}
|
||||
return scope.buildSelectQuery(scope.Search.selects)
|
||||
}
|
||||
|
||||
func (scope *Scope) orderSQL() string {
|
||||
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
|
||||
return ""
|
||||
}
|
||||
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
||||
}
|
||||
|
||||
func (scope *Scope) limitAndOffsetSQL() string {
|
||||
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||
}
|
||||
|
||||
func (scope *Scope) groupSQL() string {
|
||||
if len(scope.Search.group) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " GROUP BY " + scope.Search.group
|
||||
}
|
||||
|
||||
func (scope *Scope) havingSQL() string {
|
||||
if len(scope.Search.havingConditions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var andConditions []string
|
||||
for _, clause := range scope.Search.havingConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
andConditions = append(andConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
combinedSQL := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSQL) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return " HAVING " + combinedSQL
|
||||
}
|
||||
|
||||
func (scope *Scope) joinsSQL() string {
|
||||
var joinConditions []string
|
||||
for _, clause := range scope.Search.joinConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(joinConditions, " ") + " "
|
||||
}
|
||||
|
||||
func (scope *Scope) prepareQuerySQL() {
|
||||
if scope.Search.raw {
|
||||
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
|
||||
if len(values) > 0 {
|
||||
scope.Search.Where(values[0], values[1:]...)
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
||||
for _, f := range funcs {
|
||||
(*f)(scope)
|
||||
if scope.skipLeft {
|
||||
break
|
||||
}
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
|
||||
if scope.IndirectValue().Kind() != reflect.Struct {
|
||||
return values, true
|
||||
}
|
||||
|
||||
results = map[string]interface{}{}
|
||||
for key, value := range values {
|
||||
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||
if _, ok := value.(*expr); ok {
|
||||
hasUpdate = true
|
||||
results[field.DBName] = value
|
||||
} else if !equalAsString(field.Field.Interface(), value) {
|
||||
field.Set(value)
|
||||
if field.IsNormal {
|
||||
hasUpdate = true
|
||||
results[field.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
field.Set(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) row() *sql.Row {
|
||||
defer scope.trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||
scope.prepareQuerySQL()
|
||||
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||
defer scope.trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||
scope.prepareQuerySQL()
|
||||
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) initialize() *Scope {
|
||||
for _, clause := range scope.Search.whereConditions {
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
|
||||
}
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
|
||||
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
||||
dest := reflect.Indirect(reflect.ValueOf(value))
|
||||
scope.Search.Select(column)
|
||||
if dest.Kind() != reflect.Slice {
|
||||
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
|
||||
return scope
|
||||
}
|
||||
|
||||
rows, err := scope.rows()
|
||||
if scope.Err(err) == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
elem := reflect.New(dest.Type().Elem()).Interface()
|
||||
scope.Err(rows.Scan(elem))
|
||||
dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
|
||||
}
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) count(value interface{}) *Scope {
|
||||
scope.Search.Select("count(*)")
|
||||
scope.Search.countingQuery = true
|
||||
scope.Err(scope.row().Scan(value))
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) typeName() string {
|
||||
typ := scope.IndirectValue().Type()
|
||||
|
||||
for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
return typ.Name()
|
||||
}
|
||||
|
||||
// trace print sql log
|
||||
func (scope *Scope) trace(t time.Time) {
|
||||
if len(scope.SQL) > 0 {
|
||||
scope.db.slog(scope.SQL, t, scope.SQLVars...)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) changeableField(field *Field) bool {
|
||||
if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
|
||||
for _, attr := range selectAttrs {
|
||||
if field.Name == attr || field.DBName == attr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, attr := range scope.OmitAttrs() {
|
||||
if field.Name == attr || field.DBName == attr {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (scope *Scope) shouldSaveAssociations() bool {
|
||||
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) {
|
||||
return false
|
||||
}
|
||||
return true && !scope.HasError()
|
||||
}
|
||||
|
||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
toScope := scope.db.NewScope(value)
|
||||
|
||||
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
||||
fromField, _ := scope.FieldByName(foreignKey)
|
||||
toField, _ := toScope.FieldByName(foreignKey)
|
||||
|
||||
if fromField != nil {
|
||||
if relationship := fromField.Relationship; relationship != nil {
|
||||
if relationship.Kind == "many_to_many" {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
query := toScope.db
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
scope.Err(query.Find(value).Error)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
query := toScope.db
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||
}
|
||||
scope.Err(query.Find(value).Error)
|
||||
}
|
||||
} else {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
|
||||
}
|
||||
return scope
|
||||
} else if toField != nil {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
|
||||
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
|
||||
return scope
|
||||
}
|
||||
}
|
||||
|
||||
scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
|
||||
return scope
|
||||
}
|
||||
|
||||
/**
|
||||
Return the table options string or an empty string if the table options does not exist
|
||||
*/
|
||||
func (scope *Scope) getTableOptions() string {
|
||||
tableOptions, ok := scope.Get("gorm:table_options")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return tableOptions.(string)
|
||||
}
|
||||
|
||||
func (scope *Scope) createJoinTable(field *StructField) {
|
||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
joinTable := joinTableHandler.Table(scope.db)
|
||||
if !scope.Dialect().HasTable(joinTable) {
|
||||
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
||||
|
||||
var sqlTypes, primaryKeys []string
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
if field, ok := scope.FieldByName(fieldName); ok {
|
||||
foreignKeyStruct := field.clone()
|
||||
foreignKeyStruct.IsPrimaryKey = false
|
||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||
}
|
||||
}
|
||||
|
||||
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := toScope.FieldByName(fieldName); ok {
|
||||
foreignKeyStruct := field.clone()
|
||||
foreignKeyStruct.IsPrimaryKey = false
|
||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||
}
|
||||
}
|
||||
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
|
||||
}
|
||||
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) createTable() *Scope {
|
||||
var tags []string
|
||||
var primaryKeys []string
|
||||
var primaryKeyInColumnType = false
|
||||
for _, field := range scope.GetModelStruct().StructFields {
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||
|
||||
// Check if the primary key constraint was specified as
|
||||
// part of the column type. If so, we can only support
|
||||
// one column as the primary key.
|
||||
if strings.Contains(strings.ToLower(sqlTag), "primary key") {
|
||||
primaryKeyInColumnType = true
|
||||
}
|
||||
|
||||
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
|
||||
}
|
||||
|
||||
if field.IsPrimaryKey {
|
||||
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
}
|
||||
|
||||
var primaryKeyStr string
|
||||
if len(primaryKeys) > 0 && !primaryKeyInColumnType {
|
||||
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
||||
}
|
||||
|
||||
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
|
||||
|
||||
scope.autoIndex()
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) dropTable() *Scope {
|
||||
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) modifyColumn(column string, typ string) {
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) dropColumn(column string) {
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
|
||||
return
|
||||
}
|
||||
|
||||
var columns []string
|
||||
for _, name := range column {
|
||||
columns = append(columns, scope.quoteIfPossible(name))
|
||||
}
|
||||
|
||||
sqlCreate := "CREATE INDEX"
|
||||
if unique {
|
||||
sqlCreate = "CREATE UNIQUE INDEX"
|
||||
}
|
||||
|
||||
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
|
||||
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
|
||||
|
||||
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
|
||||
return
|
||||
}
|
||||
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) removeIndex(indexName string) {
|
||||
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
||||
}
|
||||
|
||||
func (scope *Scope) autoMigrate() *Scope {
|
||||
tableName := scope.TableName()
|
||||
quotedTableName := scope.QuotedTableName()
|
||||
|
||||
if !scope.Dialect().HasTable(tableName) {
|
||||
scope.createTable()
|
||||
} else {
|
||||
for _, field := range scope.GetModelStruct().StructFields {
|
||||
if !scope.Dialect().HasColumn(tableName, field.DBName) {
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
||||
}
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
}
|
||||
scope.autoIndex()
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) autoIndex() *Scope {
|
||||
var indexes = map[string][]string{}
|
||||
var uniqueIndexes = map[string][]string{}
|
||||
|
||||
for _, field := range scope.GetStructFields() {
|
||||
if name, ok := field.TagSettings["INDEX"]; ok {
|
||||
if name == "INDEX" {
|
||||
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
|
||||
}
|
||||
indexes[name] = append(indexes[name], field.DBName)
|
||||
}
|
||||
|
||||
if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
|
||||
if name == "UNIQUE_INDEX" {
|
||||
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
|
||||
}
|
||||
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
for name, columns := range indexes {
|
||||
scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
|
||||
}
|
||||
|
||||
for name, columns := range uniqueIndexes {
|
||||
scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
|
||||
}
|
||||
|
||||
return scope
|
||||
}
|
|
@ -1,67 +0,0 @@
|
|||
package gorm
|
||||
|
||||
import "reflect"
|
||||
|
||||
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
|
||||
for _, value := range values {
|
||||
indirectValue := reflect.ValueOf(value)
|
||||
for indirectValue.Kind() == reflect.Ptr {
|
||||
indirectValue = indirectValue.Elem()
|
||||
}
|
||||
|
||||
switch indirectValue.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < indirectValue.Len(); i++ {
|
||||
var result []interface{}
|
||||
var object = indirect(indirectValue.Index(i))
|
||||
for _, column := range columns {
|
||||
result = append(result, object.FieldByName(column).Interface())
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
case reflect.Struct:
|
||||
var result []interface{}
|
||||
for _, column := range columns {
|
||||
result = append(result, indirectValue.FieldByName(column).Interface())
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnAsScope(column string) *Scope {
|
||||
indirectScopeValue := scope.IndirectValue()
|
||||
|
||||
switch indirectScopeValue.Kind() {
|
||||
case reflect.Slice:
|
||||
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
|
||||
fieldType := fieldStruct.Type
|
||||
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
|
||||
|
||||
for i := 0; i < indirectScopeValue.Len(); i++ {
|
||||
result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
|
||||
|
||||
if result.Kind() == reflect.Slice {
|
||||
for j := 0; j < result.Len(); j++ {
|
||||
if elem := result.Index(j); elem.CanAddr() {
|
||||
results = reflect.Append(results, elem.Addr())
|
||||
}
|
||||
}
|
||||
} else if result.CanAddr() {
|
||||
results = reflect.Append(results, result.Addr())
|
||||
}
|
||||
}
|
||||
return scope.New(results.Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
|
||||
return scope.New(field.Addr().Interface())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
222
structs_test.go
222
structs_test.go
|
@ -1,222 +0,0 @@
|
|||
package gorm_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Id int64
|
||||
Age int64
|
||||
UserNum Num
|
||||
Name string `sql:"size:255"`
|
||||
Birthday time.Time // Time
|
||||
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
||||
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
|
||||
Emails []Email // Embedded structs
|
||||
BillingAddress Address // Embedded struct
|
||||
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
|
||||
ShippingAddress Address // Embedded struct
|
||||
ShippingAddressId int64 // Embedded struct's foreign key
|
||||
CreditCard CreditCard
|
||||
Latitude float64
|
||||
Languages []Language `gorm:"many2many:user_languages;"`
|
||||
CompanyID *int
|
||||
Company Company
|
||||
Role
|
||||
PasswordHash []byte
|
||||
IgnoreMe int64 `sql:"-"`
|
||||
IgnoreStringSlice []string `sql:"-"`
|
||||
Ignored struct{ Name string } `sql:"-"`
|
||||
IgnoredPointer *User `sql:"-"`
|
||||
}
|
||||
|
||||
type CreditCard struct {
|
||||
ID int8
|
||||
Number string
|
||||
UserId sql.NullInt64
|
||||
CreatedAt time.Time `sql:"not null"`
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
|
||||
type Email struct {
|
||||
Id int16
|
||||
UserId int
|
||||
Email string `sql:"type:varchar(100);"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Address struct {
|
||||
ID int
|
||||
Address1 string
|
||||
Address2 string
|
||||
Post string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
|
||||
type Language struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Users []User `gorm:"many2many:user_languages;"`
|
||||
}
|
||||
|
||||
type Product struct {
|
||||
Id int64
|
||||
Code string
|
||||
Price int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
AfterFindCallTimes int64
|
||||
BeforeCreateCallTimes int64
|
||||
AfterCreateCallTimes int64
|
||||
BeforeUpdateCallTimes int64
|
||||
AfterUpdateCallTimes int64
|
||||
BeforeSaveCallTimes int64
|
||||
AfterSaveCallTimes int64
|
||||
BeforeDeleteCallTimes int64
|
||||
AfterDeleteCallTimes int64
|
||||
}
|
||||
|
||||
type Company struct {
|
||||
Id int64
|
||||
Name string
|
||||
Owner *User `sql:"-"`
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (role *Role) Scan(value interface{}) error {
|
||||
if b, ok := value.([]uint8); ok {
|
||||
role.Name = string(b)
|
||||
} else {
|
||||
role.Name = value.(string)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (role Role) Value() (driver.Value, error) {
|
||||
return role.Name, nil
|
||||
}
|
||||
|
||||
func (role Role) IsAdmin() bool {
|
||||
return role.Name == "admin"
|
||||
}
|
||||
|
||||
type Num int64
|
||||
|
||||
func (i *Num) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
case int64:
|
||||
*i = Num(s)
|
||||
default:
|
||||
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Animal struct {
|
||||
Counter uint64 `gorm:"primary_key:yes"`
|
||||
Name string `sql:"DEFAULT:'galeone'"`
|
||||
From string //test reserved sql keyword as field name
|
||||
Age time.Time `sql:"DEFAULT:current_timestamp"`
|
||||
unexported string // unexported value
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type JoinTable struct {
|
||||
From uint64
|
||||
To uint64
|
||||
Time time.Time `sql:"default: null"`
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
Id int64
|
||||
CategoryId sql.NullInt64
|
||||
MainCategoryId int64
|
||||
Title string
|
||||
Body string
|
||||
Comments []*Comment
|
||||
Category Category
|
||||
MainCategory Category
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
}
|
||||
|
||||
type Comment struct {
|
||||
gorm.Model
|
||||
PostId int64
|
||||
Content string
|
||||
Post Post
|
||||
}
|
||||
|
||||
// Scanner
|
||||
type NullValue struct {
|
||||
Id int64
|
||||
Name sql.NullString `sql:"not null"`
|
||||
Gender *sql.NullString `sql:"not null"`
|
||||
Age sql.NullInt64
|
||||
Male sql.NullBool
|
||||
Height sql.NullFloat64
|
||||
AddedAt NullTime
|
||||
}
|
||||
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
nt.Valid = false
|
||||
return nil
|
||||
}
|
||||
nt.Time, nt.Valid = value.(time.Time), true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
||||
|
||||
func getPreparedUser(name string, role string) *User {
|
||||
var company Company
|
||||
DB.Where(Company{Name: role}).FirstOrCreate(&company)
|
||||
|
||||
return &User{
|
||||
Name: name,
|
||||
Age: 20,
|
||||
Role: Role{role},
|
||||
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
||||
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
||||
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
||||
Emails: []Email{
|
||||
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
||||
},
|
||||
Company: company,
|
||||
Languages: []Language{
|
||||
{Name: fmt.Sprintf("lang_1_%v", name)},
|
||||
{Name: fmt.Sprintf("lang_2_%v", name)},
|
||||
},
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue