From 8dd7b4ed917a0f1ecef2f9b3a2f4318780899752 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 26 Jan 2014 19:34:06 +0800 Subject: [PATCH] make callback create works --- callback_create.go | 40 +++++++++++++++++++++----- callback_delete.go | 26 ++++++++--------- callback_shared.go | 6 ++++ do.go | 15 ---------- field.go | 26 ++++++++++++++--- main.go | 7 ++++- scope.go | 72 +++++++++++++++++++++++++++++++++++++++++++++- 7 files changed, 150 insertions(+), 42 deletions(-) diff --git a/callback_create.go b/callback_create.go index 10f7537e..00345b3d 100644 --- a/callback_create.go +++ b/callback_create.go @@ -1,15 +1,40 @@ package gorm +import ( + "fmt" + "strings" + "time" +) + func BeforeCreate(scope *Scope) { scope.CallMethod("BeforeSave") scope.CallMethod("BeforeCreate") } -func SaveBeforeAssociations(scope *Scope) { -} - func Create(scope *Scope) { + defer scope.Trace(time.Now()) + if !scope.HasError() { + // set create sql + var sqls, columns []string + + for _, field := range scope.Fields() { + if field.IsBlank || len(field.SqlTag) == 0 { + continue + } + columns = append(columns, scope.quote(field.DBName)) + sqls = append(sqls, scope.AddToVars(field.Value)) + } + + scope.Raw(fmt.Sprintf( + "INSERT INTO %v (%v) VALUES (%v) %v", + scope.TableName(), + strings.Join(columns, ","), + strings.Join(sqls, ","), + scope.Dialect().ReturningStr(scope.PrimaryKey()), + )) + + // execute create sql var id interface{} if scope.Dialect().SupportLastInsertId() { if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { @@ -20,7 +45,9 @@ func Create(scope *Scope) { scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) } - scope.SetColumn(scope.PrimaryKey(), id) + if !scope.HasError() { + scope.SetColumn(scope.PrimaryKey(), id) + } } } @@ -29,13 +56,12 @@ func AfterCreate(scope *Scope) { scope.CallMethod("AfterSave") } -func SaveAfterAssociations(scope *Scope) { -} - func init() { + DefaultCallback.Create().Register("begin_transaction", BeginTransaction) DefaultCallback.Create().Register("before_create", BeforeCreate) DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations) DefaultCallback.Create().Register("create", Create) DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations) DefaultCallback.Create().Register("after_create", AfterCreate) + DefaultCallback.Create().Register("commit_or_rollback_transaction", CommitOrRollbackTransaction) } diff --git a/callback_delete.go b/callback_delete.go index 9dc5b692..16cf8556 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -12,22 +12,20 @@ func BeforeDelete(scope *Scope) { func Delete(scope *Scope) { defer scope.Trace(time.Now()) - if scope.HasError() { - return - } + if !scope.HasError() { + if !scope.Search.unscope && scope.HasColumn("DeletedAt") { + scope.Raw( + fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", + scope.TableName(), + scope.AddToVars(time.Now()), + scope.CombinedConditionSql(), + )) + } else { + scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.TableName(), scope.CombinedConditionSql())) + } - if !scope.Search.unscope && scope.HasColumn("DeletedAt") { - scope.Raw( - fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", - scope.TableName(), - scope.AddToVars(time.Now()), - scope.CombinedConditionSql(), - )) - } else { - scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.TableName(), scope.CombinedConditionSql())) + scope.Exec() } - - scope.Exec() } func AfterDelete(scope *Scope) { diff --git a/callback_shared.go b/callback_shared.go index 3ff5f104..896413af 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -7,3 +7,9 @@ func BeginTransaction(scope *Scope) { func CommitOrRollbackTransaction(scope *Scope) { scope.CommitOrRollback() } + +func SaveBeforeAssociations(scope *Scope) { +} + +func SaveAfterAssociations(scope *Scope) { +} diff --git a/do.go b/do.go index 6c818eaa..031989db 100644 --- a/do.go +++ b/do.go @@ -311,21 +311,6 @@ func (s *Do) update() *Do { return s } -func (s *Do) delete() *Do { - s.model.callMethod("BeforeDelete") - - if !s.db.hasError() { - if !s.search.unscope && s.model.hasColumn("DeletedAt") { - s.setSql(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql())) - } else { - s.setSql(fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql())) - } - s.exec() - s.model.callMethod("AfterDelete") - } - return s -} - func (s *Do) prepareQuerySql() { if s.search.raw { s.setSql(strings.TrimLeft(s.combinedSql(), "WHERE ")) diff --git a/field.go b/field.go index d655e70a..21083227 100644 --- a/field.go +++ b/field.go @@ -9,10 +9,18 @@ import ( ) type Field struct { - Name string - Value interface{} - model *Model + Name string + DBName string + Value interface{} + IsBlank bool + IsIgnored bool + Tag string + AddationalTag string + Size int + SqlTag string + dbName string + model *Model isBlank bool ignoreField bool isPrimaryKey bool @@ -25,6 +33,16 @@ type Field struct { structField reflect.StructField } +func (f *Field) IsScanner() bool { + _, is_scanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner) + return is_scanner +} + +func (f *Field) IsTime() bool { + _, is_time := f.Value.(time.Time) + return is_time +} + func (f *Field) parseBlank() { f.isBlank = isBlank(f.reflectValue) } @@ -38,7 +56,7 @@ func (f *Field) parseIgnore() { } func (f *Field) isScanner() bool { - _, is_scanner := reflect.New(f.reflectValue.Type()).Interface().(sql.Scanner) + _, is_scanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner) return is_scanner } diff --git a/main.go b/main.go index 527f1204..8f41c44a 100644 --- a/main.go +++ b/main.go @@ -178,7 +178,12 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) * } func (s *DB) Save(value interface{}) *DB { - return s.clone().do(value).begin().save().commit_or_rollback().db + scope := s.clone().newScope(value) + if scope.PrimaryKeyZero() { + return scope.callCallbacks(s.parent.callback.creates).db + } else { + return s.clone().do(value).begin().save().commit_or_rollback().db + } } func (s *DB) Delete(value interface{}) *DB { diff --git a/scope.go b/scope.go index 6d908cab..c69cac99 100644 --- a/scope.go +++ b/scope.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/jinzhu/gorm/dialect" + "go/ast" "strings" "time" @@ -149,8 +150,77 @@ func (s *Scope) CombinedConditionSql() string { return s.joinsSql() + s.whereSql() + s.groupSql() + s.havingSql() + s.orderSql() + s.limitSql() + s.offsetSql() } +func (scope *Scope) SqlTagForField(field *Field) (tag string) { + value := field.Value + reflect_value := reflect.ValueOf(value) + + if field.IsScanner() { + value = reflect_value.Field(0).Interface() + } + + switch reflect_value.Kind() { + case reflect.Slice: + if _, ok := value.([]byte); !ok { + return + } + case reflect.Struct: + if !field.IsTime() && !field.IsScanner() { + return + } + } + + if tag = field.Tag; len(tag) == 0 && tag != "-" { + if field.isPrimaryKey { + tag = scope.Dialect().PrimaryKeyTag(value, field.Size) + } else { + tag = scope.Dialect().SqlTag(value, field.Size) + } + + if len(field.AddationalTag) > 0 { + tag = tag + " " + field.AddationalTag + } + } + return +} + func (scope *Scope) Fields() []*Field { - return []*Field{} + indirect_value := reflect.Indirect(reflect.ValueOf(scope.Value)) + fields := []*Field{} + + if !indirect_value.IsValid() { + return fields + } + + typ := indirect_value.Type() + for i := 0; i < typ.NumField(); i++ { + field_struct := typ.Field(i) + if field_struct.Anonymous || !ast.IsExported(field_struct.Name) { + continue + } + + var field Field + field.Name = field_struct.Name + field.DBName = toSnake(field_struct.Name) + + value := indirect_value.FieldByName(field_struct.Name) + field.Value = value.Interface() + field.IsBlank = isBlank(value) + + tag, addational_tag, size := parseSqlTag(field_struct.Tag.Get(scope.db.parent.tagIdentifier)) + field.Tag = tag + field.AddationalTag = addational_tag + field.Size = size + field.SqlTag = scope.SqlTagForField(&field) + + if tag == "-" { + field.IsIgnored = true + } + + field.parseAssociation() + fields = append(fields, &field) + } + + return fields } func (scope *Scope) Raw(sql string) {