Finish CreateConstraint

This commit is contained in:
Jinzhu 2020-02-22 11:15:51 +08:00
parent ea0b13f7a3
commit 0be4817ff9
6 changed files with 241 additions and 20 deletions

View File

@ -22,7 +22,7 @@ type Expr struct {
func (expr Expr) Build(builder Builder) { func (expr Expr) Build(builder Builder) {
sql := expr.SQL sql := expr.SQL
for _, v := range expr.Vars { for _, v := range expr.Vars {
sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1)
} }
builder.Write(sql) builder.Write(sql)
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
) )
// Migrator migrator struct // Migrator migrator struct
@ -33,17 +34,25 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement
// AutoMigrate // AutoMigrate
func (migrator Migrator) AutoMigrate(values ...interface{}) error { func (migrator Migrator) AutoMigrate(values ...interface{}) error {
// if has table
// not -> create table
// check columns -> add column, change column type
// check foreign keys -> create indexes
// check indexes -> create indexes
return gorm.ErrNotImplemented return gorm.ErrNotImplemented
} }
func (migrator Migrator) CreateTable(values ...interface{}) error { func (migrator Migrator) CreateTable(values ...interface{}) error {
// migrate
// create join table
return gorm.ErrNotImplemented return gorm.ErrNotImplemented
} }
func (migrator Migrator) DropTable(values ...interface{}) error { func (migrator Migrator) DropTable(values ...interface{}) error {
for _, value := range values { for _, value := range values {
if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return migrator.DB.Exec("DROP TABLE " + stmt.Quote(stmt.Table)).Error return migrator.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
}); err != nil { }); err != nil {
return err return err
} }
@ -74,7 +83,10 @@ func (migrator Migrator) RenameTable(oldName, newName string) error {
func (migrator Migrator) AddColumn(value interface{}, field string) error { func (migrator Migrator) AddColumn(value interface{}, field string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if field := stmt.Schema.LookUpField(field); field != nil {
return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ADD ? %s", field.DBDataType), stmt.Table, field.DBName).Error return migrator.DB.Exec(
"ALTER TABLE ? ADD ? ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType},
).Error
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
@ -83,7 +95,9 @@ func (migrator Migrator) AddColumn(value interface{}, field string) error {
func (migrator Migrator) DropColumn(value interface{}, field string) error { func (migrator Migrator) DropColumn(value interface{}, field string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if field := stmt.Schema.LookUpField(field); field != nil {
return migrator.DB.Exec("ALTER TABLE ? DROP COLUMN ?", stmt.Table, field.DBName).Error return migrator.DB.Exec(
"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName},
).Error
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
@ -92,7 +106,10 @@ func (migrator Migrator) DropColumn(value interface{}, field string) error {
func (migrator Migrator) AlterColumn(value interface{}, field string) error { func (migrator Migrator) AlterColumn(value interface{}, field string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if field := stmt.Schema.LookUpField(field); field != nil {
return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ALTER COLUMN ? TYPE %s", field.DBDataType), stmt.Table, field.DBName).Error return migrator.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType},
).Error
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
@ -102,7 +119,10 @@ func (migrator Migrator) RenameColumn(value interface{}, oldName, field string)
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if field := stmt.Schema.LookUpField(field); field != nil {
oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName) oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName)
return migrator.DB.Exec("ALTER TABLE ? RENAME COLUMN ? TO ?", stmt.Table, oldName, field.DBName).Error return migrator.DB.Exec(
"ALTER TABLE ? RENAME COLUMN ? TO ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName},
).Error
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
@ -121,22 +141,126 @@ func (migrator Migrator) DropView(name string) error {
} }
func (migrator Migrator) CreateConstraint(value interface{}, name string) error { func (migrator Migrator) CreateConstraint(value interface{}, name string) error {
return gorm.ErrNotImplemented return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok {
return migrator.DB.Exec(
"ALTER TABLE ? ADD CONSTRAINT ? CHECK ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
).Error
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
return migrator.DB.Exec(
sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references,
).Error
}
}
err := fmt.Errorf("failed to create constraint with name %v", name)
if field := stmt.Schema.LookUpField(name); field != nil {
for _, cc := range checkConstraints {
if err = migrator.CreateIndex(value, cc.Name); err != nil {
return err
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
if err = migrator.CreateIndex(value, constraint.Name); err != nil {
return err
}
}
}
}
return err
})
} }
func (migrator Migrator) DropConstraint(value interface{}, name string) error { func (migrator Migrator) DropConstraint(value interface{}, name string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return migrator.DB.Raw("ALTER TABLE ? DROP CONSTRAINT ?", stmt.Table, name).Error return migrator.DB.Exec(
"ALTER TABLE ? DROP CONSTRAINT ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: name},
).Error
}) })
} }
func (migrator Migrator) CreateIndex(value interface{}, name string) error { func (migrator Migrator) CreateIndex(value interface{}, name string) error {
return gorm.ErrNotImplemented return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
err := fmt.Errorf("failed to create index with name %v", name)
indexes := stmt.Schema.ParseIndexes()
if idx, ok := indexes[name]; ok {
fields := []interface{}{}
for _, field := range idx.Fields {
str := stmt.Quote(field.DBName)
if field.Expression != "" {
str = field.Expression
} else if field.Length > 0 {
str += fmt.Sprintf("(%d)", field.Length)
}
if field.Sort != "" {
str += " " + field.Sort
}
fields = append(fields, clause.Expr{SQL: str})
}
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields}
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ? ON ??"
if idx.Comment != "" {
values = append(values, idx.Comment)
createIndexSQL += " COMMENT ?"
}
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
return migrator.DB.Raw(createIndexSQL, values...).Error
} else if field := stmt.Schema.LookUpField(name); field != nil {
for _, idx := range indexes {
for _, idxOpt := range idx.Fields {
if idxOpt.Field == field {
if err = migrator.CreateIndex(value, idx.Name); err != nil {
return err
}
}
}
}
}
return err
})
} }
func (migrator Migrator) DropIndex(value interface{}, name string) error { func (migrator Migrator) DropIndex(value interface{}, name string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return migrator.DB.Raw("DROP INDEX ? ON ?", name, stmt.Table).Error return migrator.DB.Raw("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error
}) })
} }
@ -144,7 +268,10 @@ func (migrator Migrator) HasIndex(value interface{}, name string) bool {
var count int64 var count int64
migrator.RunWithValue(value, func(stmt *gorm.Statement) error { migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := migrator.DB.Migrator().CurrentDatabase() currentDatabase := migrator.DB.Migrator().CurrentDatabase()
return migrator.DB.Raw("SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name).Scan(&count).Error return migrator.DB.Raw(
"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
currentDatabase, stmt.Table, name,
).Scan(&count).Error
}) })
if count != 0 { if count != 0 {
@ -155,7 +282,10 @@ func (migrator Migrator) HasIndex(value interface{}, name string) bool {
func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error { func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return migrator.DB.Exec("ALTER TABLE ? RENAME INDEX ? TO ?", stmt.Table, oldName, newName).Error return migrator.DB.Exec(
"ALTER TABLE ? RENAME INDEX ? TO ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
}) })
} }

29
schema/check.go Normal file
View File

@ -0,0 +1,29 @@
package schema
import (
"regexp"
"strings"
)
type Check struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]Check {
var checks = map[string]Check{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(names[0]) {
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = Check{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}

View File

@ -15,7 +15,7 @@ type UserIndex struct {
Name4 string `gorm:"unique_index"` Name4 string `gorm:"unique_index"`
Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"`
Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"`
Age int64 `gorm:"index:profile,expression:(age+10)"` Age int64 `gorm:"index:profile,expression:ABS(age)"`
} }
func TestParseIndex(t *testing.T) { func TestParseIndex(t *testing.T) {
@ -61,7 +61,7 @@ func TestParseIndex(t *testing.T) {
Comment: "hello , world", Comment: "hello , world",
Where: "age > 10", Where: "age > 10",
Fields: []schema.IndexOption{{}, { Fields: []schema.IndexOption{{}, {
Expression: "(age+10)", Expression: "ABS(age)",
}}, }},
}, },
} }

View File

@ -14,8 +14,10 @@ import (
type Namer interface { type Namer interface {
TableName(table string) string TableName(table string) string
ColumnName(table, column string) string ColumnName(table, column string) string
IndexName(table, column string) string
JoinTableName(table string) string JoinTableName(table string) string
RelationshipFKName(Relationship) string
CheckerName(table, column string) string
IndexName(table, column string) string
} }
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
@ -37,6 +39,22 @@ func (ns NamingStrategy) ColumnName(table, column string) string {
return toDBName(column) return toDBName(column)
} }
// JoinTableName convert string to join table name
func (ns NamingStrategy) JoinTableName(str string) string {
return ns.TablePrefix + inflection.Plural(toDBName(str))
}
// RelationshipFKName generate fk name for relation
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table)
}
// CheckerName generate checker name
func (ns NamingStrategy) CheckerName(table, column string) string {
return fmt.Sprintf("chk_%s_%s", table, column)
}
// IndexName generate index name
func (ns NamingStrategy) IndexName(table, column string) string { func (ns NamingStrategy) IndexName(table, column string) string {
idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column))
@ -50,11 +68,6 @@ func (ns NamingStrategy) IndexName(table, column string) string {
return idxName return idxName
} }
// JoinTableName convert string to join table name
func (ns NamingStrategy) JoinTableName(str string) string {
return ns.TablePrefix + inflection.Plural(toDBName(str))
}
var ( var (
smap sync.Map smap sync.Map
// https://github.com/golang/lint/blob/master/lint.go#L770 // https://github.com/golang/lint/blob/master/lint.go#L770

View File

@ -3,6 +3,7 @@ package schema
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strings" "strings"
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
@ -292,3 +293,51 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
relation.Type = BelongsTo relation.Type = BelongsTo
} }
} }
type Constraint struct {
Name string
Field *Field
Schema *Schema
ForeignKeys []*Field
ReferenceSchema *Schema
References []*Field
OnDelete string
OnUpdate string
}
func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" {
return nil
}
var (
name string
idx = strings.Index(str, ",")
settings = ParseTagSetting(str, ",")
)
if idx != -1 && regexp.MustCompile("^[A-Za-z]+$").MatchString(str[0:idx]) {
name = str[0:idx]
} else {
name = rel.Schema.namer.RelationshipFKName(*rel)
}
constraint := Constraint{
Name: name,
Field: rel.Field,
OnUpdate: settings["ONUPDATE"],
OnDelete: settings["ONDELETE"],
Schema: rel.Schema,
}
for _, ref := range rel.References {
if ref.PrimaryKey != nil && !ref.OwnPrimaryKey {
constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey)
constraint.References = append(constraint.References, ref.PrimaryKey)
constraint.ReferenceSchema = ref.PrimaryKey.Schema
}
}
return &constraint
}