mirror of https://github.com/go-gorm/gorm.git
Refactor
This commit is contained in:
parent
63e48191a8
commit
3f355dc050
|
@ -5,8 +5,6 @@ import (
|
|||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func SaveBeforeAssociations(db *gorm.DB) {
|
||||
|
@ -15,7 +13,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||
|
||||
// Save Belongs To associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -94,7 +92,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||
|
||||
// Save Has One associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasOne {
|
||||
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -172,7 +170,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||
|
||||
// Save Has Many associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasMany {
|
||||
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -230,7 +228,7 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||
|
||||
// Save Many2Many associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
|
||||
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -299,18 +297,3 @@ func SaveAfterAssociations(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func saveAssociationCheck(db *gorm.DB, rel *schema.Relationship, selectColumns map[string]bool, restricted bool) bool {
|
||||
savable := true
|
||||
if value, ok := db.Get("gorm:save_association"); ok {
|
||||
savable = utils.CheckTruth(value)
|
||||
}
|
||||
|
||||
if savable {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && v) || (!ok && !restricted) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -58,6 +58,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
|||
stmt, err := db.prepare(query)
|
||||
if err == nil {
|
||||
return stmt.ExecContext(ctx, args...)
|
||||
} else {
|
||||
db.mux.Lock()
|
||||
delete(db.Stmts, query)
|
||||
db.mux.Unlock()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
@ -66,6 +70,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
|||
stmt, err := db.prepare(query)
|
||||
if err == nil {
|
||||
return stmt.QueryContext(ctx, args...)
|
||||
} else {
|
||||
db.mux.Lock()
|
||||
delete(db.Stmts, query)
|
||||
db.mux.Unlock()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
@ -74,6 +82,10 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
|
|||
stmt, err := db.prepare(query)
|
||||
if err == nil {
|
||||
return stmt.QueryRowContext(ctx, args...)
|
||||
} else {
|
||||
db.mux.Lock()
|
||||
delete(db.Stmts, query)
|
||||
db.mux.Unlock()
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
@ -87,6 +99,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
|||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||
if err == nil {
|
||||
return tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
|
||||
} else {
|
||||
tx.PreparedStmtDB.mux.Lock()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
tx.PreparedStmtDB.mux.Unlock()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
@ -95,6 +111,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
|||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||
if err == nil {
|
||||
return tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
||||
} else {
|
||||
tx.PreparedStmtDB.mux.Lock()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
tx.PreparedStmtDB.mux.Unlock()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
@ -103,6 +123,10 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
|
|||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||
if err == nil {
|
||||
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...)
|
||||
} else {
|
||||
tx.PreparedStmtDB.mux.Lock()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
tx.PreparedStmtDB.mux.Unlock()
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue