mirror of https://github.com/go-gorm/gorm.git
Refactor
This commit is contained in:
parent
d08894cd69
commit
a0aa21aec5
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
type Hstore map[string]*string
|
||||
|
||||
// Value get value of Hstore
|
||||
func (h Hstore) Value() (driver.Value, error) {
|
||||
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
|
||||
if len(h) == 0 {
|
||||
|
@ -27,6 +28,7 @@ func (h Hstore) Value() (driver.Value, error) {
|
|||
return hstore.Value()
|
||||
}
|
||||
|
||||
// Scan scan value into Hstore
|
||||
func (h *Hstore) Scan(value interface{}) error {
|
||||
hstore := hstore.Hstore{}
|
||||
|
||||
|
|
33
field.go
33
field.go
|
@ -56,36 +56,3 @@ func (field *Field) Set(value interface{}) (err error) {
|
|||
field.IsBlank = isBlank(field.Field)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fields get value's fields
|
||||
func (scope *Scope) Fields() []*Field {
|
||||
var (
|
||||
fields []*Field
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
isStruct = indirectScopeValue.Kind() == reflect.Struct
|
||||
)
|
||||
|
||||
for _, structField := range scope.GetModelStruct().StructFields {
|
||||
if isStruct {
|
||||
fieldValue := indirectScopeValue
|
||||
for _, name := range structField.Names {
|
||||
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
|
||||
}
|
||||
fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
|
||||
} else {
|
||||
fields = append(fields, &Field{StructField: structField, IsBlank: true})
|
||||
}
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func (scope *Scope) fieldsMap() map[string]*Field {
|
||||
var results = map[string]*Field{}
|
||||
for _, field := range scope.Fields() {
|
||||
if field.IsNormal {
|
||||
results[field.DBName] = field
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ var compareToys = func(toys []Toy, contents []string) bool {
|
|||
|
||||
func TestPolymorphic(t *testing.T) {
|
||||
cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}}
|
||||
dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "dog toy 1"}, Toy{Name: "dog toy 2"}}}
|
||||
dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}}
|
||||
DB.Save(&cat).Save(&dog)
|
||||
|
||||
if DB.Model(&cat).Association("Toy").Count() != 1 {
|
||||
|
|
|
@ -167,13 +167,13 @@ func TestNestedPreload2(t *testing.T) {
|
|||
Level2s: []Level2{
|
||||
{
|
||||
Level1s: []*Level1{
|
||||
&Level1{Value: "value1"},
|
||||
&Level1{Value: "value2"},
|
||||
{Value: "value1"},
|
||||
{Value: "value2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Level1s: []*Level1{
|
||||
&Level1{Value: "value3"},
|
||||
{Value: "value3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -265,8 +265,8 @@ func TestNestedPreload4(t *testing.T) {
|
|||
want := Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value1"},
|
||||
Level1{Value: "value2"},
|
||||
{Value: "value1"},
|
||||
{Value: "value2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -489,8 +489,8 @@ func TestNestedPreload8(t *testing.T) {
|
|||
want[0] = Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value1"},
|
||||
Level1{Value: "value2"},
|
||||
{Value: "value1"},
|
||||
{Value: "value2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -500,8 +500,8 @@ func TestNestedPreload8(t *testing.T) {
|
|||
want[1] = Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value3"},
|
||||
Level1{Value: "value4"},
|
||||
{Value: "value3"},
|
||||
{Value: "value4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -563,17 +563,17 @@ func TestNestedPreload9(t *testing.T) {
|
|||
want[0] = Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value1"},
|
||||
Level1{Value: "value2"},
|
||||
{Value: "value1"},
|
||||
{Value: "value2"},
|
||||
},
|
||||
},
|
||||
Level2_1: Level2_1{
|
||||
Level1s: []Level1{
|
||||
Level1{
|
||||
{
|
||||
Value: "value1-1",
|
||||
Level0s: []Level0{{Value: "Level0-1"}},
|
||||
},
|
||||
Level1{
|
||||
{
|
||||
Value: "value2-2",
|
||||
Level0s: []Level0{{Value: "Level0-2"}},
|
||||
},
|
||||
|
@ -586,14 +586,14 @@ func TestNestedPreload9(t *testing.T) {
|
|||
want[1] = Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value3"},
|
||||
Level1{Value: "value4"},
|
||||
{Value: "value3"},
|
||||
{Value: "value4"},
|
||||
},
|
||||
},
|
||||
Level2_1: Level2_1{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value3-3"},
|
||||
Level1{Value: "value4-4"},
|
||||
{Value: "value3-3"},
|
||||
{Value: "value4-4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -646,16 +646,16 @@ func TestNestedPreload10(t *testing.T) {
|
|||
}
|
||||
|
||||
want := []*LevelA2{
|
||||
&LevelA2{
|
||||
{
|
||||
Value: "bar",
|
||||
LevelA3s: []*LevelA3{
|
||||
&LevelA3{
|
||||
{
|
||||
Value: "qux",
|
||||
LevelA1: levelA1,
|
||||
},
|
||||
},
|
||||
},
|
||||
&LevelA2{
|
||||
{
|
||||
Value: "bar 2",
|
||||
},
|
||||
}
|
||||
|
@ -1151,11 +1151,11 @@ func TestNestedManyToManyPreload4(t *testing.T) {
|
|||
|
||||
dummy := Level1{
|
||||
Value: "Level1",
|
||||
Level2s: []*Level2{&Level2{
|
||||
Level2s: []*Level2{{
|
||||
Value: "Level2",
|
||||
Level3s: []*Level3{&Level3{
|
||||
Level3s: []*Level3{{
|
||||
Value: "Level3",
|
||||
Level4s: []*Level4{&Level4{
|
||||
Level4s: []*Level4{{
|
||||
Value: "Level4",
|
||||
}},
|
||||
}},
|
||||
|
|
233
scope.go
233
scope.go
|
@ -23,7 +23,7 @@ type Scope struct {
|
|||
instanceID string
|
||||
primaryKeyField *Field
|
||||
skipLeft bool
|
||||
fields map[string]*Field
|
||||
fields *[]*Field
|
||||
selectAttrs *[]string
|
||||
}
|
||||
|
||||
|
@ -37,6 +37,15 @@ func (scope *Scope) New(value interface{}) *Scope {
|
|||
return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Scope DB
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// DB return scope's DB connection
|
||||
func (scope *Scope) DB() *DB {
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// NewDB create a new DB without search information
|
||||
func (scope *Scope) NewDB() *DB {
|
||||
if scope.db != nil {
|
||||
|
@ -48,19 +57,14 @@ func (scope *Scope) NewDB() *DB {
|
|||
return nil
|
||||
}
|
||||
|
||||
// DB return scope's DB connection
|
||||
func (scope *Scope) DB() *DB {
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// SQLDB return *sql.DB
|
||||
func (scope *Scope) SQLDB() sqlCommon {
|
||||
return scope.db.db
|
||||
}
|
||||
|
||||
// SkipLeft skip remaining callbacks
|
||||
func (scope *Scope) SkipLeft() {
|
||||
scope.skipLeft = true
|
||||
// Dialect get dialect
|
||||
func (scope *Scope) Dialect() Dialect {
|
||||
return scope.db.parent.dialect
|
||||
}
|
||||
|
||||
// Quote used to quote string to escape them for database
|
||||
|
@ -76,18 +80,6 @@ func (scope *Scope) Quote(str string) string {
|
|||
return scope.Dialect().Quote(str)
|
||||
}
|
||||
|
||||
func (scope *Scope) quoteIfPossible(str string) string {
|
||||
if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
|
||||
return scope.Quote(str)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// Dialect get dialect
|
||||
func (scope *Scope) Dialect() Dialect {
|
||||
return scope.db.parent.dialect
|
||||
}
|
||||
|
||||
// Err add error to Scope
|
||||
func (scope *Scope) Err(err error) error {
|
||||
if err != nil {
|
||||
|
@ -96,14 +88,63 @@ func (scope *Scope) Err(err error) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// HasError check if there are any error
|
||||
func (scope *Scope) HasError() bool {
|
||||
return scope.db.Error != nil
|
||||
}
|
||||
|
||||
// Log print log message
|
||||
func (scope *Scope) Log(v ...interface{}) {
|
||||
scope.db.log(v...)
|
||||
}
|
||||
|
||||
// HasError check if there are any error
|
||||
func (scope *Scope) HasError() bool {
|
||||
return scope.db.Error != nil
|
||||
// SkipLeft skip remaining callbacks
|
||||
func (scope *Scope) SkipLeft() {
|
||||
scope.skipLeft = true
|
||||
}
|
||||
|
||||
// Fields get value's fields
|
||||
func (scope *Scope) Fields() []*Field {
|
||||
if scope.fields == nil {
|
||||
var (
|
||||
fields []*Field
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
isStruct = indirectScopeValue.Kind() == reflect.Struct
|
||||
)
|
||||
|
||||
for _, structField := range scope.GetModelStruct().StructFields {
|
||||
if isStruct {
|
||||
fieldValue := indirectScopeValue
|
||||
for _, name := range structField.Names {
|
||||
fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
|
||||
}
|
||||
fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
|
||||
} else {
|
||||
fields = append(fields, &Field{StructField: structField, IsBlank: true})
|
||||
}
|
||||
}
|
||||
scope.fields = &fields
|
||||
}
|
||||
|
||||
return *scope.fields
|
||||
}
|
||||
|
||||
// FieldByName find `gorm.Field` with field name or db name
|
||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||
var (
|
||||
dbName = ToDBName(name)
|
||||
mostMatchedField *Field
|
||||
)
|
||||
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Name == name || field.DBName == name {
|
||||
return field, true
|
||||
}
|
||||
if field.DBName == dbName {
|
||||
mostMatchedField = field
|
||||
}
|
||||
}
|
||||
return mostMatchedField, mostMatchedField != nil
|
||||
}
|
||||
|
||||
// PrimaryFields return scope's primary fields
|
||||
|
@ -195,35 +236,6 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
|||
return errors.New("could not convert column to field")
|
||||
}
|
||||
|
||||
func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
|
||||
if reflectValue.CanAddr() {
|
||||
reflectValue = reflectValue.Addr()
|
||||
}
|
||||
|
||||
if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
|
||||
switch method := methodValue.Interface().(type) {
|
||||
case func():
|
||||
method()
|
||||
case func(*Scope):
|
||||
method(scope)
|
||||
case func(*DB):
|
||||
newDB := scope.NewDB()
|
||||
method(newDB)
|
||||
scope.Err(newDB.Error)
|
||||
case func() error:
|
||||
scope.Err(method())
|
||||
case func(*Scope) error:
|
||||
scope.Err(method(scope))
|
||||
case func(*DB) error:
|
||||
newDB := scope.NewDB()
|
||||
scope.Err(method(newDB))
|
||||
scope.Err(newDB.Error)
|
||||
default:
|
||||
scope.Err(fmt.Errorf("unsupported function %v", methodName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
|
||||
func (scope *Scope) CallMethod(methodName string) {
|
||||
if scope.Value == nil {
|
||||
|
@ -253,6 +265,31 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
|||
return scope.Dialect().BindVar(len(scope.SQLVars))
|
||||
}
|
||||
|
||||
// SelectAttrs return selected attributes
|
||||
func (scope *Scope) SelectAttrs() []string {
|
||||
if scope.selectAttrs == nil {
|
||||
attrs := []string{}
|
||||
for _, value := range scope.Search.selects {
|
||||
if str, ok := value.(string); ok {
|
||||
attrs = append(attrs, str)
|
||||
} else if strs, ok := value.([]string); ok {
|
||||
attrs = append(attrs, strs...)
|
||||
} else if strs, ok := value.([]interface{}); ok {
|
||||
for _, str := range strs {
|
||||
attrs = append(attrs, fmt.Sprintf("%v", str))
|
||||
}
|
||||
}
|
||||
}
|
||||
scope.selectAttrs = &attrs
|
||||
}
|
||||
return *scope.selectAttrs
|
||||
}
|
||||
|
||||
// OmitAttrs return omitted attributes
|
||||
func (scope *Scope) OmitAttrs() []string {
|
||||
return scope.Search.omits
|
||||
}
|
||||
|
||||
type tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
@ -296,24 +333,6 @@ func (scope *Scope) CombinedConditionSql() string {
|
|||
scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
|
||||
}
|
||||
|
||||
// FieldByName find `gorm.Field` with field name or db name
|
||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||
var (
|
||||
dbName = ToDBName(name)
|
||||
mostMatchedField *Field
|
||||
)
|
||||
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Name == name || field.DBName == name {
|
||||
return field, true
|
||||
}
|
||||
if field.DBName == dbName {
|
||||
mostMatchedField = field
|
||||
}
|
||||
}
|
||||
return mostMatchedField, mostMatchedField != nil
|
||||
}
|
||||
|
||||
// Raw set raw sql
|
||||
func (scope *Scope) Raw(sql string) *Scope {
|
||||
scope.SQL = strings.Replace(sql, "$$", "?", -1)
|
||||
|
@ -389,35 +408,56 @@ func (scope *Scope) CommitOrRollback() *Scope {
|
|||
return scope
|
||||
}
|
||||
|
||||
// SelectAttrs return selected attributes
|
||||
func (scope *Scope) SelectAttrs() []string {
|
||||
if scope.selectAttrs == nil {
|
||||
attrs := []string{}
|
||||
for _, value := range scope.Search.selects {
|
||||
if str, ok := value.(string); ok {
|
||||
attrs = append(attrs, str)
|
||||
} else if strs, ok := value.([]string); ok {
|
||||
attrs = append(attrs, strs...)
|
||||
} else if strs, ok := value.([]interface{}); ok {
|
||||
for _, str := range strs {
|
||||
attrs = append(attrs, fmt.Sprintf("%v", str))
|
||||
}
|
||||
}
|
||||
}
|
||||
scope.selectAttrs = &attrs
|
||||
}
|
||||
return *scope.selectAttrs
|
||||
}
|
||||
|
||||
// OmitAttrs return omitted attributes
|
||||
func (scope *Scope) OmitAttrs() []string {
|
||||
return scope.Search.omits
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Private Methods For *gorm.Scope
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func (scope *Scope) fieldsMap() map[string]*Field {
|
||||
var results = map[string]*Field{}
|
||||
for _, field := range scope.Fields() {
|
||||
if field.IsNormal {
|
||||
results[field.DBName] = field
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
|
||||
if reflectValue.CanAddr() {
|
||||
reflectValue = reflectValue.Addr()
|
||||
}
|
||||
|
||||
if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
|
||||
switch method := methodValue.Interface().(type) {
|
||||
case func():
|
||||
method()
|
||||
case func(*Scope):
|
||||
method(scope)
|
||||
case func(*DB):
|
||||
newDB := scope.NewDB()
|
||||
method(newDB)
|
||||
scope.Err(newDB.Error)
|
||||
case func() error:
|
||||
scope.Err(method())
|
||||
case func(*Scope) error:
|
||||
scope.Err(method(scope))
|
||||
case func(*DB) error:
|
||||
newDB := scope.NewDB()
|
||||
scope.Err(method(newDB))
|
||||
scope.Err(newDB.Error)
|
||||
default:
|
||||
scope.Err(fmt.Errorf("unsupported function %v", methodName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) quoteIfPossible(str string) string {
|
||||
if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
|
||||
return scope.Quote(str)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) {
|
||||
var values = make([]interface{}, len(columns))
|
||||
var ignored interface{}
|
||||
|
@ -448,6 +488,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) primaryCondition(value interface{}) string {
|
||||
return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue