diff --git a/do.go b/do.go deleted file mode 100644 index e7bfb2b4..00000000 --- a/do.go +++ /dev/null @@ -1,774 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - - "github.com/jinzhu/gorm/dialect" - - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -type Do struct { - db *DB - search *search - model *Model - tableName string - value interface{} - usingUpdate bool - hasUpdate bool - update_attrs map[string]interface{} - sql string - sqlVars []interface{} - startedTransaction bool -} - -func (s *Do) setSql(sql string) { - s.sql = strings.Replace(sql, "$$", "?", -1) -} - -func (s *Do) table() string { - if len(s.tableName) == 0 { - if len(s.search.tableName) == 0 { - s.tableName = s.model.tableName() - } else { - s.tableName = s.search.tableName - } - } - return s.tableName -} - -func (s *Do) dialect() dialect.Dialect { - return s.db.parent.dialect -} - -func (s *Do) quote(str string) string { - return s.dialect().Quote(str) -} - -func (s *Do) err(err error) error { - if err != nil { - s.db.err(err) - } - return err -} - -func (s *Do) setModel(value interface{}) *Do { - s.model = &Model{data: value, do: s} - s.value = value - s.search = s.db.search - return s -} - -func (s *Do) addToVars(value interface{}) string { - s.sqlVars = append(s.sqlVars, value) - return s.dialect().BinVar(len(s.sqlVars)) -} - -func (s *Do) trace(t time.Time) { - if len(s.sql) > 0 { - s.db.slog(s.sql, t, s.sqlVars...) - } -} - -func (s *Do) raw(query string, values ...interface{}) *Do { - s.setSql(s.buildWhereCondition(map[string]interface{}{"query": query, "args": values})) - return s -} - -func (s *Do) exec() *Do { - defer s.trace(time.Now()) - if !s.db.hasError() { - _, err := s.db.db.Exec(s.sql, s.sqlVars...) - s.err(err) - } - return s -} - -func (s *Do) save() *Do { - if s.model.primaryKeyZero() { - s.create() - } else { - s.update() - } - return s -} - -func (s *Do) prepareCreateSql() { - var sqls, columns []string - - for key, value := range s.model.columnsAndValues("create") { - columns = append(columns, s.quote(key)) - sqls = append(sqls, s.addToVars(value)) - } - - s.setSql(fmt.Sprintf( - "INSERT INTO %v (%v) VALUES (%v) %v", - s.table(), - strings.Join(columns, ","), - strings.Join(sqls, ","), - s.dialect().ReturningStr(s.model.primaryKeyDb()), - )) - return -} - -func (s *Do) saveBeforeAssociations() { - for _, field := range s.model.beforeAssociations() { - do := &Do{db: s.db} - - if field.reflectValue.CanAddr() { - do.setModel(field.reflectValue.Addr().Interface()).save() - } else { - // If can't take address, then clone the value and set it back - dest_value := reflect.New(field.reflectValue.Type()).Elem() - m := &Model{data: field.Value, do: s} - for _, f := range m.columnsHasValue("other") { - dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) - } - do.setModel(dest_value.Addr().Interface()).save() - m.setValueByColumn(field.Name, dest_value.Interface(), s.value) - } - - if len(field.foreignKey) > 0 { - s.model.setValueByColumn(field.foreignKey, do.model.primaryKeyValue(), s.model.data) - } - } -} - -func (s *Do) saveAfterAssociations() { - for _, field := range s.model.afterAssociations() { - reflect_value := reflect.ValueOf(field.Value) - - switch reflect_value.Kind() { - case reflect.Slice: - for i := 0; i < reflect_value.Len(); i++ { - do := &Do{db: s.db} - - value := reflect_value.Index(i).Addr().Interface() - if len(field.foreignKey) > 0 { - s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value) - } - do.setModel(value).save() - } - default: - do := &Do{db: s.db} - if reflect_value.CanAddr() { - s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), field.Value) - do.setModel(field.Value).save() - } else { - dest_value := reflect.New(reflect.TypeOf(field.Value)).Elem() - m := &Model{data: field.Value, do: s} - for _, f := range m.columnsHasValue("other") { - dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) - } - - setFieldValue(dest_value.FieldByName(field.foreignKey), s.model.primaryKeyValue()) - do.setModel(dest_value.Addr().Interface()).save() - - m.setValueByColumn(field.Name, dest_value.Interface(), s.value) - } - } - } -} - -func (s *Do) create() (i interface{}) { - defer s.trace(time.Now()) - s.model.callMethod("BeforeSave") - s.model.callMethod("BeforeCreate") - - s.saveBeforeAssociations() - s.prepareCreateSql() - - if !s.db.hasError() { - var id interface{} - if s.dialect().SupportLastInsertId() { - if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil { - id, err = sql_result.LastInsertId() - s.err(err) - } - } else { - s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) - } - - if !s.db.hasError() { - s.model.setValueByColumn(s.model.primaryKey(), id, s.value) - - s.saveAfterAssociations() - s.model.callMethod("AfterCreate") - s.model.callMethod("AfterSave") - } - return id - } - - return -} - -func (s *Do) convertToMapInterface(values interface{}) map[string]interface{} { - attrs := map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - for k, v := range value { - attrs[toSnake(k)] = v - } - case []interface{}: - for _, v := range value { - for key, value := range s.convertToMapInterface(v) { - attrs[key] = value - } - } - case interface{}: - reflect_value := reflect.ValueOf(values) - - switch reflect_value.Kind() { - case reflect.Map: - for _, key := range reflect_value.MapKeys() { - attrs[toSnake(key.Interface().(string))] = reflect_value.MapIndex(key).Interface() - } - default: - m := &Model{data: values, do: s} - for _, field := range m.columnsHasValue("other") { - attrs[field.dbName] = field.Value - } - } - } - return attrs -} - -func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do { - ignore_protected := len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0] - s.usingUpdate = true - - if maps := s.convertToMapInterface(values); len(maps) > 0 { - results, has_update := s.model.updatedColumnsAndValues(maps, ignore_protected) - if len(results) > 0 { - s.update_attrs = results - } - s.hasUpdate = has_update - } - return s -} - -func (s *Do) prepareUpdateSql(include_self bool) { - var sqls []string - for key, value := range s.update_attrs { - sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value))) - } - - if include_self { - data := s.model.reflectData() - if data.CanAddr() { - for key, value := range s.model.columnsAndValues("update") { - sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value))) - } - } - } - - s.setSql(fmt.Sprintf( - "UPDATE %v SET %v %v", - s.table(), - strings.Join(sqls, ", "), - s.combinedSql(), - )) - return -} - -func (s *Do) updateColumns(value interface{}) *Do { - s.update_attrs = s.convertToMapInterface(value) - s.prepareUpdateSql(false) - if !s.db.hasError() { - s.exec() - s.updateAttrs(s.update_attrs) - } - return s -} - -func (s *Do) update() *Do { - if s.usingUpdate && !s.hasUpdate { - return s - } - - s.model.callMethod("BeforeSave") - s.model.callMethod("BeforeUpdate") - s.saveBeforeAssociations() - - s.prepareUpdateSql(true) - - if !s.db.hasError() { - s.exec() - s.saveAfterAssociations() - - s.model.callMethod("AfterUpdate") - s.model.callMethod("AfterSave") - } - - return s -} - -func (s *Do) prepareQuerySql() { - if s.search.raw { - s.setSql(strings.TrimLeft(s.combinedSql(), "WHERE ")) - } else { - s.setSql(fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.table(), s.combinedSql())) - } - return -} - -func (s *Do) first() *Do { - s.search = s.search.clone().order(s.model.primaryKeyDb()).limit(1) - s.query() - return s -} - -func (s *Do) last() *Do { - s.search = s.search.clone().order(s.model.primaryKeyDb() + " DESC").limit(1) - s.query() - return s -} - -func (s *Do) getForeignKey(from *Model, to *Model, foreign_key string) (err error, from_from bool, foreign_value interface{}) { - if has_column, is_slice, value := from.columnAndValue(foreign_key); has_column { - from_from = true - if is_slice { - foreign_value = to.primaryKeyValue() - } else { - foreign_value = value - } - } else if has_column, _, _ := to.columnAndValue(foreign_key); has_column { - foreign_value = from.primaryKeyValue() - } else { - err = errors.New("Can't find valid foreign Key") - } - return -} - -func (s *Do) related(value interface{}, foreign_keys ...string) *Do { - var foreign_value interface{} - var from_from bool - var foreign_key string - var err error - - from := &Model{data: value, do: s} - to := &Model{data: s.value, do: s} - foreign_keys = append(foreign_keys, from.typeName()+"Id", to.typeName()+"Id") - - for _, fk := range foreign_keys { - err, from_from, foreign_value = s.getForeignKey(from, to, snakeToUpperCamel(fk)) - if err == nil { - foreign_key = fk - break - } - } - - if from_from { - s.where(foreign_value).query() - } else { - query := fmt.Sprintf("%v = %v", s.quote(toSnake(foreign_key)), s.addToVars(foreign_value)) - s.where(query).query() - } - return s -} - -func (s *Do) row() *sql.Row { - defer s.trace(time.Now()) - s.prepareQuerySql() - return s.db.db.QueryRow(s.sql, s.sqlVars...) -} - -func (s *Do) rows() (*sql.Rows, error) { - defer s.trace(time.Now()) - s.prepareQuerySql() - return s.db.db.Query(s.sql, s.sqlVars...) -} - -func (s *Do) query(dests ...interface{}) *Do { - defer s.trace(time.Now()) - var ( - is_slice bool - dest_type reflect.Type - ) - var dest_out reflect.Value - if len(dests) > 0 { - dest_out = reflect.Indirect(reflect.ValueOf(dests[0])) - } else { - dest_out = reflect.Indirect(reflect.ValueOf(s.value)) - } - - if dest_out.Kind() == reflect.Slice { - is_slice = true - dest_type = dest_out.Type().Elem() - } else { - s.search = s.search.clone().limit(1) - } - - s.prepareQuerySql() - if !s.db.hasError() { - rows, err := s.db.db.Query(s.sql, s.sqlVars...) - - if s.err(err) != nil { - return s - } - - defer rows.Close() - var has_record bool - for rows.Next() { - has_record = true - dest := dest_out - if is_slice { - dest = reflect.New(dest_type).Elem() - } - - columns, _ := rows.Columns() - var values []interface{} - for _, value := range columns { - field := dest.FieldByName(snakeToUpperCamel(value)) - if field.IsValid() { - values = append(values, field.Addr().Interface()) - } else { - var ignore interface{} - values = append(values, &ignore) - } - } - s.err(rows.Scan(values...)) - - m := &Model{data: dest.Addr().Interface(), do: s} - m.callMethod("AfterFind") - if is_slice { - dest_out.Set(reflect.Append(dest_out, dest)) - } - } - - if !has_record && !is_slice { - s.err(RecordNotFound) - } - } - return s -} - -func (s *Do) count(value interface{}) *Do { - s.search = s.search.clone().selects("count(*)") - s.err(s.row().Scan(value)) - return s -} - -func (s *Do) pluck(column string, value interface{}) *Do { - dest_out := reflect.Indirect(reflect.ValueOf(value)) - s.search = s.search.clone().selects(column) - if dest_out.Kind() != reflect.Slice { - s.err(errors.New("Results should be a slice")) - return s - } - - rows, err := s.rows() - if s.err(err) == nil { - defer rows.Close() - for rows.Next() { - dest := reflect.New(dest_out.Type().Elem()).Interface() - s.err(rows.Scan(dest)) - dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest).Elem())) - } - } - return s -} - -func (s *Do) primaryCondiation(value interface{}) string { - return fmt.Sprintf("(%v = %v)", s.quote(s.model.primaryKeyDb()), value) -} - -func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { - id, _ := strconv.Atoi(value) - return s.primaryCondiation(s.addToVars(id)) - } else { - str = value - } - case int, int64, int32: - return s.primaryCondiation(s.addToVars(value)) - case sql.NullInt64: - return s.primaryCondiation(s.addToVars(value.Int64)) - case []int64, []int, []int32, []string: - str = fmt.Sprintf("(%v in (?))", s.quote(s.model.primaryKeyDb())) - clause["args"] = []interface{}{value} - case map[string]interface{}: - var sqls []string - for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", s.quote(key), s.addToVars(value))) - } - return strings.Join(sqls, " AND ") - case interface{}: - m := &Model{data: value, do: s} - var sqls []string - for _, field := range m.columnsHasValue("other") { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", s.quote(field.dbName), s.addToVars(field.Value))) - } - return strings.Join(sqls, " AND ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.TypeOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - values := reflect.ValueOf(arg) - var temp_marks []string - for i := 0; i < values.Len(); i++ { - temp_marks = append(temp_marks, s.addToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - - str = strings.Replace(str, "?", s.addToVars(arg), 1) - } - } - return -} - -func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { - var not_equal_sql string - - switch value := clause["query"].(type) { - case string: - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", s.quote(s.model.primaryKeyDb()), id) - } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { - str = fmt.Sprintf(" NOT (%v) ", value) - not_equal_sql = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v NOT IN (?))", s.quote(value)) - not_equal_sql = fmt.Sprintf("(%v <> ?)", s.quote(value)) - } - case int, int64, int32: - return fmt.Sprintf("(%v <> %v)", s.quote(s.model.primaryKeyDb()), value) - case []int64, []int, []int32, []string: - if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v not in (?))", s.quote(s.model.primaryKeyDb())) - clause["args"] = []interface{}{value} - } else { - return "" - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", s.quote(key), s.addToVars(value))) - } - return strings.Join(sqls, " AND ") - case interface{}: - m := &Model{data: value, do: s} - var sqls []string - for _, field := range m.columnsHasValue("other") { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", s.quote(field.dbName), s.addToVars(field.Value))) - } - return strings.Join(sqls, " AND ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.TypeOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - values := reflect.ValueOf(arg) - var temp_marks []string - for i := 0; i < values.Len(); i++ { - temp_marks = append(temp_marks, s.addToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = scanner.Value() - } - str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1) - } - } - return -} - -func (s *Do) where(where ...interface{}) *Do { - if len(where) > 0 { - s.search = s.search.clone().where(where[0], where[1:]...) - } - return s -} - -func (s *Do) whereSql() (sql string) { - var primary_condiations, and_conditions, or_conditions []string - - if !s.search.unscope && s.model.hasColumn("DeletedAt") { - primary_condiations = append(primary_condiations, "(deleted_at IS NULL OR deleted_at <= '0001-01-02')") - } - - if !s.model.primaryKeyZero() { - primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue()))) - } - - for _, clause := range s.search.whereClause { - and_conditions = append(and_conditions, s.buildWhereCondition(clause)) - } - - for _, clause := range s.search.orClause { - or_conditions = append(or_conditions, s.buildWhereCondition(clause)) - } - - for _, clause := range s.search.notClause { - and_conditions = append(and_conditions, s.buildNotCondition(clause)) - } - - or_sql := strings.Join(or_conditions, " OR ") - combined_sql := strings.Join(and_conditions, " AND ") - if len(combined_sql) > 0 { - if len(or_sql) > 0 { - combined_sql = combined_sql + " OR " + or_sql - } - } else { - combined_sql = or_sql - } - - if len(primary_condiations) > 0 { - sql = "WHERE " + strings.Join(primary_condiations, " AND ") - if len(combined_sql) > 0 { - sql = sql + " AND (" + combined_sql + ")" - } - } else if len(combined_sql) > 0 { - sql = "WHERE " + combined_sql - } - return -} - -func (s *Do) selectSql() string { - if len(s.search.selectStr) == 0 { - return "*" - } else { - return s.search.selectStr - } -} - -func (s *Do) orderSql() string { - if len(s.search.orders) == 0 { - return "" - } else { - return " ORDER BY " + strings.Join(s.search.orders, ",") - } -} - -func (s *Do) limitSql() string { - if len(s.search.limitStr) == 0 { - return "" - } else { - return " LIMIT " + s.search.limitStr - } -} - -func (s *Do) offsetSql() string { - if len(s.search.offsetStr) == 0 { - return "" - } else { - return " OFFSET " + s.search.offsetStr - } -} - -func (s *Do) groupSql() string { - if len(s.search.groupStr) == 0 { - return "" - } else { - return " GROUP BY " + s.search.groupStr - } -} - -func (s *Do) havingSql() string { - if s.search.havingClause == nil { - return "" - } else { - return " HAVING " + s.buildWhereCondition(s.search.havingClause) - } -} - -func (s *Do) joinsSql() string { - return s.search.joinsStr + " " -} - -func (s *Do) combinedSql() string { - return s.joinsSql() + s.whereSql() + s.groupSql() + s.havingSql() + s.orderSql() + s.limitSql() + s.offsetSql() -} - -func (s *Do) createTable() *Do { - var sqls []string - for _, field := range s.model.fields("migration") { - if len(field.sqlTag()) > 0 { - sqls = append(sqls, s.quote(field.dbName)+" "+field.sqlTag()) - } - } - s.setSql(fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))) - s.exec() - return s -} - -func (s *Do) dropTable() *Do { - s.setSql(fmt.Sprintf("DROP TABLE %v", s.table())) - s.exec() - return s -} - -func (s *Do) modifyColumn(column string, typ string) { - s.setSql(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), s.quote(column), typ)) - s.exec() -} - -func (s *Do) dropColumn(column string) { - s.setSql(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), s.quote(column))) - s.exec() -} - -func (s *Do) addIndex(column string, names ...string) { - var index_name string - if len(names) > 0 { - index_name = names[0] - } else { - index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column) - } - - s.setSql(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), s.quote(column))) - s.exec() -} - -func (s *Do) removeIndex(index_name string) { - s.setSql(fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.table())) - s.exec() -} - -func (s *Do) autoMigrate() *Do { - var table_name string - s.setSql(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.table()))) - s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&table_name) - s.sqlVars = []interface{}{} - - // If table doesn't exist - if len(table_name) == 0 { - s.createTable() - } else { - for _, field := range s.model.fields("migration") { - var column_name, data_type string - s.setSql(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v", - s.addToVars(s.table()), - s.addToVars(field.dbName), - )) - s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&column_name, &data_type) - s.sqlVars = []interface{}{} - - // If column doesn't exist - if len(column_name) == 0 && len(field.sqlTag()) > 0 { - s.setSql(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.table(), field.dbName, field.sqlTag())) - s.exec() - } - } - } - return s -} diff --git a/field.go b/field.go index e394f2de..e3c05b35 100644 --- a/field.go +++ b/field.go @@ -21,20 +21,7 @@ type Field struct { ForeignKey string BeforeAssociation bool AfterAssociation bool - - foreignKey string - beforeAssociation bool - afterAssociation bool - - dbName string - model *Model - isBlank bool - ignoreField bool - isPrimaryKey bool - autoCreateTime bool - autoUpdateTime bool - reflectValue reflect.Value - structField reflect.StructField + isPrimaryKey bool } func (f *Field) IsScanner() bool { @@ -47,97 +34,6 @@ func (f *Field) IsTime() bool { return is_time } -func (f *Field) parseAssociation() { - elem := reflect.Indirect(reflect.ValueOf(f.Value)) - typ := elem.Type() - - switch elem.Kind() { - case reflect.Slice: - typ = typ.Elem() - - if _, ok := f.Value.([]byte); !ok { - foreignKey := typ.Name() + "Id" - if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - f.foreignKey = foreignKey - } - f.afterAssociation = true - } - case reflect.Struct: - if !f.IsTime() && !f.IsScanner() { - if elem.FieldByName(f.Name + "Id").IsValid() { - f.foreignKey = f.Name + "Id" - f.beforeAssociation = true - } else { - foreignKey := typ.Name() + "Id" - if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - f.foreignKey = foreignKey - } - f.afterAssociation = true - } - } - } -} - -func (f *Field) parseBlank() { - f.isBlank = isBlank(f.reflectValue) -} - -func (f *Field) parseIgnore() { - typ, _, _ := parseSqlTag(f.structField.Tag.Get(f.model.do.db.parent.tagIdentifier)) - - if typ == "-" { - f.ignoreField = true - } -} - -func (f *Field) isScanner() bool { - _, is_scanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner) - return is_scanner -} - -func (f *Field) isTime() bool { - _, is_time := f.Value.(time.Time) - return is_time -} - -func (f *Field) sqlTag() (str string) { - value := f.Value - if f.isScanner() { - value = f.reflectValue.Field(0).Interface() - } - reflect_value := f.reflectValue - - switch reflect_value.Kind() { - case reflect.Slice: - if _, ok := f.Value.([]byte); !ok { - return - } - case reflect.Struct: - if !f.isTime() && !f.isScanner() { - return - } - } - - typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(f.model.do.db.parent.tagIdentifier)) - - if typ == "-" { - return - } - - if len(typ) == 0 { - if f.isPrimaryKey { - typ = f.model.do.dialect().PrimaryKeyTag(value, size) - } else { - typ = f.model.do.dialect().SqlTag(value, size) - } - } - - if len(addational_typ) > 0 { - typ = typ + " " + addational_typ - } - return typ -} - func parseSqlTag(str string) (typ string, addational_typ string, size int) { if str == "-" { typ = str diff --git a/gorm_test.go b/gorm_test.go index e3954577..c90631f4 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1340,53 +1340,52 @@ func (c Cart) TableName() string { } func TestTableName(t *testing.T) { - db := db.clone() - if db.do(Order{}).table() != "orders" { + if db.NewScope(Order{}).TableName() != "orders" { t.Errorf("Order's table name should be orders") } - if db.do(&Order{}).table() != "orders" { + if db.NewScope(&Order{}).TableName() != "orders" { t.Errorf("&Order's table name should be orders") } - if db.do([]Order{}).table() != "orders" { + if db.NewScope([]Order{}).TableName() != "orders" { t.Errorf("[]Order's table name should be orders") } - if db.do(&[]Order{}).table() != "orders" { + if db.NewScope(&[]Order{}).TableName() != "orders" { t.Errorf("&[]Order's table name should be orders") } db.SingularTable(true) - if db.do(Order{}).table() != "order" { + if db.NewScope(Order{}).TableName() != "order" { t.Errorf("Order's singular table name should be order") } - if db.do(&Order{}).table() != "order" { + if db.NewScope(&Order{}).TableName() != "order" { t.Errorf("&Order's singular table name should be order") } - if db.do([]Order{}).table() != "order" { + if db.NewScope([]Order{}).TableName() != "order" { t.Errorf("[]Order's singular table name should be order") } - if db.do(&[]Order{}).table() != "order" { + if db.NewScope(&[]Order{}).TableName() != "order" { t.Errorf("&[]Order's singular table name should be order") } - if db.do(&Cart{}).table() != "shopping_cart" { + if db.NewScope(&Cart{}).TableName() != "shopping_cart" { t.Errorf("&Cart's singular table name should be shopping_cart") } - if db.do(Cart{}).table() != "shopping_cart" { + if db.NewScope(Cart{}).TableName() != "shopping_cart" { t.Errorf("Cart's singular table name should be shopping_cart") } - if db.do(&[]Cart{}).table() != "shopping_cart" { + if db.NewScope(&[]Cart{}).TableName() != "shopping_cart" { t.Errorf("&[]Cart's singular table name should be shopping_cart") } - if db.do([]Cart{}).table() != "shopping_cart" { + if db.NewScope([]Cart{}).TableName() != "shopping_cart" { t.Errorf("[]Cart's singular table name should be shopping_cart") } db.SingularTable(false) diff --git a/model.go b/model.go deleted file mode 100644 index cd7bd6f6..00000000 --- a/model.go +++ /dev/null @@ -1,287 +0,0 @@ -package gorm - -import ( - "go/ast" - "reflect" - "regexp" - "strconv" - "time" -) - -var modelFieldMap = map[string][]reflect.StructField{} - -type Model struct { - data interface{} - do *Do - _cache_fields map[string][]*Field -} - -func (m *Model) reflectData() reflect.Value { - return reflect.Indirect(reflect.ValueOf(m.data)) -} - -func (m *Model) primaryKeyZero() bool { - return isBlank(reflect.ValueOf(m.primaryKeyValue())) -} - -func (m *Model) primaryKeyValue() interface{} { - if data := m.reflectData(); data.Kind() == reflect.Struct { - if field := data.FieldByName(m.primaryKey()); field.IsValid() { - return field.Interface() - } - } - return 0 -} - -func (m *Model) primaryKey() string { - return "Id" -} - -func (m *Model) primaryKeyDb() string { - return toSnake(m.primaryKey()) -} - -func getStructs(typ reflect.Type) (fs []reflect.StructField) { - name := typ.Name() - if fs = modelFieldMap[name]; fs != nil { - return - } - - for i := 0; i < typ.NumField(); i++ { - p := typ.Field(i) - if !p.Anonymous && ast.IsExported(p.Name) { - fs = append(fs, p) - } - } - - modelFieldMap[name] = fs - return -} - -func (m *Model) fields(operation string) (fields []*Field) { - if len(m._cache_fields[operation]) > 0 { - return m._cache_fields[operation] - } - - indirect_value := m.reflectData() - if !indirect_value.IsValid() { - return - } - - structs := getStructs(indirect_value.Type()) - c := make(chan *Field, len(structs)) - defer close(c) - - for _, field_struct := range structs { - go func(field_struct reflect.StructField, c chan *Field) { - var field Field - field.Name = field_struct.Name - field.dbName = toSnake(field_struct.Name) - field.isPrimaryKey = m.primaryKeyDb() == field.dbName - value := indirect_value.FieldByName(field_struct.Name) - field.model = m - - if time_value, is_time := value.Interface().(time.Time); is_time { - field.autoCreateTime = "created_at" == field.dbName - field.autoUpdateTime = "updated_at" == field.dbName - - switch operation { - case "create": - if (field.autoCreateTime || field.autoUpdateTime) && time_value.IsZero() { - value.Set(reflect.ValueOf(time.Now())) - } - case "update": - if field.autoUpdateTime { - value.Set(reflect.ValueOf(time.Now())) - } - } - } - field.structField = field_struct - field.reflectValue = value - field.Value = value.Interface() - field.parseAssociation() - field.parseBlank() - field.parseIgnore() - c <- &field - }(field_struct, c) - } - - for i := 0; i < len(structs); i++ { - fields = append(fields, <-c) - } - - if len(m._cache_fields) == 0 { - m._cache_fields = map[string][]*Field{} - } - m._cache_fields[operation] = fields - return -} - -func (m *Model) columnsHasValue(operation string) (fields []*Field) { - for _, field := range m.fields(operation) { - if !field.isBlank { - fields = append(fields, field) - } - } - return -} - -func (m *Model) updatedColumnsAndValues(values map[string]interface{}, ignore_protected_attrs bool) (results map[string]interface{}, any_updated bool) { - data := m.reflectData() - if !data.CanAddr() { - return values, true - } - - for key, value := range values { - if field := data.FieldByName(snakeToUpperCamel(key)); field.IsValid() { - if field.Interface() != value { - switch field.Kind() { - case reflect.Int, reflect.Int32, reflect.Int64: - if s, ok := value.(string); ok { - i, err := strconv.Atoi(s) - if m.do.err(err) == nil { - value = i - } - } - - if field.Int() != reflect.ValueOf(value).Int() { - any_updated = true - field.SetInt(reflect.ValueOf(value).Int()) - } - default: - any_updated = true - field.Set(reflect.ValueOf(value)) - } - } - } - } - - if values["updated_at"] != nil && any_updated { - setFieldValue(data.FieldByName("UpdatedAt"), time.Now()) - } - return -} - -func (m *Model) columnsAndValues(operation string) map[string]interface{} { - results := map[string]interface{}{} - - for _, field := range m.fields(operation) { - if !field.isPrimaryKey && len(field.sqlTag()) > 0 { - results[field.dbName] = field.Value - } - } - return results -} - -func (m *Model) hasColumn(name string) bool { - if data := m.reflectData(); data.Kind() == reflect.Struct { - return data.FieldByName(name).IsValid() - } else if data.Kind() == reflect.Slice { - return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() - } - return false -} - -func (m *Model) columnAndValue(name string) (has_column bool, is_slice bool, value interface{}) { - if data := m.reflectData(); data.Kind() == reflect.Struct { - if has_column = data.FieldByName(name).IsValid(); has_column { - value = data.FieldByName(name).Interface() - } - } else if data.Kind() == reflect.Slice { - has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() - is_slice = true - } - return -} - -func (m *Model) typ() reflect.Type { - typ := m.reflectData().Type() - if typ.Kind() == reflect.Slice { - return typ.Elem() - } else { - return typ - } -} - -func (m *Model) typeName() string { - return m.typ().Name() -} - -func (m *Model) tableName() (str string) { - if m.data == nil { - return - } - - data := m.reflectData() - - if data.Kind() == reflect.Slice { - data = reflect.New(data.Type().Elem()).Elem() - } - - if fm := data.MethodByName("TableName"); fm.IsValid() { - if v := fm.Call([]reflect.Value{}); len(v) > 0 { - if result, ok := v[0].Interface().(string); ok { - return result - } - } - } - - str = toSnake(m.typeName()) - - if !m.do.db.parent.singularTable { - pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"} - for key, value := range pluralMap { - reg := regexp.MustCompile(key + "$") - if reg.MatchString(str) { - return reg.ReplaceAllString(str, value) - } - } - } - - return -} - -func (m *Model) callMethod(method string) { - if m.data == nil || m.do.db.hasError() { - return - } - - if fm := reflect.ValueOf(m.data).MethodByName(method); fm.IsValid() { - numin := fm.Type().NumIn() - var results []reflect.Value - if numin == 0 { - results = fm.Call([]reflect.Value{}) - } else if numin == 1 { - results = fm.Call([]reflect.Value{reflect.ValueOf(m.do.db.new())}) - } - if len(results) > 0 { - if verr, ok := results[0].Interface().(error); ok { - m.do.err(verr) - } - } - } - return -} - -func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) { - data := reflect.Indirect(reflect.ValueOf(out)) - setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value) -} - -func (m *Model) beforeAssociations() (fields []*Field) { - for _, field := range m.fields("null") { - if field.beforeAssociation && !field.isBlank && !field.ignoreField { - fields = append(fields, field) - } - } - return -} - -func (m *Model) afterAssociations() (fields []*Field) { - for _, field := range m.fields("null") { - if field.afterAssociation && !field.isBlank && !field.ignoreField { - fields = append(fields, field) - } - } - return -} diff --git a/private.go b/private.go index 4990bdbd..2b6bceb0 100644 --- a/private.go +++ b/private.go @@ -23,13 +23,6 @@ func (s *DB) new() *DB { return s.clone() } -func (s *DB) do(data interface{}) *Do { - s.Value = data - do := Do{db: s} - do.setModel(data) - return &do -} - func (s *DB) err(err error) error { if err != nil { if err != RecordNotFound { diff --git a/scope.go b/scope.go index c9c5c131..1b8a0e19 100644 --- a/scope.go +++ b/scope.go @@ -195,7 +195,7 @@ func (scope *Scope) AddToVars(value interface{}) string { } func (scope *Scope) TableName() string { - if len(scope.Search.tableName) > 0 { + if scope.Search != nil && len(scope.Search.tableName) > 0 { return scope.Search.tableName } else { if scope.Value == nil {