Move scope_database to scope.go

This commit is contained in:
Jinzhu 2014-01-28 17:09:43 +08:00
parent 6f1dd5fae3
commit 036df5f46b
4 changed files with 102 additions and 108 deletions

View File

@ -218,8 +218,8 @@ func (s *DB) Model(value interface{}) *DB {
return c return c
} }
func (s *DB) Related(value interface{}, foreign_keys ...string) *DB { func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
return s.clone().NewScope(s.Value).related(value, foreign_keys...).db return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
} }
func (s *DB) Pluck(column string, value interface{}) *DB { func (s *DB) Pluck(column string, value interface{}) *DB {
@ -299,8 +299,8 @@ func (s *DB) DropColumn(column string) *DB {
return s return s
} }
func (s *DB) AddIndex(column string, index_name ...string) *DB { func (s *DB) AddIndex(column string, indexName ...string) *DB {
s.clone().NewScope(s.Value).addIndex(column, index_name...) s.clone().NewScope(s.Value).addIndex(column, indexName...)
return s return s
} }

View File

@ -242,13 +242,13 @@ func (s *Scope) CombinedConditionSql() string {
func (scope *Scope) SqlTagForField(field *Field) (tag string) { func (scope *Scope) SqlTagForField(field *Field) (tag string) {
value := field.Value value := field.Value
reflect_value := reflect.ValueOf(value) reflectValue := reflect.ValueOf(value)
if field.IsScanner() { if field.IsScanner() {
value = reflect_value.Field(0).Interface() value = reflectValue.Field(0).Interface()
} }
switch reflect_value.Kind() { switch reflectValue.Kind() {
case reflect.Slice: case reflect.Slice:
if _, ok := value.([]byte); !ok { if _, ok := value.([]byte); !ok {
return return
@ -470,3 +470,70 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
} }
return scope return scope
} }
func (scope *Scope) createTable() *Scope {
var sqls []string
for _, field := range scope.Fields() {
if !field.IsIgnored && len(field.SqlTag) > 0 {
sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag)
}
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec()
return scope
}
func (scope *Scope) dropTable() *Scope {
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec()
return scope
}
func (scope *Scope) modifyColumn(column string, typ string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.Quote(column), typ)).Exec()
}
func (scope *Scope) dropColumn(column string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.Quote(column))).Exec()
}
func (scope *Scope) addIndex(column string, names ...string) {
var indexName string
if len(names) > 0 {
indexName = names[0]
} else {
indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column)
}
scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.Quote(column))).Exec()
}
func (scope *Scope) removeIndex(indexName string) {
scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec()
}
func (scope *Scope) autoMigrate() *Scope {
var tableName string
scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName())))
scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName)
scope.SqlVars = []interface{}{}
// If table doesn't exist
if len(tableName) == 0 {
scope.createTable()
} else {
for _, field := range scope.Fields() {
var column, data string
scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v",
scope.AddToVars(scope.TableName()),
scope.AddToVars(field.DBName),
))
scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data)
scope.SqlVars = []interface{}{}
// If column doesn't exist
if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored {
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec()
}
}
}
return scope
}

View File

@ -52,11 +52,11 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
switch reflect.TypeOf(arg).Kind() { switch reflect.TypeOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2}) case reflect.Slice: // For where("id in (?)", []int64{1,2})
values := reflect.ValueOf(arg) values := reflect.ValueOf(arg)
var temp_marks []string var tempMarks []string
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
temp_marks = append(temp_marks, scope.AddToVars(values.Index(i).Interface())) tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
} }
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
default: default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok { if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value() arg, _ = valuer.Value()
@ -69,7 +69,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
} }
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
var not_equal_sql string var notEqualSql string
switch value := clause["query"].(type) { switch value := clause["query"].(type) {
case string: case string:
@ -78,10 +78,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), id) return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), id)
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value) str = fmt.Sprintf(" NOT (%v) ", value)
not_equal_sql = fmt.Sprintf("NOT (%v)", value) notEqualSql = fmt.Sprintf("NOT (%v)", value)
} else { } else {
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
not_equal_sql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
} }
case int, int64, int32: case int, int64, int32:
return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value) return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value)
@ -113,16 +113,16 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
switch reflect.TypeOf(arg).Kind() { switch reflect.TypeOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2}) case reflect.Slice: // For where("id in (?)", []int64{1,2})
values := reflect.ValueOf(arg) values := reflect.ValueOf(arg)
var temp_marks []string var tempMarks []string
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
temp_marks = append(temp_marks, scope.AddToVars(values.Index(i).Interface())) tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
} }
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
default: default:
if scanner, ok := interface{}(arg).(driver.Valuer); ok { if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value() arg, _ = scanner.Value()
} }
str = strings.Replace(not_equal_sql, "?", scope.AddToVars(arg), 1) str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1)
} }
} }
return return
@ -135,45 +135,45 @@ func (scope *Scope) where(where ...interface{}) {
} }
func (scope *Scope) whereSql() (sql string) { func (scope *Scope) whereSql() (sql string) {
var primary_condiations, and_conditions, or_conditions []string var primaryCondiations, andConditions, orConditions []string
if !scope.Search.Unscope && scope.HasColumn("DeletedAt") { if !scope.Search.Unscope && scope.HasColumn("DeletedAt") {
primary_condiations = append(primary_condiations, "(deleted_at IS NULL OR deleted_at <= '0001-01-02')") primaryCondiations = append(primaryCondiations, "(deleted_at IS NULL OR deleted_at <= '0001-01-02')")
} }
if !scope.PrimaryKeyZero() { if !scope.PrimaryKeyZero() {
primary_condiations = append(primary_condiations, scope.primaryCondiation(scope.AddToVars(scope.PrimaryKeyValue()))) primaryCondiations = append(primaryCondiations, scope.primaryCondiation(scope.AddToVars(scope.PrimaryKeyValue())))
} }
for _, clause := range scope.Search.WhereConditions { for _, clause := range scope.Search.WhereConditions {
and_conditions = append(and_conditions, scope.buildWhereCondition(clause)) andConditions = append(andConditions, scope.buildWhereCondition(clause))
} }
for _, clause := range scope.Search.OrConditions { for _, clause := range scope.Search.OrConditions {
or_conditions = append(or_conditions, scope.buildWhereCondition(clause)) orConditions = append(orConditions, scope.buildWhereCondition(clause))
} }
for _, clause := range scope.Search.NotConditions { for _, clause := range scope.Search.NotConditions {
and_conditions = append(and_conditions, scope.buildNotCondition(clause)) andConditions = append(andConditions, scope.buildNotCondition(clause))
} }
or_sql := strings.Join(or_conditions, " OR ") orSql := strings.Join(orConditions, " OR ")
combined_sql := strings.Join(and_conditions, " AND ") combinedSql := strings.Join(andConditions, " AND ")
if len(combined_sql) > 0 { if len(combinedSql) > 0 {
if len(or_sql) > 0 { if len(orSql) > 0 {
combined_sql = combined_sql + " OR " + or_sql combinedSql = combinedSql + " OR " + orSql
} }
} else { } else {
combined_sql = or_sql combinedSql = orSql
} }
if len(primary_condiations) > 0 { if len(primaryCondiations) > 0 {
sql = "WHERE " + strings.Join(primary_condiations, " AND ") sql = "WHERE " + strings.Join(primaryCondiations, " AND ")
if len(combined_sql) > 0 { if len(combinedSql) > 0 {
sql = sql + " AND (" + combined_sql + ")" sql = sql + " AND (" + combinedSql + ")"
} }
} else if len(combined_sql) > 0 { } else if len(combinedSql) > 0 {
sql = "WHERE " + combined_sql sql = "WHERE " + combinedSql
} }
return return
} }

View File

@ -1,73 +0,0 @@
package gorm
import (
"fmt"
"strings"
)
func (scope *Scope) createTable() *Scope {
var sqls []string
for _, field := range scope.Fields() {
if !field.IsIgnored && len(field.SqlTag) > 0 {
sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag)
}
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec()
return scope
}
func (scope *Scope) dropTable() *Scope {
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec()
return scope
}
func (scope *Scope) modifyColumn(column string, typ string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.Quote(column), typ)).Exec()
}
func (scope *Scope) dropColumn(column string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.Quote(column))).Exec()
}
func (scope *Scope) addIndex(column string, names ...string) {
var indexName string
if len(names) > 0 {
indexName = names[0]
} else {
indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column)
}
scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.Quote(column))).Exec()
}
func (scope *Scope) removeIndex(indexName string) {
scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec()
}
func (scope *Scope) autoMigrate() *Scope {
var tableName string
scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName())))
scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName)
scope.SqlVars = []interface{}{}
// If table doesn't exist
if len(tableName) == 0 {
scope.createTable()
} else {
for _, field := range scope.Fields() {
var column, data string
scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v",
scope.AddToVars(scope.TableName()),
scope.AddToVars(field.DBName),
))
scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data)
scope.SqlVars = []interface{}{}
// If column doesn't exist
if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored {
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec()
}
}
}
return scope
}