diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 37881090..adeeec7b 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -10,6 +10,7 @@ import ( type Hstore map[string]*string +// Value get value of Hstore func (h Hstore) Value() (driver.Value, error) { hstore := hstore.Hstore{Map: map[string]sql.NullString{}} if len(h) == 0 { @@ -27,6 +28,7 @@ func (h Hstore) Value() (driver.Value, error) { return hstore.Value() } +// Scan scan value into Hstore func (h *Hstore) Scan(value interface{}) error { hstore := hstore.Hstore{} diff --git a/field.go b/field.go index cadc1a72..87ef745b 100644 --- a/field.go +++ b/field.go @@ -56,36 +56,3 @@ func (field *Field) Set(value interface{}) (err error) { field.IsBlank = isBlank(field.Field) return nil } - -// Fields get value's fields -func (scope *Scope) Fields() []*Field { - var ( - fields []*Field - indirectScopeValue = scope.IndirectValue() - isStruct = indirectScopeValue.Kind() == reflect.Struct - ) - - for _, structField := range scope.GetModelStruct().StructFields { - if isStruct { - fieldValue := indirectScopeValue - for _, name := range structField.Names { - fieldValue = reflect.Indirect(fieldValue).FieldByName(name) - } - fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) - } else { - fields = append(fields, &Field{StructField: structField, IsBlank: true}) - } - } - - return fields -} - -func (scope *Scope) fieldsMap() map[string]*Field { - var results = map[string]*Field{} - for _, field := range scope.Fields() { - if field.IsNormal { - results[field.DBName] = field - } - } - return results -} diff --git a/polymorphic_test.go b/polymorphic_test.go index 31ffab4a..df573f97 100644 --- a/polymorphic_test.go +++ b/polymorphic_test.go @@ -37,7 +37,7 @@ var compareToys = func(toys []Toy, contents []string) bool { func TestPolymorphic(t *testing.T) { cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}} - dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "dog toy 1"}, Toy{Name: "dog toy 2"}}} + dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}} DB.Save(&cat).Save(&dog) if DB.Model(&cat).Association("Toy").Count() != 1 { diff --git a/preload_test.go b/preload_test.go index cde8e800..5c49ecc2 100644 --- a/preload_test.go +++ b/preload_test.go @@ -167,13 +167,13 @@ func TestNestedPreload2(t *testing.T) { Level2s: []Level2{ { Level1s: []*Level1{ - &Level1{Value: "value1"}, - &Level1{Value: "value2"}, + {Value: "value1"}, + {Value: "value2"}, }, }, { Level1s: []*Level1{ - &Level1{Value: "value3"}, + {Value: "value3"}, }, }, }, @@ -265,8 +265,8 @@ func TestNestedPreload4(t *testing.T) { want := Level3{ Level2: Level2{ Level1s: []Level1{ - Level1{Value: "value1"}, - Level1{Value: "value2"}, + {Value: "value1"}, + {Value: "value2"}, }, }, } @@ -489,8 +489,8 @@ func TestNestedPreload8(t *testing.T) { want[0] = Level3{ Level2: Level2{ Level1s: []Level1{ - Level1{Value: "value1"}, - Level1{Value: "value2"}, + {Value: "value1"}, + {Value: "value2"}, }, }, } @@ -500,8 +500,8 @@ func TestNestedPreload8(t *testing.T) { want[1] = Level3{ Level2: Level2{ Level1s: []Level1{ - Level1{Value: "value3"}, - Level1{Value: "value4"}, + {Value: "value3"}, + {Value: "value4"}, }, }, } @@ -563,17 +563,17 @@ func TestNestedPreload9(t *testing.T) { want[0] = Level3{ Level2: Level2{ Level1s: []Level1{ - Level1{Value: "value1"}, - Level1{Value: "value2"}, + {Value: "value1"}, + {Value: "value2"}, }, }, Level2_1: Level2_1{ Level1s: []Level1{ - Level1{ + { Value: "value1-1", Level0s: []Level0{{Value: "Level0-1"}}, }, - Level1{ + { Value: "value2-2", Level0s: []Level0{{Value: "Level0-2"}}, }, @@ -586,14 +586,14 @@ func TestNestedPreload9(t *testing.T) { want[1] = Level3{ Level2: Level2{ Level1s: []Level1{ - Level1{Value: "value3"}, - Level1{Value: "value4"}, + {Value: "value3"}, + {Value: "value4"}, }, }, Level2_1: Level2_1{ Level1s: []Level1{ - Level1{Value: "value3-3"}, - Level1{Value: "value4-4"}, + {Value: "value3-3"}, + {Value: "value4-4"}, }, }, } @@ -646,16 +646,16 @@ func TestNestedPreload10(t *testing.T) { } want := []*LevelA2{ - &LevelA2{ + { Value: "bar", LevelA3s: []*LevelA3{ - &LevelA3{ + { Value: "qux", LevelA1: levelA1, }, }, }, - &LevelA2{ + { Value: "bar 2", }, } @@ -1151,11 +1151,11 @@ func TestNestedManyToManyPreload4(t *testing.T) { dummy := Level1{ Value: "Level1", - Level2s: []*Level2{&Level2{ + Level2s: []*Level2{{ Value: "Level2", - Level3s: []*Level3{&Level3{ + Level3s: []*Level3{{ Value: "Level3", - Level4s: []*Level4{&Level4{ + Level4s: []*Level4{{ Value: "Level4", }}, }}, diff --git a/scope.go b/scope.go index f8c9dec0..a5eedbac 100644 --- a/scope.go +++ b/scope.go @@ -23,7 +23,7 @@ type Scope struct { instanceID string primaryKeyField *Field skipLeft bool - fields map[string]*Field + fields *[]*Field selectAttrs *[]string } @@ -37,6 +37,15 @@ func (scope *Scope) New(value interface{}) *Scope { return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} } +//////////////////////////////////////////////////////////////////////////////// +// Scope DB +//////////////////////////////////////////////////////////////////////////////// + +// DB return scope's DB connection +func (scope *Scope) DB() *DB { + return scope.db +} + // NewDB create a new DB without search information func (scope *Scope) NewDB() *DB { if scope.db != nil { @@ -48,19 +57,14 @@ func (scope *Scope) NewDB() *DB { return nil } -// DB return scope's DB connection -func (scope *Scope) DB() *DB { - return scope.db -} - // SQLDB return *sql.DB func (scope *Scope) SQLDB() sqlCommon { return scope.db.db } -// SkipLeft skip remaining callbacks -func (scope *Scope) SkipLeft() { - scope.skipLeft = true +// Dialect get dialect +func (scope *Scope) Dialect() Dialect { + return scope.db.parent.dialect } // Quote used to quote string to escape them for database @@ -76,18 +80,6 @@ func (scope *Scope) Quote(str string) string { return scope.Dialect().Quote(str) } -func (scope *Scope) quoteIfPossible(str string) string { - if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) { - return scope.Quote(str) - } - return str -} - -// Dialect get dialect -func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect -} - // Err add error to Scope func (scope *Scope) Err(err error) error { if err != nil { @@ -96,14 +88,63 @@ func (scope *Scope) Err(err error) error { return err } +// HasError check if there are any error +func (scope *Scope) HasError() bool { + return scope.db.Error != nil +} + // Log print log message func (scope *Scope) Log(v ...interface{}) { scope.db.log(v...) } -// HasError check if there are any error -func (scope *Scope) HasError() bool { - return scope.db.Error != nil +// SkipLeft skip remaining callbacks +func (scope *Scope) SkipLeft() { + scope.skipLeft = true +} + +// Fields get value's fields +func (scope *Scope) Fields() []*Field { + if scope.fields == nil { + var ( + fields []*Field + indirectScopeValue = scope.IndirectValue() + isStruct = indirectScopeValue.Kind() == reflect.Struct + ) + + for _, structField := range scope.GetModelStruct().StructFields { + if isStruct { + fieldValue := indirectScopeValue + for _, name := range structField.Names { + fieldValue = reflect.Indirect(fieldValue).FieldByName(name) + } + fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) + } else { + fields = append(fields, &Field{StructField: structField, IsBlank: true}) + } + } + scope.fields = &fields + } + + return *scope.fields +} + +// FieldByName find `gorm.Field` with field name or db name +func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { + var ( + dbName = ToDBName(name) + mostMatchedField *Field + ) + + for _, field := range scope.Fields() { + if field.Name == name || field.DBName == name { + return field, true + } + if field.DBName == dbName { + mostMatchedField = field + } + } + return mostMatchedField, mostMatchedField != nil } // PrimaryFields return scope's primary fields @@ -195,35 +236,6 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { return errors.New("could not convert column to field") } -func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { - if reflectValue.CanAddr() { - reflectValue = reflectValue.Addr() - } - - if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { - switch method := methodValue.Interface().(type) { - case func(): - method() - case func(*Scope): - method(scope) - case func(*DB): - newDB := scope.NewDB() - method(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(method()) - case func(*Scope) error: - scope.Err(method(scope)) - case func(*DB) error: - newDB := scope.NewDB() - scope.Err(method(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", methodName)) - } - } -} - // CallMethod call scope value's method, if it is a slice, will call its element's method one by one func (scope *Scope) CallMethod(methodName string) { if scope.Value == nil { @@ -253,6 +265,31 @@ func (scope *Scope) AddToVars(value interface{}) string { return scope.Dialect().BindVar(len(scope.SQLVars)) } +// SelectAttrs return selected attributes +func (scope *Scope) SelectAttrs() []string { + if scope.selectAttrs == nil { + attrs := []string{} + for _, value := range scope.Search.selects { + if str, ok := value.(string); ok { + attrs = append(attrs, str) + } else if strs, ok := value.([]string); ok { + attrs = append(attrs, strs...) + } else if strs, ok := value.([]interface{}); ok { + for _, str := range strs { + attrs = append(attrs, fmt.Sprintf("%v", str)) + } + } + } + scope.selectAttrs = &attrs + } + return *scope.selectAttrs +} + +// OmitAttrs return omitted attributes +func (scope *Scope) OmitAttrs() []string { + return scope.Search.omits +} + type tabler interface { TableName() string } @@ -296,24 +333,6 @@ func (scope *Scope) CombinedConditionSql() string { scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() } -// FieldByName find `gorm.Field` with field name or db name -func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - - for _, field := range scope.Fields() { - if field.Name == name || field.DBName == name { - return field, true - } - if field.DBName == dbName { - mostMatchedField = field - } - } - return mostMatchedField, mostMatchedField != nil -} - // Raw set raw sql func (scope *Scope) Raw(sql string) *Scope { scope.SQL = strings.Replace(sql, "$$", "?", -1) @@ -389,35 +408,56 @@ func (scope *Scope) CommitOrRollback() *Scope { return scope } -// SelectAttrs return selected attributes -func (scope *Scope) SelectAttrs() []string { - if scope.selectAttrs == nil { - attrs := []string{} - for _, value := range scope.Search.selects { - if str, ok := value.(string); ok { - attrs = append(attrs, str) - } else if strs, ok := value.([]string); ok { - attrs = append(attrs, strs...) - } else if strs, ok := value.([]interface{}); ok { - for _, str := range strs { - attrs = append(attrs, fmt.Sprintf("%v", str)) - } - } - } - scope.selectAttrs = &attrs - } - return *scope.selectAttrs -} - -// OmitAttrs return omitted attributes -func (scope *Scope) OmitAttrs() []string { - return scope.Search.omits -} - //////////////////////////////////////////////////////////////////////////////// // Private Methods For *gorm.Scope //////////////////////////////////////////////////////////////////////////////// +func (scope *Scope) fieldsMap() map[string]*Field { + var results = map[string]*Field{} + for _, field := range scope.Fields() { + if field.IsNormal { + results[field.DBName] = field + } + } + return results +} + +func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { + if reflectValue.CanAddr() { + reflectValue = reflectValue.Addr() + } + + if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { + switch method := methodValue.Interface().(type) { + case func(): + method() + case func(*Scope): + method(scope) + case func(*DB): + newDB := scope.NewDB() + method(newDB) + scope.Err(newDB.Error) + case func() error: + scope.Err(method()) + case func(*Scope) error: + scope.Err(method(scope)) + case func(*DB) error: + newDB := scope.NewDB() + scope.Err(method(newDB)) + scope.Err(newDB.Error) + default: + scope.Err(fmt.Errorf("unsupported function %v", methodName)) + } + } +} + +func (scope *Scope) quoteIfPossible(str string) string { + if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) { + return scope.Quote(str) + } + return str +} + func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) { var values = make([]interface{}, len(columns)) var ignored interface{} @@ -448,6 +488,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string] } } } + func (scope *Scope) primaryCondition(value interface{}) string { return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value) } diff --git a/utils.go b/utils.go index af11f5d2..c525631c 100644 --- a/utils.go +++ b/utils.go @@ -134,7 +134,7 @@ func toQueryMarks(primaryValues [][]interface{}) string { for _, primaryValue := range primaryValues { var marks []string - for _ = range primaryValue { + for range primaryValue { marks = append(marks, "?") }