diff --git a/README.md b/README.md index b0dcb7bd..9535015e 100644 --- a/README.md +++ b/README.md @@ -68,15 +68,15 @@ go get -u github.com/jinzhu/gorm - [Limit](#limit) - [Offset](#offset) - [Count](#count) - - [Pluck](#pluck) - - [Raw SQL](#raw-sql) - - [Row & Rows](#row--rows) - - [Scan](#scan) - [Group & Having](#group--having) - [Joins](#joins) - [Transactions](#transactions) - [Scopes](#scopes) - [Callbacks](#callbacks) + - [Pluck](#pluck) + - [Scan](#scan) + - [Raw SQL](#raw-sql) + - [Row & Rows](#row--rows) - [Specifying The Table Name](#specifying-the-table-name) - [Error Handling](#error-handling) - [Logger](#logger) @@ -866,73 +866,6 @@ db.Table("deleted_users").Count(&count) //// SELECT count(*) FROM deleted_users; ``` -## Pluck - -Get selected attributes as map - -```go -var ages []int64 -db.Find(&users).Pluck("age", &ages) - -var names []string -db.Model(&User{}).Pluck("name", &names) - -db.Table("deleted_users").Pluck("name", &names) - -// Requesting more than one column? Do it like this: -db.Select("name, age").Find(&users) -``` - -## Raw SQL - -```go -db.Exec("DROP TABLE users;") -db.Exec("UPDATE orders SET shipped_at=? WHERE id IN (?)", time.Now, []int64{11,22,33}) -``` - -## Row & Rows - -It is even possible to get query result as `*sql.Row` or `*sql.Rows` - -```go -row := db.Table("users").Where("name = ?", "jinzhu").Select("name, age").Row() // (*sql.Row) -row.Scan(&name, &age) - -rows, err := db.Model(User{}).Where("name = ?", "jinzhu").Select("name, age, email").Rows() // (*sql.Rows, error) -defer rows.Close() -for rows.Next() { - ... - rows.Scan(&name, &age, &email) - ... -} - -// Raw SQL -rows, err := db.Raw("select name, age, email from users where name = ?", "jinzhu").Rows() // (*sql.Rows, error) -defer rows.Close() -for rows.Next() { - ... - rows.Scan(&name, &age, &email) - ... -} -``` - -## Scan - -Scan results into another struct. - -```go -type Result struct { - Name string - Age int -} - -var result Result -db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&result) - -// Raw SQL -db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -``` - ## Group & Having ```go @@ -1108,7 +1041,7 @@ func (u *User) AfterCreate() (err error) { } ``` -Save/delete operations in gorm are running in a transaction. +Save/delete operations in gorm are running in a transaction. Changes made in that transaction are not visible unless it is commited. So if you want to use those changes in your callbacks, you need to run your SQL in the same transaction. For this Gorm supports passing transactions to callbacks like this: @@ -1120,6 +1053,86 @@ func (u *User) AfterCreate(tx *gorm.DB) (err error) { } ``` +## Pluck + +Get selected attributes as map + +```go +var ages []int64 +db.Find(&users).Pluck("age", &ages) + +var names []string +db.Model(&User{}).Pluck("name", &names) + +db.Table("deleted_users").Pluck("name", &names) + +// Requesting more than one column? Do it like this: +db.Select("name, age").Find(&users) +``` + +## Scan + +Scan results into another struct. + +```go +type Result struct { + Name string + Age int +} + +var result Result +db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&result) + +// Raw SQL +db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) +``` + +## Raw SQL + +```go +db.Exec("DROP TABLE users;") +db.Exec("UPDATE orders SET shipped_at=? WHERE id IN (?)", time.Now, []int64{11,22,33}) +``` + +## Row & Rows + +It is even possible to get query result as `*sql.Row` or `*sql.Rows` + +```go +row := db.Table("users").Where("name = ?", "jinzhu").Select("name, age").Row() // (*sql.Row) +row.Scan(&name, &age) + +rows, err := db.Model(&User{}).Where("name = ?", "jinzhu").Select("name, age, email").Rows() // (*sql.Rows, error) +defer rows.Close() +for rows.Next() { + ... + rows.Scan(&name, &age, &email) + ... +} + +// Raw SQL +rows, err := db.Raw("select name, age, email from users where name = ?", "jinzhu").Rows() // (*sql.Rows, error) +defer rows.Close() +for rows.Next() { + ... + rows.Scan(&name, &age, &email) + ... +} +``` + +### Scan Rows + +```go +rows, err := db.Model(&User{}).Where("name = ?", "jinzhu").Select("name, age, email").Rows() // (*sql.Rows, error) +defer rows.Close() + +for rows.Next() { + var user User + db.ScanRows(rows, &user) + // do something +} +``` + ## Specifying The Table Name ```go diff --git a/main.go b/main.go index 5f4d9dbd..51bd9914 100644 --- a/main.go +++ b/main.go @@ -224,6 +224,20 @@ func (s *DB) Rows() (*sql.Rows, error) { return s.NewScope(s.Value).rows() } +func (s *DB) ScanRows(rows *sql.Rows, value interface{}) error { + var ( + clone = s.clone() + scope = clone.NewScope(value) + columns, err = rows.Columns() + ) + + if clone.AddError(err) == nil { + scope.scan(rows, columns, scope.Fields()) + } + + return clone.Error +} + func (s *DB) Pluck(column string, value interface{}) *DB { return s.NewScope(s.Value).pluck(column, value).db } diff --git a/main_test.go b/main_test.go index c07d474b..97a3d84e 100644 --- a/main_test.go +++ b/main_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "reflect" "strconv" _ "github.com/denisenkom/go-mssqldb" @@ -376,7 +377,7 @@ func TestRows(t *testing.T) { rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() if err != nil { - t.Errorf("Not error should happen, but got") + t.Errorf("Not error should happen, got %v", err) } count := 0 @@ -386,8 +387,39 @@ func TestRows(t *testing.T) { rows.Scan(&name, &age) count++ } + if count != 2 { - t.Errorf("Should found two records with name 3") + t.Errorf("Should found two records") + } +} + +func TestScanRows(t *testing.T) { + user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + type Result struct { + Name string + Age int + } + + var results []Result + for rows.Next() { + var result Result + if err := DB.ScanRows(rows, &result); err != nil { + t.Errorf("should get no error, but got %v", err) + } + results = append(results, result) + } + + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { + t.Errorf("Should find expected results") } } diff --git a/utils.go b/utils.go index 9d2bb075..94e345cc 100644 --- a/utils.go +++ b/utils.go @@ -2,8 +2,11 @@ package gorm import ( "bytes" + "database/sql/driver" "fmt" "reflect" + "regexp" + "runtime" "strings" "sync" ) @@ -50,6 +53,7 @@ const ( upper strCase = true ) +// ToDBName convert string to db name func ToDBName(name string) string { if v := smap.Get(name); v != "" { return v @@ -94,11 +98,14 @@ func ToDBName(name string) string { return s } +// SQL expression type expr struct { expr string args []interface{} } +// Expr generate raw SQL expression for SQL, for example: +// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) func Expr(expression string, args ...interface{}) *expr { return &expr{expr: expression, args: args} } @@ -148,3 +155,124 @@ func toQueryValues(values [][]interface{}) (results []interface{}) { } return } + +func fileWithLineNum() string { + for i := 2; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { + return fmt.Sprintf("%v:%v", file, line) + } + } + return "" +} + +func isBlank(value reflect.Value) bool { + return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) +} + +func toSearchableMap(attrs ...interface{}) (result interface{}) { + if len(attrs) > 1 { + if str, ok := attrs[0].(string); ok { + result = map[string]interface{}{str: attrs[1]} + } + } else if len(attrs) == 1 { + if attr, ok := attrs[0].(map[string]interface{}); ok { + result = attr + } + + if attr, ok := attrs[0].(interface{}); ok { + result = attr + } + } + return +} + +func convertInterfaceToMap(values interface{}) map[string]interface{} { + attrs := map[string]interface{}{} + + switch value := values.(type) { + case map[string]interface{}: + for k, v := range value { + attrs[k] = v + } + case []interface{}: + for _, v := range value { + for key, value := range convertInterfaceToMap(v) { + attrs[key] = value + } + } + case interface{}: + reflectValue := reflect.ValueOf(values) + + switch reflectValue.Kind() { + case reflect.Map: + for _, key := range reflectValue.MapKeys() { + attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + } + default: + for _, field := range (&Scope{Value: values}).Fields() { + if !field.IsBlank && !field.IsIgnored { + attrs[field.DBName] = field.Field.Interface() + } + } + } + } + return attrs +} + +func equalAsString(a interface{}, b interface{}) bool { + return toString(a) == toString(b) +} + +func toString(str interface{}) string { + if values, ok := str.([]interface{}); ok { + var results []string + for _, value := range values { + results = append(results, toString(value)) + } + return strings.Join(results, "_") + } else if bytes, ok := str.([]byte); ok { + return string(bytes) + } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { + return fmt.Sprintf("%v", reflectValue.Interface()) + } + return "" +} + +func makeSlice(elemType reflect.Type) interface{} { + if elemType.Kind() == reflect.Slice { + elemType = elemType.Elem() + } + sliceType := reflect.SliceOf(elemType) + slice := reflect.New(sliceType) + slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) + return slice.Interface() +} + +func strInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} + +// getValueFromFields return given fields's value +func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { + // If value is a nil pointer, Indirect returns a zero Value! + // Therefor we need to check for a zero value, + // as FieldByName could panic + if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { + for _, fieldName := range fieldNames { + if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + result := fieldValue.Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) + } + } + } + return +} diff --git a/utils_private.go b/utils_private.go deleted file mode 100644 index 2851a37e..00000000 --- a/utils_private.go +++ /dev/null @@ -1,131 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "runtime" - "strings" -) - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func convertInterfaceToMap(values interface{}) map[string]interface{} { - attrs := map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - for k, v := range value { - attrs[k] = v - } - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - for _, field := range (&Scope{Value: values}).Fields() { - if !field.IsBlank && !field.IsIgnored { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func makeSlice(elemType reflect.Type) interface{} { - if elemType.Kind() == reflect.Slice { - elemType = elemType.Elem() - } - sliceType := reflect.SliceOf(elemType) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -// getValueFromFields return given fields's value -func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { - for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { - result := fieldValue.Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -}