diff --git a/schema/index.go b/schema/index.go new file mode 100644 index 00000000..ea3a68f5 --- /dev/null +++ b/schema/index.go @@ -0,0 +1,116 @@ +package schema + +import ( + "strconv" + "strings" +) + +type Index struct { + Name string + Class string // UNIQUE | FULLTEXT | SPATIAL + Fields []IndexOption +} + +type IndexOption struct { + *Field + Expression string + Sort string // DESC, ASC + Collate string + Length int + Type string // btree, hash, gist, spgist, gin, and brin + Where string + Comment string +} + +// ParseIndexes parse schema indexes +func (schema *Schema) ParseIndexes() map[string]Index { + var indexes = map[string]Index{} + + for _, field := range schema.FieldsByDBName { + if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE_INDEX"] != "" { + for _, index := range parseFieldIndexes(field) { + idx := indexes[index.Name] + idx.Name = index.Name + if idx.Class == "" { + idx.Class = index.Class + } + idx.Fields = append(idx.Fields, index.Fields...) + indexes[index.Name] = idx + } + } + } + + return indexes +} + +func parseFieldIndexes(field *Field) (indexes []Index) { + for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { + if value != "" { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if k == "INDEX" || k == "UNIQUE_INDEX" { + var ( + name string + tag = strings.Join(v[1:], ":") + settings = map[string]string{} + ) + + names := strings.Split(tag, ",") + for i := 0; i < len(names); i++ { + if len(names[i]) > 0 { + j := i + for { + if names[j][len(names[j])-1] == '\\' { + i++ + names[j] = names[j][0:len(names[j])-1] + names[i] + names[i] = "" + } else { + break + } + } + } + + if i == 0 { + name = names[0] + } + + values := strings.Split(names[i], ":") + k := strings.TrimSpace(strings.ToUpper(values[0])) + + if len(values) >= 2 { + settings[k] = strings.Join(values[1:], ":") + } else if k != "" { + settings[k] = k + } + } + + if name == "" { + name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) + } + + length, _ := strconv.Atoi(settings["LENGTH"]) + + if (k == "UNIQUE_INDEX") || settings["UNIQUE"] != "" { + settings["CLASS"] = "UNIQUE" + } + + indexes = append(indexes, Index{ + Name: name, + Class: settings["CLASS"], + Fields: []IndexOption{{ + Field: field, + Expression: settings["EXPRESSION"], + Sort: settings["SORT"], + Collate: settings["COLLATE"], + Type: settings["TYPE"], + Length: length, + Where: settings["WHERE"], + Comment: settings["COMMENT"], + }}, + }) + } + } + } + + return +} diff --git a/schema/index_test.go b/schema/index_test.go new file mode 100644 index 00000000..8c2cb9fe --- /dev/null +++ b/schema/index_test.go @@ -0,0 +1,96 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "github.com/jinzhu/gorm/schema" +) + +type UserIndex struct { + Name string `gorm:"index"` + Name2 string `gorm:"index:idx_name,unique"` + Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"` + Name4 string `gorm:"unique_index"` + Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` + Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` + Age int64 `gorm:"index:profile,expression:(age+10)"` +} + +func TestParseIndex(t *testing.T) { + user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user index index, got error %v", err) + } + + results := map[string]schema.Index{ + "idx_user_indices_name": { + Name: "idx_user_indices_name", + Fields: []schema.IndexOption{{}}, + }, + "idx_name": { + Name: "idx_name", + Class: "UNIQUE", + Fields: []schema.IndexOption{{}}, + }, + "idx_user_indices_name3": { + Name: "idx_user_indices_name3", + Fields: []schema.IndexOption{{ + Sort: "desc", + Collate: "utf8", + Length: 10, + Type: "btree", + Where: "name3 != 'jinzhu'", + }}, + }, + "idx_user_indices_name4": { + Name: "idx_user_indices_name4", + Class: "UNIQUE", + Fields: []schema.IndexOption{{}}, + }, + "idx_user_indices_name5": { + Name: "idx_user_indices_name5", + Class: "FULLTEXT", + Fields: []schema.IndexOption{{ + Comment: "hello , world", + Where: "age > 10", + }}, + }, + "profile": { + Name: "profile", + Fields: []schema.IndexOption{{ + Comment: "hello , world", + Where: "age > 10", + }, { + Expression: "(age+10)", + }}, + }, + } + + indices := user.ParseIndexes() + + for k, result := range results { + v, ok := indices[k] + if !ok { + t.Errorf("Failed to found index %v from parsed indices %+v", k, indices) + } + + if result.Name != v.Name { + t.Errorf("index %v name should equal, expects %v, got %v", k, result.Name, v.Name) + } + + if result.Class != v.Class { + t.Errorf("index %v Class should equal, expects %v, got %v", k, result.Class, v.Class) + } + + for idx, ef := range result.Fields { + rf := v.Fields[idx] + for _, name := range []string{"Expression", "Sort", "Collate", "Length", "Type", "Where"} { + if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { + t.Errorf("index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface()) + } + } + } + } +} diff --git a/schema/naming.go b/schema/naming.go index e6a5625e..80af4277 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -1,9 +1,11 @@ package schema import ( + "crypto/sha1" "fmt" "strings" "sync" + "unicode/utf8" "github.com/jinzhu/inflection" ) @@ -12,6 +14,7 @@ import ( type Namer interface { TableName(table string) string ColumnName(table, column string) string + IndexName(table, column string) string JoinTableName(table string) string } @@ -30,8 +33,21 @@ func (ns NamingStrategy) TableName(str string) string { } // ColumnName convert string to column name -func (ns NamingStrategy) ColumnName(table, str string) string { - return toDBName(str) +func (ns NamingStrategy) ColumnName(table, column string) string { + return toDBName(column) +} + +func (ns NamingStrategy) IndexName(table, column string) string { + idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) + + if utf8.RuneCountInString(idxName) > 64 { + h := sha1.New() + h.Write([]byte(idxName)) + bs := h.Sum(nil) + + idxName = fmt.Sprintf("idx%v%v", table, column)[0:56] + string(bs)[:8] + } + return idxName } // JoinTableName convert string to join table name