diff --git a/errors.go b/errors.go index 508f6957..08755083 100644 --- a/errors.go +++ b/errors.go @@ -29,4 +29,6 @@ var ( ErrInvalidField = errors.New("invalid field") // ErrEmptySlice empty slice found ErrEmptySlice = errors.New("empty slice found") + // ErrDryRunModeUnsupported dry run mode unsupported + ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") ) diff --git a/finisher_api.go b/finisher_api.go index f426839a..2c56d763 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -334,13 +334,21 @@ func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Row() *sql.Row { tx := db.getInstance().InstanceSet("rows", false) tx.callbacks.Row().Execute(tx) - return tx.Statement.Dest.(*sql.Row) + row, ok := tx.Statement.Dest.(*sql.Row) + if !ok && tx.DryRun { + db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) + } + return row } func (db *DB) Rows() (*sql.Rows, error) { tx := db.getInstance().InstanceSet("rows", true) tx.callbacks.Row().Execute(tx) - return tx.Statement.Dest.(*sql.Rows), tx.Error + rows, ok := tx.Statement.Dest.(*sql.Rows) + if !ok && tx.DryRun && tx.Error == nil { + tx.Error = ErrDryRunModeUnsupported + } + return rows, tx.Error } // Scan scan value to a struct