diff --git a/callbacks/create.go b/callbacks/create.go index 9dc5b8b1..29113128 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -57,7 +57,7 @@ func Create(config *Config) func(db *gorm.DB) { } } - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) diff --git a/callbacks/delete.go b/callbacks/delete.go index b05a9d08..7f1e09ce 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -118,13 +118,7 @@ func Delete(config *Config) func(db *gorm.DB) { return } - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) @@ -147,6 +141,15 @@ func Delete(config *Config) func(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) + } + + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } diff --git a/callbacks/query.go b/callbacks/query.go index 2f98a4b6..efb08609 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -33,7 +33,7 @@ func BuildQuerySQL(db *gorm.DB) { } } - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) clauseSelect := clause.Select{Distinct: db.Statement.Distinct} diff --git a/callbacks/update.go b/callbacks/update.go index fa7640de..b3eaaf11 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,13 +59,7 @@ func Update(config *Config) func(db *gorm.DB) { return } - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) if set := ConvertToAssignments(db.Statement); len(set) != 0 { @@ -73,6 +67,16 @@ func Update(config *Config) func(db *gorm.DB) { } else if _, ok := db.Statement.Clauses["SET"]; !ok { return } + + } + + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } diff --git a/soft_delete.go b/soft_delete.go index 11c4fafc..4e236fc4 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -103,7 +103,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { } func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { - if stmt.SQL.String() == "" { + if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { SoftDeleteQueryClause(sd).ModifyStatement(stmt) } @@ -129,7 +129,7 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { } func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { - if stmt.SQL.String() == "" { + if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { curTime := stmt.DB.NowFunc() stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) stmt.SetColumn(sd.Field.DBName, curTime, true) diff --git a/tests/update_test.go b/tests/update_test.go index 14ed9820..abe520db 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -645,7 +645,13 @@ func TestSave(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement - if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { + if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) + } + + dryDB = DB.Session(&gorm.Session{DryRun: true}) + stmt = dryDB.Unscoped().Save(&user).Statement + if !regexp.MustCompile(`WHERE .id. = [^ ]+$`).MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) }