From 2adbc4b8a6114b7173930f503df5f979a8f070ee Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 28 Jan 2014 15:54:19 +0800 Subject: [PATCH] move all code to scope --- main.go | 14 ++++----- scope.go | 16 +++++++---- scope_database.go | 73 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 12 deletions(-) create mode 100644 scope_database.go diff --git a/main.go b/main.go index 6ff04f8f..25711a82 100644 --- a/main.go +++ b/main.go @@ -278,33 +278,33 @@ func (s *DB) RecordNotFound() bool { // Migrations func (s *DB) CreateTable(value interface{}) *DB { - return s.clone().do(value).createTable().db + return s.clone().NewScope(value).createTable().db } func (s *DB) DropTable(value interface{}) *DB { - return s.clone().do(value).dropTable().db + return s.clone().NewScope(value).dropTable().db } func (s *DB) AutoMigrate(value interface{}) *DB { - return s.clone().do(value).autoMigrate().db + return s.clone().NewScope(value).autoMigrate().db } func (s *DB) ModifyColumn(column string, typ string) *DB { - s.clone().do(s.Value).modifyColumn(column, typ) + s.clone().NewScope(s.Value).modifyColumn(column, typ) return s } func (s *DB) DropColumn(column string) *DB { - s.do(s.Value).dropColumn(column) + s.clone().NewScope(s.Value).dropColumn(column) return s } func (s *DB) AddIndex(column string, index_name ...string) *DB { - s.clone().do(s.Value).addIndex(column, index_name...) + s.clone().NewScope(s.Value).addIndex(column, index_name...) return s } func (s *DB) RemoveIndex(column string) *DB { - s.clone().do(s.Value).removeIndex(column) + s.clone().NewScope(s.Value).removeIndex(column) return s } diff --git a/scope.go b/scope.go index bb92da8e..c9c5c131 100644 --- a/scope.go +++ b/scope.go @@ -255,16 +255,17 @@ func (scope *Scope) SqlTagForField(field *Field) (tag string) { } } - if tag = field.Tag; len(tag) == 0 && tag != "-" { + tag = field.Tag + if len(tag) == 0 && tag != "-" { if field.isPrimaryKey { tag = scope.Dialect().PrimaryKeyTag(value, field.Size) } else { tag = scope.Dialect().SqlTag(value, field.Size) } + } - if len(field.AddationalTag) > 0 { - tag = tag + " " + field.AddationalTag - } + if len(field.AddationalTag) > 0 { + tag = tag + " " + field.AddationalTag } return } @@ -296,7 +297,9 @@ func (scope *Scope) Fields() []*Field { tag, addationalTag, size := parseSqlTag(fieldStruct.Tag.Get(scope.db.parent.tagIdentifier)) field.Tag = tag field.AddationalTag = addationalTag + field.isPrimaryKey = scope.PrimaryKey() == field.DBName field.Size = size + field.SqlTag = scope.SqlTagForField(&field) if tag == "-" { @@ -339,11 +342,14 @@ func (scope *Scope) Fields() []*Field { return fields } -func (scope *Scope) Raw(sql string) { +func (scope *Scope) Raw(sql string) *Scope { scope.Sql = strings.Replace(sql, "$$", "?", -1) + return scope } func (scope *Scope) Exec() *Scope { + defer scope.Trace(time.Now()) + if !scope.HasError() { _, err := scope.DB().Exec(scope.Sql, scope.SqlVars...) scope.Err(err) diff --git a/scope_database.go b/scope_database.go new file mode 100644 index 00000000..99bc197e --- /dev/null +++ b/scope_database.go @@ -0,0 +1,73 @@ +package gorm + +import ( + "fmt" + "strings" +) + +func (scope *Scope) createTable() *Scope { + var sqls []string + for _, field := range scope.Fields() { + if !field.IsIgnored && len(field.SqlTag) > 0 { + sqls = append(sqls, scope.quote(field.DBName)+" "+field.SqlTag) + } + } + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec() + return scope +} + +func (scope *Scope) dropTable() *Scope { + scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec() + return scope +} + +func (scope *Scope) modifyColumn(column string, typ string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.quote(column), typ)).Exec() +} + +func (scope *Scope) dropColumn(column string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.quote(column))).Exec() +} + +func (scope *Scope) addIndex(column string, names ...string) { + var indexName string + if len(names) > 0 { + indexName = names[0] + } else { + indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column) + } + + scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.quote(column))).Exec() +} + +func (scope *Scope) removeIndex(indexName string) { + scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec() +} + +func (scope *Scope) autoMigrate() *Scope { + var tableName string + scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName()))) + scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName) + scope.SqlVars = []interface{}{} + + // If table doesn't exist + if len(tableName) == 0 { + scope.createTable() + } else { + for _, field := range scope.Fields() { + var column, data string + scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v", + scope.AddToVars(scope.TableName()), + scope.AddToVars(field.DBName), + )) + scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data) + scope.SqlVars = []interface{}{} + + // If column doesn't exist + if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored { + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec() + } + } + } + return scope +}