mirror of https://github.com/go-gorm/gorm.git
move all code to scope
This commit is contained in:
parent
05ce3d3933
commit
2adbc4b8a6
14
main.go
14
main.go
|
@ -278,33 +278,33 @@ func (s *DB) RecordNotFound() bool {
|
||||||
|
|
||||||
// Migrations
|
// Migrations
|
||||||
func (s *DB) CreateTable(value interface{}) *DB {
|
func (s *DB) CreateTable(value interface{}) *DB {
|
||||||
return s.clone().do(value).createTable().db
|
return s.clone().NewScope(value).createTable().db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) DropTable(value interface{}) *DB {
|
func (s *DB) DropTable(value interface{}) *DB {
|
||||||
return s.clone().do(value).dropTable().db
|
return s.clone().NewScope(value).dropTable().db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) AutoMigrate(value interface{}) *DB {
|
func (s *DB) AutoMigrate(value interface{}) *DB {
|
||||||
return s.clone().do(value).autoMigrate().db
|
return s.clone().NewScope(value).autoMigrate().db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
||||||
s.clone().do(s.Value).modifyColumn(column, typ)
|
s.clone().NewScope(s.Value).modifyColumn(column, typ)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) DropColumn(column string) *DB {
|
func (s *DB) DropColumn(column string) *DB {
|
||||||
s.do(s.Value).dropColumn(column)
|
s.clone().NewScope(s.Value).dropColumn(column)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) AddIndex(column string, index_name ...string) *DB {
|
func (s *DB) AddIndex(column string, index_name ...string) *DB {
|
||||||
s.clone().do(s.Value).addIndex(column, index_name...)
|
s.clone().NewScope(s.Value).addIndex(column, index_name...)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) RemoveIndex(column string) *DB {
|
func (s *DB) RemoveIndex(column string) *DB {
|
||||||
s.clone().do(s.Value).removeIndex(column)
|
s.clone().NewScope(s.Value).removeIndex(column)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
16
scope.go
16
scope.go
|
@ -255,16 +255,17 @@ func (scope *Scope) SqlTagForField(field *Field) (tag string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if tag = field.Tag; len(tag) == 0 && tag != "-" {
|
tag = field.Tag
|
||||||
|
if len(tag) == 0 && tag != "-" {
|
||||||
if field.isPrimaryKey {
|
if field.isPrimaryKey {
|
||||||
tag = scope.Dialect().PrimaryKeyTag(value, field.Size)
|
tag = scope.Dialect().PrimaryKeyTag(value, field.Size)
|
||||||
} else {
|
} else {
|
||||||
tag = scope.Dialect().SqlTag(value, field.Size)
|
tag = scope.Dialect().SqlTag(value, field.Size)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(field.AddationalTag) > 0 {
|
if len(field.AddationalTag) > 0 {
|
||||||
tag = tag + " " + field.AddationalTag
|
tag = tag + " " + field.AddationalTag
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -296,7 +297,9 @@ func (scope *Scope) Fields() []*Field {
|
||||||
tag, addationalTag, size := parseSqlTag(fieldStruct.Tag.Get(scope.db.parent.tagIdentifier))
|
tag, addationalTag, size := parseSqlTag(fieldStruct.Tag.Get(scope.db.parent.tagIdentifier))
|
||||||
field.Tag = tag
|
field.Tag = tag
|
||||||
field.AddationalTag = addationalTag
|
field.AddationalTag = addationalTag
|
||||||
|
field.isPrimaryKey = scope.PrimaryKey() == field.DBName
|
||||||
field.Size = size
|
field.Size = size
|
||||||
|
|
||||||
field.SqlTag = scope.SqlTagForField(&field)
|
field.SqlTag = scope.SqlTagForField(&field)
|
||||||
|
|
||||||
if tag == "-" {
|
if tag == "-" {
|
||||||
|
@ -339,11 +342,14 @@ func (scope *Scope) Fields() []*Field {
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) Raw(sql string) {
|
func (scope *Scope) Raw(sql string) *Scope {
|
||||||
scope.Sql = strings.Replace(sql, "$$", "?", -1)
|
scope.Sql = strings.Replace(sql, "$$", "?", -1)
|
||||||
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) Exec() *Scope {
|
func (scope *Scope) Exec() *Scope {
|
||||||
|
defer scope.Trace(time.Now())
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
_, err := scope.DB().Exec(scope.Sql, scope.SqlVars...)
|
_, err := scope.DB().Exec(scope.Sql, scope.SqlVars...)
|
||||||
scope.Err(err)
|
scope.Err(err)
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (scope *Scope) createTable() *Scope {
|
||||||
|
var sqls []string
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
|
if !field.IsIgnored && len(field.SqlTag) > 0 {
|
||||||
|
sqls = append(sqls, scope.quote(field.DBName)+" "+field.SqlTag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec()
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) dropTable() *Scope {
|
||||||
|
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec()
|
||||||
|
return scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) modifyColumn(column string, typ string) {
|
||||||
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.quote(column), typ)).Exec()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) dropColumn(column string) {
|
||||||
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.quote(column))).Exec()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) addIndex(column string, names ...string) {
|
||||||
|
var indexName string
|
||||||
|
if len(names) > 0 {
|
||||||
|
indexName = names[0]
|
||||||
|
} else {
|
||||||
|
indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column)
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.quote(column))).Exec()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) removeIndex(indexName string) {
|
||||||
|
scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) autoMigrate() *Scope {
|
||||||
|
var tableName string
|
||||||
|
scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName())))
|
||||||
|
scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName)
|
||||||
|
scope.SqlVars = []interface{}{}
|
||||||
|
|
||||||
|
// If table doesn't exist
|
||||||
|
if len(tableName) == 0 {
|
||||||
|
scope.createTable()
|
||||||
|
} else {
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
|
var column, data string
|
||||||
|
scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v",
|
||||||
|
scope.AddToVars(scope.TableName()),
|
||||||
|
scope.AddToVars(field.DBName),
|
||||||
|
))
|
||||||
|
scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data)
|
||||||
|
scope.SqlVars = []interface{}{}
|
||||||
|
|
||||||
|
// If column doesn't exist
|
||||||
|
if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored {
|
||||||
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scope
|
||||||
|
}
|
Loading…
Reference in New Issue