diff --git a/README.md b/README.md index ac843627..975263ca 100644 --- a/README.md +++ b/README.md @@ -20,12 +20,12 @@ Yet Another ORM library for Go, aims for developer friendly db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable") type User struct { - Id int64 - Age int64 - Birthday time.Time - Name string - CreatedAt time.Time - UpdatedAt time.Time + Id int64 + Age int64 + Birthday time.Time + Name string + CreatedAt time.Time + UpdatedAt time.Time } // Set database pool @@ -127,12 +127,12 @@ Below callbacks are defined now: Callbacks is a function defined to a model, if the function return error, will prevent the database operations. - func (u *User) BeforeUpdate() (err error) { - if u.readonly() { - err = errors.New("Read Only User") - } - return +func (u *User) BeforeUpdate() (err error) { + if u.readonly() { + err = errors.New("Read Only User") } + return +} // Pluck (get users's age as map) var ages []int64 @@ -149,6 +149,16 @@ db.Where("name <> ?", "jinzhu").Where("age >= ? and role <> ?", 20, "admin").Fin // Create Table with struct db.CreateTable(&User{}) +// Specify Table Name +db.Table("deleted_users").CreateTable(&User{}) +db.Table("users").Pluck("age", &ages) +//// ages -> select age from users; +var deleted_users []User +db.Table("deleted_users").Find(&deleted_users) +//// deleted_users -> select * from deleted_users; +db.Table("deleted_users").Find(&deleted_user) +//// deleted_user -> select * from deleted_users limit 1; + // Run Raw SQL db.Exec("drop table users;") @@ -185,6 +195,10 @@ db.Where("product_name = ?", "fancy_product").Find(&orders).Find(&shopping_cart) //// shopping_cart -> select * from carts where product_name = 'fancy_product' // Do you noticed the search table is different for above query, yay +db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&user2) +//// users1 -> select * from users where mail_type = 'TEXT'; +//// users2 -> select * from deleted_users where mail_type = 'TEXT'; + // Open your mind, add more cool examples ``` diff --git a/chain.go b/chain.go index ca863746..dba71b71 100644 --- a/chain.go +++ b/chain.go @@ -17,12 +17,13 @@ type Chain struct { Errors []error Error error - whereClause []map[string]interface{} - orClause []map[string]interface{} - selectStr string - orderStrs []string - offsetStr string - limitStr string + whereClause []map[string]interface{} + orClause []map[string]interface{} + selectStr string + orderStrs []string + offsetStr string + limitStr string + specifiedTableName string } func (s *Chain) err(err error) { @@ -44,6 +45,7 @@ func (s *Chain) do(value interface{}) *Do { do.orderStrs = s.orderStrs do.offsetStr = s.offsetStr do.limitStr = s.limitStr + do.specifiedTableName = s.specifiedTableName s.value = value do.setModel(value) @@ -174,6 +176,11 @@ func (s *Chain) CreateTable(value interface{}) *Chain { return s } +func (s *Chain) Table(name string) *Chain { + s.specifiedTableName = name + return s +} + func (s *Chain) validSql(str string) (result bool) { result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str) if !result { diff --git a/do.go b/do.go index 1ededa7e..6f89564b 100644 --- a/do.go +++ b/do.go @@ -12,11 +12,12 @@ import ( ) type Do struct { - chain *Chain - db *sql.DB - driver string - TableName string - Errors []error + chain *Chain + db *sql.DB + driver string + guessedTableName string + specifiedTableName string + Errors []error model *Model value interface{} @@ -34,6 +35,14 @@ type Do struct { operation string } +func (s *Do) tableName() string { + if s.specifiedTableName == "" { + return s.guessedTableName + } else { + return s.specifiedTableName + } +} + func (s *Do) err(err error) { if err != nil { s.Errors = append(s.Errors, err) @@ -49,8 +58,10 @@ func (s *Do) setModel(value interface{}) { s.value = value s.model = &Model{Data: value, driver: s.driver} var err error - s.TableName, err = s.model.tableName() - s.err(err) + if s.specifiedTableName == "" { + s.guessedTableName, err = s.model.tableName() + s.err(err) + } } func (s *Do) addToVars(value interface{}) string { @@ -91,7 +102,7 @@ func (s *Do) prepareCreateSql() *Do { s.Sql = fmt.Sprintf( "INSERT INTO \"%v\" (%v) VALUES (%v) %v", - s.TableName, + s.tableName(), strings.Join(s.quoteMap(columns), ","), strings.Join(sqls, ","), s.model.returningStr(), @@ -135,7 +146,7 @@ func (s *Do) prepareUpdateSql() *Do { s.Sql = fmt.Sprintf( "UPDATE %v SET %v %v", - s.TableName, + s.tableName(), strings.Join(sets, ", "), s.combinedSql(), ) @@ -154,7 +165,7 @@ func (s *Do) update() *Do { } func (s *Do) prepareDeleteSql() *Do { - s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql()) + s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql()) return s } @@ -168,7 +179,7 @@ func (s *Do) delete() *Do { } func (s *Do) prepareQuerySql() *Do { - s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.TableName, s.combinedSql()) + s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql()) return s } @@ -216,7 +227,10 @@ func (s *Do) query(where ...interface{}) { columns, _ := rows.Columns() var values []interface{} for _, value := range columns { - values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface()) + field := dest.FieldByName(snakeToUpperCamel(value)) + if field.IsValid() { + values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface()) + } } s.err(rows.Scan(values...)) @@ -400,10 +414,9 @@ func (s *Do) createTable() *Do { for _, field := range s.model.fields("null") { sqls = append(sqls, field.DbName+" "+field.SqlType) } - s.Sql = fmt.Sprintf( "CREATE TABLE \"%v\" (%v)", - s.TableName, + s.tableName(), strings.Join(sqls, ","), ) return s diff --git a/gorm_test.go b/gorm_test.go index 503f08e0..20353e60 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1,11 +1,11 @@ package gorm import ( - _ "github.com/lib/pq" - "errors" + _ "github.com/lib/pq" "reflect" "strconv" + "testing" "time" ) @@ -567,3 +567,52 @@ func TestNoUnExpectedHappenWithInvalidSql(t *testing.T) { db.Where("unexisting = ?", "3").Find(&[]User{}) } + +func TestSetTableDirectly(t *testing.T) { + var ages []int64 + if db.Table("users").Pluck("age", &ages).Error != nil { + t.Errorf("No errors should happen if only set table") + } + + if len(ages) == 0 { + t.Errorf("Should find out some records") + } + + var users []User + if db.Table("users").Find(&users).Error != nil { + t.Errorf("No errors should happen if set table to an existing table") + } + + if db.Table("unexisting_users_table").Find(&users).Error == nil { + t.Errorf("Should got some errors if set table to an unexisting table") + } + + if db.Table("products").Find(&users).Error == nil { + t.Errorf("Should got some errors if set table to an unexisting table") + } + + db.Exec("drop table deleted_users;") + if db.Table("deleted_users").CreateTable(&User{}).Error != nil { + t.Errorf("Should create deleted_users table") + } + + db.Table("deleted_users").Save(&User{Name: "DeletedUser"}) + + var deleted_users []User + db.Table("deleted_users").Find(&deleted_users) + if len(deleted_users) != 1 { + t.Errorf("Should query from deleted_users table") + } + + var deleted_user User + db.Table("deleted_users").Find(&deleted_user) + if deleted_user.Name != "DeletedUser" { + t.Errorf("Should query from deleted_users table") + } + + var user1, user2, user3 User + db.First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3) + if !((user1.Name != user2.Name) && (user1.Name == user3.Name)) { + t.Errorf("Set Table Chain Should works well") + } +} diff --git a/main.go b/main.go index 4eeb14e6..f8c54a03 100644 --- a/main.go +++ b/main.go @@ -65,6 +65,10 @@ func (s *DB) Model(value interface{}) *Chain { return s.buildORM().Model(value) } +func (s *DB) Table(name string) *Chain { + return s.buildORM().Table(name) +} + func (s *DB) CreateTable(value interface{}) *Chain { return s.buildORM().CreateTable(value) } diff --git a/model.go b/model.go index b75bcf51..52cd681f 100644 --- a/model.go +++ b/model.go @@ -29,6 +29,10 @@ func (m *Model) primaryKeyZero() bool { } func (m *Model) primaryKeyValue() int64 { + if m.Data == nil { + return 0 + } + t := reflect.TypeOf(m.Data).Elem() switch t.Kind() { case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: