move all code to scope

This commit is contained in:
Jinzhu 2014-01-28 15:54:19 +08:00
parent 05ce3d3933
commit 2adbc4b8a6
3 changed files with 91 additions and 12 deletions

14
main.go
View File

@ -278,33 +278,33 @@ func (s *DB) RecordNotFound() bool {
// Migrations // Migrations
func (s *DB) CreateTable(value interface{}) *DB { func (s *DB) CreateTable(value interface{}) *DB {
return s.clone().do(value).createTable().db return s.clone().NewScope(value).createTable().db
} }
func (s *DB) DropTable(value interface{}) *DB { func (s *DB) DropTable(value interface{}) *DB {
return s.clone().do(value).dropTable().db return s.clone().NewScope(value).dropTable().db
} }
func (s *DB) AutoMigrate(value interface{}) *DB { func (s *DB) AutoMigrate(value interface{}) *DB {
return s.clone().do(value).autoMigrate().db return s.clone().NewScope(value).autoMigrate().db
} }
func (s *DB) ModifyColumn(column string, typ string) *DB { func (s *DB) ModifyColumn(column string, typ string) *DB {
s.clone().do(s.Value).modifyColumn(column, typ) s.clone().NewScope(s.Value).modifyColumn(column, typ)
return s return s
} }
func (s *DB) DropColumn(column string) *DB { func (s *DB) DropColumn(column string) *DB {
s.do(s.Value).dropColumn(column) s.clone().NewScope(s.Value).dropColumn(column)
return s return s
} }
func (s *DB) AddIndex(column string, index_name ...string) *DB { func (s *DB) AddIndex(column string, index_name ...string) *DB {
s.clone().do(s.Value).addIndex(column, index_name...) s.clone().NewScope(s.Value).addIndex(column, index_name...)
return s return s
} }
func (s *DB) RemoveIndex(column string) *DB { func (s *DB) RemoveIndex(column string) *DB {
s.clone().do(s.Value).removeIndex(column) s.clone().NewScope(s.Value).removeIndex(column)
return s return s
} }

View File

@ -255,17 +255,18 @@ func (scope *Scope) SqlTagForField(field *Field) (tag string) {
} }
} }
if tag = field.Tag; len(tag) == 0 && tag != "-" { tag = field.Tag
if len(tag) == 0 && tag != "-" {
if field.isPrimaryKey { if field.isPrimaryKey {
tag = scope.Dialect().PrimaryKeyTag(value, field.Size) tag = scope.Dialect().PrimaryKeyTag(value, field.Size)
} else { } else {
tag = scope.Dialect().SqlTag(value, field.Size) tag = scope.Dialect().SqlTag(value, field.Size)
} }
}
if len(field.AddationalTag) > 0 { if len(field.AddationalTag) > 0 {
tag = tag + " " + field.AddationalTag tag = tag + " " + field.AddationalTag
} }
}
return return
} }
@ -296,7 +297,9 @@ func (scope *Scope) Fields() []*Field {
tag, addationalTag, size := parseSqlTag(fieldStruct.Tag.Get(scope.db.parent.tagIdentifier)) tag, addationalTag, size := parseSqlTag(fieldStruct.Tag.Get(scope.db.parent.tagIdentifier))
field.Tag = tag field.Tag = tag
field.AddationalTag = addationalTag field.AddationalTag = addationalTag
field.isPrimaryKey = scope.PrimaryKey() == field.DBName
field.Size = size field.Size = size
field.SqlTag = scope.SqlTagForField(&field) field.SqlTag = scope.SqlTagForField(&field)
if tag == "-" { if tag == "-" {
@ -339,11 +342,14 @@ func (scope *Scope) Fields() []*Field {
return fields return fields
} }
func (scope *Scope) Raw(sql string) { func (scope *Scope) Raw(sql string) *Scope {
scope.Sql = strings.Replace(sql, "$$", "?", -1) scope.Sql = strings.Replace(sql, "$$", "?", -1)
return scope
} }
func (scope *Scope) Exec() *Scope { func (scope *Scope) Exec() *Scope {
defer scope.Trace(time.Now())
if !scope.HasError() { if !scope.HasError() {
_, err := scope.DB().Exec(scope.Sql, scope.SqlVars...) _, err := scope.DB().Exec(scope.Sql, scope.SqlVars...)
scope.Err(err) scope.Err(err)

73
scope_database.go Normal file
View File

@ -0,0 +1,73 @@
package gorm
import (
"fmt"
"strings"
)
func (scope *Scope) createTable() *Scope {
var sqls []string
for _, field := range scope.Fields() {
if !field.IsIgnored && len(field.SqlTag) > 0 {
sqls = append(sqls, scope.quote(field.DBName)+" "+field.SqlTag)
}
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec()
return scope
}
func (scope *Scope) dropTable() *Scope {
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec()
return scope
}
func (scope *Scope) modifyColumn(column string, typ string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.quote(column), typ)).Exec()
}
func (scope *Scope) dropColumn(column string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.quote(column))).Exec()
}
func (scope *Scope) addIndex(column string, names ...string) {
var indexName string
if len(names) > 0 {
indexName = names[0]
} else {
indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column)
}
scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.quote(column))).Exec()
}
func (scope *Scope) removeIndex(indexName string) {
scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec()
}
func (scope *Scope) autoMigrate() *Scope {
var tableName string
scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName())))
scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName)
scope.SqlVars = []interface{}{}
// If table doesn't exist
if len(tableName) == 0 {
scope.createTable()
} else {
for _, field := range scope.Fields() {
var column, data string
scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v",
scope.AddToVars(scope.TableName()),
scope.AddToVars(field.DBName),
))
scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data)
scope.SqlVars = []interface{}{}
// If column doesn't exist
if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored {
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec()
}
}
}
return scope
}