mirror of https://github.com/go-gorm/gorm.git
Reflect do.go
This commit is contained in:
parent
0ad707b410
commit
6404f803e8
56
do.go
56
do.go
|
@ -14,18 +14,17 @@ import (
|
|||
)
|
||||
|
||||
type Do struct {
|
||||
db *DB
|
||||
search *search
|
||||
model *Model
|
||||
tableName string
|
||||
usingUpdate bool
|
||||
value interface{}
|
||||
update_attrs map[string]interface{}
|
||||
hasUpdate bool
|
||||
ignoreProtectedAttrs bool
|
||||
sql string
|
||||
sqlVars []interface{}
|
||||
startedTransaction bool
|
||||
db *DB
|
||||
search *search
|
||||
model *Model
|
||||
tableName string
|
||||
value interface{}
|
||||
usingUpdate bool
|
||||
hasUpdate bool
|
||||
update_attrs map[string]interface{}
|
||||
sql string
|
||||
sqlVars []interface{}
|
||||
startedTransaction bool
|
||||
}
|
||||
|
||||
func (s *Do) table() string {
|
||||
|
@ -53,11 +52,7 @@ func (s *Do) err(err error) error {
|
|||
func (s *Do) setModel(value interface{}) *Do {
|
||||
s.model = &Model{data: value, do: s}
|
||||
s.value = value
|
||||
if s.db.search == nil {
|
||||
s.search = &search{}
|
||||
} else {
|
||||
s.search = s.db.search
|
||||
}
|
||||
s.search = s.db.search
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -67,7 +62,9 @@ func (s *Do) addToVars(value interface{}) string {
|
|||
}
|
||||
|
||||
func (s *Do) trace(t time.Time) {
|
||||
s.db.slog(s.sql, t, s.sqlVars...)
|
||||
if len(s.sql) > 0 {
|
||||
s.db.slog(s.sql, t, s.sqlVars...)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Do) exec(sqls ...string) *Do {
|
||||
|
@ -113,12 +110,11 @@ func (s *Do) saveBeforeAssociations() {
|
|||
for _, field := range s.model.beforeAssociations() {
|
||||
do := &Do{db: s.db}
|
||||
|
||||
reflect_value := reflect.ValueOf(field.Value)
|
||||
if reflect_value.CanAddr() {
|
||||
do.setModel(reflect_value.Addr().Interface()).save()
|
||||
if field.reflectValue.CanAddr() {
|
||||
do.setModel(field.reflectValue.Addr().Interface()).save()
|
||||
} else {
|
||||
// If can't take address, then clone the value and set it back
|
||||
dest_value := reflect.New(reflect_value.Type()).Elem()
|
||||
dest_value := reflect.New(field.reflectValue.Type()).Elem()
|
||||
m := &Model{data: field.Value, do: s}
|
||||
for _, f := range m.columnsHasValue("other") {
|
||||
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||
|
@ -170,6 +166,7 @@ func (s *Do) saveAfterAssociations() {
|
|||
}
|
||||
|
||||
func (s *Do) create() (i interface{}) {
|
||||
defer s.trace(time.Now())
|
||||
s.model.callMethod("BeforeCreate")
|
||||
s.model.callMethod("BeforeSave")
|
||||
|
||||
|
@ -178,8 +175,6 @@ func (s *Do) create() (i interface{}) {
|
|||
|
||||
if !s.db.hasError() {
|
||||
var id interface{}
|
||||
|
||||
now := time.Now()
|
||||
if s.dialect().SupportLastInsertId() {
|
||||
if sql_result, err := s.db.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
|
||||
id, err = sql_result.LastInsertId()
|
||||
|
@ -188,7 +183,6 @@ func (s *Do) create() (i interface{}) {
|
|||
} else {
|
||||
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
||||
}
|
||||
s.db.slog(s.sql, now, s.sqlVars...)
|
||||
|
||||
if !s.db.hasError() {
|
||||
s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
|
||||
|
@ -348,6 +342,7 @@ func (s *Do) related(value interface{}, foreign_keys ...string) *Do {
|
|||
}
|
||||
|
||||
func (s *Do) query() *Do {
|
||||
defer s.trace(time.Now())
|
||||
var (
|
||||
is_slice bool
|
||||
dest_type reflect.Type
|
||||
|
@ -363,9 +358,7 @@ func (s *Do) query() *Do {
|
|||
|
||||
s.prepareQuerySql()
|
||||
if !s.db.hasError() {
|
||||
now := time.Now()
|
||||
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
|
||||
s.db.slog(s.sql, now, s.sqlVars...)
|
||||
|
||||
if s.err(err) != nil {
|
||||
return s
|
||||
|
@ -406,17 +399,19 @@ func (s *Do) query() *Do {
|
|||
}
|
||||
|
||||
func (s *Do) count(value interface{}) *Do {
|
||||
defer s.trace(time.Now())
|
||||
|
||||
s.search = s.search.clone().selects("count(*)")
|
||||
s.prepareQuerySql()
|
||||
if !s.db.hasError() {
|
||||
now := time.Now()
|
||||
s.err(s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(value))
|
||||
s.db.slog(s.sql, now, s.sqlVars...)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) pluck(column string, value interface{}) *Do {
|
||||
defer s.trace(time.Now())
|
||||
|
||||
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
||||
s.search = s.search.clone().selects(column)
|
||||
if dest_out.Kind() != reflect.Slice {
|
||||
|
@ -427,9 +422,7 @@ func (s *Do) pluck(column string, value interface{}) *Do {
|
|||
s.prepareQuerySql()
|
||||
|
||||
if !s.db.hasError() {
|
||||
now := time.Now()
|
||||
rows, err := s.db.db.Query(s.sql, s.sqlVars...)
|
||||
s.db.slog(s.sql, now, s.sqlVars...)
|
||||
|
||||
if s.err(err) == nil {
|
||||
defer rows.Close()
|
||||
|
@ -653,7 +646,6 @@ func (s *Do) createTable() *Do {
|
|||
}
|
||||
|
||||
s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))
|
||||
|
||||
s.exec()
|
||||
return s
|
||||
}
|
||||
|
|
|
@ -1242,6 +1242,7 @@ func (c Cart) TableName() string {
|
|||
}
|
||||
|
||||
func TestTableName(t *testing.T) {
|
||||
db := db.clone()
|
||||
if db.do(Order{}).table() != "orders" {
|
||||
t.Errorf("Order table name should be orders")
|
||||
}
|
||||
|
@ -1393,7 +1394,7 @@ func (s *CreditCard) BeforeSave() (err error) {
|
|||
}
|
||||
|
||||
func BenchmarkGorm(b *testing.B) {
|
||||
b.N = 5000
|
||||
b.N = 2000
|
||||
for x := 0; x < b.N; x++ {
|
||||
e := strconv.Itoa(x) + "benchmark@example.org"
|
||||
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
|
@ -1416,7 +1417,7 @@ func BenchmarkRawSql(b *testing.B) {
|
|||
update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
|
||||
delete_sql := "DELETE FROM orders WHERE id = $1"
|
||||
|
||||
b.N = 5000
|
||||
b.N = 2000
|
||||
for x := 0; x < b.N; x++ {
|
||||
var id int64
|
||||
e := strconv.Itoa(x) + "benchmark@example.org"
|
||||
|
|
23
model.go
23
model.go
|
@ -24,8 +24,7 @@ func (m *Model) primaryKeyZero() bool {
|
|||
|
||||
func (m *Model) primaryKeyValue() interface{} {
|
||||
if data := m.reflectData(); data.Kind() == reflect.Struct {
|
||||
field := data.FieldByName(m.primaryKey())
|
||||
if data.FieldByName(m.primaryKey()).IsValid() {
|
||||
if field := data.FieldByName(m.primaryKey()); field.IsValid() {
|
||||
return field.Interface()
|
||||
}
|
||||
}
|
||||
|
@ -133,7 +132,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
|||
results := map[string]interface{}{}
|
||||
|
||||
for _, field := range m.fields(operation) {
|
||||
if !field.isPrimaryKey && (len(field.sqlTag()) > 0) {
|
||||
if !field.isPrimaryKey && len(field.sqlTag()) > 0 {
|
||||
results[field.dbName] = field.Value
|
||||
}
|
||||
}
|
||||
|
@ -141,9 +140,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
|||
}
|
||||
|
||||
func (m *Model) hasColumn(name string) bool {
|
||||
data := m.reflectData()
|
||||
|
||||
if data.Kind() == reflect.Struct {
|
||||
if data := m.reflectData(); data.Kind() == reflect.Struct {
|
||||
return data.FieldByName(name).IsValid()
|
||||
} else if data.Kind() == reflect.Slice {
|
||||
return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
|
||||
|
@ -152,9 +149,7 @@ func (m *Model) hasColumn(name string) bool {
|
|||
}
|
||||
|
||||
func (m *Model) columnAndValue(name string) (has_column bool, is_slice bool, value interface{}) {
|
||||
data := m.reflectData()
|
||||
|
||||
if data.Kind() == reflect.Struct {
|
||||
if data := m.reflectData(); data.Kind() == reflect.Struct {
|
||||
if has_column = data.FieldByName(name).IsValid(); has_column {
|
||||
value = data.FieldByName(name).Interface()
|
||||
}
|
||||
|
@ -165,15 +160,19 @@ func (m *Model) columnAndValue(name string) (has_column bool, is_slice bool, val
|
|||
return
|
||||
}
|
||||
|
||||
func (m *Model) typeName() string {
|
||||
func (m *Model) typ() reflect.Type {
|
||||
typ := m.reflectData().Type()
|
||||
if typ.Kind() == reflect.Slice {
|
||||
return typ.Elem().Name()
|
||||
return typ.Elem()
|
||||
} else {
|
||||
return typ.Name()
|
||||
return typ
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) typeName() string {
|
||||
return m.typ().Name()
|
||||
}
|
||||
|
||||
func (m *Model) tableName() (str string) {
|
||||
if m.data == nil {
|
||||
m.do.err(errors.New("Model haven't been set"))
|
||||
|
|
Loading…
Reference in New Issue