diff --git a/scope.go b/scope.go index ca291c4b..fff9d379 100644 --- a/scope.go +++ b/scope.go @@ -1,12 +1,10 @@ package gorm import ( - "database/sql" "errors" "fmt" "github.com/jinzhu/gorm/dialect" "go/ast" - "strconv" "strings" "time" @@ -24,29 +22,11 @@ type Scope struct { skipLeft bool } -func (scope *Scope) Quote(str string) string { - return scope.Dialect().Quote(str) -} - func (db *DB) NewScope(value interface{}) *Scope { db.Value = value return &Scope{db: db, Search: db.search, Value: value, _values: map[string]interface{}{}} } -func (scope *Scope) SkipLeft() { - scope.skipLeft = true -} - -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { - for _, f := range funcs { - (*f)(scope) - if scope.skipLeft { - break - } - } - return scope -} - func (scope *Scope) New(value interface{}) *Scope { return &Scope{db: scope.db.parent, Search: &search{}, Value: value} } @@ -59,6 +39,14 @@ func (scope *Scope) DB() sqlCommon { return scope.db.db } +func (scope *Scope) SkipLeft() { + scope.skipLeft = true +} + +func (scope *Scope) Quote(str string) string { + return scope.Dialect().Quote(str) +} + func (scope *Scope) Dialect() dialect.Dialect { return scope.db.parent.dialect } @@ -115,40 +103,6 @@ func (scope *Scope) FieldByName(name string) (interface{}, bool) { return nil, false } -func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { - data := reflect.Indirect(reflect.ValueOf(scope.Value)) - if !data.CanAddr() { - return values, true - } - - for key, value := range values { - if field := data.FieldByName(snakeToUpperCamel(key)); field.IsValid() { - if field.Interface() != value { - - switch field.Kind() { - case reflect.Int, reflect.Int32, reflect.Int64: - if s, ok := value.(string); ok { - i, err := strconv.Atoi(s) - if scope.Err(err) == nil { - value = i - } - } - - scope.db.log(field.Int() != reflect.ValueOf(value).Int()) - if field.Int() != reflect.ValueOf(value).Int() { - hasUpdate = true - setFieldValue(field, value) - } - default: - hasUpdate = true - setFieldValue(field, value) - } - } - } - } - return -} - func (scope *Scope) SetColumn(column string, value interface{}) { if scope.Value == nil { return @@ -241,45 +195,6 @@ func (scope *Scope) CombinedConditionSql() string { scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql() } -func (scope *Scope) SqlTagForField(field *Field) (tag string) { - tag, addationalTag, size := parseSqlTag(field.Tag.Get(scope.db.parent.tagIdentifier)) - - if tag == "-" { - field.IsIgnored = true - } - - value := field.Value - reflectValue := reflect.ValueOf(value) - - if field.IsScanner() { - value = reflectValue.Field(0).Interface() - } - - switch reflectValue.Kind() { - case reflect.Slice: - if _, ok := value.([]byte); !ok { - return - } - case reflect.Struct: - if !field.IsTime() && !field.IsScanner() { - return - } - } - - if len(tag) == 0 && tag != "-" { - if field.isPrimaryKey { - tag = scope.Dialect().PrimaryKeyTag(value, size) - } else { - tag = scope.Dialect().SqlTag(value, size) - } - } - - if len(addationalTag) > 0 { - tag = tag + " " + addationalTag - } - return -} - func (scope *Scope) Fields() []*Field { indirectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) fields := []*Field{} @@ -306,7 +221,7 @@ func (scope *Scope) Fields() []*Field { if scope.db != nil { field.Tag = fieldStruct.Tag - field.SqlTag = scope.SqlTagForField(&field) + field.SqlTag = scope.sqlTagForField(&field) // parse association elem := reflect.Indirect(value) @@ -398,140 +313,3 @@ func (scope *Scope) CommitOrRollback() *Scope { } return scope } - -func (scope *Scope) row() *sql.Row { - defer scope.Trace(time.Now()) - scope.prepareQuerySql() - return scope.DB().QueryRow(scope.Sql, scope.SqlVars...) -} - -func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.Trace(time.Now()) - scope.prepareQuerySql() - return scope.DB().Query(scope.Sql, scope.SqlVars...) -} - -func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.WhereConditions { - scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) - } - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.InitAttrs), false) - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.AssignAttrs), false) - return scope -} - -func (scope *Scope) pluck(column string, value interface{}) *Scope { - dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search = scope.Search.clone().selects(column) - if dest.Kind() != reflect.Slice { - scope.Err(errors.New("Results should be a slice")) - return scope - } - - rows, err := scope.rows() - if scope.Err(err) == nil { - defer rows.Close() - for rows.Next() { - elem := reflect.New(dest.Type().Elem()).Interface() - scope.Err(rows.Scan(elem)) - dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) - } - } - return scope -} - -func (scope *Scope) count(value interface{}) *Scope { - scope.Search = scope.Search.clone().selects("count(*)") - scope.Err(scope.row().Scan(value)) - return scope -} - -func (scope *Scope) typeName() string { - value := reflect.Indirect(reflect.ValueOf(scope.Value)) - if value.Kind() == reflect.Slice { - return value.Type().Elem().Name() - } else { - return value.Type().Name() - } -} - -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.New(value) - - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - if foreignValue, ok := scope.FieldByName(foreignKey); ok { - return toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries) - } else if toScope.HasColumn(foreignKey) { - sql := fmt.Sprintf("%v = ?", scope.Quote(toSnake(foreignKey))) - return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries) - } - } - return scope -} - -func (scope *Scope) createTable() *Scope { - var sqls []string - for _, field := range scope.Fields() { - if !field.IsIgnored && len(field.SqlTag) > 0 { - sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag) - } - } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.Quote(column), typ)).Exec() -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(column string, names ...string) { - var indexName string - if len(names) > 0 { - indexName = names[0] - } else { - indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column) - } - - scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec() -} - -func (scope *Scope) autoMigrate() *Scope { - var tableName string - scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName()))) - scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName) - scope.SqlVars = []interface{}{} - - // If table doesn't exist - if len(tableName) == 0 { - scope.createTable() - } else { - for _, field := range scope.Fields() { - var column, data string - scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v", - scope.AddToVars(scope.TableName()), - scope.AddToVars(field.DBName), - )) - scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data) - scope.SqlVars = []interface{}{} - - // If column doesn't exist - if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec() - } - } - } - return scope -} diff --git a/scope_condition.go b/scope_private.go similarity index 52% rename from scope_condition.go rename to scope_private.go index 14ed9033..5159f156 100644 --- a/scope_condition.go +++ b/scope_private.go @@ -3,11 +3,13 @@ package gorm import ( "database/sql" "database/sql/driver" + "errors" "fmt" "reflect" "regexp" "strconv" "strings" + "time" ) func (scope *Scope) primaryCondiation(value interface{}) string { @@ -245,3 +247,223 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope { } return scope } + +func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { + for _, f := range funcs { + (*f)(scope) + if scope.skipLeft { + break + } + } + return scope +} + +func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { + data := reflect.Indirect(reflect.ValueOf(scope.Value)) + if !data.CanAddr() { + return values, true + } + + for key, value := range values { + if field := data.FieldByName(snakeToUpperCamel(key)); field.IsValid() { + if field.Interface() != value { + + switch field.Kind() { + case reflect.Int, reflect.Int32, reflect.Int64: + if s, ok := value.(string); ok { + i, err := strconv.Atoi(s) + if scope.Err(err) == nil { + value = i + } + } + + scope.db.log(field.Int() != reflect.ValueOf(value).Int()) + if field.Int() != reflect.ValueOf(value).Int() { + hasUpdate = true + setFieldValue(field, value) + } + default: + hasUpdate = true + setFieldValue(field, value) + } + } + } + } + return +} + +func (scope *Scope) sqlTagForField(field *Field) (tag string) { + tag, addationalTag, size := parseSqlTag(field.Tag.Get(scope.db.parent.tagIdentifier)) + + if tag == "-" { + field.IsIgnored = true + } + + value := field.Value + reflectValue := reflect.ValueOf(value) + + if field.IsScanner() { + value = reflectValue.Field(0).Interface() + } + + switch reflectValue.Kind() { + case reflect.Slice: + if _, ok := value.([]byte); !ok { + return + } + case reflect.Struct: + if !field.IsTime() && !field.IsScanner() { + return + } + } + + if len(tag) == 0 && tag != "-" { + if field.isPrimaryKey { + tag = scope.Dialect().PrimaryKeyTag(value, size) + } else { + tag = scope.Dialect().SqlTag(value, size) + } + } + + if len(addationalTag) > 0 { + tag = tag + " " + addationalTag + } + return +} + +func (scope *Scope) row() *sql.Row { + defer scope.Trace(time.Now()) + scope.prepareQuerySql() + return scope.DB().QueryRow(scope.Sql, scope.SqlVars...) +} + +func (scope *Scope) rows() (*sql.Rows, error) { + defer scope.Trace(time.Now()) + scope.prepareQuerySql() + return scope.DB().Query(scope.Sql, scope.SqlVars...) +} + +func (scope *Scope) initialize() *Scope { + for _, clause := range scope.Search.WhereConditions { + scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) + } + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.InitAttrs), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.AssignAttrs), false) + return scope +} + +func (scope *Scope) pluck(column string, value interface{}) *Scope { + dest := reflect.Indirect(reflect.ValueOf(value)) + scope.Search = scope.Search.clone().selects(column) + if dest.Kind() != reflect.Slice { + scope.Err(errors.New("Results should be a slice")) + return scope + } + + rows, err := scope.rows() + if scope.Err(err) == nil { + defer rows.Close() + for rows.Next() { + elem := reflect.New(dest.Type().Elem()).Interface() + scope.Err(rows.Scan(elem)) + dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) + } + } + return scope +} + +func (scope *Scope) count(value interface{}) *Scope { + scope.Search = scope.Search.clone().selects("count(*)") + scope.Err(scope.row().Scan(value)) + return scope +} + +func (scope *Scope) typeName() string { + value := reflect.Indirect(reflect.ValueOf(scope.Value)) + if value.Kind() == reflect.Slice { + return value.Type().Elem().Name() + } else { + return value.Type().Name() + } +} + +func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { + toScope := scope.New(value) + + for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { + if foreignValue, ok := scope.FieldByName(foreignKey); ok { + return toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries) + } else if toScope.HasColumn(foreignKey) { + sql := fmt.Sprintf("%v = ?", scope.Quote(toSnake(foreignKey))) + return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries) + } + } + return scope +} + +func (scope *Scope) createTable() *Scope { + var sqls []string + for _, field := range scope.Fields() { + if !field.IsIgnored && len(field.SqlTag) > 0 { + sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag) + } + } + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec() + return scope +} + +func (scope *Scope) dropTable() *Scope { + scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec() + return scope +} + +func (scope *Scope) modifyColumn(column string, typ string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.Quote(column), typ)).Exec() +} + +func (scope *Scope) dropColumn(column string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.Quote(column))).Exec() +} + +func (scope *Scope) addIndex(column string, names ...string) { + var indexName string + if len(names) > 0 { + indexName = names[0] + } else { + indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column) + } + + scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.Quote(column))).Exec() +} + +func (scope *Scope) removeIndex(indexName string) { + scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec() +} + +func (scope *Scope) autoMigrate() *Scope { + var tableName string + scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName()))) + scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName) + scope.SqlVars = []interface{}{} + + // If table doesn't exist + if len(tableName) == 0 { + scope.createTable() + } else { + for _, field := range scope.Fields() { + var column, data string + scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v", + scope.AddToVars(scope.TableName()), + scope.AddToVars(field.DBName), + )) + scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data) + scope.SqlVars = []interface{}{} + + // If column doesn't exist + if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored { + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec() + } + } + } + return scope +}