This commit is contained in:
Jinzhu 2016-03-08 22:29:58 +08:00
parent d08894cd69
commit a0aa21aec5
6 changed files with 164 additions and 154 deletions

View File

@ -10,6 +10,7 @@ import (
type Hstore map[string]*string
// Value get value of Hstore
func (h Hstore) Value() (driver.Value, error) {
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
if len(h) == 0 {
@ -27,6 +28,7 @@ func (h Hstore) Value() (driver.Value, error) {
return hstore.Value()
}
// Scan scan value into Hstore
func (h *Hstore) Scan(value interface{}) error {
hstore := hstore.Hstore{}

View File

@ -56,36 +56,3 @@ func (field *Field) Set(value interface{}) (err error) {
field.IsBlank = isBlank(field.Field)
return nil
}
// Fields get value's fields
func (scope *Scope) Fields() []*Field {
var (
fields []*Field
indirectScopeValue = scope.IndirectValue()
isStruct = indirectScopeValue.Kind() == reflect.Struct
)
for _, structField := range scope.GetModelStruct().StructFields {
if isStruct {
fieldValue := indirectScopeValue
for _, name := range structField.Names {
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
}
fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
} else {
fields = append(fields, &Field{StructField: structField, IsBlank: true})
}
}
return fields
}
func (scope *Scope) fieldsMap() map[string]*Field {
var results = map[string]*Field{}
for _, field := range scope.Fields() {
if field.IsNormal {
results[field.DBName] = field
}
}
return results
}

View File

@ -37,7 +37,7 @@ var compareToys = func(toys []Toy, contents []string) bool {
func TestPolymorphic(t *testing.T) {
cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}}
dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "dog toy 1"}, Toy{Name: "dog toy 2"}}}
dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}}
DB.Save(&cat).Save(&dog)
if DB.Model(&cat).Association("Toy").Count() != 1 {

View File

@ -167,13 +167,13 @@ func TestNestedPreload2(t *testing.T) {
Level2s: []Level2{
{
Level1s: []*Level1{
&Level1{Value: "value1"},
&Level1{Value: "value2"},
{Value: "value1"},
{Value: "value2"},
},
},
{
Level1s: []*Level1{
&Level1{Value: "value3"},
{Value: "value3"},
},
},
},
@ -265,8 +265,8 @@ func TestNestedPreload4(t *testing.T) {
want := Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value1"},
Level1{Value: "value2"},
{Value: "value1"},
{Value: "value2"},
},
},
}
@ -489,8 +489,8 @@ func TestNestedPreload8(t *testing.T) {
want[0] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value1"},
Level1{Value: "value2"},
{Value: "value1"},
{Value: "value2"},
},
},
}
@ -500,8 +500,8 @@ func TestNestedPreload8(t *testing.T) {
want[1] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value3"},
Level1{Value: "value4"},
{Value: "value3"},
{Value: "value4"},
},
},
}
@ -563,17 +563,17 @@ func TestNestedPreload9(t *testing.T) {
want[0] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value1"},
Level1{Value: "value2"},
{Value: "value1"},
{Value: "value2"},
},
},
Level2_1: Level2_1{
Level1s: []Level1{
Level1{
{
Value: "value1-1",
Level0s: []Level0{{Value: "Level0-1"}},
},
Level1{
{
Value: "value2-2",
Level0s: []Level0{{Value: "Level0-2"}},
},
@ -586,14 +586,14 @@ func TestNestedPreload9(t *testing.T) {
want[1] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value3"},
Level1{Value: "value4"},
{Value: "value3"},
{Value: "value4"},
},
},
Level2_1: Level2_1{
Level1s: []Level1{
Level1{Value: "value3-3"},
Level1{Value: "value4-4"},
{Value: "value3-3"},
{Value: "value4-4"},
},
},
}
@ -646,16 +646,16 @@ func TestNestedPreload10(t *testing.T) {
}
want := []*LevelA2{
&LevelA2{
{
Value: "bar",
LevelA3s: []*LevelA3{
&LevelA3{
{
Value: "qux",
LevelA1: levelA1,
},
},
},
&LevelA2{
{
Value: "bar 2",
},
}
@ -1151,11 +1151,11 @@ func TestNestedManyToManyPreload4(t *testing.T) {
dummy := Level1{
Value: "Level1",
Level2s: []*Level2{&Level2{
Level2s: []*Level2{{
Value: "Level2",
Level3s: []*Level3{&Level3{
Level3s: []*Level3{{
Value: "Level3",
Level4s: []*Level4{&Level4{
Level4s: []*Level4{{
Value: "Level4",
}},
}},

233
scope.go
View File

@ -23,7 +23,7 @@ type Scope struct {
instanceID string
primaryKeyField *Field
skipLeft bool
fields map[string]*Field
fields *[]*Field
selectAttrs *[]string
}
@ -37,6 +37,15 @@ func (scope *Scope) New(value interface{}) *Scope {
return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
}
////////////////////////////////////////////////////////////////////////////////
// Scope DB
////////////////////////////////////////////////////////////////////////////////
// DB return scope's DB connection
func (scope *Scope) DB() *DB {
return scope.db
}
// NewDB create a new DB without search information
func (scope *Scope) NewDB() *DB {
if scope.db != nil {
@ -48,19 +57,14 @@ func (scope *Scope) NewDB() *DB {
return nil
}
// DB return scope's DB connection
func (scope *Scope) DB() *DB {
return scope.db
}
// SQLDB return *sql.DB
func (scope *Scope) SQLDB() sqlCommon {
return scope.db.db
}
// SkipLeft skip remaining callbacks
func (scope *Scope) SkipLeft() {
scope.skipLeft = true
// Dialect get dialect
func (scope *Scope) Dialect() Dialect {
return scope.db.parent.dialect
}
// Quote used to quote string to escape them for database
@ -76,18 +80,6 @@ func (scope *Scope) Quote(str string) string {
return scope.Dialect().Quote(str)
}
func (scope *Scope) quoteIfPossible(str string) string {
if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
return scope.Quote(str)
}
return str
}
// Dialect get dialect
func (scope *Scope) Dialect() Dialect {
return scope.db.parent.dialect
}
// Err add error to Scope
func (scope *Scope) Err(err error) error {
if err != nil {
@ -96,14 +88,63 @@ func (scope *Scope) Err(err error) error {
return err
}
// HasError check if there are any error
func (scope *Scope) HasError() bool {
return scope.db.Error != nil
}
// Log print log message
func (scope *Scope) Log(v ...interface{}) {
scope.db.log(v...)
}
// HasError check if there are any error
func (scope *Scope) HasError() bool {
return scope.db.Error != nil
// SkipLeft skip remaining callbacks
func (scope *Scope) SkipLeft() {
scope.skipLeft = true
}
// Fields get value's fields
func (scope *Scope) Fields() []*Field {
if scope.fields == nil {
var (
fields []*Field
indirectScopeValue = scope.IndirectValue()
isStruct = indirectScopeValue.Kind() == reflect.Struct
)
for _, structField := range scope.GetModelStruct().StructFields {
if isStruct {
fieldValue := indirectScopeValue
for _, name := range structField.Names {
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
}
fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
} else {
fields = append(fields, &Field{StructField: structField, IsBlank: true})
}
}
scope.fields = &fields
}
return *scope.fields
}
// FieldByName find `gorm.Field` with field name or db name
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
var (
dbName = ToDBName(name)
mostMatchedField *Field
)
for _, field := range scope.Fields() {
if field.Name == name || field.DBName == name {
return field, true
}
if field.DBName == dbName {
mostMatchedField = field
}
}
return mostMatchedField, mostMatchedField != nil
}
// PrimaryFields return scope's primary fields
@ -195,35 +236,6 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
return errors.New("could not convert column to field")
}
func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
if reflectValue.CanAddr() {
reflectValue = reflectValue.Addr()
}
if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
switch method := methodValue.Interface().(type) {
case func():
method()
case func(*Scope):
method(scope)
case func(*DB):
newDB := scope.NewDB()
method(newDB)
scope.Err(newDB.Error)
case func() error:
scope.Err(method())
case func(*Scope) error:
scope.Err(method(scope))
case func(*DB) error:
newDB := scope.NewDB()
scope.Err(method(newDB))
scope.Err(newDB.Error)
default:
scope.Err(fmt.Errorf("unsupported function %v", methodName))
}
}
}
// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
func (scope *Scope) CallMethod(methodName string) {
if scope.Value == nil {
@ -253,6 +265,31 @@ func (scope *Scope) AddToVars(value interface{}) string {
return scope.Dialect().BindVar(len(scope.SQLVars))
}
// SelectAttrs return selected attributes
func (scope *Scope) SelectAttrs() []string {
if scope.selectAttrs == nil {
attrs := []string{}
for _, value := range scope.Search.selects {
if str, ok := value.(string); ok {
attrs = append(attrs, str)
} else if strs, ok := value.([]string); ok {
attrs = append(attrs, strs...)
} else if strs, ok := value.([]interface{}); ok {
for _, str := range strs {
attrs = append(attrs, fmt.Sprintf("%v", str))
}
}
}
scope.selectAttrs = &attrs
}
return *scope.selectAttrs
}
// OmitAttrs return omitted attributes
func (scope *Scope) OmitAttrs() []string {
return scope.Search.omits
}
type tabler interface {
TableName() string
}
@ -296,24 +333,6 @@ func (scope *Scope) CombinedConditionSql() string {
scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
}
// FieldByName find `gorm.Field` with field name or db name
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
var (
dbName = ToDBName(name)
mostMatchedField *Field
)
for _, field := range scope.Fields() {
if field.Name == name || field.DBName == name {
return field, true
}
if field.DBName == dbName {
mostMatchedField = field
}
}
return mostMatchedField, mostMatchedField != nil
}
// Raw set raw sql
func (scope *Scope) Raw(sql string) *Scope {
scope.SQL = strings.Replace(sql, "$$", "?", -1)
@ -389,35 +408,56 @@ func (scope *Scope) CommitOrRollback() *Scope {
return scope
}
// SelectAttrs return selected attributes
func (scope *Scope) SelectAttrs() []string {
if scope.selectAttrs == nil {
attrs := []string{}
for _, value := range scope.Search.selects {
if str, ok := value.(string); ok {
attrs = append(attrs, str)
} else if strs, ok := value.([]string); ok {
attrs = append(attrs, strs...)
} else if strs, ok := value.([]interface{}); ok {
for _, str := range strs {
attrs = append(attrs, fmt.Sprintf("%v", str))
}
}
}
scope.selectAttrs = &attrs
}
return *scope.selectAttrs
}
// OmitAttrs return omitted attributes
func (scope *Scope) OmitAttrs() []string {
return scope.Search.omits
}
////////////////////////////////////////////////////////////////////////////////
// Private Methods For *gorm.Scope
////////////////////////////////////////////////////////////////////////////////
func (scope *Scope) fieldsMap() map[string]*Field {
var results = map[string]*Field{}
for _, field := range scope.Fields() {
if field.IsNormal {
results[field.DBName] = field
}
}
return results
}
func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
if reflectValue.CanAddr() {
reflectValue = reflectValue.Addr()
}
if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
switch method := methodValue.Interface().(type) {
case func():
method()
case func(*Scope):
method(scope)
case func(*DB):
newDB := scope.NewDB()
method(newDB)
scope.Err(newDB.Error)
case func() error:
scope.Err(method())
case func(*Scope) error:
scope.Err(method(scope))
case func(*DB) error:
newDB := scope.NewDB()
scope.Err(method(newDB))
scope.Err(newDB.Error)
default:
scope.Err(fmt.Errorf("unsupported function %v", methodName))
}
}
}
func (scope *Scope) quoteIfPossible(str string) string {
if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
return scope.Quote(str)
}
return str
}
func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
var values = make([]interface{}, len(columns))
var ignored interface{}
@ -448,6 +488,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]
}
}
}
func (scope *Scope) primaryCondition(value interface{}) string {
return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
}

View File

@ -134,7 +134,7 @@ func toQueryMarks(primaryValues [][]interface{}) string {
for _, primaryValue := range primaryValues {
var marks []string
for _ = range primaryValue {
for range primaryValue {
marks = append(marks, "?")
}