mirror of https://github.com/go-gorm/gorm.git
Add Raw, Row, Rows
This commit is contained in:
parent
fab7d96da5
commit
215f5e7765
|
@ -38,4 +38,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) {
|
|||
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
|
||||
updateCallback.Register("gorm:after_update", AfterUpdate)
|
||||
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
|
||||
db.Callback().Row().Register("gorm:raw", RowQuery)
|
||||
db.Callback().Raw().Register("gorm:raw", RawExec)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
package callbacks
|
||||
|
||||
import "github.com/jinzhu/gorm"
|
||||
|
||||
func RawExec(db *gorm.DB) {
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
)
|
||||
|
||||
func RowQuery(db *gorm.DB) {
|
||||
db.Statement.AddClauseIfNotExists(clause.Select{})
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
|
||||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||
|
||||
if _, ok := db.Get("rows"); ok {
|
||||
db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
} else {
|
||||
db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
}
|
||||
}
|
|
@ -222,5 +222,8 @@ func (db *DB) Unscoped() (tx *DB) {
|
|||
|
||||
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
stmt := tx.Statement
|
||||
stmt.SQL = strings.Builder{}
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(stmt)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
|||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
|
||||
|
|
|
@ -29,7 +29,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
|||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
|
||||
|
|
|
@ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
|||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
|
||||
|
|
|
@ -27,7 +27,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
|||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}}
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
|
||||
|
|
|
@ -22,6 +22,10 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSqlite(t *testing.T) {
|
||||
func TestCURD(t *testing.T) {
|
||||
tests.RunTestsSuit(t, DB)
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
tests.TestMigrate(t, DB)
|
||||
}
|
||||
|
|
|
@ -108,11 +108,15 @@ func (db *DB) Count(value interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
return nil
|
||||
tx := db.getInstance()
|
||||
tx.callbacks.Row().Execute(tx)
|
||||
return tx.Statement.Dest.(*sql.Row)
|
||||
}
|
||||
|
||||
func (db *DB) Rows() (*sql.Rows, error) {
|
||||
return nil, nil
|
||||
tx := db.Set("rows", true)
|
||||
tx.callbacks.Row().Execute(tx)
|
||||
return tx.Statement.Dest.(*sql.Rows), tx.Error
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
|
@ -162,5 +166,6 @@ func (db *DB) Rollback() (tx *DB) {
|
|||
|
||||
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.callbacks.Raw().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
|
5
gorm.go
5
gorm.go
|
@ -138,6 +138,11 @@ func (db *DB) Callback() *callbacks {
|
|||
return db.callbacks
|
||||
}
|
||||
|
||||
// AutoMigrate run auto migration for given models
|
||||
func (db *DB) AutoMigrate(dst ...interface{}) error {
|
||||
return db.Migrator().AutoMigrate(dst...)
|
||||
}
|
||||
|
||||
func (db *DB) getInstance() *DB {
|
||||
if db.clone {
|
||||
ctx := db.Instance.Context
|
||||
|
|
|
@ -265,8 +265,15 @@ func (m Migrator) RenameColumn(value interface{}, oldName, field string) error {
|
|||
})
|
||||
}
|
||||
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) {
|
||||
return nil, gorm.ErrNotImplemented
|
||||
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
|
||||
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows()
|
||||
if err == nil {
|
||||
columnTypes, err = rows.ColumnTypes()
|
||||
}
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
|
||||
|
|
|
@ -17,9 +17,12 @@ func (schema *Schema) ParseCheckConstraints() map[string]Check {
|
|||
for _, field := range schema.FieldsByDBName {
|
||||
if chk := field.TagSettings["CHECK"]; chk != "" {
|
||||
names := strings.Split(chk, ",")
|
||||
if len(names) > 1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(names[0]) {
|
||||
if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) {
|
||||
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
|
||||
} else {
|
||||
if names[0] == "" {
|
||||
chk = strings.Join(names[1:], ",")
|
||||
}
|
||||
name := schema.namer.CheckerName(schema.Table, field.DBName)
|
||||
checks[name] = Check{Name: name, Constraint: chk, Field: field}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
package schema_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
||||
type UserCheck struct {
|
||||
Name string `gorm:"check:name_checker,name <> 'jinzhu'"`
|
||||
Name2 string `gorm:"check:name <> 'jinzhu'"`
|
||||
Name3 string `gorm:"check:,name <> 'jinzhu'"`
|
||||
}
|
||||
|
||||
func TestParseCheck(t *testing.T) {
|
||||
user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user check, got error %v", err)
|
||||
}
|
||||
|
||||
results := map[string]schema.Check{
|
||||
"name_checker": {
|
||||
Name: "name_checker",
|
||||
Constraint: "name <> 'jinzhu'",
|
||||
},
|
||||
"chk_user_checks_name2": {
|
||||
Name: "chk_user_checks_name2",
|
||||
Constraint: "name <> 'jinzhu'",
|
||||
},
|
||||
"chk_user_checks_name3": {
|
||||
Name: "chk_user_checks_name3",
|
||||
Constraint: "name <> 'jinzhu'",
|
||||
},
|
||||
}
|
||||
|
||||
checks := user.ParseCheckConstraints()
|
||||
|
||||
for k, result := range results {
|
||||
v, ok := checks[k]
|
||||
if !ok {
|
||||
t.Errorf("Failed to found check %v from parsed checks %+v", k, checks)
|
||||
}
|
||||
|
||||
for _, name := range []string{"Name", "Constraint"} {
|
||||
if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
|
||||
t.Errorf(
|
||||
"check %v %v should equal, expects %v, got %v",
|
||||
k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -21,7 +21,7 @@ type UserIndex struct {
|
|||
func TestParseIndex(t *testing.T) {
|
||||
user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user index index, got error %v", err)
|
||||
t.Fatalf("failed to parse user index, got error %v", err)
|
||||
}
|
||||
|
||||
results := map[string]schema.Index{
|
||||
|
|
|
@ -317,7 +317,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
|
|||
settings = ParseTagSetting(str, ",")
|
||||
)
|
||||
|
||||
if idx != -1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(str[0:idx]) {
|
||||
if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) {
|
||||
name = str[0:idx]
|
||||
} else {
|
||||
name = rel.Schema.namer.RelationshipFKName(*rel)
|
||||
|
@ -339,5 +339,9 @@ func (rel *Relationship) ParseConstraint() *Constraint {
|
|||
}
|
||||
}
|
||||
|
||||
if constraint.ReferenceSchema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &constraint
|
||||
}
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func TestMigrate(t *testing.T, db *gorm.DB) {
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}}
|
||||
|
||||
db.AutoMigrate(allModels...)
|
||||
|
||||
for _, m := range allModels {
|
||||
if !db.Migrator().HasTable(m) {
|
||||
t.Errorf("Failed to create table for %+v", m)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue