2014-01-26 08:41:37 +04:00
package gorm
2014-01-26 09:51:23 +04:00
import (
2016-01-13 11:53:04 +03:00
"database/sql"
2014-01-26 09:51:23 +04:00
"errors"
"fmt"
2015-08-01 04:25:06 +03:00
"regexp"
2014-01-26 10:18:21 +04:00
"strings"
2014-01-26 09:51:23 +04:00
"reflect"
)
2014-01-26 08:41:37 +04:00
2016-03-07 16:09:05 +03:00
// Scope contain current operation's information when you perform any operation on the database
2014-01-26 08:41:37 +04:00
type Scope struct {
2014-11-11 06:46:21 +03:00
Search * search
2015-03-12 08:52:29 +03:00
Value interface { }
2016-03-07 09:54:20 +03:00
SQL string
SQLVars [ ] interface { }
2014-11-11 06:46:21 +03:00
db * DB
2016-01-15 16:03:35 +03:00
instanceID string
2015-03-12 08:52:29 +03:00
primaryKeyField * Field
skipLeft bool
2014-11-11 06:46:21 +03:00
fields map [ string ] * Field
2015-03-13 06:01:05 +03:00
selectAttrs * [ ] string
2014-07-30 10:58:00 +04:00
}
2016-03-07 09:54:20 +03:00
// IndirectValue return scope's reflect value's indirect value
2014-07-30 10:58:00 +04:00
func ( scope * Scope ) IndirectValue ( ) reflect . Value {
2016-01-18 07:20:27 +03:00
return indirect ( reflect . ValueOf ( scope . Value ) )
2014-01-26 08:41:37 +04:00
}
2014-01-29 15:14:37 +04:00
// New create a new Scope without search information
2014-01-27 04:26:59 +04:00
func ( scope * Scope ) New ( value interface { } ) * Scope {
2015-02-26 07:35:33 +03:00
return & Scope { db : scope . NewDB ( ) , Search : & search { } , Value : value }
2014-01-27 04:26:59 +04:00
}
2014-01-29 15:14:37 +04:00
// NewDB create a new DB without search information
2014-01-27 04:26:59 +04:00
func ( scope * Scope ) NewDB ( ) * DB {
2015-02-26 07:35:33 +03:00
if scope . db != nil {
db := scope . db . clone ( )
db . search = nil
2015-02-26 11:08:15 +03:00
db . Value = nil
2015-02-26 07:35:33 +03:00
return db
}
return nil
2014-01-27 04:26:59 +04:00
}
2016-03-07 09:54:20 +03:00
// DB return scope's DB connection
2015-02-26 11:08:15 +03:00
func ( scope * Scope ) DB ( ) * DB {
return scope . db
}
2016-03-07 09:54:20 +03:00
// SQLDB return *sql.DB
func ( scope * Scope ) SQLDB ( ) sqlCommon {
2014-01-26 08:41:37 +04:00
return scope . db . db
}
2014-01-29 15:14:37 +04:00
// SkipLeft skip remaining callbacks
2014-01-29 08:00:57 +04:00
func ( scope * Scope ) SkipLeft ( ) {
scope . skipLeft = true
}
2016-03-07 16:09:05 +03:00
// Quote used to quote string to escape them for database
2014-01-29 08:00:57 +04:00
func ( scope * Scope ) Quote ( str string ) string {
2015-02-24 13:02:22 +03:00
if strings . Index ( str , "." ) != - 1 {
newStrs := [ ] string { }
2015-02-26 07:35:33 +03:00
for _ , str := range strings . Split ( str , "." ) {
2015-02-24 13:02:22 +03:00
newStrs = append ( newStrs , scope . Dialect ( ) . Quote ( str ) )
}
return strings . Join ( newStrs , "." )
}
2016-01-15 16:03:35 +03:00
return scope . Dialect ( ) . Quote ( str )
2014-01-29 08:00:57 +04:00
}
2016-01-17 11:37:17 +03:00
func ( scope * Scope ) quoteIfPossible ( str string ) string {
2015-08-01 04:25:06 +03:00
if regexp . MustCompile ( "^[a-zA-Z]+(.[a-zA-Z]+)*$" ) . MatchString ( str ) {
return scope . Quote ( str )
}
return str
}
2014-01-29 15:14:37 +04:00
// Dialect get dialect
2014-04-25 03:20:23 +04:00
func ( scope * Scope ) Dialect ( ) Dialect {
2014-01-26 08:41:37 +04:00
return scope . db . parent . dialect
}
2016-03-07 16:09:05 +03:00
// Err add error to Scope
2014-01-26 08:41:37 +04:00
func ( scope * Scope ) Err ( err error ) error {
if err != nil {
2015-08-13 11:42:13 +03:00
scope . db . AddError ( err )
2014-01-26 08:41:37 +04:00
}
return err
}
2014-01-29 15:14:37 +04:00
// Log print log message
2014-01-28 04:27:12 +04:00
func ( scope * Scope ) Log ( v ... interface { } ) {
scope . db . log ( v ... )
}
2014-01-29 15:14:37 +04:00
// HasError check if there are any error
2014-01-26 08:41:37 +04:00
func ( scope * Scope ) HasError ( ) bool {
2014-01-28 12:29:42 +04:00
return scope . db . Error != nil
2014-01-26 08:41:37 +04:00
}
2016-03-07 09:54:20 +03:00
// PrimaryFields return scope's primary fields
2016-03-07 07:15:15 +03:00
func ( scope * Scope ) PrimaryFields ( ) ( fields [ ] * Field ) {
for _ , field := range scope . Fields ( ) {
if field . IsPrimaryKey {
fields = append ( fields , field )
}
2015-06-29 13:04:15 +03:00
}
return fields
}
2016-03-07 09:54:20 +03:00
// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
2015-03-11 06:28:30 +03:00
func ( scope * Scope ) PrimaryField ( ) * Field {
if primaryFields := scope . GetModelStruct ( ) . PrimaryFields ; len ( primaryFields ) > 0 {
if len ( primaryFields ) > 1 {
2016-03-07 07:15:15 +03:00
if field , ok := scope . FieldByName ( "id" ) ; ok {
2015-03-11 06:28:30 +03:00
return field
}
}
2016-03-07 07:15:15 +03:00
return scope . PrimaryFields ( ) [ 0 ]
2015-02-17 09:30:37 +03:00
}
return nil
2014-11-11 06:46:21 +03:00
}
2016-03-07 16:09:05 +03:00
// PrimaryKey get main primary field's db name
2014-11-11 06:46:21 +03:00
func ( scope * Scope ) PrimaryKey ( ) string {
2015-03-11 06:28:30 +03:00
if field := scope . PrimaryField ( ) ; field != nil {
2014-11-11 06:46:21 +03:00
return field . DBName
}
2015-02-17 02:15:34 +03:00
return ""
2014-01-26 08:41:37 +04:00
}
2016-03-07 16:09:05 +03:00
// PrimaryKeyZero check main primary field's value is blank or not
2014-01-26 13:10:33 +04:00
func ( scope * Scope ) PrimaryKeyZero ( ) bool {
2015-03-11 06:28:30 +03:00
field := scope . PrimaryField ( )
2015-02-17 17:55:14 +03:00
return field == nil || field . IsBlank
2014-01-26 13:10:33 +04:00
}
2014-01-29 15:14:37 +04:00
// PrimaryKeyValue get the primary key's value
2014-01-26 13:10:33 +04:00
func ( scope * Scope ) PrimaryKeyValue ( ) interface { } {
2015-03-11 06:28:30 +03:00
if field := scope . PrimaryField ( ) ; field != nil && field . Field . IsValid ( ) {
2014-11-11 06:46:21 +03:00
return field . Field . Interface ( )
2014-01-26 13:10:33 +04:00
}
2015-02-17 02:15:34 +03:00
return 0
2014-01-26 13:10:33 +04:00
}
2014-01-29 15:14:37 +04:00
// HasColumn to check if has column
2015-01-19 11:23:33 +03:00
func ( scope * Scope ) HasColumn ( column string ) bool {
2015-02-17 12:40:21 +03:00
for _ , field := range scope . GetStructFields ( ) {
2015-02-17 17:55:14 +03:00
if field . IsNormal && ( field . Name == column || field . DBName == column ) {
return true
2015-02-17 12:40:21 +03:00
}
2014-09-01 13:03:58 +04:00
}
2015-02-17 12:40:21 +03:00
return false
2014-01-28 08:28:44 +04:00
}
2016-03-07 16:09:05 +03:00
// SetColumn to set the column's value, column could be field or field's name/dbname
2014-09-30 16:02:51 +04:00
func ( scope * Scope ) SetColumn ( column interface { } , value interface { } ) error {
2016-02-18 17:24:35 +03:00
var updateAttrs = map [ string ] interface { } { }
if attrs , ok := scope . InstanceGet ( "gorm:update_attrs" ) ; ok {
updateAttrs = attrs . ( map [ string ] interface { } )
defer scope . InstanceSet ( "gorm:update_attrs" , updateAttrs )
}
2014-09-02 15:03:01 +04:00
if field , ok := column . ( * Field ) ; ok {
2016-02-18 17:24:35 +03:00
updateAttrs [ field . DBName ] = value
2014-09-02 15:03:01 +04:00
return field . Set ( value )
2015-04-16 09:08:13 +03:00
} else if name , ok := column . ( string ) ; ok {
2016-03-07 07:15:15 +03:00
var (
dbName = ToDBName ( name )
mostMatchedField * Field
)
for _ , field := range scope . Fields ( ) {
if field . DBName == value {
updateAttrs [ field . DBName ] = value
return field . Set ( value )
}
if ( field . DBName == dbName ) || ( field . Name == name && mostMatchedField == nil ) {
mostMatchedField = field
}
2014-08-30 18:39:28 +04:00
}
2015-04-16 09:08:13 +03:00
2016-03-07 07:15:15 +03:00
if mostMatchedField != nil {
updateAttrs [ mostMatchedField . DBName ] = value
return mostMatchedField . Set ( value )
2015-04-16 09:08:13 +03:00
}
2014-08-30 18:39:28 +04:00
}
2014-09-30 16:02:51 +04:00
return errors . New ( "could not convert column to field" )
2014-01-26 08:41:37 +04:00
}
2016-01-17 16:35:32 +03:00
func ( scope * Scope ) callMethod ( methodName string , reflectValue reflect . Value ) {
if reflectValue . CanAddr ( ) {
reflectValue = reflectValue . Addr ( )
2014-01-27 18:36:08 +04:00
}
2016-01-17 16:35:32 +03:00
if methodValue := reflectValue . MethodByName ( methodName ) ; methodValue . IsValid ( ) {
switch method := methodValue . Interface ( ) . ( type ) {
case func ( ) :
method ( )
case func ( * Scope ) :
method ( scope )
case func ( * DB ) :
newDB := scope . NewDB ( )
method ( newDB )
scope . Err ( newDB . Error )
case func ( ) error :
scope . Err ( method ( ) )
case func ( * Scope ) error :
scope . Err ( method ( scope ) )
case func ( * DB ) error :
newDB := scope . NewDB ( )
scope . Err ( method ( newDB ) )
scope . Err ( newDB . Error )
default :
scope . Err ( fmt . Errorf ( "unsupported function %v" , methodName ) )
2014-01-26 09:51:23 +04:00
}
}
2016-01-17 16:35:32 +03:00
}
2014-01-28 05:25:30 +04:00
2016-03-07 16:09:05 +03:00
// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
2016-01-17 16:35:32 +03:00
func ( scope * Scope ) CallMethod ( methodName string ) {
if scope . Value == nil {
return
}
if indirectScopeValue := scope . IndirectValue ( ) ; indirectScopeValue . Kind ( ) == reflect . Slice {
for i := 0 ; i < indirectScopeValue . Len ( ) ; i ++ {
scope . callMethod ( methodName , indirectScopeValue . Index ( i ) )
2014-01-28 05:25:30 +04:00
}
} else {
2016-01-17 16:35:32 +03:00
scope . callMethod ( methodName , indirectScopeValue )
2014-01-28 05:25:30 +04:00
}
2014-01-26 08:41:37 +04:00
}
2016-03-07 16:09:05 +03:00
// AddToVars add value as sql's vars, used to prevent SQL injection
2014-01-26 08:41:37 +04:00
func ( scope * Scope ) AddToVars ( value interface { } ) string {
2015-02-24 17:06:35 +03:00
if expr , ok := value . ( * expr ) ; ok {
exp := expr . expr
for _ , arg := range expr . args {
exp = strings . Replace ( exp , "?" , scope . AddToVars ( arg ) , 1 )
}
return exp
}
2016-01-15 16:03:35 +03:00
2016-03-07 09:54:20 +03:00
scope . SQLVars = append ( scope . SQLVars , value )
return scope . Dialect ( ) . BindVar ( len ( scope . SQLVars ) )
2014-01-26 08:41:37 +04:00
}
2015-04-08 06:36:01 +03:00
type (
	// tabler is implemented by values that supply their own table name.
	tabler interface {
		TableName() string
	}

	// dbTabler is like tabler, but its table name is computed from a *DB.
	dbTabler interface {
		TableName(*DB) string
	}
)
2016-01-13 09:58:30 +03:00
// TableName return table name
2014-01-26 08:41:37 +04:00
func ( scope * Scope ) TableName ( ) string {
2015-03-12 08:52:29 +03:00
if scope . Search != nil && len ( scope . Search . tableName ) > 0 {
return scope . Search . tableName
2015-02-17 03:34:01 +03:00
}
2015-04-08 06:36:01 +03:00
if tabler , ok := scope . Value . ( tabler ) ; ok {
return tabler . TableName ( )
}
if tabler , ok := scope . Value . ( dbTabler ) ; ok {
return tabler . TableName ( scope . db )
}
2015-06-30 05:39:29 +03:00
return scope . GetModelStruct ( ) . TableName ( scope . db . Model ( scope . Value ) )
2014-01-26 10:18:21 +04:00
}
2016-01-13 09:58:30 +03:00
// QuotedTableName return quoted table name
2015-02-26 11:08:15 +03:00
func ( scope * Scope ) QuotedTableName ( ) ( name string ) {
2015-03-12 08:52:29 +03:00
if scope . Search != nil && len ( scope . Search . tableName ) > 0 {
2015-05-19 05:43:32 +03:00
if strings . Index ( scope . Search . tableName , " " ) != - 1 {
return scope . Search . tableName
}
2015-03-12 08:52:29 +03:00
return scope . Quote ( scope . Search . tableName )
2014-06-03 13:15:05 +04:00
}
2016-01-15 16:03:35 +03:00
return scope . Quote ( scope . TableName ( ) )
2014-06-03 13:15:05 +04:00
}
2016-03-07 09:54:20 +03:00
// CombinedConditionSql return combined condition sql
2014-01-29 04:55:45 +04:00
func ( scope * Scope ) CombinedConditionSql ( ) string {
2016-03-07 09:54:20 +03:00
return scope . joinsSQL ( ) + scope . whereSQL ( ) + scope . groupSQL ( ) +
scope . havingSQL ( ) + scope . orderSQL ( ) + scope . limitAndOffsetSQL ( )
2014-01-26 10:55:41 +04:00
}
2016-03-07 16:09:05 +03:00
// FieldByName find `gorm.Field` with field name or db name
2014-07-31 07:08:26 +04:00
func ( scope * Scope ) FieldByName ( name string ) ( field * Field , ok bool ) {
2016-03-07 07:15:15 +03:00
var (
dbName = ToDBName ( name )
mostMatchedField * Field
)
2014-09-02 15:03:01 +04:00
for _ , field := range scope . Fields ( ) {
2015-06-04 07:10:09 +03:00
if field . Name == name || field . DBName == name {
2014-09-02 15:03:01 +04:00
return field , true
2014-07-31 07:08:26 +04:00
}
2016-03-07 07:15:15 +03:00
if field . DBName == dbName {
mostMatchedField = field
}
2014-07-31 07:08:26 +04:00
}
2016-03-07 07:15:15 +03:00
return mostMatchedField , mostMatchedField != nil
2014-07-31 07:08:26 +04:00
}
2016-03-07 16:09:05 +03:00
// Raw set raw sql
2014-01-28 11:54:19 +04:00
func ( scope * Scope ) Raw ( sql string ) * Scope {
2016-03-07 09:54:20 +03:00
scope . SQL = strings . Replace ( sql , "$$" , "?" , - 1 )
2014-01-28 11:54:19 +04:00
return scope
2014-01-26 08:41:37 +04:00
}
2016-03-07 16:09:05 +03:00
// Exec perform generated SQL
2014-01-28 06:23:31 +04:00
func ( scope * Scope ) Exec ( ) * Scope {
2016-01-13 09:58:30 +03:00
defer scope . trace ( NowFunc ( ) )
2014-01-28 11:54:19 +04:00
2014-01-26 10:18:21 +04:00
if ! scope . HasError ( ) {
2016-03-07 09:54:20 +03:00
if result , err := scope . SQLDB ( ) . Exec ( scope . SQL , scope . SQLVars ... ) ; scope . Err ( err ) == nil {
2015-07-02 22:06:06 +03:00
if count , err := result . RowsAffected ( ) ; scope . Err ( err ) == nil {
2014-06-05 13:58:14 +04:00
scope . db . RowsAffected = count
}
}
2014-01-26 10:18:21 +04:00
}
2014-01-28 06:23:31 +04:00
return scope
2014-01-26 08:41:37 +04:00
}
2014-01-26 13:10:33 +04:00
2014-01-29 15:14:37 +04:00
// Set set value by name
2014-08-20 13:05:02 +04:00
func ( scope * Scope ) Set ( name string , value interface { } ) * Scope {
2014-08-25 13:10:46 +04:00
scope . db . InstantSet ( name , value )
2014-08-20 13:05:02 +04:00
return scope
2014-01-27 07:56:04 +04:00
}
2016-03-07 16:09:05 +03:00
// Get get setting by name
2014-08-20 12:25:01 +04:00
func ( scope * Scope ) Get ( name string ) ( interface { } , bool ) {
return scope . db . Get ( name )
2014-01-29 15:14:37 +04:00
}
2016-01-15 16:03:35 +03:00
// InstanceID get InstanceID for scope
func ( scope * Scope ) InstanceID ( ) string {
if scope . instanceID == "" {
scope . instanceID = fmt . Sprintf ( "%v%v" , & scope , & scope . db )
2014-08-20 13:05:02 +04:00
}
2016-01-15 16:03:35 +03:00
return scope . instanceID
2014-08-20 13:05:02 +04:00
}
2016-03-07 16:09:05 +03:00
// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback
2014-08-20 13:05:02 +04:00
func ( scope * Scope ) InstanceSet ( name string , value interface { } ) * Scope {
2016-01-15 16:03:35 +03:00
return scope . Set ( name + scope . InstanceID ( ) , value )
2014-08-20 13:05:02 +04:00
}
2016-03-07 16:09:05 +03:00
// InstanceGet get instance setting from current operation
2014-08-20 13:05:02 +04:00
func ( scope * Scope ) InstanceGet ( name string ) ( interface { } , bool ) {
2016-01-15 16:03:35 +03:00
return scope . Get ( name + scope . InstanceID ( ) )
2014-08-20 13:05:02 +04:00
}
2014-01-29 15:14:37 +04:00
// Begin start a transaction
2014-01-26 13:10:33 +04:00
func ( scope * Scope ) Begin ( ) * Scope {
2016-03-07 09:54:20 +03:00
if db , ok := scope . SQLDB ( ) . ( sqlDb ) ; ok {
2014-01-27 06:47:37 +04:00
if tx , err := db . Begin ( ) ; err == nil {
scope . db . db = interface { } ( tx ) . ( sqlCommon )
2014-08-20 13:05:02 +04:00
scope . InstanceSet ( "gorm:started_transaction" , true )
2014-01-27 06:47:37 +04:00
}
2014-01-26 13:10:33 +04:00
}
return scope
}
2016-03-07 16:09:05 +03:00
// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it
2014-01-26 13:10:33 +04:00
func ( scope * Scope ) CommitOrRollback ( ) * Scope {
2014-08-20 13:05:02 +04:00
if _ , ok := scope . InstanceGet ( "gorm:started_transaction" ) ; ok {
2014-01-26 13:10:33 +04:00
if db , ok := scope . db . db . ( sqlTx ) ; ok {
if scope . HasError ( ) {
db . Rollback ( )
} else {
2016-01-04 08:32:35 +03:00
scope . Err ( db . Commit ( ) )
2014-01-26 13:10:33 +04:00
}
scope . db . db = scope . db . parent . db
}
}
return scope
}
2015-03-12 12:47:31 +03:00
2016-03-07 16:09:05 +03:00
// SelectAttrs return selected attributes
2015-03-13 06:01:05 +03:00
func ( scope * Scope ) SelectAttrs ( ) [ ] string {
if scope . selectAttrs == nil {
attrs := [ ] string { }
for _ , value := range scope . Search . selects {
if str , ok := value . ( string ) ; ok {
attrs = append ( attrs , str )
2015-05-11 10:17:35 +03:00
} else if strs , ok := value . ( [ ] string ) ; ok {
attrs = append ( attrs , strs ... )
2015-03-13 06:01:05 +03:00
} else if strs , ok := value . ( [ ] interface { } ) ; ok {
for _ , str := range strs {
attrs = append ( attrs , fmt . Sprintf ( "%v" , str ) )
}
2015-03-12 12:47:31 +03:00
}
}
2015-03-13 06:01:05 +03:00
scope . selectAttrs = & attrs
2015-03-12 12:47:31 +03:00
}
2015-03-13 06:01:05 +03:00
return * scope . selectAttrs
2015-03-12 12:47:31 +03:00
}
2016-03-07 09:54:20 +03:00
// OmitAttrs return omited attributes
2015-03-12 12:47:31 +03:00
func ( scope * Scope ) OmitAttrs ( ) [ ] string {
return scope . Search . omits
}
2016-01-13 11:53:04 +03:00
2016-03-07 07:15:15 +03:00
func ( scope * Scope ) scan ( rows * sql . Rows , columns [ ] string , fieldsMap map [ string ] * Field ) {
2016-01-13 11:53:04 +03:00
var values = make ( [ ] interface { } , len ( columns ) )
var ignored interface { }
for index , column := range columns {
2016-03-07 07:15:15 +03:00
if field , ok := fieldsMap [ column ] ; ok {
2016-01-13 11:53:04 +03:00
if field . Field . Kind ( ) == reflect . Ptr {
values [ index ] = field . Field . Addr ( ) . Interface ( )
} else {
reflectValue := reflect . New ( reflect . PtrTo ( field . Struct . Type ) )
reflectValue . Elem ( ) . Set ( field . Field . Addr ( ) )
values [ index ] = reflectValue . Interface ( )
}
} else {
values [ index ] = & ignored
}
}
scope . Err ( rows . Scan ( values ... ) )
for index , column := range columns {
2016-03-07 07:15:15 +03:00
if field , ok := fieldsMap [ column ] ; ok {
2016-01-13 11:53:04 +03:00
if field . Field . Kind ( ) != reflect . Ptr {
if v := reflect . ValueOf ( values [ index ] ) . Elem ( ) . Elem ( ) ; v . IsValid ( ) {
field . Field . Set ( v )
}
}
}
}
}