Rename test files

This commit is contained in:
Jinzhu 2016-03-08 21:45:20 +08:00
parent 9d57c6b961
commit d08894cd69
15 changed files with 1004 additions and 1025 deletions

View File

@ -18,10 +18,9 @@ func queryCallback(scope *Scope) {
defer scope.trace(NowFunc())
var (
isSlice bool
isPtr bool
results = scope.IndirectValue()
resultType reflect.Type
isSlice, isPtr bool
resultType reflect.Type
results = scope.IndirectValue()
)
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {

View File

@ -17,13 +17,14 @@ func saveBeforeAssociationsCallback(scope *Scope) {
for _, field := range scope.Fields() {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
fieldValue := field.Field.Addr().Interface()
scope.Err(scope.NewDB().Save(fieldValue).Error)
if len(relationship.ForeignFieldNames) != 0 {
// set value's foreign key
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok {
scope.Err(scope.SetColumn(fieldName, f.Field.Interface()))
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
}
}
}
@ -65,7 +66,7 @@ func saveAfterAssociationsCallback(scope *Scope) {
scope.Err(newDB.Save(elem).Error)
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
}
}
default:

View File

@ -92,7 +92,7 @@ func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
// We must fetch the value again, to have the default fields updated
// (We can't do this in the update statements, since sql default can be expressions
// And be different from the fields' type (eg. a time.Time fiels has a default value of "now()"
// And be different from the fields' type (eg. a time.Time fields has a default value of "now()"
DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
if an.Name != "galeone" {

View File

@ -1,23 +0,0 @@
package gorm_test
import (
"testing"
)
func TestDdlErrors(t *testing.T) {
var err error
if err = DB.Close(); err != nil {
t.Errorf("Closing DDL test db connection err=%s", err)
}
defer func() {
// Reopen DB connection.
if DB, err = OpenTestConnection(); err != nil {
t.Fatalf("Failed re-opening db connection: %s", err)
}
}()
if err := DB.Find(&User{}).Error; err == nil {
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
}
}

42
main.go
View File

@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"time"
)
// DB contains information for current db connection
@ -221,7 +222,7 @@ func (s *DB) Unscoped() *DB {
return s.clone().search.unscoped().db
}
// Attrs initalize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
func (s *DB) Attrs(attrs ...interface{}) *DB {
return s.clone().search.Attrs(attrs...).db
}
@ -299,7 +300,7 @@ func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
}
// FirstOrInit find first matched record or initalize a new one with given conditions (only works with struct, map conditions)
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
// https://jinzhu.github.io/gorm/curd.html#firstorinit
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
c := s.clone()
@ -660,3 +661,40 @@ func (s *DB) GetErrors() (errors []error) {
}
return
}
////////////////////////////////////////////////////////////////////////////////
// Private Methods For *gorm.DB
////////////////////////////////////////////////////////////////////////////////
func (s *DB) clone() *DB {
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
for key, value := range s.values {
db.values[key] = value
}
if s.search == nil {
db.search = &search{limit: -1, offset: -1}
} else {
db.search = s.search.clone()
}
db.search.db = &db
return &db
}
func (s *DB) print(v ...interface{}) {
s.logger.(logger).Print(v...)
}
func (s *DB) log(v ...interface{}) {
if s != nil && s.logMode == 2 {
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
}
}
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
if s.logMode == 2 {
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
}
}

View File

@ -1,36 +0,0 @@
package gorm
import "time"
func (s *DB) clone() *DB {
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
for key, value := range s.values {
db.values[key] = value
}
if s.search == nil {
db.search = &search{limit: -1, offset: -1}
} else {
db.search = s.search.clone()
}
db.search.db = &db
return &db
}
func (s *DB) print(v ...interface{}) {
s.logger.(logger).Print(v...)
}
func (s *DB) log(v ...interface{}) {
if s != nil && s.logMode == 2 {
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
}
}
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
if s.logMode == 2 {
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
}
}

View File

@ -714,6 +714,24 @@ func TestOpenExistingDB(t *testing.T) {
}
}
func TestDdlErrors(t *testing.T) {
var err error
if err = DB.Close(); err != nil {
t.Errorf("Closing DDL test db connection err=%s", err)
}
defer func() {
// Reopen DB connection.
if DB, err = OpenTestConnection(); err != nil {
t.Fatalf("Failed re-opening db connection: %s", err)
}
}()
if err := DB.Find(&User{}).Error; err == nil {
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
}
}
func BenchmarkGorm(b *testing.B) {
b.N = 2000
for x := 0; x < b.N; x++ {

View File

@ -1,11 +1,226 @@
package gorm_test
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"testing"
"time"
"github.com/jinzhu/gorm"
)
type User struct {
Id int64
Age int64
UserNum Num
Name string `sql:"size:255"`
Birthday time.Time // Time
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
Emails []Email // Embedded structs
BillingAddress Address // Embedded struct
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct's foreign key
CreditCard CreditCard
Latitude float64
Languages []Language `gorm:"many2many:user_languages;"`
CompanyID *int
Company Company
Role
PasswordHash []byte
IgnoreMe int64 `sql:"-"`
IgnoreStringSlice []string `sql:"-"`
Ignored struct{ Name string } `sql:"-"`
IgnoredPointer *User `sql:"-"`
}
type CreditCard struct {
ID int8
Number string
UserId sql.NullInt64
CreatedAt time.Time `sql:"not null"`
UpdatedAt time.Time
DeletedAt *time.Time
}
type Email struct {
Id int16
UserId int
Email string `sql:"type:varchar(100);"`
CreatedAt time.Time
UpdatedAt time.Time
}
type Address struct {
ID int
Address1 string
Address2 string
Post string
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
}
type Language struct {
gorm.Model
Name string
Users []User `gorm:"many2many:user_languages;"`
}
type Product struct {
Id int64
Code string
Price int64
CreatedAt time.Time
UpdatedAt time.Time
AfterFindCallTimes int64
BeforeCreateCallTimes int64
AfterCreateCallTimes int64
BeforeUpdateCallTimes int64
AfterUpdateCallTimes int64
BeforeSaveCallTimes int64
AfterSaveCallTimes int64
BeforeDeleteCallTimes int64
AfterDeleteCallTimes int64
}
type Company struct {
Id int64
Name string
Owner *User `sql:"-"`
}
type Role struct {
Name string
}
func (role *Role) Scan(value interface{}) error {
if b, ok := value.([]uint8); ok {
role.Name = string(b)
} else {
role.Name = value.(string)
}
return nil
}
func (role Role) Value() (driver.Value, error) {
return role.Name, nil
}
func (role Role) IsAdmin() bool {
return role.Name == "admin"
}
type Num int64
func (i *Num) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
case int64:
*i = Num(s)
default:
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
}
return nil
}
type Animal struct {
Counter uint64 `gorm:"primary_key:yes"`
Name string `sql:"DEFAULT:'galeone'"`
From string //test reserved sql keyword as field name
Age time.Time `sql:"DEFAULT:current_timestamp"`
unexported string // unexported value
CreatedAt time.Time
UpdatedAt time.Time
}
type JoinTable struct {
From uint64
To uint64
Time time.Time `sql:"default: null"`
}
type Post struct {
Id int64
CategoryId sql.NullInt64
MainCategoryId int64
Title string
Body string
Comments []*Comment
Category Category
MainCategory Category
}
type Category struct {
gorm.Model
Name string
}
type Comment struct {
gorm.Model
PostId int64
Content string
Post Post
}
// Scanner
type NullValue struct {
Id int64
Name sql.NullString `sql:"not null"`
Gender *sql.NullString `sql:"not null"`
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
AddedAt NullTime
}
type NullTime struct {
Time time.Time
Valid bool
}
func (nt *NullTime) Scan(value interface{}) error {
if value == nil {
nt.Valid = false
return nil
}
nt.Time, nt.Valid = value.(time.Time), true
return nil
}
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
func getPreparedUser(name string, role string) *User {
var company Company
DB.Where(Company{Name: role}).FirstOrCreate(&company)
return &User{
Name: name,
Age: 20,
Role: Role{role},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
},
Company: company,
Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)},
},
}
}
func runMigration() {
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
fmt.Printf("Got error when try to delete table users, %+v\n", err)

View File

@ -2,7 +2,7 @@ package gorm
import "time"
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
// type User struct {
// gorm.Model
// }

721
scope.go
View File

@ -2,10 +2,13 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"time"
"reflect"
)
@ -406,11 +409,15 @@ func (scope *Scope) SelectAttrs() []string {
return *scope.selectAttrs
}
// OmitAttrs return omited attributes
// OmitAttrs return omitted attributes
func (scope *Scope) OmitAttrs() []string {
return scope.Search.omits
}
////////////////////////////////////////////////////////////////////////////////
// Private Methods For *gorm.Scope
////////////////////////////////////////////////////////////////////////////////
func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
var values = make([]interface{}, len(columns))
var ignored interface{}
@ -441,3 +448,715 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]
}
}
}
func (scope *Scope) primaryCondition(value interface{}) string {
return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
}
func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
switch value := clause["query"].(type) {
case string:
// if string is number
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
return scope.primaryCondition(scope.AddToVars(value))
} else if value != "" {
str = fmt.Sprintf("(%v)", value)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return scope.primaryCondition(scope.AddToVars(value))
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey()))
clause["args"] = []interface{}{value}
case map[string]interface{}:
var sqls []string
for key, value := range value {
if value != nil {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key)))
}
}
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
for _, field := range scope.New(value).Fields() {
if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
if bytes, ok := arg.([]byte); ok {
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
} else {
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
}
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
}
}
return
}
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
var notEqualSQL string
var primaryKey = scope.PrimaryKey()
switch value := clause["query"].(type) {
case string:
// is number
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value)
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
} else {
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
if reflect.ValueOf(value).Len() > 0 {
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
clause["args"] = []interface{}{value}
}
return ""
case map[string]interface{}:
var sqls []string
for key, value := range value {
if value != nil {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key)))
}
}
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
for _, field := range scope.New(value).Fields() {
if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
if bytes, ok := arg.([]byte); ok {
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
} else {
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
}
default:
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value()
}
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
}
}
return
}
func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
switch value := clause["query"].(type) {
case string:
str = value
case []string:
str = strings.Join(value, ", ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice:
values := reflect.ValueOf(arg)
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
}
}
return
}
func (scope *Scope) whereSQL() (sql string) {
var (
quotedTableName = scope.QuotedTableName()
primaryConditions, andConditions, orConditions []string
)
if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName)
primaryConditions = append(primaryConditions, sql)
}
if !scope.PrimaryKeyZero() {
for _, field := range scope.PrimaryFields() {
sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
primaryConditions = append(primaryConditions, sql)
}
}
for _, clause := range scope.Search.whereConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}
for _, clause := range scope.Search.orConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
orConditions = append(orConditions, sql)
}
}
for _, clause := range scope.Search.notConditions {
if sql := scope.buildNotCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}
orSQL := strings.Join(orConditions, " OR ")
combinedSQL := strings.Join(andConditions, " AND ")
if len(combinedSQL) > 0 {
if len(orSQL) > 0 {
combinedSQL = combinedSQL + " OR " + orSQL
}
} else {
combinedSQL = orSQL
}
if len(primaryConditions) > 0 {
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
if len(combinedSQL) > 0 {
sql = sql + " AND (" + combinedSQL + ")"
}
} else if len(combinedSQL) > 0 {
sql = "WHERE " + combinedSQL
}
return
}
func (scope *Scope) selectSQL() string {
if len(scope.Search.selects) == 0 {
if len(scope.Search.joinConditions) > 0 {
return fmt.Sprintf("%v.*", scope.QuotedTableName())
}
return "*"
}
return scope.buildSelectQuery(scope.Search.selects)
}
func (scope *Scope) orderSQL() string {
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
return ""
}
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
}
func (scope *Scope) limitAndOffsetSQL() string {
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
}
func (scope *Scope) groupSQL() string {
if len(scope.Search.group) == 0 {
return ""
}
return " GROUP BY " + scope.Search.group
}
func (scope *Scope) havingSQL() string {
if len(scope.Search.havingConditions) == 0 {
return ""
}
var andConditions []string
for _, clause := range scope.Search.havingConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}
combinedSQL := strings.Join(andConditions, " AND ")
if len(combinedSQL) == 0 {
return ""
}
return " HAVING " + combinedSQL
}
func (scope *Scope) joinsSQL() string {
var joinConditions []string
for _, clause := range scope.Search.joinConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
}
}
return strings.Join(joinConditions, " ") + " "
}
func (scope *Scope) prepareQuerySQL() {
if scope.Search.raw {
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
} else {
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
}
return
}
func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
if len(values) > 0 {
scope.Search.Where(values[0], values[1:]...)
}
return scope
}
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
for _, f := range funcs {
(*f)(scope)
if scope.skipLeft {
break
}
}
return scope
}
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
if scope.IndirectValue().Kind() != reflect.Struct {
return values, true
}
results = map[string]interface{}{}
for key, value := range values {
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
if _, ok := value.(*expr); ok {
hasUpdate = true
results[field.DBName] = value
} else if !equalAsString(field.Field.Interface(), value) {
field.Set(value)
if field.IsNormal {
hasUpdate = true
results[field.DBName] = field.Field.Interface()
}
}
} else {
field.Set(value)
}
}
}
return
}
func (scope *Scope) row() *sql.Row {
defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.prepareQuerySQL()
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
}
func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.prepareQuerySQL()
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
}
func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.whereConditions {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
}
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
return scope
}
func (scope *Scope) pluck(column string, value interface{}) *Scope {
dest := reflect.Indirect(reflect.ValueOf(value))
scope.Search.Select(column)
if dest.Kind() != reflect.Slice {
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
return scope
}
rows, err := scope.rows()
if scope.Err(err) == nil {
defer rows.Close()
for rows.Next() {
elem := reflect.New(dest.Type().Elem()).Interface()
scope.Err(rows.Scan(elem))
dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
}
}
return scope
}
func (scope *Scope) count(value interface{}) *Scope {
scope.Search.Select("count(*)")
scope.Search.countingQuery = true
scope.Err(scope.row().Scan(value))
return scope
}
func (scope *Scope) typeName() string {
typ := scope.IndirectValue().Type()
for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
return typ.Name()
}
// trace print sql log
func (scope *Scope) trace(t time.Time) {
if len(scope.SQL) > 0 {
scope.db.slog(scope.SQL, t, scope.SQLVars...)
}
}
func (scope *Scope) changeableField(field *Field) bool {
if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
for _, attr := range selectAttrs {
if field.Name == attr || field.DBName == attr {
return true
}
}
return false
}
for _, attr := range scope.OmitAttrs() {
if field.Name == attr || field.DBName == attr {
return false
}
}
return true
}
func (scope *Scope) shouldSaveAssociations() bool {
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) {
return false
}
return true && !scope.HasError()
}
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
toScope := scope.db.NewScope(value)
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
fromField, _ := scope.FieldByName(foreignKey)
toField, _ := toScope.FieldByName(foreignKey)
if fromField != nil {
if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" {
joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
} else if relationship.Kind == "belongs_to" {
query := toScope.db
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(foreignKey); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
}
}
scope.Err(query.Find(value).Error)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
query := toScope.db
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
if relationship.PolymorphicType != "" {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
}
scope.Err(query.Find(value).Error)
}
} else {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
}
return scope
} else if toField != nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
return scope
}
}
scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
return scope
}
// getTableOptions return the table options string or an empty string if the table options does not exist
func (scope *Scope) getTableOptions() string {
tableOptions, ok := scope.Get("gorm:table_options")
if !ok {
return ""
}
return tableOptions.(string)
}
func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
if !scope.Dialect().HasTable(joinTable) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes, primaryKeys []string
for idx, fieldName := range relationship.ForeignFieldNames {
if field, ok := scope.FieldByName(fieldName); ok {
foreignKeyStruct := field.clone()
foreignKeyStruct.IsPrimaryKey = false
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
}
}
for idx, fieldName := range relationship.AssociationForeignFieldNames {
if field, ok := toScope.FieldByName(fieldName); ok {
foreignKeyStruct := field.clone()
foreignKeyStruct.IsPrimaryKey = false
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
}
}
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
}
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
}
}
func (scope *Scope) createTable() *Scope {
var tags []string
var primaryKeys []string
var primaryKeyInColumnType = false
for _, field := range scope.GetModelStruct().StructFields {
if field.IsNormal {
sqlTag := scope.Dialect().DataTypeOf(field)
// Check if the primary key constraint was specified as
// part of the column type. If so, we can only support
// one column as the primary key.
if strings.Contains(strings.ToLower(sqlTag), "primary key") {
primaryKeyInColumnType = true
}
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
}
if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
}
scope.createJoinTable(field)
}
var primaryKeyStr string
if len(primaryKeys) > 0 && !primaryKeyInColumnType {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
scope.autoIndex()
return scope
}
func (scope *Scope) dropTable() *Scope {
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
return scope
}
func (scope *Scope) modifyColumn(column string, typ string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
}
func (scope *Scope) dropColumn(column string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
}
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
return
}
var columns []string
for _, name := range column {
columns = append(columns, scope.quoteIfPossible(name))
}
sqlCreate := "CREATE INDEX"
if unique {
sqlCreate = "CREATE UNIQUE INDEX"
}
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
}
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return
}
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
}
func (scope *Scope) removeIndex(indexName string) {
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
}
func (scope *Scope) autoMigrate() *Scope {
tableName := scope.TableName()
quotedTableName := scope.QuotedTableName()
if !scope.Dialect().HasTable(tableName) {
scope.createTable()
} else {
for _, field := range scope.GetModelStruct().StructFields {
if !scope.Dialect().HasColumn(tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.Dialect().DataTypeOf(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
}
}
scope.createJoinTable(field)
}
scope.autoIndex()
}
return scope
}
func (scope *Scope) autoIndex() *Scope {
var indexes = map[string][]string{}
var uniqueIndexes = map[string][]string{}
for _, field := range scope.GetStructFields() {
if name, ok := field.TagSettings["INDEX"]; ok {
if name == "INDEX" {
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
}
indexes[name] = append(indexes[name], field.DBName)
}
if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
if name == "UNIQUE_INDEX" {
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
}
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
}
}
for name, columns := range indexes {
scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
}
for name, columns := range uniqueIndexes {
scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
}
return scope
}
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
for _, value := range values {
indirectValue := reflect.ValueOf(value)
for indirectValue.Kind() == reflect.Ptr {
indirectValue = indirectValue.Elem()
}
switch indirectValue.Kind() {
case reflect.Slice:
for i := 0; i < indirectValue.Len(); i++ {
var result []interface{}
var object = indirect(indirectValue.Index(i))
for _, column := range columns {
result = append(result, object.FieldByName(column).Interface())
}
results = append(results, result)
}
case reflect.Struct:
var result []interface{}
for _, column := range columns {
result = append(result, indirectValue.FieldByName(column).Interface())
}
results = append(results, result)
}
}
return
}
func (scope *Scope) getColumnAsScope(column string) *Scope {
indirectScopeValue := scope.IndirectValue()
switch indirectScopeValue.Kind() {
case reflect.Slice:
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
fieldType := fieldStruct.Type
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
for i := 0; i < indirectScopeValue.Len(); i++ {
result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
if result.Kind() == reflect.Slice {
for j := 0; j < result.Len(); j++ {
if elem := result.Index(j); elem.CanAddr() {
results = reflect.Append(results, elem.Addr())
}
}
} else if result.CanAddr() {
results = reflect.Append(results, result.Addr())
}
}
return scope.New(results.Interface())
}
case reflect.Struct:
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
return scope.New(field.Addr().Interface())
}
}
return nil
}

View File

@ -1,663 +0,0 @@
package gorm
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
func (scope *Scope) primaryCondition(value interface{}) string {
return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
}
func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
switch value := clause["query"].(type) {
case string:
// if string is number
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
return scope.primaryCondition(scope.AddToVars(value))
} else if value != "" {
str = fmt.Sprintf("(%v)", value)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return scope.primaryCondition(scope.AddToVars(value))
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey()))
clause["args"] = []interface{}{value}
case map[string]interface{}:
var sqls []string
for key, value := range value {
if value != nil {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key)))
}
}
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
for _, field := range scope.New(value).Fields() {
if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
if bytes, ok := arg.([]byte); ok {
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
} else {
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
}
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
}
}
return
}
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
var notEqualSQL string
var primaryKey = scope.PrimaryKey()
switch value := clause["query"].(type) {
case string:
// is number
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value)
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
} else {
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
if reflect.ValueOf(value).Len() > 0 {
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
clause["args"] = []interface{}{value}
}
return ""
case map[string]interface{}:
var sqls []string
for key, value := range value {
if value != nil {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key)))
}
}
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
for _, field := range scope.New(value).Fields() {
if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
if bytes, ok := arg.([]byte); ok {
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
} else {
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
}
default:
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value()
}
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
}
}
return
}
func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
switch value := clause["query"].(type) {
case string:
str = value
case []string:
str = strings.Join(value, ", ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice:
values := reflect.ValueOf(arg)
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
}
}
return
}
func (scope *Scope) whereSQL() (sql string) {
var (
quotedTableName = scope.QuotedTableName()
primaryConditions, andConditions, orConditions []string
)
if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName)
primaryConditions = append(primaryConditions, sql)
}
if !scope.PrimaryKeyZero() {
for _, field := range scope.PrimaryFields() {
sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
primaryConditions = append(primaryConditions, sql)
}
}
for _, clause := range scope.Search.whereConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}
for _, clause := range scope.Search.orConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
orConditions = append(orConditions, sql)
}
}
for _, clause := range scope.Search.notConditions {
if sql := scope.buildNotCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}
orSQL := strings.Join(orConditions, " OR ")
combinedSQL := strings.Join(andConditions, " AND ")
if len(combinedSQL) > 0 {
if len(orSQL) > 0 {
combinedSQL = combinedSQL + " OR " + orSQL
}
} else {
combinedSQL = orSQL
}
if len(primaryConditions) > 0 {
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
if len(combinedSQL) > 0 {
sql = sql + " AND (" + combinedSQL + ")"
}
} else if len(combinedSQL) > 0 {
sql = "WHERE " + combinedSQL
}
return
}
func (scope *Scope) selectSQL() string {
if len(scope.Search.selects) == 0 {
if len(scope.Search.joinConditions) > 0 {
return fmt.Sprintf("%v.*", scope.QuotedTableName())
}
return "*"
}
return scope.buildSelectQuery(scope.Search.selects)
}
func (scope *Scope) orderSQL() string {
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
return ""
}
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
}
func (scope *Scope) limitAndOffsetSQL() string {
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
}
func (scope *Scope) groupSQL() string {
if len(scope.Search.group) == 0 {
return ""
}
return " GROUP BY " + scope.Search.group
}
func (scope *Scope) havingSQL() string {
if len(scope.Search.havingConditions) == 0 {
return ""
}
var andConditions []string
for _, clause := range scope.Search.havingConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}
combinedSQL := strings.Join(andConditions, " AND ")
if len(combinedSQL) == 0 {
return ""
}
return " HAVING " + combinedSQL
}
func (scope *Scope) joinsSQL() string {
var joinConditions []string
for _, clause := range scope.Search.joinConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
}
}
return strings.Join(joinConditions, " ") + " "
}
func (scope *Scope) prepareQuerySQL() {
if scope.Search.raw {
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
} else {
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
}
return
}
func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
if len(values) > 0 {
scope.Search.Where(values[0], values[1:]...)
}
return scope
}
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
for _, f := range funcs {
(*f)(scope)
if scope.skipLeft {
break
}
}
return scope
}
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) {
if scope.IndirectValue().Kind() != reflect.Struct {
return values, true
}
results = map[string]interface{}{}
for key, value := range values {
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
if _, ok := value.(*expr); ok {
hasUpdate = true
results[field.DBName] = value
} else if !equalAsString(field.Field.Interface(), value) {
field.Set(value)
if field.IsNormal {
hasUpdate = true
results[field.DBName] = field.Field.Interface()
}
}
} else {
field.Set(value)
}
}
}
return
}
func (scope *Scope) row() *sql.Row {
defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.prepareQuerySQL()
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
}
func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.prepareQuerySQL()
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
}
func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.whereConditions {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]))
}
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs))
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs))
return scope
}
func (scope *Scope) pluck(column string, value interface{}) *Scope {
dest := reflect.Indirect(reflect.ValueOf(value))
scope.Search.Select(column)
if dest.Kind() != reflect.Slice {
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
return scope
}
rows, err := scope.rows()
if scope.Err(err) == nil {
defer rows.Close()
for rows.Next() {
elem := reflect.New(dest.Type().Elem()).Interface()
scope.Err(rows.Scan(elem))
dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
}
}
return scope
}
func (scope *Scope) count(value interface{}) *Scope {
scope.Search.Select("count(*)")
scope.Search.countingQuery = true
scope.Err(scope.row().Scan(value))
return scope
}
func (scope *Scope) typeName() string {
typ := scope.IndirectValue().Type()
for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
return typ.Name()
}
// trace print sql log
func (scope *Scope) trace(t time.Time) {
if len(scope.SQL) > 0 {
scope.db.slog(scope.SQL, t, scope.SQLVars...)
}
}
func (scope *Scope) changeableField(field *Field) bool {
if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
for _, attr := range selectAttrs {
if field.Name == attr || field.DBName == attr {
return true
}
}
return false
}
for _, attr := range scope.OmitAttrs() {
if field.Name == attr || field.DBName == attr {
return false
}
}
return true
}
func (scope *Scope) shouldSaveAssociations() bool {
if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) {
return false
}
return true && !scope.HasError()
}
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
toScope := scope.db.NewScope(value)
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
fromField, _ := scope.FieldByName(foreignKey)
toField, _ := toScope.FieldByName(foreignKey)
if fromField != nil {
if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" {
joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
} else if relationship.Kind == "belongs_to" {
query := toScope.db
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(foreignKey); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
}
}
scope.Err(query.Find(value).Error)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
query := toScope.db
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
if relationship.PolymorphicType != "" {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
}
scope.Err(query.Find(value).Error)
}
} else {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
}
return scope
} else if toField != nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
return scope
}
}
scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
return scope
}
/**
Return the table options string or an empty string if the table options does not exist
*/
func (scope *Scope) getTableOptions() string {
tableOptions, ok := scope.Get("gorm:table_options")
if !ok {
return ""
}
return tableOptions.(string)
}
func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
if !scope.Dialect().HasTable(joinTable) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes, primaryKeys []string
for idx, fieldName := range relationship.ForeignFieldNames {
if field, ok := scope.FieldByName(fieldName); ok {
foreignKeyStruct := field.clone()
foreignKeyStruct.IsPrimaryKey = false
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
}
}
for idx, fieldName := range relationship.AssociationForeignFieldNames {
if field, ok := toScope.FieldByName(fieldName); ok {
foreignKeyStruct := field.clone()
foreignKeyStruct.IsPrimaryKey = false
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
}
}
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
}
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
}
}
func (scope *Scope) createTable() *Scope {
var tags []string
var primaryKeys []string
var primaryKeyInColumnType = false
for _, field := range scope.GetModelStruct().StructFields {
if field.IsNormal {
sqlTag := scope.Dialect().DataTypeOf(field)
// Check if the primary key constraint was specified as
// part of the column type. If so, we can only support
// one column as the primary key.
if strings.Contains(strings.ToLower(sqlTag), "primary key") {
primaryKeyInColumnType = true
}
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
}
if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
}
scope.createJoinTable(field)
}
var primaryKeyStr string
if len(primaryKeys) > 0 && !primaryKeyInColumnType {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
scope.autoIndex()
return scope
}
func (scope *Scope) dropTable() *Scope {
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
return scope
}
func (scope *Scope) modifyColumn(column string, typ string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
}
func (scope *Scope) dropColumn(column string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
}
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
return
}
var columns []string
for _, name := range column {
columns = append(columns, scope.quoteIfPossible(name))
}
sqlCreate := "CREATE INDEX"
if unique {
sqlCreate = "CREATE UNIQUE INDEX"
}
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
}
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return
}
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
}
func (scope *Scope) removeIndex(indexName string) {
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
}
func (scope *Scope) autoMigrate() *Scope {
tableName := scope.TableName()
quotedTableName := scope.QuotedTableName()
if !scope.Dialect().HasTable(tableName) {
scope.createTable()
} else {
for _, field := range scope.GetModelStruct().StructFields {
if !scope.Dialect().HasColumn(tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.Dialect().DataTypeOf(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
}
}
scope.createJoinTable(field)
}
scope.autoIndex()
}
return scope
}
func (scope *Scope) autoIndex() *Scope {
var indexes = map[string][]string{}
var uniqueIndexes = map[string][]string{}
for _, field := range scope.GetStructFields() {
if name, ok := field.TagSettings["INDEX"]; ok {
if name == "INDEX" {
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
}
indexes[name] = append(indexes[name], field.DBName)
}
if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok {
if name == "UNIQUE_INDEX" {
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
}
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
}
}
for name, columns := range indexes {
scope.NewDB().Model(scope.Value).AddIndex(name, columns...)
}
for name, columns := range uniqueIndexes {
scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...)
}
return scope
}

View File

@ -1,67 +0,0 @@
package gorm
import "reflect"
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
for _, value := range values {
indirectValue := reflect.ValueOf(value)
for indirectValue.Kind() == reflect.Ptr {
indirectValue = indirectValue.Elem()
}
switch indirectValue.Kind() {
case reflect.Slice:
for i := 0; i < indirectValue.Len(); i++ {
var result []interface{}
var object = indirect(indirectValue.Index(i))
for _, column := range columns {
result = append(result, object.FieldByName(column).Interface())
}
results = append(results, result)
}
case reflect.Struct:
var result []interface{}
for _, column := range columns {
result = append(result, indirectValue.FieldByName(column).Interface())
}
results = append(results, result)
}
}
return
}
func (scope *Scope) getColumnAsScope(column string) *Scope {
indirectScopeValue := scope.IndirectValue()
switch indirectScopeValue.Kind() {
case reflect.Slice:
if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
fieldType := fieldStruct.Type
if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
for i := 0; i < indirectScopeValue.Len(); i++ {
result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column))
if result.Kind() == reflect.Slice {
for j := 0; j < result.Len(); j++ {
if elem := result.Index(j); elem.CanAddr() {
results = reflect.Append(results, elem.Addr())
}
}
} else if result.CanAddr() {
results = reflect.Append(results, result.Addr())
}
}
return scope.New(results.Interface())
}
case reflect.Struct:
if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
return scope.New(field.Addr().Interface())
}
}
return nil
}

View File

@ -1,222 +0,0 @@
package gorm_test
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"github.com/jinzhu/gorm"
"reflect"
"time"
)
type User struct {
Id int64
Age int64
UserNum Num
Name string `sql:"size:255"`
Birthday time.Time // Time
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
Emails []Email // Embedded structs
BillingAddress Address // Embedded struct
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct's foreign key
CreditCard CreditCard
Latitude float64
Languages []Language `gorm:"many2many:user_languages;"`
CompanyID *int
Company Company
Role
PasswordHash []byte
IgnoreMe int64 `sql:"-"`
IgnoreStringSlice []string `sql:"-"`
Ignored struct{ Name string } `sql:"-"`
IgnoredPointer *User `sql:"-"`
}
type CreditCard struct {
ID int8
Number string
UserId sql.NullInt64
CreatedAt time.Time `sql:"not null"`
UpdatedAt time.Time
DeletedAt *time.Time
}
type Email struct {
Id int16
UserId int
Email string `sql:"type:varchar(100);"`
CreatedAt time.Time
UpdatedAt time.Time
}
type Address struct {
ID int
Address1 string
Address2 string
Post string
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
}
type Language struct {
gorm.Model
Name string
Users []User `gorm:"many2many:user_languages;"`
}
type Product struct {
Id int64
Code string
Price int64
CreatedAt time.Time
UpdatedAt time.Time
AfterFindCallTimes int64
BeforeCreateCallTimes int64
AfterCreateCallTimes int64
BeforeUpdateCallTimes int64
AfterUpdateCallTimes int64
BeforeSaveCallTimes int64
AfterSaveCallTimes int64
BeforeDeleteCallTimes int64
AfterDeleteCallTimes int64
}
type Company struct {
Id int64
Name string
Owner *User `sql:"-"`
}
type Role struct {
Name string
}
func (role *Role) Scan(value interface{}) error {
if b, ok := value.([]uint8); ok {
role.Name = string(b)
} else {
role.Name = value.(string)
}
return nil
}
func (role Role) Value() (driver.Value, error) {
return role.Name, nil
}
func (role Role) IsAdmin() bool {
return role.Name == "admin"
}
type Num int64
func (i *Num) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
case int64:
*i = Num(s)
default:
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
}
return nil
}
type Animal struct {
Counter uint64 `gorm:"primary_key:yes"`
Name string `sql:"DEFAULT:'galeone'"`
From string //test reserved sql keyword as field name
Age time.Time `sql:"DEFAULT:current_timestamp"`
unexported string // unexported value
CreatedAt time.Time
UpdatedAt time.Time
}
type JoinTable struct {
From uint64
To uint64
Time time.Time `sql:"default: null"`
}
type Post struct {
Id int64
CategoryId sql.NullInt64
MainCategoryId int64
Title string
Body string
Comments []*Comment
Category Category
MainCategory Category
}
type Category struct {
gorm.Model
Name string
}
type Comment struct {
gorm.Model
PostId int64
Content string
Post Post
}
// Scanner
type NullValue struct {
Id int64
Name sql.NullString `sql:"not null"`
Gender *sql.NullString `sql:"not null"`
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
AddedAt NullTime
}
type NullTime struct {
Time time.Time
Valid bool
}
func (nt *NullTime) Scan(value interface{}) error {
if value == nil {
nt.Valid = false
return nil
}
nt.Time, nt.Valid = value.(time.Time), true
return nil
}
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
func getPreparedUser(name string, role string) *User {
var company Company
DB.Where(Company{Name: role}).FirstOrCreate(&company)
return &User{
Name: name,
Age: 20,
Role: Role{role},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
},
Company: company,
Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)},
},
}
}