diff --git a/main.go b/main.go index b262ad8f..6ff04f8f 100644 --- a/main.go +++ b/main.go @@ -120,17 +120,17 @@ func (s *DB) Assign(attrs ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB { scope := s.clone().NewScope(out) scope.Search = scope.Search.clone().order(scope.PrimaryKey()).limit(1) - return scope.inlineCondition(where).callCallbacks(s.parent.callback.queries).db + return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } func (s *DB) Last(out interface{}, where ...interface{}) *DB { scope := s.clone().NewScope(out) scope.Search = scope.Search.clone().order(scope.PrimaryKey() + " DESC").limit(1) - return scope.inlineCondition(where).callCallbacks(s.parent.callback.queries).db + return scope.inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where).callCallbacks(s.parent.callback.queries).db + return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } func (s *DB) Row() *sql.Row { @@ -150,7 +150,7 @@ func (s *DB) Scan(dest interface{}) *DB { func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { c := s.clone() if c.First(out, where...).Error == RecordNotFound { - c.NewScope(out).inlineCondition(where).initialize() + c.NewScope(out).inlineCondition(where...).initialize() } else { c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.assignAttrs), false) } @@ -160,7 +160,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { c := s.clone() if c.First(out, where...).Error == RecordNotFound { - c.NewScope(out).inlineCondition(where).initialize().callCallbacks(s.parent.callback.creates) + c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates) } else if len(s.search.assignAttrs) > 0 { c.NewScope(out).Set("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates) } @@ -219,8 +219,7 @@ func (s *DB) Model(value interface{}) *DB { } func (s *DB) Related(value interface{}, foreign_keys ...string) *DB { - old_data := s.Value - return s.do(value).related(old_data, foreign_keys...).db + return s.clone().NewScope(s.Value).related(value, foreign_keys...).db } func (s *DB) Pluck(column string, value interface{}) *DB { diff --git a/scope.go b/scope.go index 41362452..bb92da8e 100644 --- a/scope.go +++ b/scope.go @@ -94,14 +94,21 @@ func (scope *Scope) PrimaryKeyValue() interface{} { } func (scope *Scope) HasColumn(name string) bool { + _, result := scope.FieldByName(name) + return result +} + +func (scope *Scope) FieldByName(name string) (interface{}, bool) { data := reflect.Indirect(reflect.ValueOf(scope.Value)) if data.Kind() == reflect.Struct { - return data.FieldByName(name).IsValid() + if field := data.FieldByName(name); field.IsValid() { + return field.Interface(), true + } } else if data.Kind() == reflect.Slice { - return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() + return nil, reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() } - return false + return nil, false } func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { @@ -430,3 +437,26 @@ func (scope *Scope) count(value interface{}) *Scope { scope.Err(scope.row().Scan(value)) return scope } + +func (scope *Scope) typeName() string { + value := reflect.Indirect(reflect.ValueOf(scope.Value)) + if value.Kind() == reflect.Slice { + return value.Type().Elem().Name() + } else { + return value.Type().Name() + } +} + +func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { + toScope := scope.New(value) + + for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { + if foreignValue, ok := scope.FieldByName(foreignKey); ok { + return toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries) + } else if toScope.HasColumn(foreignKey) { + sql := fmt.Sprintf("%v = ?", scope.quote(toSnake(foreignKey))) + return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries) + } + } + return scope +} diff --git a/scope_condition.go b/scope_condition.go index 556855b3..b5e99b4e 100644 --- a/scope_condition.go +++ b/scope_condition.go @@ -243,7 +243,7 @@ func (scope *Scope) prepareQuerySql() { return } -func (scope *Scope) inlineCondition(values []interface{}) *Scope { +func (scope *Scope) inlineCondition(values ...interface{}) *Scope { if len(values) > 0 { scope.Search = scope.Search.clone().where(values[0], values[1:]...) }