mirror of https://github.com/go-gorm/gorm.git
Review and Refactor
This commit is contained in:
parent
38cbff9c79
commit
0b32041135
|
@ -9,7 +9,6 @@ import (
|
|||
type Association struct {
|
||||
Scope *Scope
|
||||
PrimaryKey interface{}
|
||||
PrimaryType interface{}
|
||||
Column string
|
||||
Error error
|
||||
Field *Field
|
||||
|
@ -158,11 +157,11 @@ func (association *Association) Count() int {
|
|||
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
||||
countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey)
|
||||
if relationship.ForeignType != "" {
|
||||
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), association.PrimaryType)
|
||||
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
|
||||
}
|
||||
countScope.Count(&count)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
if v, err := scope.FieldValueByName(association.Column); err == nil {
|
||||
if v, ok := scope.FieldByName(association.Column); ok {
|
||||
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
||||
scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count)
|
||||
}
|
||||
|
|
|
@ -26,13 +26,12 @@ func Create(scope *Scope) {
|
|||
var sqls, columns []string
|
||||
for _, field := range scope.Fields() {
|
||||
if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) {
|
||||
if field.DefaultValue != nil && field.IsBlank {
|
||||
continue
|
||||
}
|
||||
if !field.IsBlank || field.DefaultValue == nil {
|
||||
columns = append(columns, scope.Quote(field.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
returningKey := "*"
|
||||
primaryField := scope.PrimaryKeyField()
|
||||
|
|
|
@ -41,21 +41,19 @@ func Update(scope *Scope) {
|
|||
if !scope.HasError() {
|
||||
var sqls []string
|
||||
|
||||
updateAttrs, ok := scope.InstanceGet("gorm:update_attrs")
|
||||
if ok {
|
||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||
for key, value := range updateAttrs.(map[string]interface{}) {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
|
||||
}
|
||||
} else {
|
||||
for _, field := range scope.Fields() {
|
||||
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
|
||||
if field.DefaultValue != nil && field.IsBlank {
|
||||
continue
|
||||
}
|
||||
if !field.IsPrimaryKey && field.IsNormal {
|
||||
if !field.IsBlank || field.DefaultValue == nil {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v SET %v %v",
|
||||
|
@ -68,8 +66,7 @@ func Update(scope *Scope) {
|
|||
}
|
||||
|
||||
func AfterUpdate(scope *Scope) {
|
||||
_, ok := scope.Get("gorm:update_column")
|
||||
if !ok {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
scope.CallMethod("AfterUpdate")
|
||||
scope.CallMethod("AfterSave")
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type commonDialect struct{}
|
||||
|
@ -36,7 +37,7 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
|
|||
}
|
||||
return "VARCHAR(65532)"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "TIMESTAMP"
|
||||
}
|
||||
default:
|
||||
|
|
|
@ -3,11 +3,8 @@ package gorm
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
var timeType = reflect.TypeOf(time.Time{})
|
||||
|
||||
type Dialect interface {
|
||||
BinVar(i int) string
|
||||
SupportLastInsertId() bool
|
||||
|
|
37
field.go
37
field.go
|
@ -12,7 +12,7 @@ type Field struct {
|
|||
Field reflect.Value
|
||||
}
|
||||
|
||||
func (field *Field) Set(value interface{}) (err error) {
|
||||
func (field *Field) Set(value interface{}) error {
|
||||
if !field.Field.IsValid() {
|
||||
return errors.New("field value not valid")
|
||||
}
|
||||
|
@ -26,16 +26,26 @@ func (field *Field) Set(value interface{}) (err error) {
|
|||
}
|
||||
|
||||
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
||||
if v, ok := value.(reflect.Value); ok {
|
||||
scanner.Scan(v.Interface())
|
||||
} else {
|
||||
scanner.Scan(value)
|
||||
} else if reflect.TypeOf(value).ConvertibleTo(field.Field.Type()) {
|
||||
field.Field.Set(reflect.ValueOf(value).Convert(field.Field.Type()))
|
||||
}
|
||||
} else {
|
||||
reflectValue, ok := value.(reflect.Value)
|
||||
if !ok {
|
||||
reflectValue = reflect.ValueOf(value)
|
||||
}
|
||||
|
||||
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
|
||||
field.Field.Set(reflectValue.Convert(field.Field.Type()))
|
||||
} else {
|
||||
return errors.New("could not convert argument")
|
||||
}
|
||||
}
|
||||
|
||||
field.IsBlank = isBlank(field.Field)
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fields get value's fields
|
||||
|
@ -44,8 +54,14 @@ func (scope *Scope) Fields() map[string]*Field {
|
|||
fields := map[string]*Field{}
|
||||
structFields := scope.GetStructFields()
|
||||
|
||||
indirectValue := scope.IndirectValue()
|
||||
isStruct := indirectValue.Kind() == reflect.Struct
|
||||
for _, structField := range structFields {
|
||||
fields[structField.DBName] = scope.getField(structField)
|
||||
if isStruct {
|
||||
fields[structField.DBName] = getField(indirectValue, structField)
|
||||
} else {
|
||||
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
|
||||
}
|
||||
}
|
||||
|
||||
scope.fields = fields
|
||||
|
@ -53,15 +69,12 @@ func (scope *Scope) Fields() map[string]*Field {
|
|||
return scope.fields
|
||||
}
|
||||
|
||||
func (scope *Scope) getField(structField *StructField) *Field {
|
||||
field := Field{StructField: structField}
|
||||
indirectValue := scope.IndirectValue()
|
||||
if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct {
|
||||
func getField(indirectValue reflect.Value, structField *StructField) *Field {
|
||||
field := &Field{StructField: structField}
|
||||
for _, name := range structField.Names {
|
||||
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
|
||||
}
|
||||
field.Field = indirectValue
|
||||
}
|
||||
field.IsBlank = isBlank(indirectValue)
|
||||
return &field
|
||||
return field
|
||||
}
|
||||
|
|
14
logger.go
14
logger.go
|
@ -21,21 +21,21 @@ var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
|||
// Format log
|
||||
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
||||
|
||||
func (logger Logger) Print(v ...interface{}) {
|
||||
if len(v) > 1 {
|
||||
level := v[0]
|
||||
func (logger Logger) Print(values ...interface{}) {
|
||||
if len(values) > 1 {
|
||||
level := values[0]
|
||||
currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
|
||||
source := fmt.Sprintf("\033[35m(%v)\033[0m", v[1])
|
||||
source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
|
||||
messages := []interface{}{source, currentTime}
|
||||
|
||||
if level == "sql" {
|
||||
// duration
|
||||
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(v[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
||||
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
||||
// sql
|
||||
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(v[3].(string), "'%v'"), v[4].([]interface{})...))
|
||||
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "'%v'"), values[4].([]interface{})...))
|
||||
} else {
|
||||
messages = append(messages, "\033[31;1m")
|
||||
messages = append(messages, v[2:]...)
|
||||
messages = append(messages, values[2:]...)
|
||||
messages = append(messages, "\033[0m")
|
||||
}
|
||||
logger.Println(messages...)
|
||||
|
|
82
main.go
82
main.go
|
@ -30,7 +30,6 @@ type DB struct {
|
|||
logMode int
|
||||
logger logger
|
||||
dialect Dialect
|
||||
tagIdentifier string
|
||||
singularTable bool
|
||||
source string
|
||||
values map[string]interface{}
|
||||
|
@ -39,12 +38,12 @@ type DB struct {
|
|||
func Open(dialect string, args ...interface{}) (DB, error) {
|
||||
var db DB
|
||||
var err error
|
||||
var source string
|
||||
var dbSql sqlCommon
|
||||
|
||||
if len(args) == 0 {
|
||||
err = errors.New("invalid database source")
|
||||
}
|
||||
} else {
|
||||
var source string
|
||||
var dbSql sqlCommon
|
||||
|
||||
switch value := args[0].(type) {
|
||||
case string:
|
||||
|
@ -61,11 +60,16 @@ func Open(dialect string, args ...interface{}) (DB, error) {
|
|||
dbSql = value
|
||||
}
|
||||
|
||||
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
|
||||
logger: defaultLogger, callback: DefaultCallback, source: source,
|
||||
values: map[string]interface{}{}}
|
||||
db.db = dbSql
|
||||
db = DB{
|
||||
dialect: NewDialect(dialect),
|
||||
logger: defaultLogger,
|
||||
callback: DefaultCallback,
|
||||
source: source,
|
||||
values: map[string]interface{}{},
|
||||
db: dbSql,
|
||||
}
|
||||
db.parent = &db
|
||||
}
|
||||
|
||||
return db, err
|
||||
}
|
||||
|
@ -84,7 +88,7 @@ func (s *DB) New() *DB {
|
|||
return clone
|
||||
}
|
||||
|
||||
// Return the underlying sql.DB or sql.Tx instance.
|
||||
// CommonDB Return the underlying sql.DB or sql.Tx instance.
|
||||
// Use of this method is discouraged. It's mainly intended to allow
|
||||
// coexistence with legacy non-GORM code.
|
||||
func (s *DB) CommonDB() sqlCommon {
|
||||
|
@ -96,16 +100,12 @@ func (s *DB) Callback() *callback {
|
|||
return s.parent.callback
|
||||
}
|
||||
|
||||
func (s *DB) SetTagIdentifier(str string) {
|
||||
s.parent.tagIdentifier = str
|
||||
}
|
||||
|
||||
func (s *DB) SetLogger(l logger) {
|
||||
s.parent.logger = l
|
||||
}
|
||||
|
||||
func (s *DB) LogMode(b bool) *DB {
|
||||
if b {
|
||||
func (s *DB) LogMode(enable bool) *DB {
|
||||
if enable {
|
||||
s.logMode = 2
|
||||
} else {
|
||||
s.logMode = 1
|
||||
|
@ -113,8 +113,8 @@ func (s *DB) LogMode(b bool) *DB {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *DB) SingularTable(b bool) {
|
||||
s.parent.singularTable = b
|
||||
func (s *DB) SingularTable(enable bool) {
|
||||
s.parent.singularTable = enable
|
||||
}
|
||||
|
||||
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
||||
|
@ -158,11 +158,10 @@ func (s *DB) Joins(query string) *DB {
|
|||
}
|
||||
|
||||
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
||||
c := s
|
||||
for _, f := range funcs {
|
||||
c = f(c)
|
||||
s = f(s)
|
||||
}
|
||||
return c
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *DB) Unscoped() *DB {
|
||||
|
@ -179,16 +178,14 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
|
|||
|
||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||
newScope := s.clone().NewScope(out)
|
||||
newScope.Search = newScope.Search.clone()
|
||||
newScope.Search.limit(1)
|
||||
newScope.Search = newScope.Search.clone().limit(1)
|
||||
return newScope.InstanceSet("gorm:order_by_primary_key", "ASC").
|
||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||
}
|
||||
|
||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||
newScope := s.clone().NewScope(out)
|
||||
newScope.Search = newScope.Search.clone()
|
||||
newScope.Search.limit(1)
|
||||
newScope.Search = newScope.Search.clone().limit(1)
|
||||
return newScope.InstanceSet("gorm:order_by_primary_key", "DESC").
|
||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||
}
|
||||
|
@ -213,10 +210,9 @@ func (s *DB) Scan(dest interface{}) *DB {
|
|||
|
||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||
c := s.clone()
|
||||
r := c.First(out, where...)
|
||||
if r.Error != nil {
|
||||
if !r.RecordNotFound() {
|
||||
return r
|
||||
if result := c.First(out, where...); result.Error != nil {
|
||||
if !result.RecordNotFound() {
|
||||
return result
|
||||
}
|
||||
c.NewScope(out).inlineCondition(where...).initialize()
|
||||
} else {
|
||||
|
@ -227,10 +223,9 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
|||
|
||||
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||
c := s.clone()
|
||||
r := c.First(out, where...)
|
||||
if r.Error != nil {
|
||||
if !r.RecordNotFound() {
|
||||
return r
|
||||
if result := c.First(out, where...); result.Error != nil {
|
||||
if !result.RecordNotFound() {
|
||||
return result
|
||||
}
|
||||
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
|
||||
} else if len(c.search.AssignAttrs) > 0 {
|
||||
|
@ -418,25 +413,24 @@ func (s *DB) RemoveIndex(indexName string) *DB {
|
|||
}
|
||||
|
||||
func (s *DB) Association(column string) *Association {
|
||||
var err error
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
|
||||
primaryKey := scope.PrimaryKeyValue()
|
||||
primaryType := scope.TableName()
|
||||
if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
|
||||
scope.Err(errors.New("primary key can't be nil"))
|
||||
}
|
||||
|
||||
var field *Field
|
||||
var ok bool
|
||||
if field, ok = scope.FieldByName(column); ok {
|
||||
if primaryField := scope.PrimaryKeyField(); primaryField.IsBlank {
|
||||
err = errors.New("primary key can't be nil")
|
||||
} else {
|
||||
if field, ok := scope.FieldByName(column); ok {
|
||||
if field.Relationship == nil || field.Relationship.ForeignFieldName == "" {
|
||||
scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()))
|
||||
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
||||
} else {
|
||||
return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field}
|
||||
}
|
||||
} else {
|
||||
scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column))
|
||||
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
||||
}
|
||||
}
|
||||
|
||||
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field}
|
||||
return &Association{Error: err}
|
||||
}
|
||||
|
||||
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
|
||||
|
|
|
@ -133,11 +133,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
|
||||
// Set fields
|
||||
for i := 0; i < scopeType.NumField(); i++ {
|
||||
fieldStruct := scopeType.Field(i)
|
||||
if !ast.IsExported(fieldStruct.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
field := &StructField{
|
||||
Struct: fieldStruct,
|
||||
Name: fieldStruct.Name,
|
||||
|
@ -265,6 +261,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
}
|
||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range modelStruct.StructFields {
|
||||
if field.IsNormal {
|
||||
|
|
3
mssql.go
3
mssql.go
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mssql struct{}
|
||||
|
@ -36,7 +37,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int) string {
|
|||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "datetime2"
|
||||
}
|
||||
default:
|
||||
|
|
3
mysql.go
3
mysql.go
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mysql struct{}
|
||||
|
@ -36,7 +37,7 @@ func (s *mysql) SqlTag(value reflect.Value, size int) string {
|
|||
}
|
||||
return "longtext"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "timestamp NULL"
|
||||
}
|
||||
default:
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq/hstore"
|
||||
)
|
||||
|
@ -40,7 +41,7 @@ func (s *postgres) SqlTag(value reflect.Value, size int) string {
|
|||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "timestamp with time zone"
|
||||
}
|
||||
case reflect.Map:
|
||||
|
|
40
preload.go
40
preload.go
|
@ -7,7 +7,7 @@ import (
|
|||
"reflect"
|
||||
)
|
||||
|
||||
func getFieldValue(value reflect.Value, field string) interface{} {
|
||||
func getRealValue(value reflect.Value, field string) interface{} {
|
||||
result := reflect.Indirect(value).FieldByName(field).Interface()
|
||||
if r, ok := result.(driver.Valuer); ok {
|
||||
result, _ = r.Value()
|
||||
|
@ -20,26 +20,14 @@ func equalAsString(a interface{}, b interface{}) bool {
|
|||
}
|
||||
|
||||
func Preload(scope *Scope) {
|
||||
// Get Fields
|
||||
var fields map[string]*Field
|
||||
var isSlice bool
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
isSlice = true
|
||||
typ := scope.IndirectValue().Type().Elem()
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
elem := reflect.New(typ).Elem()
|
||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||
} else {
|
||||
fields = scope.Fields()
|
||||
}
|
||||
fields := scope.Fields()
|
||||
isSlice := scope.IndirectValue().Kind() == reflect.Slice
|
||||
|
||||
if scope.Search.Preload != nil {
|
||||
for key, conditions := range scope.Search.Preload {
|
||||
for _, field := range fields {
|
||||
if field.Name == key && field.Relationship != nil {
|
||||
results := makeSlice(field.Field)
|
||||
results := makeSlice(field.Struct.Type)
|
||||
relation := field.Relationship
|
||||
primaryName := scope.PrimaryKeyField().Name
|
||||
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
|
||||
|
@ -53,10 +41,10 @@ func Preload(scope *Scope) {
|
|||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if isSlice {
|
||||
value := getFieldValue(result, relation.ForeignFieldName)
|
||||
value := getRealValue(result, relation.ForeignFieldName)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
if equalAsString(getFieldValue(objects.Index(j), primaryName), value) {
|
||||
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
|
||||
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
|
||||
break
|
||||
}
|
||||
|
@ -72,11 +60,11 @@ func Preload(scope *Scope) {
|
|||
if isSlice {
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
value := getFieldValue(result, relation.ForeignFieldName)
|
||||
value := getRealValue(result, relation.ForeignFieldName)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
if equalAsString(getFieldValue(object, primaryName), value) {
|
||||
if equalAsString(getRealValue(object, primaryName), value) {
|
||||
f := object.FieldByName(field.Name)
|
||||
f.Set(reflect.Append(f, result))
|
||||
break
|
||||
|
@ -92,11 +80,11 @@ func Preload(scope *Scope) {
|
|||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if isSlice {
|
||||
value := getFieldValue(result, associationPrimaryKey)
|
||||
value := getRealValue(result, associationPrimaryKey)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
if equalAsString(getFieldValue(object, relation.ForeignFieldName), value) {
|
||||
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
|
||||
object.FieldByName(field.Name).Set(result)
|
||||
}
|
||||
}
|
||||
|
@ -116,9 +104,8 @@ func Preload(scope *Scope) {
|
|||
}
|
||||
}
|
||||
|
||||
func makeSlice(value reflect.Value) interface{} {
|
||||
typ := value.Type()
|
||||
if value.Kind() == reflect.Slice {
|
||||
func makeSlice(typ reflect.Type) interface{} {
|
||||
if typ.Kind() == reflect.Slice {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
sliceType := reflect.SliceOf(typ)
|
||||
|
@ -132,8 +119,7 @@ func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{})
|
|||
switch values.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
value := values.Index(i)
|
||||
primaryKeys = append(primaryKeys, reflect.Indirect(value).FieldByName(column).Interface())
|
||||
primaryKeys = append(primaryKeys, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
return []interface{}{values.FieldByName(column).Interface()}
|
||||
|
|
56
scope.go
56
scope.go
|
@ -111,7 +111,8 @@ func (scope *Scope) PrimaryKey() string {
|
|||
|
||||
// PrimaryKeyZero check the primary key is blank or not
|
||||
func (scope *Scope) PrimaryKeyZero() bool {
|
||||
return isBlank(reflect.ValueOf(scope.PrimaryKeyValue()))
|
||||
field := scope.PrimaryKeyField()
|
||||
return field == nil || field.IsBlank
|
||||
}
|
||||
|
||||
// PrimaryKeyValue get the primary key's value
|
||||
|
@ -125,35 +126,23 @@ func (scope *Scope) PrimaryKeyValue() interface{} {
|
|||
// HasColumn to check if has column
|
||||
func (scope *Scope) HasColumn(column string) bool {
|
||||
for _, field := range scope.GetStructFields() {
|
||||
if !field.IsIgnored {
|
||||
if field.Name == column || field.DBName == column {
|
||||
if field.IsNormal && (field.Name == column || field.DBName == column) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FieldValueByName to get column's value and existence
|
||||
func (scope *Scope) FieldValueByName(name string) (interface{}, error) {
|
||||
return FieldValueByName(name, scope.Value)
|
||||
}
|
||||
|
||||
// SetColumn to set the column's value
|
||||
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
||||
if field, ok := column.(*Field); ok {
|
||||
return field.Set(value)
|
||||
} else if dbName, ok := column.(string); ok {
|
||||
if scope.Value == nil {
|
||||
return errors.New("scope value must not be nil for string columns")
|
||||
}
|
||||
|
||||
if field, ok := scope.Fields()[dbName]; ok {
|
||||
return field.Set(value)
|
||||
}
|
||||
|
||||
dbName = ToSnake(dbName)
|
||||
|
||||
if field, ok := scope.Fields()[dbName]; ok {
|
||||
return field.Set(value)
|
||||
}
|
||||
|
@ -204,45 +193,11 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
|||
}
|
||||
|
||||
// TableName get table name
|
||||
|
||||
func (scope *Scope) TableName() string {
|
||||
if scope.Search != nil && len(scope.Search.TableName) > 0 {
|
||||
return scope.Search.TableName
|
||||
}
|
||||
|
||||
if scope.Value == nil {
|
||||
scope.Err(errors.New("can't get table name"))
|
||||
return ""
|
||||
}
|
||||
|
||||
data := scope.IndirectValue()
|
||||
if data.Kind() == reflect.Slice {
|
||||
elem := data.Type().Elem()
|
||||
if elem.Kind() == reflect.Ptr {
|
||||
elem = elem.Elem()
|
||||
}
|
||||
data = reflect.New(elem).Elem()
|
||||
}
|
||||
|
||||
if fm := data.MethodByName("TableName"); fm.IsValid() {
|
||||
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
|
||||
if result, ok := v[0].Interface().(string); ok {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
str := ToSnake(data.Type().Name())
|
||||
|
||||
if scope.db == nil || !scope.db.parent.singularTable {
|
||||
for index, reg := range pluralMapKeys {
|
||||
if reg.MatchString(str) {
|
||||
return reg.ReplaceAllString(str, pluralMapValues[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return str
|
||||
return scope.GetModelStruct().TableName
|
||||
}
|
||||
|
||||
func (scope *Scope) QuotedTableName() string {
|
||||
|
@ -284,8 +239,7 @@ func (scope *Scope) Exec() *Scope {
|
|||
defer scope.Trace(NowFunc())
|
||||
|
||||
if !scope.HasError() {
|
||||
result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...)
|
||||
if scope.Err(err) == nil {
|
||||
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if count, err := result.RowsAffected(); err == nil {
|
||||
scope.db.RowsAffected = count
|
||||
}
|
||||
|
|
|
@ -25,10 +25,8 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
|||
} else if value != "" {
|
||||
str = fmt.Sprintf("(%v)", value)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||
return scope.primaryCondition(scope.AddToVars(value))
|
||||
case sql.NullInt64:
|
||||
return scope.primaryCondition(scope.AddToVars(value.Int64))
|
||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
|
||||
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
|
||||
clause["args"] = []interface{}{value}
|
||||
|
@ -71,12 +69,14 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
|||
|
||||
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||
var notEqualSql string
|
||||
var primaryKey = scope.PrimaryKey()
|
||||
|
||||
switch value := clause["query"].(type) {
|
||||
case string:
|
||||
// is number
|
||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||
id, _ := strconv.Atoi(value)
|
||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), id)
|
||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
|
||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||
notEqualSql = fmt.Sprintf("NOT (%v)", value)
|
||||
|
@ -84,15 +84,14 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
|||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
||||
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value)
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
||||
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
|
||||
if reflect.ValueOf(value).Len() > 0 {
|
||||
str = fmt.Sprintf("(%v not in (?))", scope.Quote(scope.PrimaryKey()))
|
||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
|
||||
clause["args"] = []interface{}{value}
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
case map[string]interface{}:
|
||||
var sqls []string
|
||||
for key, value := range value {
|
||||
|
@ -157,16 +156,10 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
|||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) where(where ...interface{}) {
|
||||
if len(where) > 0 {
|
||||
scope.Search = scope.Search.clone().where(where[0], where[1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) whereSql() (sql string) {
|
||||
var primaryConditions, andConditions, orConditions []string
|
||||
|
||||
if !scope.Search.Unscope && scope.HasColumn("DeletedAt") {
|
||||
if !scope.Search.Unscope && scope.Fields()["deleted_at"] != nil {
|
||||
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
|
||||
primaryConditions = append(primaryConditions, sql)
|
||||
}
|
||||
|
@ -317,41 +310,19 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
|
|||
}
|
||||
|
||||
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
|
||||
data := scope.IndirectValue()
|
||||
if !data.CanAddr() {
|
||||
if !scope.IndirectValue().CanAddr() {
|
||||
return values, true
|
||||
}
|
||||
|
||||
fields := scope.Fields()
|
||||
for key, value := range values {
|
||||
if field, ok := scope.FieldByName(SnakeToUpperCamel(key)); ok && field.Field.IsValid() {
|
||||
func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
hasUpdate = true
|
||||
field.Set(value)
|
||||
}
|
||||
}()
|
||||
|
||||
if field.Field.Interface() != value {
|
||||
switch field.Field.Kind() {
|
||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||
if s, ok := value.(string); ok {
|
||||
i, err := strconv.Atoi(s)
|
||||
if scope.Err(err) == nil {
|
||||
value = i
|
||||
}
|
||||
}
|
||||
|
||||
if field.Field.Int() != reflect.ValueOf(value).Int() {
|
||||
hasUpdate = true
|
||||
field.Set(value)
|
||||
}
|
||||
default:
|
||||
if field, ok := fields[ToSnake(key)]; ok && field.Field.IsValid() {
|
||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||
if !equalAsString(field.Field.Interface(), value) {
|
||||
hasUpdate = true
|
||||
field.Set(value)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -464,7 +435,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||
if field.Relationship != nil && field.Relationship.JoinTable != "" {
|
||||
if !scope.Dialect().HasTable(scope, field.Relationship.JoinTable) {
|
||||
newScope := scope.db.NewScope("")
|
||||
primaryKeySqlType := scope.Dialect().SqlTag(reflect.ValueOf(scope.PrimaryKeyValue()), 255)
|
||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255)
|
||||
newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||
field.Relationship.JoinTable,
|
||||
strings.Join([]string{
|
||||
|
@ -523,14 +494,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
|||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||
var table = scope.TableName()
|
||||
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
|
||||
var query = `
|
||||
ALTER TABLE %s
|
||||
ADD CONSTRAINT %s
|
||||
FOREIGN KEY (%s)
|
||||
REFERENCES %s
|
||||
ON DELETE %s
|
||||
ON UPDATE %s;
|
||||
`
|
||||
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||
scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec()
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package gorm
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type sqlite3 struct{}
|
||||
|
@ -35,7 +36,7 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
|||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
return "datetime"
|
||||
}
|
||||
default:
|
||||
|
|
20
utils.go
20
utils.go
|
@ -2,9 +2,6 @@ package gorm
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
@ -26,23 +23,6 @@ func (s *safeMap) Get(key string) string {
|
|||
return s.m[key]
|
||||
}
|
||||
|
||||
func FieldValueByName(name string, value interface{}) (i interface{}, err error) {
|
||||
data := reflect.Indirect(reflect.ValueOf(value))
|
||||
name = SnakeToUpperCamel(name)
|
||||
|
||||
if data.Kind() == reflect.Struct {
|
||||
if field := data.FieldByName(name); field.IsValid() {
|
||||
i = field.Interface()
|
||||
} else {
|
||||
return nil, fmt.Errorf("struct has no field with name %s", name)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("value must be of kind struct")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func newSafeMap() *safeMap {
|
||||
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
||||
}
|
||||
|
|
|
@ -2,18 +2,16 @@ package gorm
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func fileWithLineNum() string {
|
||||
for i := 2; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
||||
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
|
||||
return fmt.Sprintf("%v:%v", file, line)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
|
Loading…
Reference in New Issue