diff --git a/callback_query.go b/callback_query.go index 08678f92..32f2e001 100644 --- a/callback_query.go +++ b/callback_query.go @@ -18,10 +18,9 @@ func queryCallback(scope *Scope) { defer scope.trace(NowFunc()) var ( - isSlice bool - isPtr bool - results = scope.IndirectValue() - resultType reflect.Type + isSlice, isPtr bool + resultType reflect.Type + results = scope.IndirectValue() ) if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { diff --git a/callback_save.go b/callback_save.go index a525b709..5ffe53b9 100644 --- a/callback_save.go +++ b/callback_save.go @@ -17,13 +17,14 @@ func saveBeforeAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - value := field.Field - scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) + fieldValue := field.Field.Addr().Interface() + scope.Err(scope.NewDB().Save(fieldValue).Error) if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) } } } @@ -65,7 +66,7 @@ func saveAfterAssociationsCallback(scope *Scope) { scope.Err(newDB.Save(elem).Error) if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value)) + scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) } } default: diff --git a/callback_test.go b/callback_system_test.go similarity index 100% rename from callback_test.go rename to callback_system_test.go diff --git a/create_test.go b/create_test.go index 4d1f623d..dc82de50 100644 --- a/create_test.go +++ b/create_test.go @@ -92,7 +92,7 @@ func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { // We must fetch the value again, to have the default fields updated // (We can't do this in the update statements, since sql default can be expressions - // And be different from the fields' type (eg. a time.Time fiels has a default value of "now()" + // And be different from the fields' type (eg. a time.Time fields has a default value of "now()" DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an) if an.Name != "galeone" { diff --git a/ddl_errors_test.go b/ddl_errors_test.go deleted file mode 100644 index 2c31b354..00000000 --- a/ddl_errors_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package gorm_test - -import ( - "testing" -) - -func TestDdlErrors(t *testing.T) { - var err error - - if err = DB.Close(); err != nil { - t.Errorf("Closing DDL test db connection err=%s", err) - } - defer func() { - // Reopen DB connection. - if DB, err = OpenTestConnection(); err != nil { - t.Fatalf("Failed re-opening db connection: %s", err) - } - }() - - if err := DB.Find(&User{}).Error; err == nil { - t.Errorf("Expected operation on closed db to produce an error, but err was nil") - } -} diff --git a/main.go b/main.go index 46d64b29..a5201d48 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "time" ) // DB contains information for current db connection @@ -221,7 +222,7 @@ func (s *DB) Unscoped() *DB { return s.clone().search.unscoped().db } -// Attrs initalize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate +// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate func (s *DB) Attrs(attrs ...interface{}) *DB { return s.clone().search.Attrs(attrs...).db } @@ -299,7 +300,7 @@ func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { return s.clone().NewScope(s.Value).related(value, foreignKeys...).db } -// FirstOrInit find first matched record or initalize a new one with given conditions (only works with struct, map conditions) +// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) // https://jinzhu.github.io/gorm/curd.html#firstorinit func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { c := s.clone() @@ -660,3 +661,40 @@ func (s *DB) GetErrors() (errors []error) { } return } + +//////////////////////////////////////////////////////////////////////////////// +// Private Methods For *gorm.DB +//////////////////////////////////////////////////////////////////////////////// + +func (s *DB) clone() *DB { + db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} + + for key, value := range s.values { + db.values[key] = value + } + + if s.search == nil { + db.search = &search{limit: -1, offset: -1} + } else { + db.search = s.search.clone() + } + + db.search.db = &db + return &db +} + +func (s *DB) print(v ...interface{}) { + s.logger.(logger).Print(v...) +} + +func (s *DB) log(v ...interface{}) { + if s != nil && s.logMode == 2 { + s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) + } +} + +func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { + if s.logMode == 2 { + s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars) + } +} diff --git a/main_private.go b/main_private.go deleted file mode 100644 index a6e5a6a9..00000000 --- a/main_private.go +++ /dev/null @@ -1,36 +0,0 @@ -package gorm - -import "time" - -func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} - - for key, value := range s.values { - db.values[key] = value - } - - if s.search == nil { - db.search = &search{limit: -1, offset: -1} - } else { - db.search = s.search.clone() - } - - db.search.db = &db - return &db -} - -func (s *DB) print(v ...interface{}) { - s.logger.(logger).Print(v...) -} - -func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == 2 { - s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) - } -} - -func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == 2 { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars) - } -} diff --git a/main_test.go b/main_test.go index 2da3de94..e9bffd0f 100644 --- a/main_test.go +++ b/main_test.go @@ -714,6 +714,24 @@ func TestOpenExistingDB(t *testing.T) { } } +func TestDdlErrors(t *testing.T) { + var err error + + if err = DB.Close(); err != nil { + t.Errorf("Closing DDL test db connection err=%s", err) + } + defer func() { + // Reopen DB connection. + if DB, err = OpenTestConnection(); err != nil { + t.Fatalf("Failed re-opening db connection: %s", err) + } + }() + + if err := DB.Find(&User{}).Error; err == nil { + t.Errorf("Expected operation on closed db to produce an error, but err was nil") + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { diff --git a/migration_test.go b/migration_test.go index de35c1df..81fe4fd3 100644 --- a/migration_test.go +++ b/migration_test.go @@ -1,11 +1,226 @@ package gorm_test import ( + "database/sql" + "database/sql/driver" + "errors" "fmt" + "reflect" "testing" "time" + + "github.com/jinzhu/gorm" ) +type User struct { + Id int64 + Age int64 + UserNum Num + Name string `sql:"size:255"` + Birthday time.Time // Time + CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically + UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically + Emails []Email // Embedded structs + BillingAddress Address // Embedded struct + BillingAddressID sql.NullInt64 // Embedded struct's foreign key + ShippingAddress Address // Embedded struct + ShippingAddressId int64 // Embedded struct's foreign key + CreditCard CreditCard + Latitude float64 + Languages []Language `gorm:"many2many:user_languages;"` + CompanyID *int + Company Company + Role + PasswordHash []byte + IgnoreMe int64 `sql:"-"` + IgnoreStringSlice []string `sql:"-"` + Ignored struct{ Name string } `sql:"-"` + IgnoredPointer *User `sql:"-"` +} + +type CreditCard struct { + ID int8 + Number string + UserId sql.NullInt64 + CreatedAt time.Time `sql:"not null"` + UpdatedAt time.Time + DeletedAt *time.Time +} + +type Email struct { + Id int16 + UserId int + Email string `sql:"type:varchar(100);"` + CreatedAt time.Time + UpdatedAt time.Time +} + +type Address struct { + ID int + Address1 string + Address2 string + Post string + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time +} + +type Language struct { + gorm.Model + Name string + Users []User `gorm:"many2many:user_languages;"` +} + +type Product struct { + Id int64 + Code string + Price int64 + CreatedAt time.Time + UpdatedAt time.Time + AfterFindCallTimes int64 + BeforeCreateCallTimes int64 + AfterCreateCallTimes int64 + BeforeUpdateCallTimes int64 + AfterUpdateCallTimes int64 + BeforeSaveCallTimes int64 + AfterSaveCallTimes int64 + BeforeDeleteCallTimes int64 + AfterDeleteCallTimes int64 +} + +type Company struct { + Id int64 + Name string + Owner *User `sql:"-"` +} + +type Role struct { + Name string +} + +func (role *Role) Scan(value interface{}) error { + if b, ok := value.([]uint8); ok { + role.Name = string(b) + } else { + role.Name = value.(string) + } + return nil +} + +func (role Role) Value() (driver.Value, error) { + return role.Name, nil +} + +func (role Role) IsAdmin() bool { + return role.Name == "admin" +} + +type Num int64 + +func (i *Num) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + case int64: + *i = Num(s) + default: + return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) + } + return nil +} + +type Animal struct { + Counter uint64 `gorm:"primary_key:yes"` + Name string `sql:"DEFAULT:'galeone'"` + From string //test reserved sql keyword as field name + Age time.Time `sql:"DEFAULT:current_timestamp"` + unexported string // unexported value + CreatedAt time.Time + UpdatedAt time.Time +} + +type JoinTable struct { + From uint64 + To uint64 + Time time.Time `sql:"default: null"` +} + +type Post struct { + Id int64 + CategoryId sql.NullInt64 + MainCategoryId int64 + Title string + Body string + Comments []*Comment + Category Category + MainCategory Category +} + +type Category struct { + gorm.Model + Name string +} + +type Comment struct { + gorm.Model + PostId int64 + Content string + Post Post +} + +// Scanner +type NullValue struct { + Id int64 + Name sql.NullString `sql:"not null"` + Gender *sql.NullString `sql:"not null"` + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + AddedAt NullTime +} + +type NullTime struct { + Time time.Time + Valid bool +} + +func (nt *NullTime) Scan(value interface{}) error { + if value == nil { + nt.Valid = false + return nil + } + nt.Time, nt.Valid = value.(time.Time), true + return nil +} + +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + +func getPreparedUser(name string, role string) *User { + var company Company + DB.Where(Company{Name: role}).FirstOrCreate(&company) + + return &User{ + Name: name, + Age: 20, + Role: Role{role}, + BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, + ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, + CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, + Emails: []Email{ + {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, + }, + Company: company, + Languages: []Language{ + {Name: fmt.Sprintf("lang_1_%v", name)}, + {Name: fmt.Sprintf("lang_2_%v", name)}, + }, + } +} + func runMigration() { if err := DB.DropTableIfExists(&User{}).Error; err != nil { fmt.Printf("Got error when try to delete table users, %+v\n", err) diff --git a/model.go b/model.go index a6650877..f37ff7ea 100644 --- a/model.go +++ b/model.go @@ -2,7 +2,7 @@ package gorm import "time" -// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models +// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models // type User struct { // gorm.Model // } diff --git a/slice_test.go b/scaner_test.go similarity index 100% rename from slice_test.go rename to scaner_test.go diff --git a/scope.go b/scope.go index 6239db7a..f8c9dec0 100644 --- a/scope.go +++ b/scope.go @@ -2,10 +2,13 @@ package gorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" "regexp" + "strconv" "strings" + "time" "reflect" ) @@ -406,11 +409,15 @@ func (scope *Scope) SelectAttrs() []string { return *scope.selectAttrs } -// OmitAttrs return omited attributes +// OmitAttrs return omitted attributes func (scope *Scope) OmitAttrs() []string { return scope.Search.omits } +//////////////////////////////////////////////////////////////////////////////// +// Private Methods For *gorm.Scope +//////////////////////////////////////////////////////////////////////////////// + func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) { var values = make([]interface{}, len(columns)) var ignored interface{} @@ -441,3 +448,715 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string] } } } +func (scope *Scope) primaryCondition(value interface{}) string { + return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value) +} + +func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { + switch value := clause["query"].(type) { + case string: + // if string is number + if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { + return scope.primaryCondition(scope.AddToVars(value)) + } else if value != "" { + str = fmt.Sprintf("(%v)", value) + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: + return scope.primaryCondition(scope.AddToVars(value)) + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: + str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey())) + clause["args"] = []interface{}{value} + case map[string]interface{}: + var sqls []string + for key, value := range value { + if value != nil { + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key))) + } + } + return strings.Join(sqls, " AND ") + case interface{}: + var sqls []string + for _, field := range scope.New(value).Fields() { + if !field.IsIgnored && !field.IsBlank { + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } + } + return strings.Join(sqls, " AND ") + } + + args := clause["args"].([]interface{}) + for _, arg := range args { + switch reflect.ValueOf(arg).Kind() { + case reflect.Slice: // For where("id in (?)", []int64{1,2}) + if bytes, ok := arg.([]byte); ok { + str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + } else if values := reflect.ValueOf(arg); values.Len() > 0 { + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + } else { + str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + } + default: + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = valuer.Value() + } + + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + } + } + return +} + +func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { + var notEqualSQL string + var primaryKey = scope.PrimaryKey() + + switch value := clause["query"].(type) { + case string: + // is number + if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { + id, _ := strconv.Atoi(value) + return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) + } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { + str = fmt.Sprintf(" NOT (%v) ", value) + notEqualSQL = fmt.Sprintf("NOT (%v)", value) + } else { + str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) + notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: + return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: + if reflect.ValueOf(value).Len() > 0 { + str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey)) + clause["args"] = []interface{}{value} + } + return "" + case map[string]interface{}: + var sqls []string + for key, value := range value { + if value != nil { + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key))) + } + } + return strings.Join(sqls, " AND ") + case interface{}: + var sqls []string + for _, field := range scope.New(value).Fields() { + if !field.IsBlank { + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } + } + return strings.Join(sqls, " AND ") + } + + args := clause["args"].([]interface{}) + for _, arg := range args { + switch reflect.ValueOf(arg).Kind() { + case reflect.Slice: // For where("id in (?)", []int64{1,2}) + if bytes, ok := arg.([]byte); ok { + str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + } else if values := reflect.ValueOf(arg); values.Len() > 0 { + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + } else { + str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + } + default: + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = scanner.Value() + } + str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) + } + } + return +} + +func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { + switch value := clause["query"].(type) { + case string: + str = value + case []string: + str = strings.Join(value, ", ") + } + + args := clause["args"].([]interface{}) + for _, arg := range args { + switch reflect.ValueOf(arg).Kind() { + case reflect.Slice: + values := reflect.ValueOf(arg) + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + default: + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, _ = valuer.Value() + } + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + } + } + return +} + +func (scope *Scope) whereSQL() (sql string) { + var ( + quotedTableName = scope.QuotedTableName() + primaryConditions, andConditions, orConditions []string + ) + + if !scope.Search.Unscoped && scope.HasColumn("deleted_at") { + sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName) + primaryConditions = append(primaryConditions, sql) + } + + if !scope.PrimaryKeyZero() { + for _, field := range scope.PrimaryFields() { + sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) + primaryConditions = append(primaryConditions, sql) + } + } + + for _, clause := range scope.Search.whereConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + andConditions = append(andConditions, sql) + } + } + + for _, clause := range scope.Search.orConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + orConditions = append(orConditions, sql) + } + } + + for _, clause := range scope.Search.notConditions { + if sql := scope.buildNotCondition(clause); sql != "" { + andConditions = append(andConditions, sql) + } + } + + orSQL := strings.Join(orConditions, " OR ") + combinedSQL := strings.Join(andConditions, " AND ") + if len(combinedSQL) > 0 { + if len(orSQL) > 0 { + combinedSQL = combinedSQL + " OR " + orSQL + } + } else { + combinedSQL = orSQL + } + + if len(primaryConditions) > 0 { + sql = "WHERE " + strings.Join(primaryConditions, " AND ") + if len(combinedSQL) > 0 { + sql = sql + " AND (" + combinedSQL + ")" + } + } else if len(combinedSQL) > 0 { + sql = "WHERE " + combinedSQL + } + return +} + +func (scope *Scope) selectSQL() string { + if len(scope.Search.selects) == 0 { + if len(scope.Search.joinConditions) > 0 { + return fmt.Sprintf("%v.*", scope.QuotedTableName()) + } + return "*" + } + return scope.buildSelectQuery(scope.Search.selects) +} + +func (scope *Scope) orderSQL() string { + if len(scope.Search.orders) == 0 || scope.Search.countingQuery { + return "" + } + return " ORDER BY " + strings.Join(scope.Search.orders, ",") +} + +func (scope *Scope) limitAndOffsetSQL() string { + return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) +} + +func (scope *Scope) groupSQL() string { + if len(scope.Search.group) == 0 { + return "" + } + return " GROUP BY " + scope.Search.group +} + +func (scope *Scope) havingSQL() string { + if len(scope.Search.havingConditions) == 0 { + return "" + } + + var andConditions []string + for _, clause := range scope.Search.havingConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + andConditions = append(andConditions, sql) + } + } + + combinedSQL := strings.Join(andConditions, " AND ") + if len(combinedSQL) == 0 { + return "" + } + + return " HAVING " + combinedSQL +} + +func (scope *Scope) joinsSQL() string { + var joinConditions []string + for _, clause := range scope.Search.joinConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) + } + } + + return strings.Join(joinConditions, " ") + " " +} + +func (scope *Scope) prepareQuerySQL() { + if scope.Search.raw { + scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) + } else { + scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) + } + return +} + +func (scope *Scope) inlineCondition(values ...interface{}) *Scope { + if len(values) > 0 { + scope.Search.Where(values[0], values[1:]...) + } + return scope +} + +func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { + for _, f := range funcs { + (*f)(scope) + if scope.skipLeft { + break + } + } + return scope +} + +func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) { + if scope.IndirectValue().Kind() != reflect.Struct { + return values, true + } + + results = map[string]interface{}{} + for key, value := range values { + if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { + if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { + if _, ok := value.(*expr); ok { + hasUpdate = true + results[field.DBName] = value + } else if !equalAsString(field.Field.Interface(), value) { + field.Set(value) + if field.IsNormal { + hasUpdate = true + results[field.DBName] = field.Field.Interface() + } + } + } else { + field.Set(value) + } + } + } + return +} + +func (scope *Scope) row() *sql.Row { + defer scope.trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callbacks.rowQueries) + scope.prepareQuerySQL() + return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) +} + +func (scope *Scope) rows() (*sql.Rows, error) { + defer scope.trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callbacks.rowQueries) + scope.prepareQuerySQL() + return scope.SQLDB().Query(scope.SQL, scope.SQLVars...) +} + +func (scope *Scope) initialize() *Scope { + for _, clause := range scope.Search.whereConditions { + scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"])) + } + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs)) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs)) + return scope +} + +func (scope *Scope) pluck(column string, value interface{}) *Scope { + dest := reflect.Indirect(reflect.ValueOf(value)) + scope.Search.Select(column) + if dest.Kind() != reflect.Slice { + scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) + return scope + } + + rows, err := scope.rows() + if scope.Err(err) == nil { + defer rows.Close() + for rows.Next() { + elem := reflect.New(dest.Type().Elem()).Interface() + scope.Err(rows.Scan(elem)) + dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) + } + } + return scope +} + +func (scope *Scope) count(value interface{}) *Scope { + scope.Search.Select("count(*)") + scope.Search.countingQuery = true + scope.Err(scope.row().Scan(value)) + return scope +} + +func (scope *Scope) typeName() string { + typ := scope.IndirectValue().Type() + + for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + return typ.Name() +} + +// trace print sql log +func (scope *Scope) trace(t time.Time) { + if len(scope.SQL) > 0 { + scope.db.slog(scope.SQL, t, scope.SQLVars...) + } +} + +func (scope *Scope) changeableField(field *Field) bool { + if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { + for _, attr := range selectAttrs { + if field.Name == attr || field.DBName == attr { + return true + } + } + return false + } + + for _, attr := range scope.OmitAttrs() { + if field.Name == attr || field.DBName == attr { + return false + } + } + + return true +} + +func (scope *Scope) shouldSaveAssociations() bool { + if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) { + return false + } + return true && !scope.HasError() +} + +func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { + toScope := scope.db.NewScope(value) + + for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { + fromField, _ := scope.FieldByName(foreignKey) + toField, _ := toScope.FieldByName(foreignKey) + + if fromField != nil { + if relationship := fromField.Relationship; relationship != nil { + if relationship.Kind == "many_to_many" { + joinTableHandler := relationship.JoinTableHandler + scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) + } else if relationship.Kind == "belongs_to" { + query := toScope.db + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(foreignKey); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) + } + } + scope.Err(query.Find(value).Error) + } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { + query := toScope.db + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + if relationship.PolymorphicType != "" { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + } + scope.Err(query.Find(value).Error) + } + } else { + sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) + scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) + } + return scope + } else if toField != nil { + sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) + scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) + return scope + } + } + + scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) + return scope +} + +// getTableOptions return the table options string or an empty string if the table options does not exist +func (scope *Scope) getTableOptions() string { + tableOptions, ok := scope.Get("gorm:table_options") + if !ok { + return "" + } + return tableOptions.(string) +} + +func (scope *Scope) createJoinTable(field *StructField) { + if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { + joinTableHandler := relationship.JoinTableHandler + joinTable := joinTableHandler.Table(scope.db) + if !scope.Dialect().HasTable(joinTable) { + toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} + + var sqlTypes, primaryKeys []string + for idx, fieldName := range relationship.ForeignFieldNames { + if field, ok := scope.FieldByName(fieldName); ok { + foreignKeyStruct := field.clone() + foreignKeyStruct.IsPrimaryKey = false + foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" + sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) + primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) + } + } + + for idx, fieldName := range relationship.AssociationForeignFieldNames { + if field, ok := toScope.FieldByName(fieldName); ok { + foreignKeyStruct := field.clone() + foreignKeyStruct.IsPrimaryKey = false + foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" + sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) + primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) + } + } + + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) + } + scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) + } +} + +func (scope *Scope) createTable() *Scope { + var tags []string + var primaryKeys []string + var primaryKeyInColumnType = false + for _, field := range scope.GetModelStruct().StructFields { + if field.IsNormal { + sqlTag := scope.Dialect().DataTypeOf(field) + + // Check if the primary key constraint was specified as + // part of the column type. If so, we can only support + // one column as the primary key. + if strings.Contains(strings.ToLower(sqlTag), "primary key") { + primaryKeyInColumnType = true + } + + tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) + } + + if field.IsPrimaryKey { + primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) + } + scope.createJoinTable(field) + } + + var primaryKeyStr string + if len(primaryKeys) > 0 && !primaryKeyInColumnType { + primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) + } + + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() + + scope.autoIndex() + return scope +} + +func (scope *Scope) dropTable() *Scope { + scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() + return scope +} + +func (scope *Scope) modifyColumn(column string, typ string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() +} + +func (scope *Scope) dropColumn(column string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() +} + +func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { + if scope.Dialect().HasIndex(scope.TableName(), indexName) { + return + } + + var columns []string + for _, name := range column { + columns = append(columns, scope.quoteIfPossible(name)) + } + + sqlCreate := "CREATE INDEX" + if unique { + sqlCreate = "CREATE UNIQUE INDEX" + } + + scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() +} + +func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { + var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) + keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + + if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + return + } + var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() +} + +func (scope *Scope) removeIndex(indexName string) { + scope.Dialect().RemoveIndex(scope.TableName(), indexName) +} + +func (scope *Scope) autoMigrate() *Scope { + tableName := scope.TableName() + quotedTableName := scope.QuotedTableName() + + if !scope.Dialect().HasTable(tableName) { + scope.createTable() + } else { + for _, field := range scope.GetModelStruct().StructFields { + if !scope.Dialect().HasColumn(tableName, field.DBName) { + if field.IsNormal { + sqlTag := scope.Dialect().DataTypeOf(field) + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() + } + } + scope.createJoinTable(field) + } + scope.autoIndex() + } + return scope +} + +func (scope *Scope) autoIndex() *Scope { + var indexes = map[string][]string{} + var uniqueIndexes = map[string][]string{} + + for _, field := range scope.GetStructFields() { + if name, ok := field.TagSettings["INDEX"]; ok { + if name == "INDEX" { + name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) + } + indexes[name] = append(indexes[name], field.DBName) + } + + if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { + if name == "UNIQUE_INDEX" { + name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) + } + uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) + } + } + + for name, columns := range indexes { + scope.NewDB().Model(scope.Value).AddIndex(name, columns...) + } + + for name, columns := range uniqueIndexes { + scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) + } + + return scope +} + +func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { + for _, value := range values { + indirectValue := reflect.ValueOf(value) + for indirectValue.Kind() == reflect.Ptr { + indirectValue = indirectValue.Elem() + } + + switch indirectValue.Kind() { + case reflect.Slice: + for i := 0; i < indirectValue.Len(); i++ { + var result []interface{} + var object = indirect(indirectValue.Index(i)) + for _, column := range columns { + result = append(result, object.FieldByName(column).Interface()) + } + results = append(results, result) + } + case reflect.Struct: + var result []interface{} + for _, column := range columns { + result = append(result, indirectValue.FieldByName(column).Interface()) + } + results = append(results, result) + } + } + return +} + +func (scope *Scope) getColumnAsScope(column string) *Scope { + indirectScopeValue := scope.IndirectValue() + + switch indirectScopeValue.Kind() { + case reflect.Slice: + if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { + fieldType := fieldStruct.Type + if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + + results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() + + for i := 0; i < indirectScopeValue.Len(); i++ { + result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) + + if result.Kind() == reflect.Slice { + for j := 0; j < result.Len(); j++ { + if elem := result.Index(j); elem.CanAddr() { + results = reflect.Append(results, elem.Addr()) + } + } + } else if result.CanAddr() { + results = reflect.Append(results, result.Addr()) + } + } + return scope.New(results.Interface()) + } + case reflect.Struct: + if field := indirectScopeValue.FieldByName(column); field.CanAddr() { + return scope.New(field.Addr().Interface()) + } + } + return nil +} diff --git a/scope_private.go b/scope_private.go deleted file mode 100644 index c491cc7a..00000000 --- a/scope_private.go +++ /dev/null @@ -1,663 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -func (scope *Scope) primaryCondition(value interface{}) string { - return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value) -} - -func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - // if string is number - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { - return scope.primaryCondition(scope.AddToVars(value)) - } else if value != "" { - str = fmt.Sprintf("(%v)", value) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey())) - clause["args"] = []interface{}{value} - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - for _, field := range scope.New(value).Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSQL string - var primaryKey = scope.PrimaryKey() - - switch value := clause["query"].(type) { - case string: - // is number - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { - str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSQL = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) - notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: - if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey)) - clause["args"] = []interface{}{value} - } - return "" - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - for _, field := range scope.New(value).Fields() { - if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = scanner.Value() - } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - str = value - case []string: - str = strings.Join(value, ", ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) whereSQL() (sql string) { - var ( - quotedTableName = scope.QuotedTableName() - primaryConditions, andConditions, orConditions []string - ) - - if !scope.Search.Unscoped && scope.HasColumn("deleted_at") { - sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName) - primaryConditions = append(primaryConditions, sql) - } - - if !scope.PrimaryKeyZero() { - for _, field := range scope.PrimaryFields() { - sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) - primaryConditions = append(primaryConditions, sql) - } - } - - for _, clause := range scope.Search.whereConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - for _, clause := range scope.Search.orConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - orConditions = append(orConditions, sql) - } - } - - for _, clause := range scope.Search.notConditions { - if sql := scope.buildNotCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - orSQL := strings.Join(orConditions, " OR ") - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) > 0 { - if len(orSQL) > 0 { - combinedSQL = combinedSQL + " OR " + orSQL - } - } else { - combinedSQL = orSQL - } - - if len(primaryConditions) > 0 { - sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSQL) > 0 { - sql = sql + " AND (" + combinedSQL + ")" - } - } else if len(combinedSQL) > 0 { - sql = "WHERE " + combinedSQL - } - return -} - -func (scope *Scope) selectSQL() string { - if len(scope.Search.selects) == 0 { - if len(scope.Search.joinConditions) > 0 { - return fmt.Sprintf("%v.*", scope.QuotedTableName()) - } - return "*" - } - return scope.buildSelectQuery(scope.Search.selects) -} - -func (scope *Scope) orderSQL() string { - if len(scope.Search.orders) == 0 || scope.Search.countingQuery { - return "" - } - return " ORDER BY " + strings.Join(scope.Search.orders, ",") -} - -func (scope *Scope) limitAndOffsetSQL() string { - return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) -} - -func (scope *Scope) groupSQL() string { - if len(scope.Search.group) == 0 { - return "" - } - return " GROUP BY " + scope.Search.group -} - -func (scope *Scope) havingSQL() string { - if len(scope.Search.havingConditions) == 0 { - return "" - } - - var andConditions []string - for _, clause := range scope.Search.havingConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) == 0 { - return "" - } - - return " HAVING " + combinedSQL -} - -func (scope *Scope) joinsSQL() string { - var joinConditions []string - for _, clause := range scope.Search.joinConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) - } - } - - return strings.Join(joinConditions, " ") + " " -} - -func (scope *Scope) prepareQuerySQL() { - if scope.Search.raw { - scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) - } - return -} - -func (scope *Scope) inlineCondition(values ...interface{}) *Scope { - if len(values) > 0 { - scope.Search.Where(values[0], values[1:]...) - } - return scope -} - -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { - for _, f := range funcs { - (*f)(scope) - if scope.skipLeft { - break - } - } - return scope -} - -func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) { - if scope.IndirectValue().Kind() != reflect.Struct { - return values, true - } - - results = map[string]interface{}{} - for key, value := range values { - if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { - if _, ok := value.(*expr); ok { - hasUpdate = true - results[field.DBName] = value - } else if !equalAsString(field.Field.Interface(), value) { - field.Set(value) - if field.IsNormal { - hasUpdate = true - results[field.DBName] = field.Field.Interface() - } - } - } else { - field.Set(value) - } - } - } - return -} - -func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - scope.prepareQuerySQL() - return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) -} - -func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - scope.prepareQuerySQL() - return scope.SQLDB().Query(scope.SQL, scope.SQLVars...) -} - -func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"])) - } - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs)) - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs)) - return scope -} - -func (scope *Scope) pluck(column string, value interface{}) *Scope { - dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search.Select(column) - if dest.Kind() != reflect.Slice { - scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) - return scope - } - - rows, err := scope.rows() - if scope.Err(err) == nil { - defer rows.Close() - for rows.Next() { - elem := reflect.New(dest.Type().Elem()).Interface() - scope.Err(rows.Scan(elem)) - dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) - } - } - return scope -} - -func (scope *Scope) count(value interface{}) *Scope { - scope.Search.Select("count(*)") - scope.Search.countingQuery = true - scope.Err(scope.row().Scan(value)) - return scope -} - -func (scope *Scope) typeName() string { - typ := scope.IndirectValue().Type() - - for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - return typ.Name() -} - -// trace print sql log -func (scope *Scope) trace(t time.Time) { - if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) - } -} - -func (scope *Scope) changeableField(field *Field) bool { - if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if field.Name == attr || field.DBName == attr { - return true - } - } - return false - } - - for _, attr := range scope.OmitAttrs() { - if field.Name == attr || field.DBName == attr { - return false - } - } - - return true -} - -func (scope *Scope) shouldSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) { - return false - } - return true && !scope.HasError() -} - -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.db.NewScope(value) - - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - fromField, _ := scope.FieldByName(foreignKey) - toField, _ := toScope.FieldByName(foreignKey) - - if fromField != nil { - if relationship := fromField.Relationship; relationship != nil { - if relationship.Kind == "many_to_many" { - joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) - } else if relationship.Kind == "belongs_to" { - query := toScope.db - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(foreignKey); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) - } - } - scope.Err(query.Find(value).Error) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - query := toScope.db - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) - } - scope.Err(query.Find(value).Error) - } - } else { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) - } - return scope - } else if toField != nil { - sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) - return scope - } - } - - scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) - return scope -} - -/** - Return the table options string or an empty string if the table options does not exist -*/ -func (scope *Scope) getTableOptions() string { - tableOptions, ok := scope.Get("gorm:table_options") - if !ok { - return "" - } - return tableOptions.(string) -} - -func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTableHandler := relationship.JoinTableHandler - joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { - toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} - - var sqlTypes, primaryKeys []string - for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) - } - } - - for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) - } - } - - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) - } - scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) - } -} - -func (scope *Scope) createTable() *Scope { - var tags []string - var primaryKeys []string - var primaryKeyInColumnType = false - for _, field := range scope.GetModelStruct().StructFields { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - - // Check if the primary key constraint was specified as - // part of the column type. If so, we can only support - // one column as the primary key. - if strings.Contains(strings.ToLower(sqlTag), "primary key") { - primaryKeyInColumnType = true - } - - tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) - } - - if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) - } - scope.createJoinTable(field) - } - - var primaryKeyStr string - if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) - } - - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() - - scope.autoIndex() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { - return - } - - var columns []string - for _, name := range column { - columns = append(columns, scope.quoteIfPossible(name)) - } - - sqlCreate := "CREATE INDEX" - if unique { - sqlCreate = "CREATE UNIQUE INDEX" - } - - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() -} - -func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") - - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) -} - -func (scope *Scope) autoMigrate() *Scope { - tableName := scope.TableName() - quotedTableName := scope.QuotedTableName() - - if !scope.Dialect().HasTable(tableName) { - scope.createTable() - } else { - for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() - } - } - scope.createJoinTable(field) - } - scope.autoIndex() - } - return scope -} - -func (scope *Scope) autoIndex() *Scope { - var indexes = map[string][]string{} - var uniqueIndexes = map[string][]string{} - - for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettings["INDEX"]; ok { - if name == "INDEX" { - name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) - } - indexes[name] = append(indexes[name], field.DBName) - } - - if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { - if name == "UNIQUE_INDEX" { - name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) - } - uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) - } - } - - for name, columns := range indexes { - scope.NewDB().Model(scope.Value).AddIndex(name, columns...) - } - - for name, columns := range uniqueIndexes { - scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) - } - - return scope -} diff --git a/scope_utils.go b/scope_utils.go deleted file mode 100644 index 2d914314..00000000 --- a/scope_utils.go +++ /dev/null @@ -1,67 +0,0 @@ -package gorm - -import "reflect" - -func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { - for _, value := range values { - indirectValue := reflect.ValueOf(value) - for indirectValue.Kind() == reflect.Ptr { - indirectValue = indirectValue.Elem() - } - - switch indirectValue.Kind() { - case reflect.Slice: - for i := 0; i < indirectValue.Len(); i++ { - var result []interface{} - var object = indirect(indirectValue.Index(i)) - for _, column := range columns { - result = append(result, object.FieldByName(column).Interface()) - } - results = append(results, result) - } - case reflect.Struct: - var result []interface{} - for _, column := range columns { - result = append(result, indirectValue.FieldByName(column).Interface()) - } - results = append(results, result) - } - } - return -} - -func (scope *Scope) getColumnAsScope(column string) *Scope { - indirectScopeValue := scope.IndirectValue() - - switch indirectScopeValue.Kind() { - case reflect.Slice: - if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { - fieldType := fieldStruct.Type - if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - - results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() - - for i := 0; i < indirectScopeValue.Len(); i++ { - result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) - - if result.Kind() == reflect.Slice { - for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() { - results = reflect.Append(results, elem.Addr()) - } - } - } else if result.CanAddr() { - results = reflect.Append(results, result.Addr()) - } - } - return scope.New(results.Interface()) - } - case reflect.Struct: - if field := indirectScopeValue.FieldByName(column); field.CanAddr() { - return scope.New(field.Addr().Interface()) - } - } - return nil -} diff --git a/structs_test.go b/structs_test.go deleted file mode 100644 index e7bae25f..00000000 --- a/structs_test.go +++ /dev/null @@ -1,222 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - - "github.com/jinzhu/gorm" - - "reflect" - "time" -) - -type User struct { - Id int64 - Age int64 - UserNum Num - Name string `sql:"size:255"` - Birthday time.Time // Time - CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically - UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically - Emails []Email // Embedded structs - BillingAddress Address // Embedded struct - BillingAddressID sql.NullInt64 // Embedded struct's foreign key - ShippingAddress Address // Embedded struct - ShippingAddressId int64 // Embedded struct's foreign key - CreditCard CreditCard - Latitude float64 - Languages []Language `gorm:"many2many:user_languages;"` - CompanyID *int - Company Company - Role - PasswordHash []byte - IgnoreMe int64 `sql:"-"` - IgnoreStringSlice []string `sql:"-"` - Ignored struct{ Name string } `sql:"-"` - IgnoredPointer *User `sql:"-"` -} - -type CreditCard struct { - ID int8 - Number string - UserId sql.NullInt64 - CreatedAt time.Time `sql:"not null"` - UpdatedAt time.Time - DeletedAt *time.Time -} - -type Email struct { - Id int16 - UserId int - Email string `sql:"type:varchar(100);"` - CreatedAt time.Time - UpdatedAt time.Time -} - -type Address struct { - ID int - Address1 string - Address2 string - Post string - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time -} - -type Language struct { - gorm.Model - Name string - Users []User `gorm:"many2many:user_languages;"` -} - -type Product struct { - Id int64 - Code string - Price int64 - CreatedAt time.Time - UpdatedAt time.Time - AfterFindCallTimes int64 - BeforeCreateCallTimes int64 - AfterCreateCallTimes int64 - BeforeUpdateCallTimes int64 - AfterUpdateCallTimes int64 - BeforeSaveCallTimes int64 - AfterSaveCallTimes int64 - BeforeDeleteCallTimes int64 - AfterDeleteCallTimes int64 -} - -type Company struct { - Id int64 - Name string - Owner *User `sql:"-"` -} - -type Role struct { - Name string -} - -func (role *Role) Scan(value interface{}) error { - if b, ok := value.([]uint8); ok { - role.Name = string(b) - } else { - role.Name = value.(string) - } - return nil -} - -func (role Role) Value() (driver.Value, error) { - return role.Name, nil -} - -func (role Role) IsAdmin() bool { - return role.Name == "admin" -} - -type Num int64 - -func (i *Num) Scan(src interface{}) error { - switch s := src.(type) { - case []byte: - case int64: - *i = Num(s) - default: - return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) - } - return nil -} - -type Animal struct { - Counter uint64 `gorm:"primary_key:yes"` - Name string `sql:"DEFAULT:'galeone'"` - From string //test reserved sql keyword as field name - Age time.Time `sql:"DEFAULT:current_timestamp"` - unexported string // unexported value - CreatedAt time.Time - UpdatedAt time.Time -} - -type JoinTable struct { - From uint64 - To uint64 - Time time.Time `sql:"default: null"` -} - -type Post struct { - Id int64 - CategoryId sql.NullInt64 - MainCategoryId int64 - Title string - Body string - Comments []*Comment - Category Category - MainCategory Category -} - -type Category struct { - gorm.Model - Name string -} - -type Comment struct { - gorm.Model - PostId int64 - Content string - Post Post -} - -// Scanner -type NullValue struct { - Id int64 - Name sql.NullString `sql:"not null"` - Gender *sql.NullString `sql:"not null"` - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - AddedAt NullTime -} - -type NullTime struct { - Time time.Time - Valid bool -} - -func (nt *NullTime) Scan(value interface{}) error { - if value == nil { - nt.Valid = false - return nil - } - nt.Time, nt.Valid = value.(time.Time), true - return nil -} - -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} - -func getPreparedUser(name string, role string) *User { - var company Company - DB.Where(Company{Name: role}).FirstOrCreate(&company) - - return &User{ - Name: name, - Age: 20, - Role: Role{role}, - BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, - ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, - CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, - Emails: []Email{ - {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, - }, - Company: company, - Languages: []Language{ - {Name: fmt.Sprintf("lang_1_%v", name)}, - {Name: fmt.Sprintf("lang_2_%v", name)}, - }, - } -}