From da31f58607c8cc8d787db4b2154301f2725774a9 Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Sat, 8 Aug 2015 14:08:25 -0700 Subject: [PATCH 1/6] Ensure DDL dialect queries propagate error states to descendent scopes. Includes relevant unit-test. Branched from jay/current_database (please merge that branch first!). --- common_dialect.go | 22 +++++++++++++++----- ddl_errors_test.go | 24 ++++++++++++++++++++++ delete_test.go | 12 +++++------ foundation.go | 24 +++++++++++----------- main.go | 9 +++++++-- main_test.go | 50 +++++++++++++++++++++++++--------------------- mssql.go | 12 +++++------ mysql.go | 4 ++-- postgres.go | 18 ++++++++--------- sqlite3.go | 14 ++++++------- 10 files changed, 117 insertions(+), 72 deletions(-) create mode 100644 ddl_errors_test.go diff --git a/common_dialect.go b/common_dialect.go index d7546ede..956121ff 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -74,7 +74,7 @@ func (c commonDialect) HasTable(scope *Scope, tableName string) bool { databaseName string ) c.CurrentDatabase(scope, &databaseName) - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, databaseName).Row().Scan(&count) + c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, databaseName) return count > 0 } @@ -84,18 +84,30 @@ func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName stri databaseName string ) c.CurrentDatabase(scope, &databaseName) - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count) + c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) return count > 0 } -func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) + c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName) return count > 0 } func (commonDialect) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error) +} + +// RawScanInt scans the first column of the first row into the `scan' int pointer. +// This function captures raw query errors and propagates them to the original scope. +func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) { + scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) +} + +// RawScanString scans the first column of the first row into the `scan' string pointer. +// This function captures raw query errors and propagates them to the original scope. +func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) { + scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) } func (commonDialect) CurrentDatabase(scope *Scope, name *string) { diff --git a/ddl_errors_test.go b/ddl_errors_test.go new file mode 100644 index 00000000..aca59553 --- /dev/null +++ b/ddl_errors_test.go @@ -0,0 +1,24 @@ +package gorm_test + +import ( + "testing" +) + +func TestDdlErrors(t *testing.T) { + var err error + + if err = DB.Close(); err != nil { + t.Errorf("Closing DDL test db connection err=%s", err) + } + defer func() { + // Reopen DB connection. + if DB, err = OpenTestConnection(); err != nil { + t.Fatalf("Failed re-opening db connection: %s", err) + } + }() + + DB.HasTable("foobarbaz") + if DB.Error == nil { + t.Errorf("Expected operation on closed db to produce an error, but err was nil") + } +} diff --git a/delete_test.go b/delete_test.go index 74224a75..e0c71660 100644 --- a/delete_test.go +++ b/delete_test.go @@ -10,8 +10,8 @@ func TestDelete(t *testing.T) { DB.Save(&user1) DB.Save(&user2) - if DB.Delete(&user1).Error != nil { - t.Errorf("No error should happen when delete a record") + if err := DB.Delete(&user1).Error; err != nil { + t.Errorf("No error should happen when delete a record, err=%s", err) } if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { @@ -34,8 +34,8 @@ func TestInlineDelete(t *testing.T) { t.Errorf("User can't be found after delete") } - if DB.Delete(&User{}, "name = ?", user2.Name).Error != nil { - t.Errorf("No error should happen when delete a record") + if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Errorf("No error should happen when delete a record, err=%s", err) } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { t.Errorf("User can't be found after delete") } @@ -57,8 +57,8 @@ func TestSoftDelete(t *testing.T) { t.Errorf("Can't find a soft deleted record") } - if DB.Unscoped().First(&User{}, "name = ?", user.Name).Error != nil { - t.Errorf("Should be able to find soft deleted record with Unscoped") + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) } DB.Unscoped().Delete(&user) diff --git a/foundation.go b/foundation.go index 8f68c720..0413360e 100644 --- a/foundation.go +++ b/foundation.go @@ -51,32 +51,32 @@ func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) strin panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String())) } -func (f foundation) ReturningStr(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", f.Quote(tableName), key) +func (s foundation) ReturningStr(tableName, key string) string { + return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key) } -func (foundation) HasTable(scope *Scope, tableName string) bool { +func (s foundation) HasTable(scope *Scope, tableName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName).Row().Scan(&count) + s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName) return count > 0 } -func (foundation) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (s foundation) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) + s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName) return count > 0 } -func (f foundation) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", f.Quote(indexName))) +func (s foundation) RemoveIndex(scope *Scope, indexName string) { + scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName))) } -func (foundation) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) + s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName) return count > 0 } -func (foundation) CurrentDatabase(scope *Scope, name *string) { - scope.Err(scope.NewDB().Raw("SELECT CURRENT_SCHEMA").Row().Scan(name)) +func (s foundation) CurrentDatabase(scope *Scope, name *string) { + s.RawScanString(scope, name, "SELECT CURRENT_SCHEMA") } diff --git a/main.go b/main.go index 7005f6b5..7e58e592 100644 --- a/main.go +++ b/main.go @@ -365,7 +365,10 @@ func (s *DB) Rollback() *DB { } func (s *DB) NewRecord(value interface{}) bool { - return s.clone().NewScope(value).PrimaryKeyZero() + scope := s.clone().NewScope(value) + result := scope.PrimaryKeyZero() + s.err(scope.db.Error) + return result } func (s *DB) RecordNotFound() bool { @@ -388,7 +391,9 @@ func (s *DB) DropTableIfExists(value interface{}) *DB { func (s *DB) HasTable(value interface{}) bool { scope := s.clone().NewScope(value) tableName := scope.TableName() - return scope.Dialect().HasTable(scope, tableName) + has := scope.Dialect().HasTable(scope, tableName) + s.err(scope.db.Error) + return has } func (s *DB) AutoMigrate(values ...interface{}) *DB { diff --git a/main_test.go b/main_test.go index 0dc5e337..613cc543 100644 --- a/main_test.go +++ b/main_test.go @@ -26,25 +26,9 @@ var ( func init() { var err error - switch os.Getenv("GORM_DIALECT") { - case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; - fmt.Println("testing mysql...") - DB, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") - case "postgres": - fmt.Println("testing postgres...") - DB, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") - case "foundation": - fmt.Println("testing foundation...") - DB, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") - case "mssql": - fmt.Println("testing mssql...") - DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") - default: - fmt.Println("testing sqlite3...") - DB, err = gorm.Open("sqlite3", "/tmp/gorm.db") + + if DB, err = OpenTestConnection(); err != nil { + panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) } // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) @@ -52,15 +36,35 @@ func init() { DB.LogMode(true) DB.LogMode(false) - if err != nil { - panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err)) - } - DB.DB().SetMaxIdleConns(10) runMigration() } +func OpenTestConnection() (db gorm.DB, err error) { + switch os.Getenv("GORM_DIALECT") { + case "mysql": + // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; + // CREATE DATABASE gorm; + // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; + fmt.Println("testing mysql...") + db, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") + case "postgres": + fmt.Println("testing postgres...") + db, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") + case "foundation": + fmt.Println("testing foundation...") + db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") + case "mssql": + fmt.Println("testing mssql...") + db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") + default: + fmt.Println("testing sqlite3...") + db, err = gorm.Open("sqlite3", "/tmp/gorm.db") + } + return +} + func TestStringPrimaryKey(t *testing.T) { type UUIDStruct struct { ID string `gorm:"primary_key"` diff --git a/mssql.go b/mssql.go index b2633292..49d17fed 100644 --- a/mssql.go +++ b/mssql.go @@ -56,7 +56,7 @@ func (s mssql) HasTable(scope *Scope, tableName string) bool { databaseName string ) s.CurrentDatabase(scope, &databaseName) - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName).Row().Scan(&count) + s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName) return count > 0 } @@ -66,16 +66,16 @@ func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool databaseName string ) s.CurrentDatabase(scope, &databaseName) - scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName).Row().Scan(&count) + s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) return count > 0 } -func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count) + s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName) return count > 0 } -func (mssql) CurrentDatabase(scope *Scope, name *string) { - scope.Err(scope.NewDB().Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(name)) +func (s mssql) CurrentDatabase(scope *Scope, name *string) { + s.RawScanString(scope, name, "SELECT DB_NAME() AS [Current Database]") } diff --git a/mysql.go b/mysql.go index 025f87a5..7fadb5af 100644 --- a/mysql.go +++ b/mysql.go @@ -64,6 +64,6 @@ func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } -func (mysql) CurrentDatabase(scope *Scope, name *string) { - scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(name)) +func (s mysql) CurrentDatabase(scope *Scope, name *string) { + s.RawScanString(scope, name, "SELECT DATABASE()") } diff --git a/postgres.go b/postgres.go index 7374af9c..9d1d776c 100644 --- a/postgres.go +++ b/postgres.go @@ -63,30 +63,30 @@ func (s postgres) ReturningStr(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key) } -func (postgres) HasTable(scope *Scope, tableName string) bool { +func (s postgres) HasTable(scope *Scope, tableName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count) + s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName) return count > 0 } -func (postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) + s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName) return count > 0 } func (postgres) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) } -func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count) + s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName) return count > 0 } -func (postgres) CurrentDatabase(scope *Scope, name *string) { - scope.Err(scope.NewDB().Raw("SELECT CURRENT_DATABASE()").Row().Scan(name)) +func (s postgres) CurrentDatabase(scope *Scope, name *string) { + s.RawScanString(scope, name, "SELECT CURRENT_DATABASE()") } var hstoreType = reflect.TypeOf(Hstore{}) diff --git a/sqlite3.go b/sqlite3.go index a73d2379..96183b77 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -40,26 +40,26 @@ func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) } -func (sqlite3) HasTable(scope *Scope, tableName string) bool { +func (s sqlite3) HasTable(scope *Scope, tableName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count) + s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName) return count > 0 } -func (sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count) + s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName) return count > 0 } -func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int - scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count) + s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName) return count > 0 } func (sqlite3) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) } func (sqlite3) CurrentDatabase(scope *Scope, name *string) { From 17917d49d8d46788148d3e904595d9d5030754f5 Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Wed, 12 Aug 2015 09:32:18 -0700 Subject: [PATCH 2/6] Reverted to original `NewRecord' func as per @jinzhu's feedback. --- main.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/main.go b/main.go index df3c29b4..ae406b8d 100644 --- a/main.go +++ b/main.go @@ -365,10 +365,7 @@ func (s *DB) Rollback() *DB { } func (s *DB) NewRecord(value interface{}) bool { - scope := s.clone().NewScope(value) - result := scope.PrimaryKeyZero() - s.err(scope.db.Error) - return result + return s.clone().NewScope(value).PrimaryKeyZero() } func (s *DB) RecordNotFound() bool { From ab1832b9a5db1c1ad60e000e89ac0d7cf99dc973 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 7 Aug 2015 16:35:48 +0800 Subject: [PATCH 3/6] Handle children db's Error in callbacks --- scope.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 104a3728..75b524a0 100644 --- a/scope.go +++ b/scope.go @@ -205,7 +205,9 @@ func (scope *Scope) CallMethod(name string, checkError bool) { case func(s *Scope): f(scope) case func(s *DB): - f(scope.NewDB()) + newDB := scope.NewDB() + f(newDB) + scope.Err(newDB.Error) case func() error: scope.Err(f()) case func(s *Scope) error: From ff3c23c9e9ec09ce5a3b2c64c5e65567df515355 Mon Sep 17 00:00:00 2001 From: Leon Maia Date: Fri, 7 Aug 2015 17:16:40 -0300 Subject: [PATCH 4/6] fixes #593 - Dont include quotes on dest table --- scope_private.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope_private.go b/scope_private.go index 08da634e..f6f815d9 100644 --- a/scope_private.go +++ b/scope_private.go @@ -575,7 +575,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on var keyName = fmt.Sprintf("%s_%s_%s_foreign", table, field, regexp.MustCompile("[^a-zA-Z]").ReplaceAllString(dest, "_")) keyName = regexp.MustCompile("_+").ReplaceAllString(keyName, "_") var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), scope.QuoteIfPossible(dest), onDelete, onUpdate)).Exec() + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), dest, onDelete, onUpdate)).Exec() } func (scope *Scope) removeIndex(indexName string) { From 905b6232a3ccef63fed92bb51f50bd47310328b1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 10 Aug 2015 12:40:56 +0800 Subject: [PATCH 5/6] Fix can't call callbacks for embedded pointers --- scope.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/scope.go b/scope.go index 75b524a0..2af95566 100644 --- a/scope.go +++ b/scope.go @@ -222,10 +222,18 @@ func (scope *Scope) CallMethod(name string, checkError bool) { if values := scope.IndirectValue(); values.Kind() == reflect.Slice { for i := 0; i < values.Len(); i++ { - call(values.Index(i).Addr().Interface()) + value := values.Index(i).Addr().Interface() + if values.Index(i).Kind() == reflect.Ptr { + value = values.Index(i).Interface() + } + call(value) } } else { - call(scope.Value) + if scope.IndirectValue().CanAddr() { + call(scope.IndirectValue().Addr().Interface()) + } else { + call(scope.IndirectValue().Interface()) + } } } From 197ae0e893fd848c0be9e666ea6f684adb6169fb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 12 Aug 2015 22:28:01 +0800 Subject: [PATCH 6/6] Fixed detecting pointers as a Scanner. thanks @NOX73 --- model_struct.go | 4 ++-- query_test.go | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/model_struct.go b/model_struct.go index 72caec24..26c58fc5 100644 --- a/model_struct.go +++ b/model_struct.go @@ -156,12 +156,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, field := range fields { if !field.IsIgnored { fieldStruct := field.Struct - fieldType, indirectType := fieldStruct.Type, fieldStruct.Type + indirectType := fieldStruct.Type if indirectType.Kind() == reflect.Ptr { indirectType = indirectType.Elem() } - if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner { + if _, isScanner := reflect.New(indirectType).Interface().(sql.Scanner); isScanner { field.IsScanner, field.IsNormal = true, true } diff --git a/query_test.go b/query_test.go index ff022dbb..0fd58302 100644 --- a/query_test.go +++ b/query_test.go @@ -581,7 +581,6 @@ func TestSelectWithArrayInput(t *testing.T) { } func TestCurrentDatabase(t *testing.T) { - DB.LogMode(true) databaseName := DB.CurrentDatabase() if err := DB.Error; err != nil { t.Errorf("Problem getting current db name: %s", err)