diff --git a/main.go b/main.go index 519874df..a72e091f 100644 --- a/main.go +++ b/main.go @@ -218,8 +218,8 @@ func (s *DB) Model(value interface{}) *DB { return c } -func (s *DB) Related(value interface{}, foreign_keys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreign_keys...).db +func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { + return s.clone().NewScope(s.Value).related(value, foreignKeys...).db } func (s *DB) Pluck(column string, value interface{}) *DB { @@ -299,8 +299,8 @@ func (s *DB) DropColumn(column string) *DB { return s } -func (s *DB) AddIndex(column string, index_name ...string) *DB { - s.clone().NewScope(s.Value).addIndex(column, index_name...) +func (s *DB) AddIndex(column string, indexName ...string) *DB { + s.clone().NewScope(s.Value).addIndex(column, indexName...) return s } diff --git a/scope.go b/scope.go index bff009ed..fd100307 100644 --- a/scope.go +++ b/scope.go @@ -242,13 +242,13 @@ func (s *Scope) CombinedConditionSql() string { func (scope *Scope) SqlTagForField(field *Field) (tag string) { value := field.Value - reflect_value := reflect.ValueOf(value) + reflectValue := reflect.ValueOf(value) if field.IsScanner() { - value = reflect_value.Field(0).Interface() + value = reflectValue.Field(0).Interface() } - switch reflect_value.Kind() { + switch reflectValue.Kind() { case reflect.Slice: if _, ok := value.([]byte); !ok { return @@ -470,3 +470,70 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } return scope } + +func (scope *Scope) createTable() *Scope { + var sqls []string + for _, field := range scope.Fields() { + if !field.IsIgnored && len(field.SqlTag) > 0 { + sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag) + } + } + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec() + return scope +} + +func (scope *Scope) dropTable() *Scope { + scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec() + return scope +} + +func (scope *Scope) modifyColumn(column string, typ string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.Quote(column), typ)).Exec() +} + +func (scope *Scope) dropColumn(column string) { + scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.Quote(column))).Exec() +} + +func (scope *Scope) addIndex(column string, names ...string) { + var indexName string + if len(names) > 0 { + indexName = names[0] + } else { + indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column) + } + + scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.Quote(column))).Exec() +} + +func (scope *Scope) removeIndex(indexName string) { + scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec() +} + +func (scope *Scope) autoMigrate() *Scope { + var tableName string + scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName()))) + scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName) + scope.SqlVars = []interface{}{} + + // If table doesn't exist + if len(tableName) == 0 { + scope.createTable() + } else { + for _, field := range scope.Fields() { + var column, data string + scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v", + scope.AddToVars(scope.TableName()), + scope.AddToVars(field.DBName), + )) + scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data) + scope.SqlVars = []interface{}{} + + // If column doesn't exist + if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored { + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec() + } + } + } + return scope +} diff --git a/scope_condition.go b/scope_condition.go index 7e106844..14ed9033 100644 --- a/scope_condition.go +++ b/scope_condition.go @@ -52,11 +52,11 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri switch reflect.TypeOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) values := reflect.ValueOf(arg) - var temp_marks []string + var tempMarks []string for i := 0; i < values.Len(); i++ { - temp_marks = append(temp_marks, scope.AddToVars(values.Index(i).Interface())) + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() @@ -69,7 +69,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var not_equal_sql string + var notEqualSql string switch value := clause["query"].(type) { case string: @@ -78,10 +78,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), id) } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { str = fmt.Sprintf(" NOT (%v) ", value) - not_equal_sql = fmt.Sprintf("NOT (%v)", value) + notEqualSql = fmt.Sprintf("NOT (%v)", value) } else { str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) - not_equal_sql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) + notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) } case int, int64, int32: return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value) @@ -113,16 +113,16 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string switch reflect.TypeOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) values := reflect.ValueOf(arg) - var temp_marks []string + var tempMarks []string for i := 0; i < values.Len(); i++ { - temp_marks = append(temp_marks, scope.AddToVars(values.Index(i).Interface())) + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) default: if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = scanner.Value() } - str = strings.Replace(not_equal_sql, "?", scope.AddToVars(arg), 1) + str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1) } } return @@ -135,45 +135,45 @@ func (scope *Scope) where(where ...interface{}) { } func (scope *Scope) whereSql() (sql string) { - var primary_condiations, and_conditions, or_conditions []string + var primaryCondiations, andConditions, orConditions []string if !scope.Search.Unscope && scope.HasColumn("DeletedAt") { - primary_condiations = append(primary_condiations, "(deleted_at IS NULL OR deleted_at <= '0001-01-02')") + primaryCondiations = append(primaryCondiations, "(deleted_at IS NULL OR deleted_at <= '0001-01-02')") } if !scope.PrimaryKeyZero() { - primary_condiations = append(primary_condiations, scope.primaryCondiation(scope.AddToVars(scope.PrimaryKeyValue()))) + primaryCondiations = append(primaryCondiations, scope.primaryCondiation(scope.AddToVars(scope.PrimaryKeyValue()))) } for _, clause := range scope.Search.WhereConditions { - and_conditions = append(and_conditions, scope.buildWhereCondition(clause)) + andConditions = append(andConditions, scope.buildWhereCondition(clause)) } for _, clause := range scope.Search.OrConditions { - or_conditions = append(or_conditions, scope.buildWhereCondition(clause)) + orConditions = append(orConditions, scope.buildWhereCondition(clause)) } for _, clause := range scope.Search.NotConditions { - and_conditions = append(and_conditions, scope.buildNotCondition(clause)) + andConditions = append(andConditions, scope.buildNotCondition(clause)) } - or_sql := strings.Join(or_conditions, " OR ") - combined_sql := strings.Join(and_conditions, " AND ") - if len(combined_sql) > 0 { - if len(or_sql) > 0 { - combined_sql = combined_sql + " OR " + or_sql + orSql := strings.Join(orConditions, " OR ") + combinedSql := strings.Join(andConditions, " AND ") + if len(combinedSql) > 0 { + if len(orSql) > 0 { + combinedSql = combinedSql + " OR " + orSql } } else { - combined_sql = or_sql + combinedSql = orSql } - if len(primary_condiations) > 0 { - sql = "WHERE " + strings.Join(primary_condiations, " AND ") - if len(combined_sql) > 0 { - sql = sql + " AND (" + combined_sql + ")" + if len(primaryCondiations) > 0 { + sql = "WHERE " + strings.Join(primaryCondiations, " AND ") + if len(combinedSql) > 0 { + sql = sql + " AND (" + combinedSql + ")" } - } else if len(combined_sql) > 0 { - sql = "WHERE " + combined_sql + } else if len(combinedSql) > 0 { + sql = "WHERE " + combinedSql } return } diff --git a/scope_database.go b/scope_database.go deleted file mode 100644 index 3f17c0ff..00000000 --- a/scope_database.go +++ /dev/null @@ -1,73 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -func (scope *Scope) createTable() *Scope { - var sqls []string - for _, field := range scope.Fields() { - if !field.IsIgnored && len(field.SqlTag) > 0 { - sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag) - } - } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.Quote(column), typ)).Exec() -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(column string, names ...string) { - var indexName string - if len(names) > 0 { - indexName = names[0] - } else { - indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column) - } - - scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec() -} - -func (scope *Scope) autoMigrate() *Scope { - var tableName string - scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName()))) - scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName) - scope.SqlVars = []interface{}{} - - // If table doesn't exist - if len(tableName) == 0 { - scope.createTable() - } else { - for _, field := range scope.Fields() { - var column, data string - scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v", - scope.AddToVars(scope.TableName()), - scope.AddToVars(field.DBName), - )) - scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data) - scope.SqlVars = []interface{}{} - - // If column doesn't exist - if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec() - } - } - } - return scope -}