mirror of https://github.com/go-gorm/gorm.git
Improve Hooks
This commit is contained in:
parent
95a6539331
commit
befef0c9a9
|
@ -10,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func SaveBeforeAssociations(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
|
||||
|
||||
// Save Belongs To associations
|
||||
|
@ -83,7 +83,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func SaveAfterAssociations(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
|
||||
|
||||
// Save Has One associations
|
||||
|
|
|
@ -9,20 +9,21 @@ import (
|
|||
)
|
||||
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
var ok bool
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
||||
ok = true
|
||||
i.BeforeSave(db)
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeCreate {
|
||||
if i, ok := value.(gorm.BeforeCreateInterface); ok {
|
||||
ok = true
|
||||
i.BeforeCreate(db)
|
||||
db.AddError(i.BeforeCreate(tx))
|
||||
}
|
||||
}
|
||||
return ok
|
||||
|
@ -31,7 +32,7 @@ func BeforeCreate(db *gorm.DB) {
|
|||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
|
@ -46,48 +47,127 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
return CreateWithReturning
|
||||
} else {
|
||||
return func(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||
Table: clause.Table{Name: db.Statement.Table},
|
||||
})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
||||
}
|
||||
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
|
||||
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
|
||||
if insertID, err := result.LastInsertId(); err == nil {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||
insertID--
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||
insertID++
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||
Table: clause.Table{Name: db.Statement.Table},
|
||||
})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
||||
}
|
||||
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
|
||||
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
|
||||
if insertID, err := result.LastInsertId(); err == nil {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||
insertID--
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||
insertID++
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CreateWithReturning(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||
Table: clause.Table{Name: db.Statement.Table},
|
||||
})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
||||
}
|
||||
|
||||
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
|
||||
db.Statement.WriteString(" RETURNING ")
|
||||
|
||||
var (
|
||||
idx int
|
||||
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
|
||||
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
|
||||
)
|
||||
|
||||
for dbName, field := range sch.FieldsWithDefaultDBValue {
|
||||
if idx != 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
|
||||
fields[idx] = field
|
||||
db.Statement.WriteQuoted(dbName)
|
||||
idx++
|
||||
}
|
||||
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for rows.Next() {
|
||||
for idx, field := range fields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||
}
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
db.RowsAffected++
|
||||
}
|
||||
case reflect.Struct:
|
||||
for idx, field := range fields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
err = rows.Scan(values...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
} else {
|
||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
|
@ -96,96 +176,22 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
|
||||
func CreateWithReturning(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||
Table: clause.Table{Name: db.Statement.Table},
|
||||
})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
||||
}
|
||||
|
||||
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
|
||||
db.Statement.WriteString(" RETURNING ")
|
||||
|
||||
var (
|
||||
idx int
|
||||
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
|
||||
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
|
||||
)
|
||||
|
||||
for dbName, field := range sch.FieldsWithDefaultDBValue {
|
||||
if idx != 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
|
||||
fields[idx] = field
|
||||
db.Statement.WriteQuoted(dbName)
|
||||
idx++
|
||||
}
|
||||
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for rows.Next() {
|
||||
for idx, field := range fields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||
}
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
db.RowsAffected++
|
||||
}
|
||||
case reflect.Struct:
|
||||
for idx, field := range fields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
err = rows.Scan(values...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
} else {
|
||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterCreate(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
var ok bool
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
||||
ok = true
|
||||
i.AfterSave(db)
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterCreate {
|
||||
if i, ok := value.(gorm.AfterCreateInterface); ok {
|
||||
ok = true
|
||||
i.AfterCreate(db)
|
||||
db.AddError(i.AfterCreate(tx))
|
||||
}
|
||||
}
|
||||
return ok
|
||||
|
@ -194,7 +200,7 @@ func AfterCreate(db *gorm.DB) {
|
|||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
|
|
|
@ -9,11 +9,12 @@ import (
|
|||
)
|
||||
|
||||
func BeforeDelete(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
if db.Statement.Schema.BeforeDelete {
|
||||
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
|
||||
i.BeforeDelete(db)
|
||||
db.AddError(i.BeforeDelete(tx))
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -23,7 +24,7 @@ func BeforeDelete(db *gorm.DB) {
|
|||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
|
@ -34,57 +35,60 @@ func BeforeDelete(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func Delete(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
||||
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
|
||||
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
||||
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
db.Statement.Build("DELETE", "FROM", "WHERE")
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
db.Statement.Build("DELETE", "FROM", "WHERE")
|
||||
}
|
||||
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func AfterDelete(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
if db.Statement.Schema.AfterDelete {
|
||||
if i, ok := value.(gorm.AfterDeleteInterface); ok {
|
||||
i.AfterDelete(db)
|
||||
db.AddError(i.AfterDelete(tx))
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -94,7 +98,7 @@ func AfterDelete(db *gorm.DB) {
|
|||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
|
|
|
@ -12,24 +12,26 @@ import (
|
|||
)
|
||||
|
||||
func Query(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.QueryClauses {
|
||||
db.Statement.AddClause(c)
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.QueryClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
BuildQuerySQL(db)
|
||||
}
|
||||
if db.Statement.SQL.String() == "" {
|
||||
BuildQuerySQL(db)
|
||||
}
|
||||
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
gorm.Scan(rows, db, false)
|
||||
gorm.Scan(rows, db, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BuildQuerySQL(db *gorm.DB) {
|
||||
|
@ -129,50 +131,53 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func Preload(db *gorm.DB) {
|
||||
if len(db.Statement.Preloads) > 0 {
|
||||
preloadMap := map[string][]string{}
|
||||
for name := range db.Statement.Preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
for idx := range preloadFields {
|
||||
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
|
||||
}
|
||||
}
|
||||
|
||||
preloadNames := make([]string, len(preloadMap))
|
||||
idx := 0
|
||||
for key := range preloadMap {
|
||||
preloadNames[idx] = key
|
||||
idx++
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
for _, name := range preloadNames {
|
||||
var (
|
||||
curSchema = db.Statement.Schema
|
||||
preloadFields = preloadMap[name]
|
||||
rels = make([]*schema.Relationship, len(preloadFields))
|
||||
)
|
||||
|
||||
for idx, preloadField := range preloadFields {
|
||||
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
|
||||
rels[idx] = rel
|
||||
curSchema = rel.FieldSchema
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
|
||||
if db.Error == nil {
|
||||
if len(db.Statement.Preloads) > 0 {
|
||||
preloadMap := map[string][]string{}
|
||||
for name := range db.Statement.Preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
for idx := range preloadFields {
|
||||
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
|
||||
}
|
||||
}
|
||||
|
||||
preload(db, rels, db.Statement.Preloads[name])
|
||||
preloadNames := make([]string, len(preloadMap))
|
||||
idx := 0
|
||||
for key := range preloadMap {
|
||||
preloadNames[idx] = key
|
||||
idx++
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
for _, name := range preloadNames {
|
||||
var (
|
||||
curSchema = db.Statement.Schema
|
||||
preloadFields = preloadMap[name]
|
||||
rels = make([]*schema.Relationship, len(preloadFields))
|
||||
)
|
||||
|
||||
for idx, preloadField := range preloadFields {
|
||||
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
|
||||
rels[idx] = rel
|
||||
curSchema = rel.FieldSchema
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
|
||||
}
|
||||
}
|
||||
|
||||
preload(db, rels, db.Statement.Preloads[name])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
if db.Statement.Schema.AfterFind {
|
||||
if i, ok := value.(gorm.AfterFindInterface); ok {
|
||||
i.AfterFind(db)
|
||||
db.AddError(i.AfterFind(tx))
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -182,7 +187,7 @@ func AfterQuery(db *gorm.DB) {
|
|||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
|
|
|
@ -5,10 +5,12 @@ import (
|
|||
)
|
||||
|
||||
func RawExec(db *gorm.DB) {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
} else {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
if db.Error == nil {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
} else {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,13 +5,15 @@ import (
|
|||
)
|
||||
|
||||
func RowQuery(db *gorm.DB) {
|
||||
if db.Statement.SQL.String() == "" {
|
||||
BuildQuerySQL(db)
|
||||
}
|
||||
if db.Error == nil {
|
||||
if db.Statement.SQL.String() == "" {
|
||||
BuildQuerySQL(db)
|
||||
}
|
||||
|
||||
if _, ok := db.Get("rows"); ok {
|
||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
} else {
|
||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if _, ok := db.Get("rows"); ok {
|
||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
} else {
|
||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,25 @@
|
|||
package callbacks
|
||||
|
||||
import "github.com/jinzhu/gorm"
|
||||
import (
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func BeginTransaction(db *gorm.DB) {
|
||||
if tx := db.Begin(); tx.Error == nil {
|
||||
db.Statement.ConnPool = tx.Statement.ConnPool
|
||||
tx.InstanceSet("gorm:started_transaction", true)
|
||||
} else {
|
||||
tx.Error = nil
|
||||
}
|
||||
}
|
||||
|
||||
func CommitOrRollbackTransaction(db *gorm.DB) {
|
||||
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
|
||||
if db.Error == nil {
|
||||
db.Commit()
|
||||
} else {
|
||||
db.Rollback()
|
||||
}
|
||||
db.Statement.ConnPool = db.ConnPool
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,20 +10,21 @@ import (
|
|||
)
|
||||
|
||||
func BeforeUpdate(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
var ok bool
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
||||
ok = true
|
||||
i.BeforeSave(db)
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeUpdate {
|
||||
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
|
||||
ok = true
|
||||
i.BeforeUpdate(db)
|
||||
db.AddError(i.BeforeUpdate(tx))
|
||||
}
|
||||
}
|
||||
return ok
|
||||
|
@ -32,7 +33,7 @@ func BeforeUpdate(db *gorm.DB) {
|
|||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
|
@ -43,51 +44,54 @@ func BeforeUpdate(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func Update(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||
db.Statement.AddClause(set)
|
||||
} else {
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||
db.Statement.AddClause(set)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
db.Statement.Build("UPDATE", "SET", "WHERE")
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
db.Statement.Build("UPDATE", "SET", "WHERE")
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
if err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterUpdate(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
var ok bool
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
||||
ok = true
|
||||
i.AfterSave(db)
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterUpdate {
|
||||
if i, ok := value.(gorm.AfterUpdateInterface); ok {
|
||||
ok = true
|
||||
i.AfterUpdate(db)
|
||||
db.AddError(i.AfterUpdate(tx))
|
||||
}
|
||||
}
|
||||
return ok
|
||||
|
@ -96,7 +100,7 @@ func AfterUpdate(db *gorm.DB) {
|
|||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
|
|
|
@ -16,7 +16,7 @@ var (
|
|||
// ErrNotImplemented not implemented
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
// ErrMissingWhereClause missing where clause
|
||||
ErrMissingWhereClause = errors.New("missing WHERE clause while deleting")
|
||||
ErrMissingWhereClause = errors.New("WHERE conditions required")
|
||||
// ErrUnsupportedRelation unsupported relations
|
||||
ErrUnsupportedRelation = errors.New("unsupported relations")
|
||||
// ErrPtrStructSupported only ptr of struct supported
|
||||
|
|
108
gorm.go
108
gorm.go
|
@ -40,14 +40,15 @@ type DB struct {
|
|||
Error error
|
||||
RowsAffected int64
|
||||
Statement *Statement
|
||||
clone bool
|
||||
clone int
|
||||
}
|
||||
|
||||
// Session session config when create session with Session() method
|
||||
type Session struct {
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
WithConditions bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
}
|
||||
|
||||
// Open initialize db session based on dialector
|
||||
|
@ -76,10 +77,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||
config.cacheStore = &sync.Map{}
|
||||
}
|
||||
|
||||
db = &DB{
|
||||
Config: config,
|
||||
clone: true,
|
||||
}
|
||||
db = &DB{Config: config, clone: 1}
|
||||
|
||||
db.callbacks = initializeCallbacks(db)
|
||||
|
||||
|
@ -96,38 +94,54 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||
// Session create new db session
|
||||
func (db *DB) Session(config *Session) *DB {
|
||||
var (
|
||||
tx = db.getInstance()
|
||||
stmt = tx.Statement.clone()
|
||||
txConfig = *tx.Config
|
||||
txConfig = *db.Config
|
||||
tx = &DB{
|
||||
Config: &txConfig,
|
||||
Statement: db.Statement,
|
||||
clone: 1,
|
||||
}
|
||||
)
|
||||
|
||||
if config.Context != nil {
|
||||
stmt.Context = config.Context
|
||||
if tx.Statement != nil {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
} else {
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
ConnPool: tx.ConnPool,
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.Context = config.Context
|
||||
}
|
||||
|
||||
if config.WithConditions {
|
||||
tx.clone = 3
|
||||
}
|
||||
|
||||
if config.Logger != nil {
|
||||
txConfig.Logger = config.Logger
|
||||
tx.Config.Logger = config.Logger
|
||||
}
|
||||
|
||||
if config.NowFunc != nil {
|
||||
txConfig.NowFunc = config.NowFunc
|
||||
tx.Config.NowFunc = config.NowFunc
|
||||
}
|
||||
|
||||
return &DB{
|
||||
Config: &txConfig,
|
||||
Statement: stmt,
|
||||
clone: true,
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
// WithContext change current instance db's context to ctx
|
||||
func (db *DB) WithContext(ctx context.Context) *DB {
|
||||
return db.Session(&Session{Context: ctx})
|
||||
return db.Session(&Session{WithConditions: true, Context: ctx})
|
||||
}
|
||||
|
||||
// Debug start debug mode
|
||||
func (db *DB) Debug() (tx *DB) {
|
||||
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)})
|
||||
return db.Session(&Session{
|
||||
WithConditions: true,
|
||||
Logger: db.Logger.LogMode(logger.Info),
|
||||
})
|
||||
}
|
||||
|
||||
// Set store value with key into current db instance's context
|
||||
|
@ -145,6 +159,21 @@ func (db *DB) Get(key string) (interface{}, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
// InstanceSet store value with key into current db instance's context
|
||||
func (db *DB) InstanceSet(key string, value interface{}) *DB {
|
||||
tx := db.getInstance()
|
||||
tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
|
||||
return tx
|
||||
}
|
||||
|
||||
// InstanceGet get value with key from current db instance's context
|
||||
func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
||||
if db.Statement != nil {
|
||||
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Callback returns callback manager
|
||||
func (db *DB) Callback() *callbacks {
|
||||
return db.callbacks
|
||||
|
@ -166,18 +195,37 @@ func (db *DB) AddError(err error) error {
|
|||
}
|
||||
|
||||
func (db *DB) getInstance() *DB {
|
||||
if db.clone {
|
||||
stmt := &Statement{DB: db, Clauses: map[string]clause.Clause{}}
|
||||
if db.clone > 0 {
|
||||
tx := &DB{Config: db.Config}
|
||||
|
||||
if db.Statement != nil {
|
||||
stmt.Context = db.Statement.Context
|
||||
stmt.ConnPool = db.Statement.ConnPool
|
||||
} else {
|
||||
stmt.Context = context.Background()
|
||||
stmt.ConnPool = db.ConnPool
|
||||
switch db.clone {
|
||||
case 1: // clone with new statement
|
||||
case 2: // with old statement, generate new statement for future call, used to pass to callbacks
|
||||
db.clone = 1
|
||||
tx.Statement = db.Statement
|
||||
case 3: // with clone statement
|
||||
if db.Statement != nil {
|
||||
tx.Statement = db.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
}
|
||||
}
|
||||
|
||||
return &DB{Config: db.Config, Statement: stmt}
|
||||
if tx.Statement == nil {
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement != nil {
|
||||
tx.Statement.Context = db.Statement.Context
|
||||
tx.Statement.ConnPool = db.Statement.ConnPool
|
||||
} else {
|
||||
tx.Statement.Context = context.Background()
|
||||
tx.Statement.ConnPool = db.ConnPool
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
return db
|
||||
|
|
|
@ -36,37 +36,37 @@ type TxCommiter interface {
|
|||
}
|
||||
|
||||
type BeforeCreateInterface interface {
|
||||
BeforeCreate(*DB)
|
||||
BeforeCreate(*DB) error
|
||||
}
|
||||
|
||||
type AfterCreateInterface interface {
|
||||
AfterCreate(*DB)
|
||||
AfterCreate(*DB) error
|
||||
}
|
||||
|
||||
type BeforeUpdateInterface interface {
|
||||
BeforeUpdate(*DB)
|
||||
BeforeUpdate(*DB) error
|
||||
}
|
||||
|
||||
type AfterUpdateInterface interface {
|
||||
AfterUpdate(*DB)
|
||||
AfterUpdate(*DB) error
|
||||
}
|
||||
|
||||
type BeforeSaveInterface interface {
|
||||
BeforeSave(*DB)
|
||||
BeforeSave(*DB) error
|
||||
}
|
||||
|
||||
type AfterSaveInterface interface {
|
||||
AfterSave(*DB)
|
||||
AfterSave(*DB) error
|
||||
}
|
||||
|
||||
type BeforeDeleteInterface interface {
|
||||
BeforeDelete(*DB)
|
||||
BeforeDelete(*DB) error
|
||||
}
|
||||
|
||||
type AfterDeleteInterface interface {
|
||||
AfterDelete(*DB)
|
||||
AfterDelete(*DB) error
|
||||
}
|
||||
|
||||
type AfterFindInterface interface {
|
||||
AfterFind(*DB)
|
||||
AfterFind(*DB) error
|
||||
}
|
||||
|
|
|
@ -12,10 +12,12 @@ import (
|
|||
type UserWithCallback struct {
|
||||
}
|
||||
|
||||
func (UserWithCallback) BeforeSave(*gorm.DB) {
|
||||
func (UserWithCallback) BeforeSave(*gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (UserWithCallback) AfterCreate(*gorm.DB) {
|
||||
func (UserWithCallback) AfterCreate(*gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestCallback(t *testing.T) {
|
||||
|
|
|
@ -200,12 +200,12 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.New(modelType))
|
||||
reflectValue := reflect.New(modelType)
|
||||
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
|
||||
for _, name := range callbacks {
|
||||
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
|
||||
switch methodValue.Type().String() {
|
||||
case "func(*gorm.DB)": // TODO hack
|
||||
case "func(*gorm.DB) error": // TODO hack
|
||||
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
|
||||
default:
|
||||
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
|
||||
|
|
|
@ -0,0 +1,201 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
. "github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
type Product struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Code string
|
||||
Price float64
|
||||
AfterFindCallTimes int64
|
||||
BeforeCreateCallTimes int64
|
||||
AfterCreateCallTimes int64
|
||||
BeforeUpdateCallTimes int64
|
||||
AfterUpdateCallTimes int64
|
||||
BeforeSaveCallTimes int64
|
||||
AfterSaveCallTimes int64
|
||||
BeforeDeleteCallTimes int64
|
||||
AfterDeleteCallTimes int64
|
||||
}
|
||||
|
||||
func (s *Product) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if s.Code == "Invalid" {
|
||||
err = errors.New("invalid product")
|
||||
}
|
||||
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) BeforeUpdate(tx *gorm.DB) (err error) {
|
||||
if s.Code == "dont_update" {
|
||||
err = errors.New("can't update")
|
||||
}
|
||||
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) BeforeSave(tx *gorm.DB) (err error) {
|
||||
if s.Code == "dont_save" {
|
||||
err = errors.New("can't save")
|
||||
}
|
||||
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) AfterFind(tx *gorm.DB) (err error) {
|
||||
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) AfterCreate(tx *gorm.DB) (err error) {
|
||||
return tx.Model(s).UpdateColumn("AfterCreateCallTimes", s.AfterCreateCallTimes+1).Error
|
||||
}
|
||||
|
||||
func (s *Product) AfterUpdate(tx *gorm.DB) (err error) {
|
||||
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) AfterSave(tx *gorm.DB) (err error) {
|
||||
if s.Code == "after_save_error" {
|
||||
err = errors.New("can't save")
|
||||
}
|
||||
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) BeforeDelete(tx *gorm.DB) (err error) {
|
||||
if s.Code == "dont_delete" {
|
||||
err = errors.New("can't delete")
|
||||
}
|
||||
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) AfterDelete(tx *gorm.DB) (err error) {
|
||||
if s.Code == "after_delete_error" {
|
||||
err = errors.New("can't delete")
|
||||
}
|
||||
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) GetCallTimes() []int64 {
|
||||
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
|
||||
}
|
||||
|
||||
func TestRunCallbacks(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Product{})
|
||||
DB.AutoMigrate(&Product{})
|
||||
|
||||
p := Product{Code: "unique_code", Price: 100}
|
||||
DB.Save(&p)
|
||||
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
|
||||
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
DB.Where("Code = ?", "unique_code").First(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
|
||||
t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
p.Price = 200
|
||||
DB.Save(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
|
||||
t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
var products []Product
|
||||
DB.Find(&products, "code = ?", "unique_code")
|
||||
if products[0].AfterFindCallTimes != 1 {
|
||||
t.Fatalf("AfterFind callbacks should work with slice, called %v", products[0].AfterFindCallTimes)
|
||||
}
|
||||
|
||||
DB.Where("Code = ?", "unique_code").First(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
|
||||
t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
DB.Delete(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
|
||||
t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
|
||||
t.Fatalf("Can't find a deleted record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksWithErrors(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Product{})
|
||||
DB.AutoMigrate(&Product{})
|
||||
|
||||
p := Product{Code: "Invalid", Price: 100}
|
||||
if DB.Save(&p).Error == nil {
|
||||
t.Fatalf("An error from before create callbacks happened when create with invalid value")
|
||||
}
|
||||
|
||||
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
|
||||
t.Fatalf("Should not save record that have errors")
|
||||
}
|
||||
|
||||
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
|
||||
t.Fatalf("An error from after create callbacks happened when create with invalid value")
|
||||
}
|
||||
|
||||
p2 := Product{Code: "update_callback", Price: 100}
|
||||
DB.Save(&p2)
|
||||
|
||||
p2.Code = "dont_update"
|
||||
if DB.Save(&p2).Error == nil {
|
||||
t.Fatalf("An error from before update callbacks happened when update with invalid value")
|
||||
}
|
||||
|
||||
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
|
||||
t.Fatalf("Record Should not be updated due to errors happened in before update callback")
|
||||
}
|
||||
|
||||
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
|
||||
t.Fatalf("Record Should not be updated due to errors happened in before update callback")
|
||||
}
|
||||
|
||||
p2.Code = "dont_save"
|
||||
if DB.Save(&p2).Error == nil {
|
||||
t.Fatalf("An error from before save callbacks happened when update with invalid value")
|
||||
}
|
||||
|
||||
p3 := Product{Code: "dont_delete", Price: 100}
|
||||
DB.Save(&p3)
|
||||
if DB.Delete(&p3).Error == nil {
|
||||
t.Fatalf("An error from before delete callbacks happened when delete")
|
||||
}
|
||||
|
||||
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
|
||||
t.Fatalf("An error from before delete callbacks happened")
|
||||
}
|
||||
|
||||
p4 := Product{Code: "after_save_error", Price: 100}
|
||||
DB.Save(&p4)
|
||||
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
|
||||
t.Fatalf("Record should be reverted if get an error in after save callback")
|
||||
}
|
||||
|
||||
p5 := Product{Code: "after_delete_error", Price: 100}
|
||||
DB.Save(&p5)
|
||||
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||
t.Fatalf("Record should be found")
|
||||
}
|
||||
|
||||
DB.Delete(&p5)
|
||||
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||
t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback")
|
||||
}
|
||||
}
|
|
@ -59,9 +59,9 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||
}
|
||||
|
||||
if debug := os.Getenv("DEBUG"); debug == "true" {
|
||||
db.Logger.LogMode(logger.Info)
|
||||
db.Logger = db.Logger.LogMode(logger.Info)
|
||||
} else if debug == "false" {
|
||||
db.Logger.LogMode(logger.Silent)
|
||||
db.Logger = db.Logger.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
return
|
||||
|
|
|
@ -14,37 +14,37 @@ func TestTransaction(t *testing.T) {
|
|||
user := *GetUser("transcation", Config{})
|
||||
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should raise, but got %v", err)
|
||||
t.Fatalf("No error should raise, but got %v", err)
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
||||
t.Errorf("Should find saved record, but got %v", err)
|
||||
t.Fatalf("Should find saved record, but got %v", err)
|
||||
}
|
||||
|
||||
if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil {
|
||||
t.Errorf("Should return the underlying sql.Tx")
|
||||
t.Fatalf("Should return the underlying sql.Tx")
|
||||
}
|
||||
|
||||
tx.Rollback()
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||
t.Errorf("Should not find record after rollback, but got %v", err)
|
||||
t.Fatalf("Should not find record after rollback, but got %v", err)
|
||||
}
|
||||
|
||||
tx2 := DB.Begin()
|
||||
user2 := *GetUser("transcation-2", Config{})
|
||||
if err := tx2.Save(&user2).Error; err != nil {
|
||||
t.Errorf("No error should raise, but got %v", err)
|
||||
t.Fatalf("No error should raise, but got %v", err)
|
||||
}
|
||||
|
||||
if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||
t.Errorf("Should find saved record, but got %v", err)
|
||||
t.Fatalf("Should find saved record, but got %v", err)
|
||||
}
|
||||
|
||||
tx2.Commit()
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||
t.Errorf("Should be able to find committed record, but got %v", err)
|
||||
t.Fatalf("Should be able to find committed record, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -52,7 +52,7 @@ func TestTransactionWithBlock(t *testing.T) {
|
|||
assertPanic := func(f func()) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("The code did not panic")
|
||||
t.Fatalf("The code did not panic")
|
||||
}
|
||||
}()
|
||||
f()
|
||||
|
@ -62,39 +62,39 @@ func TestTransactionWithBlock(t *testing.T) {
|
|||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
user := *GetUser("transcation-block", Config{})
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
t.Fatalf("No error should raise")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Errorf("Should find saved record")
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
|
||||
return errors.New("the error message")
|
||||
})
|
||||
|
||||
if err.Error() != "the error message" {
|
||||
t.Errorf("Transaction return error will equal the block returns error")
|
||||
t.Fatalf("Transaction return error will equal the block returns error")
|
||||
}
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil {
|
||||
t.Errorf("Should not find record after rollback")
|
||||
t.Fatalf("Should not find record after rollback")
|
||||
}
|
||||
|
||||
// commit
|
||||
DB.Transaction(func(tx *gorm.DB) error {
|
||||
user := *GetUser("transcation-block-2", Config{})
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
t.Fatalf("No error should raise")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Errorf("Should find saved record")
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil {
|
||||
t.Errorf("Should be able to find committed record")
|
||||
t.Fatalf("Should be able to find committed record")
|
||||
}
|
||||
|
||||
// panic will rollback
|
||||
|
@ -102,11 +102,11 @@ func TestTransactionWithBlock(t *testing.T) {
|
|||
DB.Transaction(func(tx *gorm.DB) error {
|
||||
user := *GetUser("transcation-block-3", Config{})
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
t.Fatalf("No error should raise")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Errorf("Should find saved record")
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
|
||||
panic("force panic")
|
||||
|
@ -114,7 +114,7 @@ func TestTransactionWithBlock(t *testing.T) {
|
|||
})
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil {
|
||||
t.Errorf("Should not find record after panic rollback")
|
||||
t.Fatalf("Should not find record after panic rollback")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -122,14 +122,14 @@ func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) {
|
|||
tx := DB.Begin()
|
||||
user := User{Name: "transcation"}
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
t.Fatalf("No error should raise")
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
t.Errorf("Commit should not raise error")
|
||||
t.Fatalf("Commit should not raise error")
|
||||
}
|
||||
|
||||
if err := tx.Rollback().Error; err == nil {
|
||||
t.Errorf("Rollback after commit should raise error")
|
||||
t.Fatalf("Rollback after commit should raise error")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue