Add Raw, Row, Rows

This commit is contained in:
Jinzhu 2020-02-22 19:41:01 +08:00
parent fab7d96da5
commit 215f5e7765
17 changed files with 162 additions and 12 deletions

View File

@ -38,4 +38,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) {
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
db.Callback().Row().Register("gorm:raw", RowQuery)
db.Callback().Raw().Register("gorm:raw", RawExec)
}

11
callbacks/raw.go Normal file
View File

@ -0,0 +1,11 @@
package callbacks
import "github.com/jinzhu/gorm"
func RawExec(db *gorm.DB) {
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.RowsAffected, _ = result.RowsAffected()
if err != nil {
db.AddError(err)
}
}

19
callbacks/row.go Normal file
View File

@ -0,0 +1,19 @@
package callbacks
import (
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
)
func RowQuery(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Select{})
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
if _, ok := db.Get("rows"); ok {
db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else {
db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
}
}

View File

@ -222,5 +222,8 @@ func (db *DB) Unscoped() (tx *DB) {
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
stmt := tx.Statement
stmt.SQL = strings.Builder{}
clause.Expr{SQL: sql, Vars: values}.Build(stmt)
return
}

View File

@ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
}}}
}
func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {

View File

@ -29,7 +29,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
}}}
}
func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {

View File

@ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
}}}
}
func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {

View File

@ -27,7 +27,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
}}}
}
func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {

View File

@ -22,6 +22,10 @@ func init() {
}
}
func TestSqlite(t *testing.T) {
func TestCURD(t *testing.T) {
tests.RunTestsSuit(t, DB)
}
func TestMigrate(t *testing.T) {
tests.TestMigrate(t, DB)
}

View File

@ -108,11 +108,15 @@ func (db *DB) Count(value interface{}) (tx *DB) {
}
func (db *DB) Row() *sql.Row {
return nil
tx := db.getInstance()
tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Row)
}
func (db *DB) Rows() (*sql.Rows, error) {
return nil, nil
tx := db.Set("rows", true)
tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Rows), tx.Error
}
// Scan scan value to a struct
@ -162,5 +166,6 @@ func (db *DB) Rollback() (tx *DB) {
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.callbacks.Raw().Execute(tx)
return
}

View File

@ -138,6 +138,11 @@ func (db *DB) Callback() *callbacks {
return db.callbacks
}
// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...)
}
func (db *DB) getInstance() *DB {
if db.clone {
ctx := db.Instance.Context

View File

@ -265,8 +265,15 @@ func (m Migrator) RenameColumn(value interface{}, oldName, field string) error {
})
}
func (m Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) {
return nil, gorm.ErrNotImplemented
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
if err == nil {
columnTypes, err = rows.ColumnTypes()
}
return err
})
return
}
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {

View File

@ -17,9 +17,12 @@ func (schema *Schema) ParseCheckConstraints() map[string]Check {
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(names[0]) {
if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) {
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = Check{Name: name, Constraint: chk, Field: field}
}

55
schema/check_test.go Normal file
View File

@ -0,0 +1,55 @@
package schema_test
import (
"reflect"
"sync"
"testing"
"github.com/jinzhu/gorm/schema"
)
type UserCheck struct {
Name string `gorm:"check:name_checker,name <> 'jinzhu'"`
Name2 string `gorm:"check:name <> 'jinzhu'"`
Name3 string `gorm:"check:,name <> 'jinzhu'"`
}
func TestParseCheck(t *testing.T) {
user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user check, got error %v", err)
}
results := map[string]schema.Check{
"name_checker": {
Name: "name_checker",
Constraint: "name <> 'jinzhu'",
},
"chk_user_checks_name2": {
Name: "chk_user_checks_name2",
Constraint: "name <> 'jinzhu'",
},
"chk_user_checks_name3": {
Name: "chk_user_checks_name3",
Constraint: "name <> 'jinzhu'",
},
}
checks := user.ParseCheckConstraints()
for k, result := range results {
v, ok := checks[k]
if !ok {
t.Errorf("Failed to found check %v from parsed checks %+v", k, checks)
}
for _, name := range []string{"Name", "Constraint"} {
if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
t.Errorf(
"check %v %v should equal, expects %v, got %v",
k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
)
}
}
}
}

View File

@ -21,7 +21,7 @@ type UserIndex struct {
func TestParseIndex(t *testing.T) {
user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user index index, got error %v", err)
t.Fatalf("failed to parse user index, got error %v", err)
}
results := map[string]schema.Index{

View File

@ -317,7 +317,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
settings = ParseTagSetting(str, ",")
)
if idx != -1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(str[0:idx]) {
if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) {
name = str[0:idx]
} else {
name = rel.Schema.namer.RelationshipFKName(*rel)
@ -339,5 +339,9 @@ func (rel *Relationship) ParseConstraint() *Constraint {
}
}
if constraint.ReferenceSchema == nil {
return nil
}
return &constraint
}

19
tests/migrate.go Normal file
View File

@ -0,0 +1,19 @@
package tests
import (
"testing"
"github.com/jinzhu/gorm"
)
func TestMigrate(t *testing.T, db *gorm.DB) {
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}}
db.AutoMigrate(allModels...)
for _, m := range allModels {
if !db.Migrator().HasTable(m) {
t.Errorf("Failed to create table for %+v", m)
}
}
}