mirror of https://github.com/go-gorm/gorm.git
Review and Refactor
This commit is contained in:
parent
38cbff9c79
commit
0b32041135
|
@ -7,12 +7,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Association struct {
|
type Association struct {
|
||||||
Scope *Scope
|
Scope *Scope
|
||||||
PrimaryKey interface{}
|
PrimaryKey interface{}
|
||||||
PrimaryType interface{}
|
Column string
|
||||||
Column string
|
Error error
|
||||||
Error error
|
Field *Field
|
||||||
Field *Field
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) setErr(err error) *Association {
|
func (association *Association) setErr(err error) *Association {
|
||||||
|
@ -158,11 +157,11 @@ func (association *Association) Count() int {
|
||||||
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
||||||
countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey)
|
countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey)
|
||||||
if relationship.ForeignType != "" {
|
if relationship.ForeignType != "" {
|
||||||
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), association.PrimaryType)
|
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
|
||||||
}
|
}
|
||||||
countScope.Count(&count)
|
countScope.Count(&count)
|
||||||
} else if relationship.Kind == "belongs_to" {
|
} else if relationship.Kind == "belongs_to" {
|
||||||
if v, err := scope.FieldValueByName(association.Column); err == nil {
|
if v, ok := scope.FieldByName(association.Column); ok {
|
||||||
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
||||||
scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count)
|
scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count)
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,11 +26,10 @@ func Create(scope *Scope) {
|
||||||
var sqls, columns []string
|
var sqls, columns []string
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) {
|
if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) {
|
||||||
if field.DefaultValue != nil && field.IsBlank {
|
if !field.IsBlank || field.DefaultValue == nil {
|
||||||
continue
|
columns = append(columns, scope.Quote(field.DBName))
|
||||||
|
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
||||||
}
|
}
|
||||||
columns = append(columns, scope.Quote(field.DBName))
|
|
||||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -41,18 +41,16 @@ func Update(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
|
|
||||||
updateAttrs, ok := scope.InstanceGet("gorm:update_attrs")
|
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||||
if ok {
|
|
||||||
for key, value := range updateAttrs.(map[string]interface{}) {
|
for key, value := range updateAttrs.(map[string]interface{}) {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
|
if !field.IsPrimaryKey && field.IsNormal {
|
||||||
if field.DefaultValue != nil && field.IsBlank {
|
if !field.IsBlank || field.DefaultValue == nil {
|
||||||
continue
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||||
}
|
}
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -68,8 +66,7 @@ func Update(scope *Scope) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func AfterUpdate(scope *Scope) {
|
func AfterUpdate(scope *Scope) {
|
||||||
_, ok := scope.Get("gorm:update_column")
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
if !ok {
|
|
||||||
scope.CallMethod("AfterUpdate")
|
scope.CallMethod("AfterUpdate")
|
||||||
scope.CallMethod("AfterSave")
|
scope.CallMethod("AfterSave")
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type commonDialect struct{}
|
type commonDialect struct{}
|
||||||
|
@ -36,7 +37,7 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
|
||||||
}
|
}
|
||||||
return "VARCHAR(65532)"
|
return "VARCHAR(65532)"
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if value.Type() == timeType {
|
if _, ok := value.Interface().(time.Time); ok {
|
||||||
return "TIMESTAMP"
|
return "TIMESTAMP"
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -3,11 +3,8 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var timeType = reflect.TypeOf(time.Time{})
|
|
||||||
|
|
||||||
type Dialect interface {
|
type Dialect interface {
|
||||||
BinVar(i int) string
|
BinVar(i int) string
|
||||||
SupportLastInsertId() bool
|
SupportLastInsertId() bool
|
||||||
|
|
47
field.go
47
field.go
|
@ -12,7 +12,7 @@ type Field struct {
|
||||||
Field reflect.Value
|
Field reflect.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (field *Field) Set(value interface{}) (err error) {
|
func (field *Field) Set(value interface{}) error {
|
||||||
if !field.Field.IsValid() {
|
if !field.Field.IsValid() {
|
||||||
return errors.New("field value not valid")
|
return errors.New("field value not valid")
|
||||||
}
|
}
|
||||||
|
@ -26,16 +26,26 @@ func (field *Field) Set(value interface{}) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
||||||
scanner.Scan(value)
|
if v, ok := value.(reflect.Value); ok {
|
||||||
} else if reflect.TypeOf(value).ConvertibleTo(field.Field.Type()) {
|
scanner.Scan(v.Interface())
|
||||||
field.Field.Set(reflect.ValueOf(value).Convert(field.Field.Type()))
|
} else {
|
||||||
|
scanner.Scan(value)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return errors.New("could not convert argument")
|
reflectValue, ok := value.(reflect.Value)
|
||||||
|
if !ok {
|
||||||
|
reflectValue = reflect.ValueOf(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
|
||||||
|
field.Field.Set(reflectValue.Convert(field.Field.Type()))
|
||||||
|
} else {
|
||||||
|
return errors.New("could not convert argument")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
field.IsBlank = isBlank(field.Field)
|
field.IsBlank = isBlank(field.Field)
|
||||||
|
return nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fields get value's fields
|
// Fields get value's fields
|
||||||
|
@ -44,8 +54,14 @@ func (scope *Scope) Fields() map[string]*Field {
|
||||||
fields := map[string]*Field{}
|
fields := map[string]*Field{}
|
||||||
structFields := scope.GetStructFields()
|
structFields := scope.GetStructFields()
|
||||||
|
|
||||||
|
indirectValue := scope.IndirectValue()
|
||||||
|
isStruct := indirectValue.Kind() == reflect.Struct
|
||||||
for _, structField := range structFields {
|
for _, structField := range structFields {
|
||||||
fields[structField.DBName] = scope.getField(structField)
|
if isStruct {
|
||||||
|
fields[structField.DBName] = getField(indirectValue, structField)
|
||||||
|
} else {
|
||||||
|
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.fields = fields
|
scope.fields = fields
|
||||||
|
@ -53,15 +69,12 @@ func (scope *Scope) Fields() map[string]*Field {
|
||||||
return scope.fields
|
return scope.fields
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) getField(structField *StructField) *Field {
|
func getField(indirectValue reflect.Value, structField *StructField) *Field {
|
||||||
field := Field{StructField: structField}
|
field := &Field{StructField: structField}
|
||||||
indirectValue := scope.IndirectValue()
|
for _, name := range structField.Names {
|
||||||
if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct {
|
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
|
||||||
for _, name := range structField.Names {
|
|
||||||
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
|
|
||||||
}
|
|
||||||
field.Field = indirectValue
|
|
||||||
}
|
}
|
||||||
|
field.Field = indirectValue
|
||||||
field.IsBlank = isBlank(indirectValue)
|
field.IsBlank = isBlank(indirectValue)
|
||||||
return &field
|
return field
|
||||||
}
|
}
|
||||||
|
|
14
logger.go
14
logger.go
|
@ -21,21 +21,21 @@ var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
||||||
// Format log
|
// Format log
|
||||||
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
||||||
|
|
||||||
func (logger Logger) Print(v ...interface{}) {
|
func (logger Logger) Print(values ...interface{}) {
|
||||||
if len(v) > 1 {
|
if len(values) > 1 {
|
||||||
level := v[0]
|
level := values[0]
|
||||||
currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
|
currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
|
||||||
source := fmt.Sprintf("\033[35m(%v)\033[0m", v[1])
|
source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
|
||||||
messages := []interface{}{source, currentTime}
|
messages := []interface{}{source, currentTime}
|
||||||
|
|
||||||
if level == "sql" {
|
if level == "sql" {
|
||||||
// duration
|
// duration
|
||||||
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(v[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
||||||
// sql
|
// sql
|
||||||
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(v[3].(string), "'%v'"), v[4].([]interface{})...))
|
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "'%v'"), values[4].([]interface{})...))
|
||||||
} else {
|
} else {
|
||||||
messages = append(messages, "\033[31;1m")
|
messages = append(messages, "\033[31;1m")
|
||||||
messages = append(messages, v[2:]...)
|
messages = append(messages, values[2:]...)
|
||||||
messages = append(messages, "\033[0m")
|
messages = append(messages, "\033[0m")
|
||||||
}
|
}
|
||||||
logger.Println(messages...)
|
logger.Println(messages...)
|
||||||
|
|
114
main.go
114
main.go
|
@ -30,7 +30,6 @@ type DB struct {
|
||||||
logMode int
|
logMode int
|
||||||
logger logger
|
logger logger
|
||||||
dialect Dialect
|
dialect Dialect
|
||||||
tagIdentifier string
|
|
||||||
singularTable bool
|
singularTable bool
|
||||||
source string
|
source string
|
||||||
values map[string]interface{}
|
values map[string]interface{}
|
||||||
|
@ -39,33 +38,38 @@ type DB struct {
|
||||||
func Open(dialect string, args ...interface{}) (DB, error) {
|
func Open(dialect string, args ...interface{}) (DB, error) {
|
||||||
var db DB
|
var db DB
|
||||||
var err error
|
var err error
|
||||||
var source string
|
|
||||||
var dbSql sqlCommon
|
|
||||||
|
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
err = errors.New("invalid database source")
|
err = errors.New("invalid database source")
|
||||||
}
|
} else {
|
||||||
|
var source string
|
||||||
|
var dbSql sqlCommon
|
||||||
|
|
||||||
switch value := args[0].(type) {
|
switch value := args[0].(type) {
|
||||||
case string:
|
case string:
|
||||||
var driver = dialect
|
var driver = dialect
|
||||||
if len(args) == 1 {
|
if len(args) == 1 {
|
||||||
source = value
|
source = value
|
||||||
} else if len(args) >= 2 {
|
} else if len(args) >= 2 {
|
||||||
driver = value
|
driver = value
|
||||||
source = args[1].(string)
|
source = args[1].(string)
|
||||||
|
}
|
||||||
|
dbSql, err = sql.Open(driver, source)
|
||||||
|
case sqlCommon:
|
||||||
|
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
||||||
|
dbSql = value
|
||||||
}
|
}
|
||||||
dbSql, err = sql.Open(driver, source)
|
|
||||||
case sqlCommon:
|
|
||||||
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
|
||||||
dbSql = value
|
|
||||||
}
|
|
||||||
|
|
||||||
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
|
db = DB{
|
||||||
logger: defaultLogger, callback: DefaultCallback, source: source,
|
dialect: NewDialect(dialect),
|
||||||
values: map[string]interface{}{}}
|
logger: defaultLogger,
|
||||||
db.db = dbSql
|
callback: DefaultCallback,
|
||||||
db.parent = &db
|
source: source,
|
||||||
|
values: map[string]interface{}{},
|
||||||
|
db: dbSql,
|
||||||
|
}
|
||||||
|
db.parent = &db
|
||||||
|
}
|
||||||
|
|
||||||
return db, err
|
return db, err
|
||||||
}
|
}
|
||||||
|
@ -84,7 +88,7 @@ func (s *DB) New() *DB {
|
||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the underlying sql.DB or sql.Tx instance.
|
// CommonDB Return the underlying sql.DB or sql.Tx instance.
|
||||||
// Use of this method is discouraged. It's mainly intended to allow
|
// Use of this method is discouraged. It's mainly intended to allow
|
||||||
// coexistence with legacy non-GORM code.
|
// coexistence with legacy non-GORM code.
|
||||||
func (s *DB) CommonDB() sqlCommon {
|
func (s *DB) CommonDB() sqlCommon {
|
||||||
|
@ -96,16 +100,12 @@ func (s *DB) Callback() *callback {
|
||||||
return s.parent.callback
|
return s.parent.callback
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) SetTagIdentifier(str string) {
|
|
||||||
s.parent.tagIdentifier = str
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DB) SetLogger(l logger) {
|
func (s *DB) SetLogger(l logger) {
|
||||||
s.parent.logger = l
|
s.parent.logger = l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) LogMode(b bool) *DB {
|
func (s *DB) LogMode(enable bool) *DB {
|
||||||
if b {
|
if enable {
|
||||||
s.logMode = 2
|
s.logMode = 2
|
||||||
} else {
|
} else {
|
||||||
s.logMode = 1
|
s.logMode = 1
|
||||||
|
@ -113,8 +113,8 @@ func (s *DB) LogMode(b bool) *DB {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) SingularTable(b bool) {
|
func (s *DB) SingularTable(enable bool) {
|
||||||
s.parent.singularTable = b
|
s.parent.singularTable = enable
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
||||||
|
@ -158,11 +158,10 @@ func (s *DB) Joins(query string) *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
||||||
c := s
|
|
||||||
for _, f := range funcs {
|
for _, f := range funcs {
|
||||||
c = f(c)
|
s = f(s)
|
||||||
}
|
}
|
||||||
return c
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Unscoped() *DB {
|
func (s *DB) Unscoped() *DB {
|
||||||
|
@ -179,16 +178,14 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
|
||||||
|
|
||||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.clone().NewScope(out)
|
newScope := s.clone().NewScope(out)
|
||||||
newScope.Search = newScope.Search.clone()
|
newScope.Search = newScope.Search.clone().limit(1)
|
||||||
newScope.Search.limit(1)
|
|
||||||
return newScope.InstanceSet("gorm:order_by_primary_key", "ASC").
|
return newScope.InstanceSet("gorm:order_by_primary_key", "ASC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.clone().NewScope(out)
|
newScope := s.clone().NewScope(out)
|
||||||
newScope.Search = newScope.Search.clone()
|
newScope.Search = newScope.Search.clone().limit(1)
|
||||||
newScope.Search.limit(1)
|
|
||||||
return newScope.InstanceSet("gorm:order_by_primary_key", "DESC").
|
return newScope.InstanceSet("gorm:order_by_primary_key", "DESC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||||
}
|
}
|
||||||
|
@ -213,10 +210,9 @@ func (s *DB) Scan(dest interface{}) *DB {
|
||||||
|
|
||||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
r := c.First(out, where...)
|
if result := c.First(out, where...); result.Error != nil {
|
||||||
if r.Error != nil {
|
if !result.RecordNotFound() {
|
||||||
if !r.RecordNotFound() {
|
return result
|
||||||
return r
|
|
||||||
}
|
}
|
||||||
c.NewScope(out).inlineCondition(where...).initialize()
|
c.NewScope(out).inlineCondition(where...).initialize()
|
||||||
} else {
|
} else {
|
||||||
|
@ -227,10 +223,9 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||||
|
|
||||||
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
r := c.First(out, where...)
|
if result := c.First(out, where...); result.Error != nil {
|
||||||
if r.Error != nil {
|
if !result.RecordNotFound() {
|
||||||
if !r.RecordNotFound() {
|
return result
|
||||||
return r
|
|
||||||
}
|
}
|
||||||
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
|
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
|
||||||
} else if len(c.search.AssignAttrs) > 0 {
|
} else if len(c.search.AssignAttrs) > 0 {
|
||||||
|
@ -418,25 +413,24 @@ func (s *DB) RemoveIndex(indexName string) *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Association(column string) *Association {
|
func (s *DB) Association(column string) *Association {
|
||||||
|
var err error
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value)
|
||||||
|
|
||||||
primaryKey := scope.PrimaryKeyValue()
|
if primaryField := scope.PrimaryKeyField(); primaryField.IsBlank {
|
||||||
primaryType := scope.TableName()
|
err = errors.New("primary key can't be nil")
|
||||||
if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
|
|
||||||
scope.Err(errors.New("primary key can't be nil"))
|
|
||||||
}
|
|
||||||
|
|
||||||
var field *Field
|
|
||||||
var ok bool
|
|
||||||
if field, ok = scope.FieldByName(column); ok {
|
|
||||||
if field.Relationship == nil || field.Relationship.ForeignFieldName == "" {
|
|
||||||
scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()))
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column))
|
if field, ok := scope.FieldByName(column); ok {
|
||||||
|
if field.Relationship == nil || field.Relationship.ForeignFieldName == "" {
|
||||||
|
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
||||||
|
} else {
|
||||||
|
return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field}
|
return &Association{Error: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
|
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
|
||||||
|
|
221
model_struct.go
221
model_struct.go
|
@ -133,137 +133,134 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
|
|
||||||
// Set fields
|
// Set fields
|
||||||
for i := 0; i < scopeType.NumField(); i++ {
|
for i := 0; i < scopeType.NumField(); i++ {
|
||||||
fieldStruct := scopeType.Field(i)
|
if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||||
if !ast.IsExported(fieldStruct.Name) {
|
field := &StructField{
|
||||||
continue
|
Struct: fieldStruct,
|
||||||
}
|
Name: fieldStruct.Name,
|
||||||
|
Names: []string{fieldStruct.Name},
|
||||||
field := &StructField{
|
Tag: fieldStruct.Tag,
|
||||||
Struct: fieldStruct,
|
|
||||||
Name: fieldStruct.Name,
|
|
||||||
Names: []string{fieldStruct.Name},
|
|
||||||
Tag: fieldStruct.Tag,
|
|
||||||
}
|
|
||||||
|
|
||||||
if fieldStruct.Tag.Get("sql") == "-" {
|
|
||||||
field.IsIgnored = true
|
|
||||||
} else {
|
|
||||||
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
|
|
||||||
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
|
|
||||||
if _, ok := gormSettings["PRIMARY_KEY"]; ok {
|
|
||||||
field.IsPrimaryKey = true
|
|
||||||
modelStruct.PrimaryKeyField = field
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if value, ok := sqlSettings["DEFAULT"]; ok {
|
if fieldStruct.Tag.Get("sql") == "-" {
|
||||||
field.DefaultValue = &value
|
field.IsIgnored = true
|
||||||
}
|
|
||||||
|
|
||||||
if value, ok := gormSettings["COLUMN"]; ok {
|
|
||||||
field.DBName = value
|
|
||||||
} else {
|
} else {
|
||||||
field.DBName = ToSnake(fieldStruct.Name)
|
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
|
||||||
}
|
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
|
||||||
|
if _, ok := gormSettings["PRIMARY_KEY"]; ok {
|
||||||
|
field.IsPrimaryKey = true
|
||||||
|
modelStruct.PrimaryKeyField = field
|
||||||
|
}
|
||||||
|
|
||||||
fieldType, indirectType := fieldStruct.Type, fieldStruct.Type
|
if value, ok := sqlSettings["DEFAULT"]; ok {
|
||||||
if indirectType.Kind() == reflect.Ptr {
|
field.DefaultValue = &value
|
||||||
indirectType = indirectType.Elem()
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner {
|
if value, ok := gormSettings["COLUMN"]; ok {
|
||||||
field.IsScanner, field.IsNormal = true, true
|
field.DBName = value
|
||||||
}
|
} else {
|
||||||
|
field.DBName = ToSnake(fieldStruct.Name)
|
||||||
|
}
|
||||||
|
|
||||||
if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
|
fieldType, indirectType := fieldStruct.Type, fieldStruct.Type
|
||||||
field.IsTime, field.IsNormal = true, true
|
if indirectType.Kind() == reflect.Ptr {
|
||||||
}
|
indirectType = indirectType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
many2many := gormSettings["MANY2MANY"]
|
if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner {
|
||||||
foreignKey := SnakeToUpperCamel(gormSettings["FOREIGNKEY"])
|
field.IsScanner, field.IsNormal = true, true
|
||||||
foreignType := SnakeToUpperCamel(gormSettings["FOREIGNTYPE"])
|
}
|
||||||
associationForeignKey := SnakeToUpperCamel(gormSettings["ASSOCIATIONFOREIGNKEY"])
|
|
||||||
if polymorphic := SnakeToUpperCamel(gormSettings["POLYMORPHIC"]); polymorphic != "" {
|
|
||||||
foreignKey = polymorphic + "Id"
|
|
||||||
foreignType = polymorphic + "Type"
|
|
||||||
}
|
|
||||||
|
|
||||||
if !field.IsNormal {
|
if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
|
||||||
switch indirectType.Kind() {
|
field.IsTime, field.IsNormal = true, true
|
||||||
case reflect.Slice:
|
}
|
||||||
typ := indirectType.Elem()
|
|
||||||
if typ.Kind() == reflect.Ptr {
|
|
||||||
typ = typ.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if typ.Kind() == reflect.Struct {
|
many2many := gormSettings["MANY2MANY"]
|
||||||
kind := "has_many"
|
foreignKey := SnakeToUpperCamel(gormSettings["FOREIGNKEY"])
|
||||||
|
foreignType := SnakeToUpperCamel(gormSettings["FOREIGNTYPE"])
|
||||||
|
associationForeignKey := SnakeToUpperCamel(gormSettings["ASSOCIATIONFOREIGNKEY"])
|
||||||
|
if polymorphic := SnakeToUpperCamel(gormSettings["POLYMORPHIC"]); polymorphic != "" {
|
||||||
|
foreignKey = polymorphic + "Id"
|
||||||
|
foreignType = polymorphic + "Type"
|
||||||
|
}
|
||||||
|
|
||||||
if foreignKey == "" {
|
if !field.IsNormal {
|
||||||
foreignKey = scopeType.Name() + "Id"
|
switch indirectType.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
typ := indirectType.Elem()
|
||||||
|
if typ.Kind() == reflect.Ptr {
|
||||||
|
typ = typ.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if associationForeignKey == "" {
|
if typ.Kind() == reflect.Struct {
|
||||||
associationForeignKey = typ.Name() + "Id"
|
kind := "has_many"
|
||||||
|
|
||||||
|
if foreignKey == "" {
|
||||||
|
foreignKey = scopeType.Name() + "Id"
|
||||||
|
}
|
||||||
|
|
||||||
|
if associationForeignKey == "" {
|
||||||
|
associationForeignKey = typ.Name() + "Id"
|
||||||
|
}
|
||||||
|
|
||||||
|
if many2many != "" {
|
||||||
|
kind = "many_to_many"
|
||||||
|
} else if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
|
||||||
|
foreignKey = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
field.Relationship = &Relationship{
|
||||||
|
JoinTable: many2many,
|
||||||
|
ForeignType: foreignType,
|
||||||
|
ForeignFieldName: foreignKey,
|
||||||
|
AssociationForeignFieldName: associationForeignKey,
|
||||||
|
ForeignDBName: ToSnake(foreignKey),
|
||||||
|
AssociationForeignDBName: ToSnake(associationForeignKey),
|
||||||
|
Kind: kind,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
field.IsNormal = true
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
|
||||||
|
for _, field := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() {
|
||||||
|
field.Names = append([]string{fieldStruct.Name}, field.Names...)
|
||||||
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
var belongsToForeignKey, hasOneForeignKey, kind string
|
||||||
|
|
||||||
|
if foreignKey == "" {
|
||||||
|
belongsToForeignKey = field.Name + "Id"
|
||||||
|
hasOneForeignKey = scopeType.Name() + "Id"
|
||||||
|
} else {
|
||||||
|
belongsToForeignKey = foreignKey
|
||||||
|
hasOneForeignKey = foreignKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := scopeType.FieldByName(belongsToForeignKey); ok {
|
||||||
|
kind = "belongs_to"
|
||||||
|
foreignKey = belongsToForeignKey
|
||||||
|
} else {
|
||||||
|
foreignKey = hasOneForeignKey
|
||||||
|
kind = "has_one"
|
||||||
|
}
|
||||||
|
|
||||||
|
field.Relationship = &Relationship{
|
||||||
|
ForeignFieldName: foreignKey,
|
||||||
|
ForeignDBName: ToSnake(foreignKey),
|
||||||
|
ForeignType: foreignType,
|
||||||
|
Kind: kind,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if many2many != "" {
|
default:
|
||||||
kind = "many_to_many"
|
|
||||||
} else if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
|
|
||||||
foreignKey = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
field.Relationship = &Relationship{
|
|
||||||
JoinTable: many2many,
|
|
||||||
ForeignType: foreignType,
|
|
||||||
ForeignFieldName: foreignKey,
|
|
||||||
AssociationForeignFieldName: associationForeignKey,
|
|
||||||
ForeignDBName: ToSnake(foreignKey),
|
|
||||||
AssociationForeignDBName: ToSnake(associationForeignKey),
|
|
||||||
Kind: kind,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
field.IsNormal = true
|
field.IsNormal = true
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
|
||||||
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
|
|
||||||
for _, field := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() {
|
|
||||||
field.Names = append([]string{fieldStruct.Name}, field.Names...)
|
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
} else {
|
|
||||||
var belongsToForeignKey, hasOneForeignKey, kind string
|
|
||||||
|
|
||||||
if foreignKey == "" {
|
|
||||||
belongsToForeignKey = field.Name + "Id"
|
|
||||||
hasOneForeignKey = scopeType.Name() + "Id"
|
|
||||||
} else {
|
|
||||||
belongsToForeignKey = foreignKey
|
|
||||||
hasOneForeignKey = foreignKey
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := scopeType.FieldByName(belongsToForeignKey); ok {
|
|
||||||
kind = "belongs_to"
|
|
||||||
foreignKey = belongsToForeignKey
|
|
||||||
} else {
|
|
||||||
foreignKey = hasOneForeignKey
|
|
||||||
kind = "has_one"
|
|
||||||
}
|
|
||||||
|
|
||||||
field.Relationship = &Relationship{
|
|
||||||
ForeignFieldName: foreignKey,
|
|
||||||
ForeignDBName: ToSnake(foreignKey),
|
|
||||||
ForeignType: foreignType,
|
|
||||||
Kind: kind,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
field.IsNormal = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||||
}
|
}
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range modelStruct.StructFields {
|
for _, field := range modelStruct.StructFields {
|
||||||
|
|
3
mssql.go
3
mssql.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mssql struct{}
|
type mssql struct{}
|
||||||
|
@ -36,7 +37,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int) string {
|
||||||
}
|
}
|
||||||
return "text"
|
return "text"
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if value.Type() == timeType {
|
if _, ok := value.Interface().(time.Time); ok {
|
||||||
return "datetime2"
|
return "datetime2"
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
3
mysql.go
3
mysql.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mysql struct{}
|
type mysql struct{}
|
||||||
|
@ -36,7 +37,7 @@ func (s *mysql) SqlTag(value reflect.Value, size int) string {
|
||||||
}
|
}
|
||||||
return "longtext"
|
return "longtext"
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if value.Type() == timeType {
|
if _, ok := value.Interface().(time.Time); ok {
|
||||||
return "timestamp NULL"
|
return "timestamp NULL"
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/lib/pq/hstore"
|
"github.com/lib/pq/hstore"
|
||||||
)
|
)
|
||||||
|
@ -40,7 +41,7 @@ func (s *postgres) SqlTag(value reflect.Value, size int) string {
|
||||||
}
|
}
|
||||||
return "text"
|
return "text"
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if value.Type() == timeType {
|
if _, ok := value.Interface().(time.Time); ok {
|
||||||
return "timestamp with time zone"
|
return "timestamp with time zone"
|
||||||
}
|
}
|
||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
|
|
40
preload.go
40
preload.go
|
@ -7,7 +7,7 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getFieldValue(value reflect.Value, field string) interface{} {
|
func getRealValue(value reflect.Value, field string) interface{} {
|
||||||
result := reflect.Indirect(value).FieldByName(field).Interface()
|
result := reflect.Indirect(value).FieldByName(field).Interface()
|
||||||
if r, ok := result.(driver.Valuer); ok {
|
if r, ok := result.(driver.Valuer); ok {
|
||||||
result, _ = r.Value()
|
result, _ = r.Value()
|
||||||
|
@ -20,26 +20,14 @@ func equalAsString(a interface{}, b interface{}) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Preload(scope *Scope) {
|
func Preload(scope *Scope) {
|
||||||
// Get Fields
|
fields := scope.Fields()
|
||||||
var fields map[string]*Field
|
isSlice := scope.IndirectValue().Kind() == reflect.Slice
|
||||||
var isSlice bool
|
|
||||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
|
||||||
isSlice = true
|
|
||||||
typ := scope.IndirectValue().Type().Elem()
|
|
||||||
if typ.Kind() == reflect.Ptr {
|
|
||||||
typ = typ.Elem()
|
|
||||||
}
|
|
||||||
elem := reflect.New(typ).Elem()
|
|
||||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
|
||||||
} else {
|
|
||||||
fields = scope.Fields()
|
|
||||||
}
|
|
||||||
|
|
||||||
if scope.Search.Preload != nil {
|
if scope.Search.Preload != nil {
|
||||||
for key, conditions := range scope.Search.Preload {
|
for key, conditions := range scope.Search.Preload {
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if field.Name == key && field.Relationship != nil {
|
if field.Name == key && field.Relationship != nil {
|
||||||
results := makeSlice(field.Field)
|
results := makeSlice(field.Struct.Type)
|
||||||
relation := field.Relationship
|
relation := field.Relationship
|
||||||
primaryName := scope.PrimaryKeyField().Name
|
primaryName := scope.PrimaryKeyField().Name
|
||||||
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
|
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
|
||||||
|
@ -53,10 +41,10 @@ func Preload(scope *Scope) {
|
||||||
for i := 0; i < resultValues.Len(); i++ {
|
for i := 0; i < resultValues.Len(); i++ {
|
||||||
result := resultValues.Index(i)
|
result := resultValues.Index(i)
|
||||||
if isSlice {
|
if isSlice {
|
||||||
value := getFieldValue(result, relation.ForeignFieldName)
|
value := getRealValue(result, relation.ForeignFieldName)
|
||||||
objects := scope.IndirectValue()
|
objects := scope.IndirectValue()
|
||||||
for j := 0; j < objects.Len(); j++ {
|
for j := 0; j < objects.Len(); j++ {
|
||||||
if equalAsString(getFieldValue(objects.Index(j), primaryName), value) {
|
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
|
||||||
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
|
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -72,11 +60,11 @@ func Preload(scope *Scope) {
|
||||||
if isSlice {
|
if isSlice {
|
||||||
for i := 0; i < resultValues.Len(); i++ {
|
for i := 0; i < resultValues.Len(); i++ {
|
||||||
result := resultValues.Index(i)
|
result := resultValues.Index(i)
|
||||||
value := getFieldValue(result, relation.ForeignFieldName)
|
value := getRealValue(result, relation.ForeignFieldName)
|
||||||
objects := scope.IndirectValue()
|
objects := scope.IndirectValue()
|
||||||
for j := 0; j < objects.Len(); j++ {
|
for j := 0; j < objects.Len(); j++ {
|
||||||
object := reflect.Indirect(objects.Index(j))
|
object := reflect.Indirect(objects.Index(j))
|
||||||
if equalAsString(getFieldValue(object, primaryName), value) {
|
if equalAsString(getRealValue(object, primaryName), value) {
|
||||||
f := object.FieldByName(field.Name)
|
f := object.FieldByName(field.Name)
|
||||||
f.Set(reflect.Append(f, result))
|
f.Set(reflect.Append(f, result))
|
||||||
break
|
break
|
||||||
|
@ -92,11 +80,11 @@ func Preload(scope *Scope) {
|
||||||
for i := 0; i < resultValues.Len(); i++ {
|
for i := 0; i < resultValues.Len(); i++ {
|
||||||
result := resultValues.Index(i)
|
result := resultValues.Index(i)
|
||||||
if isSlice {
|
if isSlice {
|
||||||
value := getFieldValue(result, associationPrimaryKey)
|
value := getRealValue(result, associationPrimaryKey)
|
||||||
objects := scope.IndirectValue()
|
objects := scope.IndirectValue()
|
||||||
for j := 0; j < objects.Len(); j++ {
|
for j := 0; j < objects.Len(); j++ {
|
||||||
object := reflect.Indirect(objects.Index(j))
|
object := reflect.Indirect(objects.Index(j))
|
||||||
if equalAsString(getFieldValue(object, relation.ForeignFieldName), value) {
|
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
|
||||||
object.FieldByName(field.Name).Set(result)
|
object.FieldByName(field.Name).Set(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -116,9 +104,8 @@ func Preload(scope *Scope) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeSlice(value reflect.Value) interface{} {
|
func makeSlice(typ reflect.Type) interface{} {
|
||||||
typ := value.Type()
|
if typ.Kind() == reflect.Slice {
|
||||||
if value.Kind() == reflect.Slice {
|
|
||||||
typ = typ.Elem()
|
typ = typ.Elem()
|
||||||
}
|
}
|
||||||
sliceType := reflect.SliceOf(typ)
|
sliceType := reflect.SliceOf(typ)
|
||||||
|
@ -132,8 +119,7 @@ func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{})
|
||||||
switch values.Kind() {
|
switch values.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
value := values.Index(i)
|
primaryKeys = append(primaryKeys, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
|
||||||
primaryKeys = append(primaryKeys, reflect.Indirect(value).FieldByName(column).Interface())
|
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
return []interface{}{values.FieldByName(column).Interface()}
|
return []interface{}{values.FieldByName(column).Interface()}
|
||||||
|
|
58
scope.go
58
scope.go
|
@ -111,7 +111,8 @@ func (scope *Scope) PrimaryKey() string {
|
||||||
|
|
||||||
// PrimaryKeyZero check the primary key is blank or not
|
// PrimaryKeyZero check the primary key is blank or not
|
||||||
func (scope *Scope) PrimaryKeyZero() bool {
|
func (scope *Scope) PrimaryKeyZero() bool {
|
||||||
return isBlank(reflect.ValueOf(scope.PrimaryKeyValue()))
|
field := scope.PrimaryKeyField()
|
||||||
|
return field == nil || field.IsBlank
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrimaryKeyValue get the primary key's value
|
// PrimaryKeyValue get the primary key's value
|
||||||
|
@ -125,35 +126,23 @@ func (scope *Scope) PrimaryKeyValue() interface{} {
|
||||||
// HasColumn to check if has column
|
// HasColumn to check if has column
|
||||||
func (scope *Scope) HasColumn(column string) bool {
|
func (scope *Scope) HasColumn(column string) bool {
|
||||||
for _, field := range scope.GetStructFields() {
|
for _, field := range scope.GetStructFields() {
|
||||||
if !field.IsIgnored {
|
if field.IsNormal && (field.Name == column || field.DBName == column) {
|
||||||
if field.Name == column || field.DBName == column {
|
return true
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// FieldValueByName to get column's value and existence
|
|
||||||
func (scope *Scope) FieldValueByName(name string) (interface{}, error) {
|
|
||||||
return FieldValueByName(name, scope.Value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetColumn to set the column's value
|
// SetColumn to set the column's value
|
||||||
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
||||||
if field, ok := column.(*Field); ok {
|
if field, ok := column.(*Field); ok {
|
||||||
return field.Set(value)
|
return field.Set(value)
|
||||||
} else if dbName, ok := column.(string); ok {
|
} else if dbName, ok := column.(string); ok {
|
||||||
if scope.Value == nil {
|
|
||||||
return errors.New("scope value must not be nil for string columns")
|
|
||||||
}
|
|
||||||
|
|
||||||
if field, ok := scope.Fields()[dbName]; ok {
|
if field, ok := scope.Fields()[dbName]; ok {
|
||||||
return field.Set(value)
|
return field.Set(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
dbName = ToSnake(dbName)
|
dbName = ToSnake(dbName)
|
||||||
|
|
||||||
if field, ok := scope.Fields()[dbName]; ok {
|
if field, ok := scope.Fields()[dbName]; ok {
|
||||||
return field.Set(value)
|
return field.Set(value)
|
||||||
}
|
}
|
||||||
|
@ -204,45 +193,11 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName get table name
|
// TableName get table name
|
||||||
|
|
||||||
func (scope *Scope) TableName() string {
|
func (scope *Scope) TableName() string {
|
||||||
if scope.Search != nil && len(scope.Search.TableName) > 0 {
|
if scope.Search != nil && len(scope.Search.TableName) > 0 {
|
||||||
return scope.Search.TableName
|
return scope.Search.TableName
|
||||||
}
|
}
|
||||||
|
return scope.GetModelStruct().TableName
|
||||||
if scope.Value == nil {
|
|
||||||
scope.Err(errors.New("can't get table name"))
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
data := scope.IndirectValue()
|
|
||||||
if data.Kind() == reflect.Slice {
|
|
||||||
elem := data.Type().Elem()
|
|
||||||
if elem.Kind() == reflect.Ptr {
|
|
||||||
elem = elem.Elem()
|
|
||||||
}
|
|
||||||
data = reflect.New(elem).Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fm := data.MethodByName("TableName"); fm.IsValid() {
|
|
||||||
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
|
|
||||||
if result, ok := v[0].Interface().(string); ok {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
str := ToSnake(data.Type().Name())
|
|
||||||
|
|
||||||
if scope.db == nil || !scope.db.parent.singularTable {
|
|
||||||
for index, reg := range pluralMapKeys {
|
|
||||||
if reg.MatchString(str) {
|
|
||||||
return reg.ReplaceAllString(str, pluralMapValues[index])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return str
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) QuotedTableName() string {
|
func (scope *Scope) QuotedTableName() string {
|
||||||
|
@ -284,8 +239,7 @@ func (scope *Scope) Exec() *Scope {
|
||||||
defer scope.Trace(NowFunc())
|
defer scope.Trace(NowFunc())
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...)
|
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||||
if scope.Err(err) == nil {
|
|
||||||
if count, err := result.RowsAffected(); err == nil {
|
if count, err := result.RowsAffected(); err == nil {
|
||||||
scope.db.RowsAffected = count
|
scope.db.RowsAffected = count
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,10 +25,8 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
||||||
} else if value != "" {
|
} else if value != "" {
|
||||||
str = fmt.Sprintf("(%v)", value)
|
str = fmt.Sprintf("(%v)", value)
|
||||||
}
|
}
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||||
return scope.primaryCondition(scope.AddToVars(value))
|
return scope.primaryCondition(scope.AddToVars(value))
|
||||||
case sql.NullInt64:
|
|
||||||
return scope.primaryCondition(scope.AddToVars(value.Int64))
|
|
||||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
||||||
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
|
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
|
||||||
clause["args"] = []interface{}{value}
|
clause["args"] = []interface{}{value}
|
||||||
|
@ -71,12 +69,14 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
||||||
|
|
||||||
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
var notEqualSql string
|
var notEqualSql string
|
||||||
|
var primaryKey = scope.PrimaryKey()
|
||||||
|
|
||||||
switch value := clause["query"].(type) {
|
switch value := clause["query"].(type) {
|
||||||
case string:
|
case string:
|
||||||
|
// is number
|
||||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||||
id, _ := strconv.Atoi(value)
|
id, _ := strconv.Atoi(value)
|
||||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), id)
|
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
||||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
|
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
|
||||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||||
notEqualSql = fmt.Sprintf("NOT (%v)", value)
|
notEqualSql = fmt.Sprintf("NOT (%v)", value)
|
||||||
|
@ -84,15 +84,14 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
||||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
||||||
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||||
}
|
}
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value)
|
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
||||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
|
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
|
||||||
if reflect.ValueOf(value).Len() > 0 {
|
if reflect.ValueOf(value).Len() > 0 {
|
||||||
str = fmt.Sprintf("(%v not in (?))", scope.Quote(scope.PrimaryKey()))
|
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
|
||||||
clause["args"] = []interface{}{value}
|
clause["args"] = []interface{}{value}
|
||||||
} else {
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
return ""
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for key, value := range value {
|
for key, value := range value {
|
||||||
|
@ -157,16 +156,10 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) where(where ...interface{}) {
|
|
||||||
if len(where) > 0 {
|
|
||||||
scope.Search = scope.Search.clone().where(where[0], where[1:]...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scope *Scope) whereSql() (sql string) {
|
func (scope *Scope) whereSql() (sql string) {
|
||||||
var primaryConditions, andConditions, orConditions []string
|
var primaryConditions, andConditions, orConditions []string
|
||||||
|
|
||||||
if !scope.Search.Unscope && scope.HasColumn("DeletedAt") {
|
if !scope.Search.Unscope && scope.Fields()["deleted_at"] != nil {
|
||||||
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
|
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
|
||||||
primaryConditions = append(primaryConditions, sql)
|
primaryConditions = append(primaryConditions, sql)
|
||||||
}
|
}
|
||||||
|
@ -317,41 +310,19 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
|
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
|
||||||
data := scope.IndirectValue()
|
if !scope.IndirectValue().CanAddr() {
|
||||||
if !data.CanAddr() {
|
|
||||||
return values, true
|
return values, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fields := scope.Fields()
|
||||||
for key, value := range values {
|
for key, value := range values {
|
||||||
if field, ok := scope.FieldByName(SnakeToUpperCamel(key)); ok && field.Field.IsValid() {
|
if field, ok := fields[ToSnake(key)]; ok && field.Field.IsValid() {
|
||||||
func() {
|
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||||
defer func() {
|
if !equalAsString(field.Field.Interface(), value) {
|
||||||
if err := recover(); err != nil {
|
hasUpdate = true
|
||||||
hasUpdate = true
|
field.Set(value)
|
||||||
field.Set(value)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if field.Field.Interface() != value {
|
|
||||||
switch field.Field.Kind() {
|
|
||||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
|
||||||
if s, ok := value.(string); ok {
|
|
||||||
i, err := strconv.Atoi(s)
|
|
||||||
if scope.Err(err) == nil {
|
|
||||||
value = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if field.Field.Int() != reflect.ValueOf(value).Int() {
|
|
||||||
hasUpdate = true
|
|
||||||
field.Set(value)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
hasUpdate = true
|
|
||||||
field.Set(value)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -464,7 +435,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
if field.Relationship != nil && field.Relationship.JoinTable != "" {
|
if field.Relationship != nil && field.Relationship.JoinTable != "" {
|
||||||
if !scope.Dialect().HasTable(scope, field.Relationship.JoinTable) {
|
if !scope.Dialect().HasTable(scope, field.Relationship.JoinTable) {
|
||||||
newScope := scope.db.NewScope("")
|
newScope := scope.db.NewScope("")
|
||||||
primaryKeySqlType := scope.Dialect().SqlTag(reflect.ValueOf(scope.PrimaryKeyValue()), 255)
|
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255)
|
||||||
newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)",
|
newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||||
field.Relationship.JoinTable,
|
field.Relationship.JoinTable,
|
||||||
strings.Join([]string{
|
strings.Join([]string{
|
||||||
|
@ -523,14 +494,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||||
var table = scope.TableName()
|
var table = scope.TableName()
|
||||||
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
|
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
|
||||||
var query = `
|
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||||
ALTER TABLE %s
|
|
||||||
ADD CONSTRAINT %s
|
|
||||||
FOREIGN KEY (%s)
|
|
||||||
REFERENCES %s
|
|
||||||
ON DELETE %s
|
|
||||||
ON UPDATE %s;
|
|
||||||
`
|
|
||||||
scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec()
|
scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sqlite3 struct{}
|
type sqlite3 struct{}
|
||||||
|
@ -35,7 +36,7 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
||||||
}
|
}
|
||||||
return "text"
|
return "text"
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if value.Type() == timeType {
|
if _, ok := value.Interface().(time.Time); ok {
|
||||||
return "datetime"
|
return "datetime"
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
20
utils.go
20
utils.go
|
@ -2,9 +2,6 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
@ -26,23 +23,6 @@ func (s *safeMap) Get(key string) string {
|
||||||
return s.m[key]
|
return s.m[key]
|
||||||
}
|
}
|
||||||
|
|
||||||
func FieldValueByName(name string, value interface{}) (i interface{}, err error) {
|
|
||||||
data := reflect.Indirect(reflect.ValueOf(value))
|
|
||||||
name = SnakeToUpperCamel(name)
|
|
||||||
|
|
||||||
if data.Kind() == reflect.Struct {
|
|
||||||
if field := data.FieldByName(name); field.IsValid() {
|
|
||||||
i = field.Interface()
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("struct has no field with name %s", name)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, errors.New("value must be of kind struct")
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSafeMap() *safeMap {
|
func newSafeMap() *safeMap {
|
||||||
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,18 +2,16 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func fileWithLineNum() string {
|
func fileWithLineNum() string {
|
||||||
for i := 2; i < 15; i++ {
|
for i := 2; i < 15; i++ {
|
||||||
_, file, line, ok := runtime.Caller(i)
|
_, file, line, ok := runtime.Caller(i)
|
||||||
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
||||||
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
|
return fmt.Sprintf("%v:%v", file, line)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
|
|
Loading…
Reference in New Issue