Reflect do.go

This commit is contained in:
Jinzhu 2013-11-16 20:16:23 +08:00
parent 0ad707b410
commit 6404f803e8
3 changed files with 38 additions and 46 deletions

56
do.go
View File

@ -14,18 +14,17 @@ import (
)
type Do struct {
db *DB
search *search
model *Model
tableName string
usingUpdate bool
value interface{}
update_attrs map[string]interface{}
hasUpdate bool
ignoreProtectedAttrs bool
sql string
sqlVars []interface{}
startedTransaction bool
db *DB
search *search
model *Model
tableName string
value interface{}
usingUpdate bool
hasUpdate bool
update_attrs map[string]interface{}
sql string
sqlVars []interface{}
startedTransaction bool
}
func (s *Do) table() string {
@ -53,11 +52,7 @@ func (s *Do) err(err error) error {
func (s *Do) setModel(value interface{}) *Do {
s.model = &Model{data: value, do: s}
s.value = value
if s.db.search == nil {
s.search = &search{}
} else {
s.search = s.db.search
}
s.search = s.db.search
return s
}
@ -67,7 +62,9 @@ func (s *Do) addToVars(value interface{}) string {
}
func (s *Do) trace(t time.Time) {
s.db.slog(s.sql, t, s.sqlVars...)
if len(s.sql) > 0 {
s.db.slog(s.sql, t, s.sqlVars...)
}
}
func (s *Do) exec(sqls ...string) *Do {
@ -113,12 +110,11 @@ func (s *Do) saveBeforeAssociations() {
for _, field := range s.model.beforeAssociations() {
do := &Do{db: s.db}
reflect_value := reflect.ValueOf(field.Value)
if reflect_value.CanAddr() {
do.setModel(reflect_value.Addr().Interface()).save()
if field.reflectValue.CanAddr() {
do.setModel(field.reflectValue.Addr().Interface()).save()
} else {
// If can't take address, then clone the value and set it back
dest_value := reflect.New(reflect_value.Type()).Elem()
dest_value := reflect.New(field.reflectValue.Type()).Elem()
m := &Model{data: field.Value, do: s}
for _, f := range m.columnsHasValue("other") {
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
@ -170,6 +166,7 @@ func (s *Do) saveAfterAssociations() {
}
func (s *Do) create() (i interface{}) {
defer s.trace(time.Now())
s.model.callMethod("BeforeCreate")
s.model.callMethod("BeforeSave")
@ -178,8 +175,6 @@ func (s *Do) create() (i interface{}) {
if !s.db.hasError() {
var id interface{}
now := time.Now()
if s.dialect().SupportLastInsertId() {
if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
id, err = sql_result.LastInsertId()
@ -188,7 +183,6 @@ func (s *Do) create() (i interface{}) {
} else {
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
}
s.db.slog(s.sql, now, s.sqlVars...)
if !s.db.hasError() {
s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
@ -348,6 +342,7 @@ func (s *Do) related(value interface{}, foreign_keys ...string) *Do {
}
func (s *Do) query() *Do {
defer s.trace(time.Now())
var (
is_slice bool
dest_type reflect.Type
@ -363,9 +358,7 @@ func (s *Do) query() *Do {
s.prepareQuerySql()
if !s.db.hasError() {
now := time.Now()
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
s.db.slog(s.sql, now, s.sqlVars...)
if s.err(err) != nil {
return s
@ -406,17 +399,19 @@ func (s *Do) query() *Do {
}
func (s *Do) count(value interface{}) *Do {
defer s.trace(time.Now())
s.search = s.search.clone().selects("count(*)")
s.prepareQuerySql()
if !s.db.hasError() {
now := time.Now()
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
s.db.slog(s.sql, now, s.sqlVars...)
}
return s
}
func (s *Do) pluck(column string, value interface{}) *Do {
defer s.trace(time.Now())
dest_out := reflect.Indirect(reflect.ValueOf(value))
s.search = s.search.clone().selects(column)
if dest_out.Kind() != reflect.Slice {
@ -427,9 +422,7 @@ func (s *Do) pluck(column string, value interface{}) *Do {
s.prepareQuerySql()
if !s.db.hasError() {
now := time.Now()
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
s.db.slog(s.sql, now, s.sqlVars...)
if s.err(err) == nil {
defer rows.Close()
@ -653,7 +646,6 @@ func (s *Do) createTable() *Do {
}
s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))
s.exec()
return s
}

View File

@ -1242,6 +1242,7 @@ func (c Cart) TableName() string {
}
func TestTableName(t *testing.T) {
db := db.clone()
if db.do(Order{}).table() != "orders" {
t.Errorf("Order table name should be orders")
}
@ -1393,7 +1394,7 @@ func (s *CreditCard) BeforeSave() (err error) {
}
func BenchmarkGorm(b *testing.B) {
b.N = 5000
b.N = 2000
for x := 0; x < b.N; x++ {
e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
@ -1416,7 +1417,7 @@ func BenchmarkRawSql(b *testing.B) {
update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
delete_sql := "DELETE FROM orders WHERE id = $1"
b.N = 5000
b.N = 2000
for x := 0; x < b.N; x++ {
var id int64
e := strconv.Itoa(x) + "benchmark@example.org"

View File

@ -24,8 +24,7 @@ func (m *Model) primaryKeyZero() bool {
func (m *Model) primaryKeyValue() interface{} {
if data := m.reflectData(); data.Kind() == reflect.Struct {
field := data.FieldByName(m.primaryKey())
if data.FieldByName(m.primaryKey()).IsValid() {
if field := data.FieldByName(m.primaryKey()); field.IsValid() {
return field.Interface()
}
}
@ -133,7 +132,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
results := map[string]interface{}{}
for _, field := range m.fields(operation) {
if !field.isPrimaryKey && (len(field.sqlTag()) > 0) {
if !field.isPrimaryKey && len(field.sqlTag()) > 0 {
results[field.dbName] = field.Value
}
}
@ -141,9 +140,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
}
func (m *Model) hasColumn(name string) bool {
data := m.reflectData()
if data.Kind() == reflect.Struct {
if data := m.reflectData(); data.Kind() == reflect.Struct {
return data.FieldByName(name).IsValid()
} else if data.Kind() == reflect.Slice {
return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
@ -152,9 +149,7 @@ func (m *Model) hasColumn(name string) bool {
}
func (m *Model) columnAndValue(name string) (has_column bool, is_slice bool, value interface{}) {
data := m.reflectData()
if data.Kind() == reflect.Struct {
if data := m.reflectData(); data.Kind() == reflect.Struct {
if has_column = data.FieldByName(name).IsValid(); has_column {
value = data.FieldByName(name).Interface()
}
@ -165,15 +160,19 @@ func (m *Model) columnAndValue(name string) (has_column bool, is_slice bool, val
return
}
func (m *Model) typeName() string {
func (m *Model) typ() reflect.Type {
typ := m.reflectData().Type()
if typ.Kind() == reflect.Slice {
return typ.Elem().Name()
return typ.Elem()
} else {
return typ.Name()
return typ
}
}
func (m *Model) typeName() string {
return m.typ().Name()
}
func (m *Model) tableName() (str string) {
if m.data == nil {
m.do.err(errors.New("Model haven't been set"))