diff --git a/callbacks/query.go b/callbacks/query.go index 95b5ead3..c9fa160f 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -90,7 +90,11 @@ func Query(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.From{}) } - db.Statement.AddClause(clauseSelect) + if len(clauseSelect.Columns) > 0 { + db.Statement.AddClause(clauseSelect) + } else { + db.Statement.AddClauseIfNotExists(clauseSelect) + } db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } diff --git a/finisher_api.go b/finisher_api.go index c64ecdda..84168e23 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -48,7 +48,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { } // First find first record that match given conditions, order by primary key -func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { +func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) @@ -56,25 +56,25 @@ func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true - tx.Statement.Dest = out + tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return } // Take return a record that match given conditions, the order will depend on the database implementation -func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) { +func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1) if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true - tx.Statement.Dest = out + tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return } // Last find last record that match given conditions, order by primary key -func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { +func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, @@ -83,28 +83,28 @@ func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } tx.Statement.RaiseErrorOnNotFound = true - tx.Statement.Dest = out + tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return } // Find find records that match given conditions -func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) { +func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) } - tx.Statement.Dest = out + tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return } -func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) FirstOrInit(dest interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } -func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { +func (db *DB) FirstOrCreate(dest interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } @@ -181,6 +181,8 @@ func (db *DB) Rows() (*sql.Rows, error) { // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) return } diff --git a/tests/scan_test.go b/tests/scan_test.go new file mode 100644 index 00000000..f7a14636 --- /dev/null +++ b/tests/scan_test.go @@ -0,0 +1,40 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestScan(t *testing.T) { + user1 := User{Name: "ScanUser1", Age: 1} + user2 := User{Name: "ScanUser2", Age: 10} + user3 := User{Name: "ScanUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + type result struct { + Name string + Age int + } + + var res result + DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res) + if res.Name != user3.Name || res.Age != int(user3.Age) { + t.Errorf("Scan into struct should work") + } + + var doubleAgeRes = &result{} + if err := DB.Debug().Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { + t.Errorf("Scan to pointer of pointer") + } + + if doubleAgeRes.Age != int(res.Age)*2 { + t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age) + } + + var ress []result + DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&ress) + if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { + t.Errorf("Scan into struct map") + } +}