diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index f9d5543d..0a48ada6 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -38,4 +38,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) { updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + + db.Callback().Row().Register("gorm:raw", RowQuery) + db.Callback().Raw().Register("gorm:raw", RawExec) } diff --git a/callbacks/raw.go b/callbacks/raw.go new file mode 100644 index 00000000..6d0a5aac --- /dev/null +++ b/callbacks/raw.go @@ -0,0 +1,11 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func RawExec(db *gorm.DB) { + result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.RowsAffected, _ = result.RowsAffected() + if err != nil { + db.AddError(err) + } +} diff --git a/callbacks/row.go b/callbacks/row.go new file mode 100644 index 00000000..04fe4f48 --- /dev/null +++ b/callbacks/row.go @@ -0,0 +1,19 @@ +package callbacks + +import ( + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" +) + +func RowQuery(db *gorm.DB) { + db.Statement.AddClauseIfNotExists(clause.Select{}) + db.Statement.AddClauseIfNotExists(clause.From{}) + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + + if _, ok := db.Get("rows"); ok { + db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } +} diff --git a/chainable_api.go b/chainable_api.go index a57deb63..ccd61716 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -222,5 +222,8 @@ func (db *DB) Unscoped() (tx *DB) { func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() + stmt := tx.Statement + stmt.SQL = strings.Builder{} + clause.Expr{SQL: sql, Vars: values}.Build(stmt) return } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index bdca667d..78c048b4 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index e2fea53c..3b456891 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -29,7 +29,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index a3eeefb9..4ffc4204 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -28,7 +28,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index b77226db..804016a5 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -27,7 +27,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { diff --git a/dialects/sqlite/sqlite_test.go b/dialects/sqlite/sqlite_test.go index 51c1def0..a42bc8ee 100644 --- a/dialects/sqlite/sqlite_test.go +++ b/dialects/sqlite/sqlite_test.go @@ -22,6 +22,10 @@ func init() { } } -func TestSqlite(t *testing.T) { +func TestCURD(t *testing.T) { tests.RunTestsSuit(t, DB) } + +func TestMigrate(t *testing.T) { + tests.TestMigrate(t, DB) +} diff --git a/finisher_api.go b/finisher_api.go index 5389ed6a..8b824d12 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -108,11 +108,15 @@ func (db *DB) Count(value interface{}) (tx *DB) { } func (db *DB) Row() *sql.Row { - return nil + tx := db.getInstance() + tx.callbacks.Row().Execute(tx) + return tx.Statement.Dest.(*sql.Row) } func (db *DB) Rows() (*sql.Rows, error) { - return nil, nil + tx := db.Set("rows", true) + tx.callbacks.Row().Execute(tx) + return tx.Statement.Dest.(*sql.Rows), tx.Error } // Scan scan value to a struct @@ -162,5 +166,6 @@ func (db *DB) Rollback() (tx *DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() + tx.callbacks.Raw().Execute(tx) return } diff --git a/gorm.go b/gorm.go index 23f812d1..2f10be60 100644 --- a/gorm.go +++ b/gorm.go @@ -138,6 +138,11 @@ func (db *DB) Callback() *callbacks { return db.callbacks } +// AutoMigrate run auto migration for given models +func (db *DB) AutoMigrate(dst ...interface{}) error { + return db.Migrator().AutoMigrate(dst...) +} + func (db *DB) getInstance() *DB { if db.clone { ctx := db.Instance.Context diff --git a/migrator/migrator.go b/migrator/migrator.go index 9e94cc68..5debc600 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -265,8 +265,15 @@ func (m Migrator) RenameColumn(value interface{}, oldName, field string) error { }) } -func (m Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { - return nil, gorm.ErrNotImplemented +func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { + err = m.RunWithValue(value, func(stmt *gorm.Statement) error { + rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() + if err == nil { + columnTypes, err = rows.ColumnTypes() + } + return err + }) + return } func (m Migrator) CreateView(name string, option gorm.ViewOption) error { diff --git a/schema/check.go b/schema/check.go index a06ac67b..7d31ec70 100644 --- a/schema/check.go +++ b/schema/check.go @@ -17,9 +17,12 @@ func (schema *Schema) ParseCheckConstraints() map[string]Check { for _, field := range schema.FieldsByDBName { if chk := field.TagSettings["CHECK"]; chk != "" { names := strings.Split(chk, ",") - if len(names) > 1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(names[0]) { + if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) { checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} } else { + if names[0] == "" { + chk = strings.Join(names[1:], ",") + } name := schema.namer.CheckerName(schema.Table, field.DBName) checks[name] = Check{Name: name, Constraint: chk, Field: field} } diff --git a/schema/check_test.go b/schema/check_test.go new file mode 100644 index 00000000..e4bc9ebe --- /dev/null +++ b/schema/check_test.go @@ -0,0 +1,55 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "github.com/jinzhu/gorm/schema" +) + +type UserCheck struct { + Name string `gorm:"check:name_checker,name <> 'jinzhu'"` + Name2 string `gorm:"check:name <> 'jinzhu'"` + Name3 string `gorm:"check:,name <> 'jinzhu'"` +} + +func TestParseCheck(t *testing.T) { + user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user check, got error %v", err) + } + + results := map[string]schema.Check{ + "name_checker": { + Name: "name_checker", + Constraint: "name <> 'jinzhu'", + }, + "chk_user_checks_name2": { + Name: "chk_user_checks_name2", + Constraint: "name <> 'jinzhu'", + }, + "chk_user_checks_name3": { + Name: "chk_user_checks_name3", + Constraint: "name <> 'jinzhu'", + }, + } + + checks := user.ParseCheckConstraints() + + for k, result := range results { + v, ok := checks[k] + if !ok { + t.Errorf("Failed to found check %v from parsed checks %+v", k, checks) + } + + for _, name := range []string{"Name", "Constraint"} { + if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { + t.Errorf( + "check %v %v should equal, expects %v, got %v", + k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), + ) + } + } + } +} diff --git a/schema/index_test.go b/schema/index_test.go index 1409b9c4..d0e8dfe0 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -21,7 +21,7 @@ type UserIndex struct { func TestParseIndex(t *testing.T) { user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { - t.Fatalf("failed to parse user index index, got error %v", err) + t.Fatalf("failed to parse user index, got error %v", err) } results := map[string]schema.Index{ diff --git a/schema/relationship.go b/schema/relationship.go index 8081b0e7..6606d77e 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -317,7 +317,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { settings = ParseTagSetting(str, ",") ) - if idx != -1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(str[0:idx]) { + if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) { name = str[0:idx] } else { name = rel.Schema.namer.RelationshipFKName(*rel) @@ -339,5 +339,9 @@ func (rel *Relationship) ParseConstraint() *Constraint { } } + if constraint.ReferenceSchema == nil { + return nil + } + return &constraint } diff --git a/tests/migrate.go b/tests/migrate.go new file mode 100644 index 00000000..0466fe11 --- /dev/null +++ b/tests/migrate.go @@ -0,0 +1,19 @@ +package tests + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestMigrate(t *testing.T, db *gorm.DB) { + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}} + + db.AutoMigrate(allModels...) + + for _, m := range allModels { + if !db.Migrator().HasTable(m) { + t.Errorf("Failed to create table for %+v", m) + } + } +}