mirror of https://github.com/go-gorm/gorm.git
move all code to scope
This commit is contained in:
parent
05ce3d3933
commit
2adbc4b8a6
14
main.go
14
main.go
|
@ -278,33 +278,33 @@ func (s *DB) RecordNotFound() bool {
|
|||
|
||||
// Migrations
|
||||
func (s *DB) CreateTable(value interface{}) *DB {
|
||||
return s.clone().do(value).createTable().db
|
||||
return s.clone().NewScope(value).createTable().db
|
||||
}
|
||||
|
||||
func (s *DB) DropTable(value interface{}) *DB {
|
||||
return s.clone().do(value).dropTable().db
|
||||
return s.clone().NewScope(value).dropTable().db
|
||||
}
|
||||
|
||||
func (s *DB) AutoMigrate(value interface{}) *DB {
|
||||
return s.clone().do(value).autoMigrate().db
|
||||
return s.clone().NewScope(value).autoMigrate().db
|
||||
}
|
||||
|
||||
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
||||
s.clone().do(s.Value).modifyColumn(column, typ)
|
||||
s.clone().NewScope(s.Value).modifyColumn(column, typ)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *DB) DropColumn(column string) *DB {
|
||||
s.do(s.Value).dropColumn(column)
|
||||
s.clone().NewScope(s.Value).dropColumn(column)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *DB) AddIndex(column string, index_name ...string) *DB {
|
||||
s.clone().do(s.Value).addIndex(column, index_name...)
|
||||
s.clone().NewScope(s.Value).addIndex(column, index_name...)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *DB) RemoveIndex(column string) *DB {
|
||||
s.clone().do(s.Value).removeIndex(column)
|
||||
s.clone().NewScope(s.Value).removeIndex(column)
|
||||
return s
|
||||
}
|
||||
|
|
16
scope.go
16
scope.go
|
@ -255,16 +255,17 @@ func (scope *Scope) SqlTagForField(field *Field) (tag string) {
|
|||
}
|
||||
}
|
||||
|
||||
if tag = field.Tag; len(tag) == 0 && tag != "-" {
|
||||
tag = field.Tag
|
||||
if len(tag) == 0 && tag != "-" {
|
||||
if field.isPrimaryKey {
|
||||
tag = scope.Dialect().PrimaryKeyTag(value, field.Size)
|
||||
} else {
|
||||
tag = scope.Dialect().SqlTag(value, field.Size)
|
||||
}
|
||||
}
|
||||
|
||||
if len(field.AddationalTag) > 0 {
|
||||
tag = tag + " " + field.AddationalTag
|
||||
}
|
||||
if len(field.AddationalTag) > 0 {
|
||||
tag = tag + " " + field.AddationalTag
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -296,7 +297,9 @@ func (scope *Scope) Fields() []*Field {
|
|||
tag, addationalTag, size := parseSqlTag(fieldStruct.Tag.Get(scope.db.parent.tagIdentifier))
|
||||
field.Tag = tag
|
||||
field.AddationalTag = addationalTag
|
||||
field.isPrimaryKey = scope.PrimaryKey() == field.DBName
|
||||
field.Size = size
|
||||
|
||||
field.SqlTag = scope.SqlTagForField(&field)
|
||||
|
||||
if tag == "-" {
|
||||
|
@ -339,11 +342,14 @@ func (scope *Scope) Fields() []*Field {
|
|||
return fields
|
||||
}
|
||||
|
||||
func (scope *Scope) Raw(sql string) {
|
||||
func (scope *Scope) Raw(sql string) *Scope {
|
||||
scope.Sql = strings.Replace(sql, "$$", "?", -1)
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) Exec() *Scope {
|
||||
defer scope.Trace(time.Now())
|
||||
|
||||
if !scope.HasError() {
|
||||
_, err := scope.DB().Exec(scope.Sql, scope.SqlVars...)
|
||||
scope.Err(err)
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (scope *Scope) createTable() *Scope {
|
||||
var sqls []string
|
||||
for _, field := range scope.Fields() {
|
||||
if !field.IsIgnored && len(field.SqlTag) > 0 {
|
||||
sqls = append(sqls, scope.quote(field.DBName)+" "+field.SqlTag)
|
||||
}
|
||||
}
|
||||
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec()
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) dropTable() *Scope {
|
||||
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec()
|
||||
return scope
|
||||
}
|
||||
|
||||
func (scope *Scope) modifyColumn(column string, typ string) {
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.quote(column), typ)).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) dropColumn(column string) {
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.quote(column))).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) addIndex(column string, names ...string) {
|
||||
var indexName string
|
||||
if len(names) > 0 {
|
||||
indexName = names[0]
|
||||
} else {
|
||||
indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column)
|
||||
}
|
||||
|
||||
scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.quote(column))).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) removeIndex(indexName string) {
|
||||
scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) autoMigrate() *Scope {
|
||||
var tableName string
|
||||
scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName())))
|
||||
scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName)
|
||||
scope.SqlVars = []interface{}{}
|
||||
|
||||
// If table doesn't exist
|
||||
if len(tableName) == 0 {
|
||||
scope.createTable()
|
||||
} else {
|
||||
for _, field := range scope.Fields() {
|
||||
var column, data string
|
||||
scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v",
|
||||
scope.AddToVars(scope.TableName()),
|
||||
scope.AddToVars(field.DBName),
|
||||
))
|
||||
scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data)
|
||||
scope.SqlVars = []interface{}{}
|
||||
|
||||
// If column doesn't exist
|
||||
if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored {
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec()
|
||||
}
|
||||
}
|
||||
}
|
||||
return scope
|
||||
}
|
Loading…
Reference in New Issue