diff --git a/clause/expression.go b/clause/expression.go index 048b0980..6b3575df 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -22,7 +22,7 @@ type Expr struct { func (expr Expr) Build(builder Builder) { sql := expr.SQL for _, v := range expr.Vars { - sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) + sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1) } builder.Write(sql) } diff --git a/migrator/migrator.go b/migrator/migrator.go index e9725935..fc93954e 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" ) // Migrator migrator struct @@ -33,17 +34,25 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement // AutoMigrate func (migrator Migrator) AutoMigrate(values ...interface{}) error { + // if has table + // not -> create table + // check columns -> add column, change column type + // check foreign keys -> create indexes + // check indexes -> create indexes + return gorm.ErrNotImplemented } func (migrator Migrator) CreateTable(values ...interface{}) error { + // migrate + // create join table return gorm.ErrNotImplemented } func (migrator Migrator) DropTable(values ...interface{}) error { for _, value := range values { if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Exec("DROP TABLE " + stmt.Quote(stmt.Table)).Error + return migrator.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error }); err != nil { return err } @@ -74,7 +83,10 @@ func (migrator Migrator) RenameTable(oldName, newName string) error { func (migrator Migrator) AddColumn(value interface{}, field string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ADD ? %s", field.DBDataType), stmt.Table, field.DBName).Error + return migrator.DB.Exec( + "ALTER TABLE ? ADD ? ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -83,7 +95,9 @@ func (migrator Migrator) AddColumn(value interface{}, field string) error { func (migrator Migrator) DropColumn(value interface{}, field string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec("ALTER TABLE ? DROP COLUMN ?", stmt.Table, field.DBName).Error + return migrator.DB.Exec( + "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, + ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -92,7 +106,10 @@ func (migrator Migrator) DropColumn(value interface{}, field string) error { func (migrator Migrator) AlterColumn(value interface{}, field string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ALTER COLUMN ? TYPE %s", field.DBDataType), stmt.Table, field.DBName).Error + return migrator.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? TYPE ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -102,7 +119,10 @@ func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName) - return migrator.DB.Exec("ALTER TABLE ? RENAME COLUMN ? TO ?", stmt.Table, oldName, field.DBName).Error + return migrator.DB.Exec( + "ALTER TABLE ? RENAME COLUMN ? TO ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName}, + ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -121,22 +141,126 @@ func (migrator Migrator) DropView(name string) error { } func (migrator Migrator) CreateConstraint(value interface{}, name string) error { - return gorm.ErrNotImplemented + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + checkConstraints := stmt.Schema.ParseCheckConstraints() + if chk, ok := checkConstraints[name]; ok { + return migrator.DB.Exec( + "ALTER TABLE ? ADD CONSTRAINT ? CHECK ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, + ).Error + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { + sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" + if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + var foreignKeys, references []interface{} + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + + return migrator.DB.Exec( + sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references, + ).Error + } + } + + err := fmt.Errorf("failed to create constraint with name %v", name) + if field := stmt.Schema.LookUpField(name); field != nil { + for _, cc := range checkConstraints { + if err = migrator.CreateIndex(value, cc.Name); err != nil { + return err + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { + if err = migrator.CreateIndex(value, constraint.Name); err != nil { + return err + } + } + } + } + + return err + }) } func (migrator Migrator) DropConstraint(value interface{}, name string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Raw("ALTER TABLE ? DROP CONSTRAINT ?", stmt.Table, name).Error + return migrator.DB.Exec( + "ALTER TABLE ? DROP CONSTRAINT ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + ).Error }) } func (migrator Migrator) CreateIndex(value interface{}, name string) error { - return gorm.ErrNotImplemented + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + err := fmt.Errorf("failed to create index with name %v", name) + indexes := stmt.Schema.ParseIndexes() + + if idx, ok := indexes[name]; ok { + fields := []interface{}{} + for _, field := range idx.Fields { + str := stmt.Quote(field.DBName) + if field.Expression != "" { + str = field.Expression + } else if field.Length > 0 { + str += fmt.Sprintf("(%d)", field.Length) + } + + if field.Sort != "" { + str += " " + field.Sort + } + fields = append(fields, clause.Expr{SQL: str}) + } + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ? ON ??" + + if idx.Comment != "" { + values = append(values, idx.Comment) + createIndexSQL += " COMMENT ?" + } + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + + return migrator.DB.Raw(createIndexSQL, values...).Error + } else if field := stmt.Schema.LookUpField(name); field != nil { + for _, idx := range indexes { + for _, idxOpt := range idx.Fields { + if idxOpt.Field == field { + if err = migrator.CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + } + } + return err + }) } func (migrator Migrator) DropIndex(value interface{}, name string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Raw("DROP INDEX ? ON ?", name, stmt.Table).Error + return migrator.DB.Raw("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error }) } @@ -144,7 +268,10 @@ func (migrator Migrator) HasIndex(value interface{}, name string) bool { var count int64 migrator.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := migrator.DB.Migrator().CurrentDatabase() - return migrator.DB.Raw("SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name).Scan(&count).Error + return migrator.DB.Raw( + "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", + currentDatabase, stmt.Table, name, + ).Scan(&count).Error }) if count != 0 { @@ -155,7 +282,10 @@ func (migrator Migrator) HasIndex(value interface{}, name string) bool { func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Exec("ALTER TABLE ? RENAME INDEX ? TO ?", stmt.Table, oldName, newName).Error + return migrator.DB.Exec( + "ALTER TABLE ? RENAME INDEX ? TO ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error }) } diff --git a/schema/check.go b/schema/check.go new file mode 100644 index 00000000..a06ac67b --- /dev/null +++ b/schema/check.go @@ -0,0 +1,29 @@ +package schema + +import ( + "regexp" + "strings" +) + +type Check struct { + Name string + Constraint string // length(phone) >= 10 + *Field +} + +// ParseCheckConstraints parse schema check constraints +func (schema *Schema) ParseCheckConstraints() map[string]Check { + var checks = map[string]Check{} + for _, field := range schema.FieldsByDBName { + if chk := field.TagSettings["CHECK"]; chk != "" { + names := strings.Split(chk, ",") + if len(names) > 1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(names[0]) { + checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} + } else { + name := schema.namer.CheckerName(schema.Table, field.DBName) + checks[name] = Check{Name: name, Constraint: chk, Field: field} + } + } + } + return checks +} diff --git a/schema/index_test.go b/schema/index_test.go index d9595ae6..1409b9c4 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -15,7 +15,7 @@ type UserIndex struct { Name4 string `gorm:"unique_index"` Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` - Age int64 `gorm:"index:profile,expression:(age+10)"` + Age int64 `gorm:"index:profile,expression:ABS(age)"` } func TestParseIndex(t *testing.T) { @@ -61,7 +61,7 @@ func TestParseIndex(t *testing.T) { Comment: "hello , world", Where: "age > 10", Fields: []schema.IndexOption{{}, { - Expression: "(age+10)", + Expression: "ABS(age)", }}, }, } diff --git a/schema/naming.go b/schema/naming.go index 80af4277..d6f26e9f 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -14,8 +14,10 @@ import ( type Namer interface { TableName(table string) string ColumnName(table, column string) string - IndexName(table, column string) string JoinTableName(table string) string + RelationshipFKName(Relationship) string + CheckerName(table, column string) string + IndexName(table, column string) string } // NamingStrategy tables, columns naming strategy @@ -37,6 +39,22 @@ func (ns NamingStrategy) ColumnName(table, column string) string { return toDBName(column) } +// JoinTableName convert string to join table name +func (ns NamingStrategy) JoinTableName(str string) string { + return ns.TablePrefix + inflection.Plural(toDBName(str)) +} + +// RelationshipFKName generate fk name for relation +func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { + return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table) +} + +// CheckerName generate checker name +func (ns NamingStrategy) CheckerName(table, column string) string { + return fmt.Sprintf("chk_%s_%s", table, column) +} + +// IndexName generate index name func (ns NamingStrategy) IndexName(table, column string) string { idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) @@ -50,11 +68,6 @@ func (ns NamingStrategy) IndexName(table, column string) string { return idxName } -// JoinTableName convert string to join table name -func (ns NamingStrategy) JoinTableName(str string) string { - return ns.TablePrefix + inflection.Plural(toDBName(str)) -} - var ( smap sync.Map // https://github.com/golang/lint/blob/master/lint.go#L770 diff --git a/schema/relationship.go b/schema/relationship.go index 671371fe..8081b0e7 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -3,6 +3,7 @@ package schema import ( "fmt" "reflect" + "regexp" "strings" "github.com/jinzhu/inflection" @@ -292,3 +293,51 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH relation.Type = BelongsTo } } + +type Constraint struct { + Name string + Field *Field + Schema *Schema + ForeignKeys []*Field + ReferenceSchema *Schema + References []*Field + OnDelete string + OnUpdate string +} + +func (rel *Relationship) ParseConstraint() *Constraint { + str := rel.Field.TagSettings["CONSTRAINT"] + if str == "-" { + return nil + } + + var ( + name string + idx = strings.Index(str, ",") + settings = ParseTagSetting(str, ",") + ) + + if idx != -1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(str[0:idx]) { + name = str[0:idx] + } else { + name = rel.Schema.namer.RelationshipFKName(*rel) + } + + constraint := Constraint{ + Name: name, + Field: rel.Field, + OnUpdate: settings["ONUPDATE"], + OnDelete: settings["ONDELETE"], + Schema: rel.Schema, + } + + for _, ref := range rel.References { + if ref.PrimaryKey != nil && !ref.OwnPrimaryKey { + constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) + constraint.References = append(constraint.References, ref.PrimaryKey) + constraint.ReferenceSchema = ref.PrimaryKey.Schema + } + } + + return &constraint +}