diff --git a/main.go b/main.go index 87a0d7e5..527f1204 100644 --- a/main.go +++ b/main.go @@ -182,7 +182,7 @@ func (s *DB) Save(value interface{}) *DB { } func (s *DB) Delete(value interface{}) *DB { - return s.clone().do(value).begin().delete().commit_or_rollback().db + return s.clone().newScope(value).callCallbacks(s.parent.callback.deletes).db } func (s *DB) Raw(sql string, values ...interface{}) *DB { diff --git a/scope.go b/scope.go index f094d770..a1c81e06 100644 --- a/scope.go +++ b/scope.go @@ -1,14 +1,32 @@ package gorm -import "github.com/jinzhu/gorm/dialect" +import ( + "errors" + "fmt" + "github.com/jinzhu/gorm/dialect" + + "reflect" +) type Scope struct { + Value interface{} Search *search Sql string SqlVars []interface{} db *DB } +func (db *DB) newScope(value interface{}) *Scope { + return &Scope{db: db, Search: db.search, Value: value} +} + +func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { + for _, f := range funcs { + (*f)(scope) + } + return scope +} + func (scope *Scope) DB() sqlCommon { return scope.db.db } @@ -25,21 +43,50 @@ func (scope *Scope) Err(err error) error { } func (scope *Scope) HasError() bool { - return true + return scope.db.hasError() } func (scope *Scope) PrimaryKey() string { - return "" + return "Id" } func (scope *Scope) HasColumn(name string) bool { + data := reflect.Indirect(reflect.ValueOf(scope.Value)) + + if data.Kind() == reflect.Struct { + return data.FieldByName(name).IsValid() + } else if data.Kind() == reflect.Slice { + return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() + } return false } func (scope *Scope) SetColumn(column string, value interface{}) { + data := reflect.Indirect(reflect.ValueOf(scope.Value)) + setFieldValue(data.FieldByName(snakeToUpperCamel(column)), value) } func (scope *Scope) CallMethod(name string) { + if fm := reflect.ValueOf(scope.Value).MethodByName(name); fm.IsValid() { + fi := fm.Interface() + if f, ok := fi.(func()); ok { + f() + } else if f, ok := fi.(func(s *Scope)); ok { + f(scope) + } else if f, ok := fi.(func(s *DB)); ok { + f(scope.db.new()) + } else if f, ok := fi.(func() error); ok { + scope.Err(f()) + } else if f, ok := fi.(func(s *Scope) error); ok { + scope.Err(f(scope)) + } else if f, ok := fi.(func(s *DB) error); ok { + scope.Err(f(scope.db.new())) + } else { + scope.Err(errors.New(fmt.Sprintf("unsupported function %v", name))) + } + } else { + scope.Err(errors.New(fmt.Sprintf("no valid function %v found", name))) + } } func (scope *Scope) CombinedConditionSql() string { @@ -55,6 +102,7 @@ func (scope *Scope) TableName() string { } func (scope *Scope) Raw(sql string, values ...interface{}) { + fmt.Println(sql, values) } func (scope *Scope) Exec() {