2014-01-26 08:41:37 +04:00
package gorm
2014-01-26 09:51:23 +04:00
import (
2016-01-13 11:53:04 +03:00
"database/sql"
2016-03-08 16:45:20 +03:00
"database/sql/driver"
2014-01-26 09:51:23 +04:00
"errors"
"fmt"
2015-08-01 04:25:06 +03:00
"regexp"
2016-03-08 16:45:20 +03:00
"strconv"
2014-01-26 10:18:21 +04:00
"strings"
2016-03-08 16:45:20 +03:00
"time"
2014-01-26 09:51:23 +04:00
"reflect"
)
2014-01-26 08:41:37 +04:00
2016-03-07 16:09:05 +03:00
// Scope contain current operation's information when you perform any operation on the database
2014-01-26 08:41:37 +04:00
type Scope struct {
2014-11-11 06:46:21 +03:00
Search * search
2015-03-12 08:52:29 +03:00
Value interface { }
2016-03-07 09:54:20 +03:00
SQL string
SQLVars [ ] interface { }
2014-11-11 06:46:21 +03:00
db * DB
2016-01-15 16:03:35 +03:00
instanceID string
2015-03-12 08:52:29 +03:00
primaryKeyField * Field
skipLeft bool
2016-03-08 17:29:58 +03:00
fields * [ ] * Field
2015-03-13 06:01:05 +03:00
selectAttrs * [ ] string
2014-07-30 10:58:00 +04:00
}
2016-03-07 09:54:20 +03:00
// IndirectValue return scope's reflect value's indirect value
2014-07-30 10:58:00 +04:00
func ( scope * Scope ) IndirectValue ( ) reflect . Value {
2016-01-18 07:20:27 +03:00
return indirect ( reflect . ValueOf ( scope . Value ) )
2014-01-26 08:41:37 +04:00
}
2014-01-29 15:14:37 +04:00
// New create a new Scope without search information
2014-01-27 04:26:59 +04:00
func ( scope * Scope ) New ( value interface { } ) * Scope {
2015-02-26 07:35:33 +03:00
return & Scope { db : scope . NewDB ( ) , Search : & search { } , Value : value }
2014-01-27 04:26:59 +04:00
}
2016-03-08 17:29:58 +03:00
////////////////////////////////////////////////////////////////////////////////
// Scope DB
////////////////////////////////////////////////////////////////////////////////
// DB returns the *DB this scope operates on.
func (scope *Scope) DB() *DB {
	current := scope.db
	return current
}
2014-01-29 15:14:37 +04:00
// NewDB create a new DB without search information
2014-01-27 04:26:59 +04:00
func ( scope * Scope ) NewDB ( ) * DB {
2015-02-26 07:35:33 +03:00
if scope . db != nil {
db := scope . db . clone ( )
db . search = nil
2015-02-26 11:08:15 +03:00
db . Value = nil
2015-02-26 07:35:33 +03:00
return db
}
return nil
2014-01-27 04:26:59 +04:00
}
2016-03-07 09:54:20 +03:00
// SQLDB return *sql.DB
func ( scope * Scope ) SQLDB ( ) sqlCommon {
2014-01-26 08:41:37 +04:00
return scope . db . db
}
2016-03-08 17:29:58 +03:00
// Dialect get dialect
func ( scope * Scope ) Dialect ( ) Dialect {
return scope . db . parent . dialect
2014-01-29 08:00:57 +04:00
}
2016-03-07 16:09:05 +03:00
// Quote used to quote string to escape them for database
2014-01-29 08:00:57 +04:00
func ( scope * Scope ) Quote ( str string ) string {
2015-02-24 13:02:22 +03:00
if strings . Index ( str , "." ) != - 1 {
newStrs := [ ] string { }
2015-02-26 07:35:33 +03:00
for _ , str := range strings . Split ( str , "." ) {
2015-02-24 13:02:22 +03:00
newStrs = append ( newStrs , scope . Dialect ( ) . Quote ( str ) )
}
return strings . Join ( newStrs , "." )
}
2016-01-15 16:03:35 +03:00
return scope . Dialect ( ) . Quote ( str )
2014-01-29 08:00:57 +04:00
}
2016-03-07 16:09:05 +03:00
// Err add error to Scope
2014-01-26 08:41:37 +04:00
func ( scope * Scope ) Err ( err error ) error {
if err != nil {
2015-08-13 11:42:13 +03:00
scope . db . AddError ( err )
2014-01-26 08:41:37 +04:00
}
return err
}
2016-03-08 17:29:58 +03:00
// HasError reports whether the scope's DB currently holds an error.
func (scope *Scope) HasError() bool {
	failed := scope.db.Error != nil
	return failed
}
2014-01-29 15:14:37 +04:00
// Log print log message
2014-01-28 04:27:12 +04:00
func ( scope * Scope ) Log ( v ... interface { } ) {
scope . db . log ( v ... )
}
2016-03-08 17:29:58 +03:00
// SkipLeft marks the scope so that callCallbacks stops after the callback
// that is currently running.
func (scope *Scope) SkipLeft() {
	scope.skipLeft = true
}
// Fields returns one *Field per struct field of the scope's model. The result
// is built on first use and cached on the scope. When the scope value is not
// a struct, every field is returned blank and without a reflect value.
func (scope *Scope) Fields() []*Field {
	if scope.fields != nil {
		return *scope.fields
	}

	var result []*Field
	scopeValue := scope.IndirectValue()
	valueIsStruct := scopeValue.Kind() == reflect.Struct

	for _, structField := range scope.GetModelStruct().StructFields {
		if !valueIsStruct {
			result = append(result, &Field{StructField: structField, IsBlank: true})
			continue
		}

		// Walk the (possibly nested) field path, dereferencing pointers on the way.
		current := scopeValue
		for _, name := range structField.Names {
			current = reflect.Indirect(current).FieldByName(name)
		}
		result = append(result, &Field{StructField: structField, Field: current, IsBlank: isBlank(current)})
	}

	scope.fields = &result
	return result
}
// FieldByName find `gorm.Field` with field name or db name
func ( scope * Scope ) FieldByName ( name string ) ( field * Field , ok bool ) {
var (
dbName = ToDBName ( name )
mostMatchedField * Field
)
for _ , field := range scope . Fields ( ) {
if field . Name == name || field . DBName == name {
return field , true
}
if field . DBName == dbName {
mostMatchedField = field
}
}
return mostMatchedField , mostMatchedField != nil
2014-01-26 08:41:37 +04:00
}
2016-03-07 09:54:20 +03:00
// PrimaryFields return scope's primary fields
2016-03-07 07:15:15 +03:00
func ( scope * Scope ) PrimaryFields ( ) ( fields [ ] * Field ) {
for _ , field := range scope . Fields ( ) {
if field . IsPrimaryKey {
fields = append ( fields , field )
}
2015-06-29 13:04:15 +03:00
}
return fields
}
2016-03-07 09:54:20 +03:00
// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
2015-03-11 06:28:30 +03:00
func ( scope * Scope ) PrimaryField ( ) * Field {
if primaryFields := scope . GetModelStruct ( ) . PrimaryFields ; len ( primaryFields ) > 0 {
if len ( primaryFields ) > 1 {
2016-03-07 07:15:15 +03:00
if field , ok := scope . FieldByName ( "id" ) ; ok {
2015-03-11 06:28:30 +03:00
return field
}
}
2016-03-07 07:15:15 +03:00
return scope . PrimaryFields ( ) [ 0 ]
2015-02-17 09:30:37 +03:00
}
return nil
2014-11-11 06:46:21 +03:00
}
2016-03-07 16:09:05 +03:00
// PrimaryKey get main primary field's db name
2014-11-11 06:46:21 +03:00
func ( scope * Scope ) PrimaryKey ( ) string {
2015-03-11 06:28:30 +03:00
if field := scope . PrimaryField ( ) ; field != nil {
2014-11-11 06:46:21 +03:00
return field . DBName
}
2015-02-17 02:15:34 +03:00
return ""
2014-01-26 08:41:37 +04:00
}
2016-03-07 16:09:05 +03:00
// PrimaryKeyZero check main primary field's value is blank or not
2014-01-26 13:10:33 +04:00
func ( scope * Scope ) PrimaryKeyZero ( ) bool {
2015-03-11 06:28:30 +03:00
field := scope . PrimaryField ( )
2015-02-17 17:55:14 +03:00
return field == nil || field . IsBlank
2014-01-26 13:10:33 +04:00
}
2014-01-29 15:14:37 +04:00
// PrimaryKeyValue get the primary key's value
2014-01-26 13:10:33 +04:00
func ( scope * Scope ) PrimaryKeyValue ( ) interface { } {
2015-03-11 06:28:30 +03:00
if field := scope . PrimaryField ( ) ; field != nil && field . Field . IsValid ( ) {
2014-11-11 06:46:21 +03:00
return field . Field . Interface ( )
2014-01-26 13:10:33 +04:00
}
2015-02-17 02:15:34 +03:00
return 0
2014-01-26 13:10:33 +04:00
}
2014-01-29 15:14:37 +04:00
// HasColumn to check if has column
2015-01-19 11:23:33 +03:00
func ( scope * Scope ) HasColumn ( column string ) bool {
2015-02-17 12:40:21 +03:00
for _ , field := range scope . GetStructFields ( ) {
2015-02-17 17:55:14 +03:00
if field . IsNormal && ( field . Name == column || field . DBName == column ) {
return true
2015-02-17 12:40:21 +03:00
}
2014-09-01 13:03:58 +04:00
}
2015-02-17 12:40:21 +03:00
return false
2014-01-28 08:28:44 +04:00
}
2016-03-07 16:09:05 +03:00
// SetColumn to set the column's value, column could be field or field's name/dbname
2014-09-30 16:02:51 +04:00
func ( scope * Scope ) SetColumn ( column interface { } , value interface { } ) error {
2016-02-18 17:24:35 +03:00
var updateAttrs = map [ string ] interface { } { }
if attrs , ok := scope . InstanceGet ( "gorm:update_attrs" ) ; ok {
updateAttrs = attrs . ( map [ string ] interface { } )
defer scope . InstanceSet ( "gorm:update_attrs" , updateAttrs )
}
2014-09-02 15:03:01 +04:00
if field , ok := column . ( * Field ) ; ok {
2016-02-18 17:24:35 +03:00
updateAttrs [ field . DBName ] = value
2014-09-02 15:03:01 +04:00
return field . Set ( value )
2015-04-16 09:08:13 +03:00
} else if name , ok := column . ( string ) ; ok {
2016-03-07 07:15:15 +03:00
var (
dbName = ToDBName ( name )
mostMatchedField * Field
)
for _ , field := range scope . Fields ( ) {
if field . DBName == value {
updateAttrs [ field . DBName ] = value
return field . Set ( value )
}
if ( field . DBName == dbName ) || ( field . Name == name && mostMatchedField == nil ) {
mostMatchedField = field
}
2014-08-30 18:39:28 +04:00
}
2015-04-16 09:08:13 +03:00
2016-03-07 07:15:15 +03:00
if mostMatchedField != nil {
updateAttrs [ mostMatchedField . DBName ] = value
return mostMatchedField . Set ( value )
2015-04-16 09:08:13 +03:00
}
2014-08-30 18:39:28 +04:00
}
2014-09-30 16:02:51 +04:00
return errors . New ( "could not convert column to field" )
2014-01-26 08:41:37 +04:00
}
2016-03-07 16:09:05 +03:00
// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
2016-01-17 16:35:32 +03:00
func ( scope * Scope ) CallMethod ( methodName string ) {
if scope . Value == nil {
return
}
if indirectScopeValue := scope . IndirectValue ( ) ; indirectScopeValue . Kind ( ) == reflect . Slice {
for i := 0 ; i < indirectScopeValue . Len ( ) ; i ++ {
scope . callMethod ( methodName , indirectScopeValue . Index ( i ) )
2014-01-28 05:25:30 +04:00
}
} else {
2016-01-17 16:35:32 +03:00
scope . callMethod ( methodName , indirectScopeValue )
2014-01-28 05:25:30 +04:00
}
2014-01-26 08:41:37 +04:00
}
2016-03-07 16:09:05 +03:00
// AddToVars add value as sql's vars, used to prevent SQL injection
2014-01-26 08:41:37 +04:00
func ( scope * Scope ) AddToVars ( value interface { } ) string {
2015-02-24 17:06:35 +03:00
if expr , ok := value . ( * expr ) ; ok {
exp := expr . expr
for _ , arg := range expr . args {
exp = strings . Replace ( exp , "?" , scope . AddToVars ( arg ) , 1 )
}
return exp
}
2016-01-15 16:03:35 +03:00
2016-03-07 09:54:20 +03:00
scope . SQLVars = append ( scope . SQLVars , value )
return scope . Dialect ( ) . BindVar ( len ( scope . SQLVars ) )
2014-01-26 08:41:37 +04:00
}
2016-03-08 17:29:58 +03:00
// SelectAttrs returns the selected attributes collected from Search.selects.
// String values are taken as-is, []string values are flattened, and
// []interface{} values are formatted with %v. The result is cached on the
// scope after the first call.
func (scope *Scope) SelectAttrs() []string {
	if scope.selectAttrs != nil {
		return *scope.selectAttrs
	}

	attrs := []string{}
	for _, value := range scope.Search.selects {
		switch selected := value.(type) {
		case string:
			attrs = append(attrs, selected)
		case []string:
			attrs = append(attrs, selected...)
		case []interface{}:
			for _, item := range selected {
				attrs = append(attrs, fmt.Sprintf("%v", item))
			}
		}
	}

	scope.selectAttrs = &attrs
	return attrs
}
// OmitAttrs returns the omitted attributes stored on the scope's search.
func (scope *Scope) OmitAttrs() []string {
	omitted := scope.Search.omits
	return omitted
}
2015-04-08 06:36:01 +03:00
type (
	// tabler is implemented by values that name their own table.
	tabler interface {
		TableName() string
	}

	// dbTabler is implemented by values that name their own table given a *DB.
	dbTabler interface {
		TableName(*DB) string
	}
)
2016-01-13 09:58:30 +03:00
// TableName return table name
2014-01-26 08:41:37 +04:00
func ( scope * Scope ) TableName ( ) string {
2015-03-12 08:52:29 +03:00
if scope . Search != nil && len ( scope . Search . tableName ) > 0 {
return scope . Search . tableName
2015-02-17 03:34:01 +03:00
}
2015-04-08 06:36:01 +03:00
if tabler , ok := scope . Value . ( tabler ) ; ok {
return tabler . TableName ( )
}
if tabler , ok := scope . Value . ( dbTabler ) ; ok {
return tabler . TableName ( scope . db )
}
2015-06-30 05:39:29 +03:00
return scope . GetModelStruct ( ) . TableName ( scope . db . Model ( scope . Value ) )
2014-01-26 10:18:21 +04:00
}
2016-01-13 09:58:30 +03:00
// QuotedTableName return quoted table name
2015-02-26 11:08:15 +03:00
func ( scope * Scope ) QuotedTableName ( ) ( name string ) {
2015-03-12 08:52:29 +03:00
if scope . Search != nil && len ( scope . Search . tableName ) > 0 {
2015-05-19 05:43:32 +03:00
if strings . Index ( scope . Search . tableName , " " ) != - 1 {
return scope . Search . tableName
}
2015-03-12 08:52:29 +03:00
return scope . Quote ( scope . Search . tableName )
2014-06-03 13:15:05 +04:00
}
2016-01-15 16:03:35 +03:00
return scope . Quote ( scope . TableName ( ) )
2014-06-03 13:15:05 +04:00
}
2016-03-07 09:54:20 +03:00
// CombinedConditionSql return combined condition sql
2014-01-29 04:55:45 +04:00
func ( scope * Scope ) CombinedConditionSql ( ) string {
2016-03-07 09:54:20 +03:00
return scope . joinsSQL ( ) + scope . whereSQL ( ) + scope . groupSQL ( ) +
scope . havingSQL ( ) + scope . orderSQL ( ) + scope . limitAndOffsetSQL ( )
2014-01-26 10:55:41 +04:00
}
2016-03-07 16:09:05 +03:00
// Raw set raw sql
2014-01-28 11:54:19 +04:00
func ( scope * Scope ) Raw ( sql string ) * Scope {
2016-03-07 09:54:20 +03:00
scope . SQL = strings . Replace ( sql , "$$" , "?" , - 1 )
2014-01-28 11:54:19 +04:00
return scope
2014-01-26 08:41:37 +04:00
}
2016-03-07 16:09:05 +03:00
// Exec perform generated SQL
2014-01-28 06:23:31 +04:00
func ( scope * Scope ) Exec ( ) * Scope {
2016-01-13 09:58:30 +03:00
defer scope . trace ( NowFunc ( ) )
2014-01-28 11:54:19 +04:00
2014-01-26 10:18:21 +04:00
if ! scope . HasError ( ) {
2016-03-07 09:54:20 +03:00
if result , err := scope . SQLDB ( ) . Exec ( scope . SQL , scope . SQLVars ... ) ; scope . Err ( err ) == nil {
2015-07-02 22:06:06 +03:00
if count , err := result . RowsAffected ( ) ; scope . Err ( err ) == nil {
2014-06-05 13:58:14 +04:00
scope . db . RowsAffected = count
}
}
2014-01-26 10:18:21 +04:00
}
2014-01-28 06:23:31 +04:00
return scope
2014-01-26 08:41:37 +04:00
}
2014-01-26 13:10:33 +04:00
2014-01-29 15:14:37 +04:00
// Set set value by name
2014-08-20 13:05:02 +04:00
func ( scope * Scope ) Set ( name string , value interface { } ) * Scope {
2014-08-25 13:10:46 +04:00
scope . db . InstantSet ( name , value )
2014-08-20 13:05:02 +04:00
return scope
2014-01-27 07:56:04 +04:00
}
2016-03-07 16:09:05 +03:00
// Get get setting by name
2014-08-20 12:25:01 +04:00
func ( scope * Scope ) Get ( name string ) ( interface { } , bool ) {
return scope . db . Get ( name )
2014-01-29 15:14:37 +04:00
}
2016-01-15 16:03:35 +03:00
// InstanceID get InstanceID for scope
func ( scope * Scope ) InstanceID ( ) string {
if scope . instanceID == "" {
scope . instanceID = fmt . Sprintf ( "%v%v" , & scope , & scope . db )
2014-08-20 13:05:02 +04:00
}
2016-01-15 16:03:35 +03:00
return scope . instanceID
2014-08-20 13:05:02 +04:00
}
2016-03-07 16:09:05 +03:00
// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback
2014-08-20 13:05:02 +04:00
func ( scope * Scope ) InstanceSet ( name string , value interface { } ) * Scope {
2016-01-15 16:03:35 +03:00
return scope . Set ( name + scope . InstanceID ( ) , value )
2014-08-20 13:05:02 +04:00
}
2016-03-07 16:09:05 +03:00
// InstanceGet get instance setting from current operation
2014-08-20 13:05:02 +04:00
func ( scope * Scope ) InstanceGet ( name string ) ( interface { } , bool ) {
2016-01-15 16:03:35 +03:00
return scope . Get ( name + scope . InstanceID ( ) )
2014-08-20 13:05:02 +04:00
}
2014-01-29 15:14:37 +04:00
// Begin start a transaction
2014-01-26 13:10:33 +04:00
func ( scope * Scope ) Begin ( ) * Scope {
2016-03-07 09:54:20 +03:00
if db , ok := scope . SQLDB ( ) . ( sqlDb ) ; ok {
2014-01-27 06:47:37 +04:00
if tx , err := db . Begin ( ) ; err == nil {
scope . db . db = interface { } ( tx ) . ( sqlCommon )
2014-08-20 13:05:02 +04:00
scope . InstanceSet ( "gorm:started_transaction" , true )
2014-01-27 06:47:37 +04:00
}
2014-01-26 13:10:33 +04:00
}
return scope
}
2016-03-07 16:09:05 +03:00
// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it
2014-01-26 13:10:33 +04:00
func ( scope * Scope ) CommitOrRollback ( ) * Scope {
2014-08-20 13:05:02 +04:00
if _ , ok := scope . InstanceGet ( "gorm:started_transaction" ) ; ok {
2014-01-26 13:10:33 +04:00
if db , ok := scope . db . db . ( sqlTx ) ; ok {
if scope . HasError ( ) {
db . Rollback ( )
} else {
2016-01-04 08:32:35 +03:00
scope . Err ( db . Commit ( ) )
2014-01-26 13:10:33 +04:00
}
scope . db . db = scope . db . parent . db
}
}
return scope
}
2015-03-12 12:47:31 +03:00
2016-03-08 17:29:58 +03:00
////////////////////////////////////////////////////////////////////////////////
// Private Methods For *gorm.Scope
////////////////////////////////////////////////////////////////////////////////
func ( scope * Scope ) callMethod ( methodName string , reflectValue reflect . Value ) {
2016-04-04 17:49:18 +03:00
// Only get address from non-pointer
if reflectValue . CanAddr ( ) && reflectValue . Kind ( ) != reflect . Ptr {
2016-03-08 17:29:58 +03:00
reflectValue = reflectValue . Addr ( )
}
if methodValue := reflectValue . MethodByName ( methodName ) ; methodValue . IsValid ( ) {
switch method := methodValue . Interface ( ) . ( type ) {
case func ( ) :
method ( )
case func ( * Scope ) :
method ( scope )
case func ( * DB ) :
newDB := scope . NewDB ( )
method ( newDB )
scope . Err ( newDB . Error )
case func ( ) error :
scope . Err ( method ( ) )
case func ( * Scope ) error :
scope . Err ( method ( scope ) )
case func ( * DB ) error :
newDB := scope . NewDB ( )
scope . Err ( method ( newDB ) )
scope . Err ( newDB . Error )
default :
scope . Err ( fmt . Errorf ( "unsupported function %v" , methodName ) )
}
}
2015-03-12 12:47:31 +03:00
}
2016-01-13 11:53:04 +03:00
2016-03-23 05:29:52 +03:00
var columnRegexp = regexp . MustCompile ( "^[a-zA-Z]+(\\.[a-zA-Z]+)*$" ) // only match string like `name`, `users.name`
2016-03-08 17:29:58 +03:00
func ( scope * Scope ) quoteIfPossible ( str string ) string {
2016-03-23 05:29:52 +03:00
if columnRegexp . MatchString ( str ) {
2016-03-08 17:29:58 +03:00
return scope . Quote ( str )
}
return str
}
2016-03-08 16:45:20 +03:00
2016-03-10 12:13:48 +03:00
func ( scope * Scope ) scan ( rows * sql . Rows , columns [ ] string , fields [ ] * Field ) {
var (
ignored interface { }
values = make ( [ ] interface { } , len ( columns ) )
2016-09-13 04:29:36 +03:00
selectFields [ ] * Field
2016-03-10 12:13:48 +03:00
selectedColumnsMap = map [ string ] int { }
2016-09-13 04:29:36 +03:00
resetFields = map [ int ] * Field { }
2016-03-10 12:13:48 +03:00
)
2016-01-13 11:53:04 +03:00
for index , column := range columns {
2016-03-10 12:13:48 +03:00
values [ index ] = & ignored
selectFields = fields
if idx , ok := selectedColumnsMap [ column ] ; ok {
2016-03-10 12:35:19 +03:00
selectFields = selectFields [ idx + 1 : ]
2016-03-10 12:13:48 +03:00
}
2016-03-10 12:35:19 +03:00
for fieldIndex , field := range selectFields {
2016-03-10 12:13:48 +03:00
if field . DBName == column {
if field . Field . Kind ( ) == reflect . Ptr {
values [ index ] = field . Field . Addr ( ) . Interface ( )
} else {
reflectValue := reflect . New ( reflect . PtrTo ( field . Struct . Type ) )
reflectValue . Elem ( ) . Set ( field . Field . Addr ( ) )
values [ index ] = reflectValue . Interface ( )
2016-09-13 04:29:36 +03:00
resetFields [ index ] = field
2016-03-10 12:13:48 +03:00
}
2016-03-10 12:35:19 +03:00
selectedColumnsMap [ column ] = fieldIndex
2016-09-13 04:29:36 +03:00
if field . IsNormal {
break
}
2016-01-13 11:53:04 +03:00
}
}
}
scope . Err ( rows . Scan ( values ... ) )
2016-09-13 04:29:36 +03:00
for index , field := range resetFields {
2016-03-10 12:13:48 +03:00
if v := reflect . ValueOf ( values [ index ] ) . Elem ( ) . Elem ( ) ; v . IsValid ( ) {
field . Field . Set ( v )
2016-01-13 11:53:04 +03:00
}
}
}
2016-03-08 17:29:58 +03:00
2016-03-08 16:45:20 +03:00
func ( scope * Scope ) primaryCondition ( value interface { } ) string {
2016-08-13 16:23:18 +03:00
return fmt . Sprintf ( "(%v.%v = %v)" , scope . QuotedTableName ( ) , scope . Quote ( scope . PrimaryKey ( ) ) , value )
2016-03-08 16:45:20 +03:00
}
func ( scope * Scope ) buildWhereCondition ( clause map [ string ] interface { } ) ( str string ) {
switch value := clause [ "query" ] . ( type ) {
case string :
// if string is number
if regexp . MustCompile ( "^\\s*\\d+\\s*$" ) . MatchString ( value ) {
return scope . primaryCondition ( scope . AddToVars ( value ) )
} else if value != "" {
str = fmt . Sprintf ( "(%v)" , value )
}
case int , int8 , int16 , int32 , int64 , uint , uint8 , uint16 , uint32 , uint64 , sql . NullInt64 :
return scope . primaryCondition ( scope . AddToVars ( value ) )
case [ ] int , [ ] int8 , [ ] int16 , [ ] int32 , [ ] int64 , [ ] uint , [ ] uint8 , [ ] uint16 , [ ] uint32 , [ ] uint64 , [ ] string , [ ] interface { } :
2016-08-13 16:23:18 +03:00
str = fmt . Sprintf ( "(%v.%v IN (?))" , scope . QuotedTableName ( ) , scope . Quote ( scope . PrimaryKey ( ) ) )
2016-03-08 16:45:20 +03:00
clause [ "args" ] = [ ] interface { } { value }
case map [ string ] interface { } :
var sqls [ ] string
for key , value := range value {
if value != nil {
2016-08-13 16:23:18 +03:00
sqls = append ( sqls , fmt . Sprintf ( "(%v.%v = %v)" , scope . QuotedTableName ( ) , scope . Quote ( key ) , scope . AddToVars ( value ) ) )
2016-03-08 16:45:20 +03:00
} else {
2016-08-13 16:23:18 +03:00
sqls = append ( sqls , fmt . Sprintf ( "(%v.%v IS NULL)" , scope . QuotedTableName ( ) , scope . Quote ( key ) ) )
2016-03-08 16:45:20 +03:00
}
}
return strings . Join ( sqls , " AND " )
case interface { } :
var sqls [ ] string
2016-08-13 16:23:18 +03:00
newScope := scope . New ( value )
for _ , field := range newScope . Fields ( ) {
2016-03-08 16:45:20 +03:00
if ! field . IsIgnored && ! field . IsBlank {
2016-08-13 16:23:18 +03:00
sqls = append ( sqls , fmt . Sprintf ( "(%v.%v = %v)" , newScope . QuotedTableName ( ) , scope . Quote ( field . DBName ) , scope . AddToVars ( field . Field . Interface ( ) ) ) )
2016-03-08 16:45:20 +03:00
}
}
return strings . Join ( sqls , " AND " )
}
args := clause [ "args" ] . ( [ ] interface { } )
for _ , arg := range args {
switch reflect . ValueOf ( arg ) . Kind ( ) {
case reflect . Slice : // For where("id in (?)", []int64{1,2})
if bytes , ok := arg . ( [ ] byte ) ; ok {
str = strings . Replace ( str , "?" , scope . AddToVars ( bytes ) , 1 )
} else if values := reflect . ValueOf ( arg ) ; values . Len ( ) > 0 {
var tempMarks [ ] string
for i := 0 ; i < values . Len ( ) ; i ++ {
tempMarks = append ( tempMarks , scope . AddToVars ( values . Index ( i ) . Interface ( ) ) )
}
str = strings . Replace ( str , "?" , strings . Join ( tempMarks , "," ) , 1 )
} else {
str = strings . Replace ( str , "?" , scope . AddToVars ( Expr ( "NULL" ) ) , 1 )
}
default :
if valuer , ok := interface { } ( arg ) . ( driver . Valuer ) ; ok {
arg , _ = valuer . Value ( )
}
str = strings . Replace ( str , "?" , scope . AddToVars ( arg ) , 1 )
}
}
return
}
func ( scope * Scope ) buildNotCondition ( clause map [ string ] interface { } ) ( str string ) {
var notEqualSQL string
var primaryKey = scope . PrimaryKey ( )
switch value := clause [ "query" ] . ( type ) {
case string :
// is number
if regexp . MustCompile ( "^\\s*\\d+\\s*$" ) . MatchString ( value ) {
id , _ := strconv . Atoi ( value )
2016-08-13 16:23:18 +03:00
return fmt . Sprintf ( "(%v <> %v)" , scope . Quote ( primaryKey ) , id )
2016-03-08 16:45:20 +03:00
} else if regexp . MustCompile ( "(?i) (=|<>|>|<|LIKE|IS|IN) " ) . MatchString ( value ) {
str = fmt . Sprintf ( " NOT (%v) " , value )
notEqualSQL = fmt . Sprintf ( "NOT (%v)" , value )
} else {
2016-08-13 16:23:18 +03:00
str = fmt . Sprintf ( "(%v.%v NOT IN (?))" , scope . QuotedTableName ( ) , scope . Quote ( value ) )
notEqualSQL = fmt . Sprintf ( "(%v.%v <> ?)" , scope . QuotedTableName ( ) , scope . Quote ( value ) )
2016-03-08 16:45:20 +03:00
}
case int , int8 , int16 , int32 , int64 , uint , uint8 , uint16 , uint32 , uint64 , sql . NullInt64 :
2016-08-13 16:23:18 +03:00
return fmt . Sprintf ( "(%v.%v <> %v)" , scope . QuotedTableName ( ) , scope . Quote ( primaryKey ) , value )
2016-03-08 16:45:20 +03:00
case [ ] int , [ ] int8 , [ ] int16 , [ ] int32 , [ ] int64 , [ ] uint , [ ] uint8 , [ ] uint16 , [ ] uint32 , [ ] uint64 , [ ] string :
if reflect . ValueOf ( value ) . Len ( ) > 0 {
2016-08-13 16:23:18 +03:00
str = fmt . Sprintf ( "(%v.%v NOT IN (?))" , scope . QuotedTableName ( ) , scope . Quote ( primaryKey ) )
2016-03-08 16:45:20 +03:00
clause [ "args" ] = [ ] interface { } { value }
}
return ""
case map [ string ] interface { } :
var sqls [ ] string
for key , value := range value {
if value != nil {
2016-08-13 16:23:18 +03:00
sqls = append ( sqls , fmt . Sprintf ( "(%v.%v <> %v)" , scope . QuotedTableName ( ) , scope . Quote ( key ) , scope . AddToVars ( value ) ) )
2016-03-08 16:45:20 +03:00
} else {
2016-08-13 16:23:18 +03:00
sqls = append ( sqls , fmt . Sprintf ( "(%v.%v IS NOT NULL)" , scope . QuotedTableName ( ) , scope . Quote ( key ) ) )
2016-03-08 16:45:20 +03:00
}
}
return strings . Join ( sqls , " AND " )
case interface { } :
var sqls [ ] string
2016-08-13 16:23:18 +03:00
var newScope = scope . New ( value )
for _ , field := range newScope . Fields ( ) {
2016-03-08 16:45:20 +03:00
if ! field . IsBlank {
2016-08-13 16:23:18 +03:00
sqls = append ( sqls , fmt . Sprintf ( "(%v.%v <> %v)" , newScope . QuotedTableName ( ) , scope . Quote ( field . DBName ) , scope . AddToVars ( field . Field . Interface ( ) ) ) )
2016-03-08 16:45:20 +03:00
}
}
return strings . Join ( sqls , " AND " )
}
args := clause [ "args" ] . ( [ ] interface { } )
for _ , arg := range args {
switch reflect . ValueOf ( arg ) . Kind ( ) {
case reflect . Slice : // For where("id in (?)", []int64{1,2})
if bytes , ok := arg . ( [ ] byte ) ; ok {
str = strings . Replace ( str , "?" , scope . AddToVars ( bytes ) , 1 )
} else if values := reflect . ValueOf ( arg ) ; values . Len ( ) > 0 {
var tempMarks [ ] string
for i := 0 ; i < values . Len ( ) ; i ++ {
tempMarks = append ( tempMarks , scope . AddToVars ( values . Index ( i ) . Interface ( ) ) )
}
str = strings . Replace ( str , "?" , strings . Join ( tempMarks , "," ) , 1 )
} else {
str = strings . Replace ( str , "?" , scope . AddToVars ( Expr ( "NULL" ) ) , 1 )
}
default :
if scanner , ok := interface { } ( arg ) . ( driver . Valuer ) ; ok {
arg , _ = scanner . Value ( )
}
str = strings . Replace ( notEqualSQL , "?" , scope . AddToVars ( arg ) , 1 )
}
}
return
}
// buildSelectQuery builds the select list from a clause with keys "query" and
// "args". A string query is used as-is and a []string query is joined with
// ", ". Each "?" is then replaced in order by the bound placeholder(s) of the
// corresponding arg; a slice arg expands to a comma-separated list.
func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
	if query, isString := clause["query"].(string); isString {
		str = query
	} else if columns, isList := clause["query"].([]string); isList {
		str = strings.Join(columns, ", ")
	}

	for _, arg := range clause["args"].([]interface{}) {
		argValue := reflect.ValueOf(arg)
		if argValue.Kind() == reflect.Slice {
			var marks []string
			for i := 0; i < argValue.Len(); i++ {
				marks = append(marks, scope.AddToVars(argValue.Index(i).Interface()))
			}
			str = strings.Replace(str, "?", strings.Join(marks, ","), 1)
			continue
		}

		if valuer, isValuer := interface{}(arg).(driver.Valuer); isValuer {
			arg, _ = valuer.Value()
		}
		str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
	}
	return str
}
// whereSQL builds the WHERE clause. Primary conditions (the deleted_at IS NULL
// filter unless the search is unscoped, plus primary key equality when the
// primary key is not blank) are ANDed together. The where and not conditions
// are ANDed, then ORed with the or conditions, and that combination is ANDed
// to the primary conditions in parentheses. It returns "" when there are no
// conditions at all.
func (scope *Scope) whereSQL() (sql string) {
	var (
		table             = scope.QuotedTableName()
		primary, ands, ors []string
	)

	if !scope.Search.Unscoped && scope.HasColumn("deleted_at") {
		primary = append(primary, fmt.Sprintf("%v.deleted_at IS NULL", table))
	}

	if !scope.PrimaryKeyZero() {
		for _, field := range scope.PrimaryFields() {
			condition := fmt.Sprintf("%v.%v = %v", table, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
			primary = append(primary, condition)
		}
	}

	// The builders bind variables as they go, so keep this order:
	// where, or, not.
	for _, clause := range scope.Search.whereConditions {
		if condition := scope.buildWhereCondition(clause); condition != "" {
			ands = append(ands, condition)
		}
	}
	for _, clause := range scope.Search.orConditions {
		if condition := scope.buildWhereCondition(clause); condition != "" {
			ors = append(ors, condition)
		}
	}
	for _, clause := range scope.Search.notConditions {
		if condition := scope.buildNotCondition(clause); condition != "" {
			ands = append(ands, condition)
		}
	}

	combined := strings.Join(ands, " AND ")
	orSQL := strings.Join(ors, " OR ")
	switch {
	case len(combined) == 0:
		combined = orSQL
	case len(orSQL) > 0:
		combined = combined + " OR " + orSQL
	}

	switch {
	case len(primary) > 0:
		sql = "WHERE " + strings.Join(primary, " AND ")
		if len(combined) > 0 {
			sql = sql + " AND (" + combined + ")"
		}
	case len(combined) > 0:
		sql = "WHERE " + combined
	}
	return sql
}
func ( scope * Scope ) selectSQL ( ) string {
if len ( scope . Search . selects ) == 0 {
2016-03-14 03:05:36 +03:00
if len ( scope . Search . joinConditions ) > 0 {
return fmt . Sprintf ( "%v.*" , scope . QuotedTableName ( ) )
}
2016-03-08 16:45:20 +03:00
return "*"
}
return scope . buildSelectQuery ( scope . Search . selects )
}
func ( scope * Scope ) orderSQL ( ) string {
if len ( scope . Search . orders ) == 0 || scope . Search . countingQuery {
return ""
}
2016-03-23 05:29:52 +03:00
var orders [ ] string
for _ , order := range scope . Search . orders {
2016-06-28 06:15:42 +03:00
if str , ok := order . ( string ) ; ok {
orders = append ( orders , scope . quoteIfPossible ( str ) )
} else if expr , ok := order . ( * expr ) ; ok {
exp := expr . expr
for _ , arg := range expr . args {
exp = strings . Replace ( exp , "?" , scope . AddToVars ( arg ) , 1 )
}
orders = append ( orders , exp )
}
2016-03-23 05:29:52 +03:00
}
return " ORDER BY " + strings . Join ( orders , "," )
2016-03-08 16:45:20 +03:00
}
// limitAndOffsetSQL asks the dialect to render the search's limit and offset.
func (scope *Scope) limitAndOffsetSQL() string {
	dialect := scope.Dialect()
	return dialect.LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
}
// groupSQL returns the GROUP BY clause, or "" when no group is set.
func (scope *Scope) groupSQL() string {
	if group := scope.Search.group; len(group) > 0 {
		return " GROUP BY " + group
	}
	return ""
}
// havingSQL returns the HAVING clause built from the non-empty having
// conditions joined with AND, or "" when none produce SQL.
func (scope *Scope) havingSQL() string {
	var conditions []string
	for _, clause := range scope.Search.havingConditions {
		condition := scope.buildWhereCondition(clause)
		if condition == "" {
			continue
		}
		conditions = append(conditions, condition)
	}

	// Only non-empty fragments are collected, so an empty list means there is
	// nothing to emit.
	if len(conditions) == 0 {
		return ""
	}
	return " HAVING " + strings.Join(conditions, " AND ")
}
// joinsSQL returns the join fragments separated by spaces, followed by a
// trailing space. Each fragment is built like a where condition and then has
// one leading "(" and one trailing ")" stripped.
func (scope *Scope) joinsSQL() string {
	var joins []string
	for _, clause := range scope.Search.joinConditions {
		fragment := scope.buildWhereCondition(clause)
		if fragment == "" {
			continue
		}
		fragment = strings.TrimPrefix(fragment, "(")
		fragment = strings.TrimSuffix(fragment, ")")
		joins = append(joins, fragment)
	}
	return strings.Join(joins, " ") + " "
}
// prepareQuerySQL sets the scope's SQL for a query. For a raw search it uses
// the combined condition SQL with the " WHERE (" prefix and ")" suffix
// stripped; otherwise it builds a SELECT ... FROM ... statement.
func (scope *Scope) prepareQuerySQL() {
	if !scope.Search.raw {
		// Build the parts in this order: the builders bind variables as they
		// go, so the select list must come before the conditions.
		selectList := scope.selectSQL()
		table := scope.QuotedTableName()
		conditions := scope.CombinedConditionSql()
		scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", selectList, table, conditions))
		return
	}

	rawSQL := scope.CombinedConditionSql()
	rawSQL = strings.TrimPrefix(rawSQL, " WHERE (")
	rawSQL = strings.TrimSuffix(rawSQL, ")")
	scope.Raw(rawSQL)
}
// inlineCondition adds a where condition to the search, using the first value
// as the query and the rest as its arguments. It does nothing when no values
// are given.
func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
	if len(values) == 0 {
		return scope
	}

	query, args := values[0], values[1:]
	scope.Search.Where(query, args...)
	return scope
}
// callCallbacks runs the callbacks in order against the scope and stops early
// once skipLeft has been set.
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
	for i := range funcs {
		callback := *funcs[i]
		callback(scope)
		if scope.skipLeft {
			break
		}
	}
	return scope
}
2016-08-14 11:10:30 +03:00
func convertInterfaceToMap ( values interface { } , withIgnoredField bool ) map [ string ] interface { } {
2016-03-09 11:18:01 +03:00
var attrs = map [ string ] interface { } { }
switch value := values . ( type ) {
case map [ string ] interface { } :
return value
case [ ] interface { } :
for _ , v := range value {
2016-08-14 11:10:30 +03:00
for key , value := range convertInterfaceToMap ( v , withIgnoredField ) {
2016-03-09 11:18:01 +03:00
attrs [ key ] = value
}
}
case interface { } :
reflectValue := reflect . ValueOf ( values )
switch reflectValue . Kind ( ) {
case reflect . Map :
for _ , key := range reflectValue . MapKeys ( ) {
attrs [ ToDBName ( key . Interface ( ) . ( string ) ) ] = reflectValue . MapIndex ( key ) . Interface ( )
}
default :
for _ , field := range ( & Scope { Value : values } ) . Fields ( ) {
2016-08-14 11:10:30 +03:00
if ! field . IsBlank && ( withIgnoredField || ! field . IsIgnored ) {
2016-03-09 11:18:01 +03:00
attrs [ field . DBName ] = field . Field . Interface ( )
}
}
}
}
return attrs
}
func ( scope * Scope ) updatedAttrsWithValues ( value interface { } ) ( results map [ string ] interface { } , hasUpdate bool ) {
2016-03-08 16:45:20 +03:00
if scope . IndirectValue ( ) . Kind ( ) != reflect . Struct {
2016-08-14 11:10:30 +03:00
return convertInterfaceToMap ( value , false ) , true
2016-03-08 16:45:20 +03:00
}
results = map [ string ] interface { } { }
2016-03-09 11:18:01 +03:00
2016-08-14 11:10:30 +03:00
for key , value := range convertInterfaceToMap ( value , true ) {
2016-03-08 16:45:20 +03:00
if field , ok := scope . FieldByName ( key ) ; ok && scope . changeableField ( field ) {
2016-03-09 11:18:01 +03:00
if _ , ok := value . ( * expr ) ; ok {
hasUpdate = true
results [ field . DBName ] = value
2016-03-08 16:45:20 +03:00
} else {
2016-04-04 16:33:11 +03:00
err := field . Set ( value )
2016-03-09 11:18:01 +03:00
if field . IsNormal {
hasUpdate = true
2016-04-04 16:33:11 +03:00
if err == ErrUnaddressable {
results [ field . DBName ] = value
} else {
results [ field . DBName ] = field . Field . Interface ( )
}
2016-03-09 11:18:01 +03:00
}
2016-03-08 16:45:20 +03:00
}
}
}
return
}
// row runs the row-query callbacks, prepares the query SQL and executes it
// with QueryRow. The statement is traced when the function returns, using the
// time captured on entry.
func (scope *Scope) row() *sql.Row {
	startedAt := NowFunc()
	defer scope.trace(startedAt)

	scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
	scope.prepareQuerySQL()

	conn := scope.SQLDB()
	return conn.QueryRow(scope.SQL, scope.SQLVars...)
}
// rows runs the row-query callbacks, prepares the query SQL and executes it
// with Query. The statement is traced when the function returns, using the
// time captured on entry.
func (scope *Scope) rows() (*sql.Rows, error) {
	startedAt := NowFunc()
	defer scope.trace(startedAt)

	scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
	scope.prepareQuerySQL()

	conn := scope.SQLDB()
	return conn.Query(scope.SQL, scope.SQLVars...)
}
func ( scope * Scope ) initialize ( ) * Scope {
for _ , clause := range scope . Search . whereConditions {
2016-03-09 11:18:01 +03:00
scope . updatedAttrsWithValues ( clause [ "query" ] )
2016-03-08 16:45:20 +03:00
}
2016-03-09 11:18:01 +03:00
scope . updatedAttrsWithValues ( scope . Search . initAttrs )
scope . updatedAttrsWithValues ( scope . Search . assignAttrs )
2016-03-08 16:45:20 +03:00
return scope
}
// pluck selects a single column and appends every scanned value to the slice
// that value points to.
//
// value must be (a pointer to) a slice; otherwise an error is recorded on the
// scope and no query is run. Scan errors are recorded on the scope per row,
// and any error that ended the row iteration early is recorded as well.
func (scope *Scope) pluck(column string, value interface{}) *Scope {
	dest := reflect.Indirect(reflect.ValueOf(value))
	scope.Search.Select(column)
	if dest.Kind() != reflect.Slice {
		scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
		return scope
	}

	rows, err := scope.rows()
	if scope.Err(err) != nil {
		return scope
	}
	defer rows.Close()

	elemType := dest.Type().Elem()
	for rows.Next() {
		elem := reflect.New(elemType).Interface()
		scope.Err(rows.Scan(elem))
		dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
	}

	// rows.Next returns false both at the end of the result set and on
	// failure; rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		scope.Err(err)
	}
	return scope
}
func ( scope * Scope ) count ( value interface { } ) * Scope {
2016-08-15 16:28:07 +03:00
if query , ok := scope . Search . selects [ "query" ] ; ! ok || ! regexp . MustCompile ( "(?i)^count(.+)$" ) . MatchString ( fmt . Sprint ( query ) ) {
2016-08-14 10:15:09 +03:00
scope . Search . Select ( "count(*)" )
}
2016-03-08 16:45:20 +03:00
scope . Search . countingQuery = true
scope . Err ( scope . row ( ) . Scan ( value ) )
return scope
}
// typeName returns the name of the scope value's type after unwrapping every
// slice and pointer layer.
func (scope *Scope) typeName() string {
	typ := scope.IndirectValue().Type()
	for {
		switch typ.Kind() {
		case reflect.Slice, reflect.Ptr:
			typ = typ.Elem()
			continue
		}
		return typ.Name()
	}
}
// trace logs the scope's SQL statement and variables together with the given
// start time. Nothing is logged when the scope has no SQL.
func (scope *Scope) trace(t time.Time) {
	if scope.SQL == "" {
		return
	}
	scope.db.slog(scope.SQL, t, scope.SQLVars...)
}
// changeableField reports whether field may be modified under the scope's
// select/omit settings. When select attributes are present, only fields whose
// Name or DBName is listed are changeable; otherwise every field is changeable
// unless its Name or DBName appears in the omit attributes.
func (scope *Scope) changeableField(field *Field) bool {
	matches := func(attr string) bool {
		return field.Name == attr || field.DBName == attr
	}

	selected := scope.SelectAttrs()
	if len(selected) > 0 {
		for _, attr := range selected {
			if matches(attr) {
				return true
			}
		}
		return false
	}

	for _, attr := range scope.OmitAttrs() {
		if matches(attr) {
			return false
		}
	}
	return true
}
// shouldSaveAssociations reports whether associations should be saved. It is
// false when the "gorm:save_associations" setting is present and false, or
// when the scope already has an error.
func (scope *Scope) shouldSaveAssociations() bool {
	saveAssociations, ok := scope.Get("gorm:save_associations")
	if ok && !saveAssociations.(bool) {
		return false
	}
	return !scope.HasError()
}
func ( scope * Scope ) related ( value interface { } , foreignKeys ... string ) * Scope {
toScope := scope . db . NewScope ( value )
for _ , foreignKey := range append ( foreignKeys , toScope . typeName ( ) + "Id" , scope . typeName ( ) + "Id" ) {
fromField , _ := scope . FieldByName ( foreignKey )
toField , _ := toScope . FieldByName ( foreignKey )
if fromField != nil {
if relationship := fromField . Relationship ; relationship != nil {
if relationship . Kind == "many_to_many" {
joinTableHandler := relationship . JoinTableHandler
scope . Err ( joinTableHandler . JoinWith ( joinTableHandler , toScope . db , scope . Value ) . Find ( value ) . Error )
} else if relationship . Kind == "belongs_to" {
query := toScope . db
for idx , foreignKey := range relationship . ForeignDBNames {
if field , ok := scope . FieldByName ( foreignKey ) ; ok {
query = query . Where ( fmt . Sprintf ( "%v = ?" , scope . Quote ( relationship . AssociationForeignDBNames [ idx ] ) ) , field . Field . Interface ( ) )
}
}
scope . Err ( query . Find ( value ) . Error )
} else if relationship . Kind == "has_many" || relationship . Kind == "has_one" {
query := toScope . db
for idx , foreignKey := range relationship . ForeignDBNames {
if field , ok := scope . FieldByName ( relationship . AssociationForeignDBNames [ idx ] ) ; ok {
query = query . Where ( fmt . Sprintf ( "%v = ?" , scope . Quote ( foreignKey ) ) , field . Field . Interface ( ) )
}
}
if relationship . PolymorphicType != "" {
2016-09-28 23:44:43 +03:00
value := scope . TableName ( )
if relationship . PolymorphicValue != "" {
value = relationship . PolymorphicValue
}
query = query . Where ( fmt . Sprintf ( "%v = ?" , scope . Quote ( relationship . PolymorphicDBName ) ) , value )
2016-03-08 16:45:20 +03:00
}
scope . Err ( query . Find ( value ) . Error )
}
} else {
sql := fmt . Sprintf ( "%v = ?" , scope . Quote ( toScope . PrimaryKey ( ) ) )
scope . Err ( toScope . db . Where ( sql , fromField . Field . Interface ( ) ) . Find ( value ) . Error )
}
return scope
} else if toField != nil {
sql := fmt . Sprintf ( "%v = ?" , scope . Quote ( toField . DBName ) )
scope . Err ( toScope . db . Where ( sql , scope . PrimaryKeyValue ( ) ) . Find ( value ) . Error )
return scope
}
}
scope . Err ( fmt . Errorf ( "invalid association %v" , foreignKeys ) )
return scope
}
// getTableOptions returns the "gorm:table_options" setting as a string, or an
// empty string when the setting is not present.
func (scope *Scope) getTableOptions() string {
	if tableOptions, ok := scope.Get("gorm:table_options"); ok {
		return tableOptions.(string)
	}
	return ""
}
func ( scope * Scope ) createJoinTable ( field * StructField ) {
if relationship := field . Relationship ; relationship != nil && relationship . JoinTableHandler != nil {
joinTableHandler := relationship . JoinTableHandler
joinTable := joinTableHandler . Table ( scope . db )
if ! scope . Dialect ( ) . HasTable ( joinTable ) {
toScope := & Scope { Value : reflect . New ( field . Struct . Type ) . Interface ( ) }
var sqlTypes , primaryKeys [ ] string
for idx , fieldName := range relationship . ForeignFieldNames {
if field , ok := scope . FieldByName ( fieldName ) ; ok {
foreignKeyStruct := field . clone ( )
foreignKeyStruct . IsPrimaryKey = false
foreignKeyStruct . TagSettings [ "IS_JOINTABLE_FOREIGNKEY" ] = "true"
2016-05-09 17:32:33 +03:00
delete ( foreignKeyStruct . TagSettings , "AUTO_INCREMENT" )
2016-03-08 16:45:20 +03:00
sqlTypes = append ( sqlTypes , scope . Quote ( relationship . ForeignDBNames [ idx ] ) + " " + scope . Dialect ( ) . DataTypeOf ( foreignKeyStruct ) )
primaryKeys = append ( primaryKeys , scope . Quote ( relationship . ForeignDBNames [ idx ] ) )
}
}
for idx , fieldName := range relationship . AssociationForeignFieldNames {
if field , ok := toScope . FieldByName ( fieldName ) ; ok {
foreignKeyStruct := field . clone ( )
foreignKeyStruct . IsPrimaryKey = false
foreignKeyStruct . TagSettings [ "IS_JOINTABLE_FOREIGNKEY" ] = "true"
2016-05-09 17:32:33 +03:00
delete ( foreignKeyStruct . TagSettings , "AUTO_INCREMENT" )
2016-03-08 16:45:20 +03:00
sqlTypes = append ( sqlTypes , scope . Quote ( relationship . AssociationForeignDBNames [ idx ] ) + " " + scope . Dialect ( ) . DataTypeOf ( foreignKeyStruct ) )
primaryKeys = append ( primaryKeys , scope . Quote ( relationship . AssociationForeignDBNames [ idx ] ) )
}
}
scope . Err ( scope . NewDB ( ) . Exec ( fmt . Sprintf ( "CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s" , scope . Quote ( joinTable ) , strings . Join ( sqlTypes , "," ) , strings . Join ( primaryKeys , "," ) , scope . getTableOptions ( ) ) ) . Error )
}
scope . NewDB ( ) . Table ( joinTable ) . AutoMigrate ( joinTableHandler )
}
}
// createTable executes a CREATE TABLE statement for the scope's model, creates
// join tables for its fields along the way, and finally creates the indexes
// declared on the model via autoIndex.
func (scope *Scope) createTable() *Scope {
	var (
		columnDefs             []string
		primaryKeys            []string
		primaryKeyInColumnType bool
	)

	for _, field := range scope.GetModelStruct().StructFields {
		if field.IsNormal {
			sqlTag := scope.Dialect().DataTypeOf(field)

			// If the column type itself already declares the primary key, a
			// separate PRIMARY KEY clause must not be added; only a single
			// primary-key column can be supported in that case.
			if strings.Contains(strings.ToLower(sqlTag), "primary key") {
				primaryKeyInColumnType = true
			}

			columnDefs = append(columnDefs, scope.Quote(field.DBName)+" "+sqlTag)
		}

		if field.IsPrimaryKey {
			primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
		}
		scope.createJoinTable(field)
	}

	primaryKeyStr := ""
	if !primaryKeyInColumnType && len(primaryKeys) > 0 {
		primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
	}

	createSQL := fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(columnDefs, ","), primaryKeyStr, scope.getTableOptions())
	scope.Raw(createSQL).Exec()

	scope.autoIndex()
	return scope
}
// dropTable executes a DROP TABLE statement for the scope's quoted table name.
func (scope *Scope) dropTable() *Scope {
	dropSQL := fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())
	scope.Raw(dropSQL).Exec()
	return scope
}
// modifyColumn executes an "ALTER TABLE ... MODIFY" statement that sets the
// quoted column to the given type. typ is inserted into the SQL as is.
func (scope *Scope) modifyColumn(column string, typ string) {
	alterSQL := fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)
	scope.Raw(alterSQL).Exec()
}
// dropColumn executes an "ALTER TABLE ... DROP COLUMN" statement for the
// quoted column on the scope's table.
func (scope *Scope) dropColumn(column string) {
	alterSQL := fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))
	scope.Raw(alterSQL).Exec()
}
// addIndex creates an index (unique when requested) named indexName over the
// given columns, unless the dialect reports that the index already exists.
// The scope's where SQL is appended to the statement. indexName is inserted
// into the SQL without quoting.
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
	if scope.Dialect().HasIndex(scope.TableName(), indexName) {
		return
	}

	var quotedColumns []string
	for i := range column {
		quotedColumns = append(quotedColumns, scope.quoteIfPossible(column[i]))
	}

	sqlCreate := "CREATE UNIQUE INDEX"
	if !unique {
		sqlCreate = "CREATE INDEX"
	}

	createSQL := fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(quotedColumns, ", "), scope.whereSQL())
	scope.Raw(createSQL).Exec()
}
func ( scope * Scope ) addForeignKey ( field string , dest string , onDelete string , onUpdate string ) {
2016-05-22 01:13:26 +03:00
keyName := scope . Dialect ( ) . BuildForeignKeyName ( scope . TableName ( ) , field , dest )
2016-03-08 16:45:20 +03:00
if scope . Dialect ( ) . HasForeignKey ( scope . TableName ( ) , keyName ) {
return
}
var query = ` ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s; `
scope . Raw ( fmt . Sprintf ( query , scope . QuotedTableName ( ) , scope . quoteIfPossible ( keyName ) , scope . quoteIfPossible ( field ) , dest , onDelete , onUpdate ) ) . Exec ( )
}
// removeIndex asks the dialect to remove the named index from the scope's
// table.
func (scope *Scope) removeIndex(indexName string) {
	dialect := scope.Dialect()
	dialect.RemoveIndex(scope.TableName(), indexName)
}
// autoMigrate brings the table for the scope's model up to date. A missing
// table is created with createTable. For an existing table, every normal field
// whose column the dialect reports as missing is added with ALTER TABLE, join
// tables are created for the fields, and indexes are created via autoIndex.
func (scope *Scope) autoMigrate() *Scope {
	tableName := scope.TableName()
	quotedTableName := scope.QuotedTableName()

	if !scope.Dialect().HasTable(tableName) {
		scope.createTable()
		return scope
	}

	for _, field := range scope.GetModelStruct().StructFields {
		if !scope.Dialect().HasColumn(tableName, field.DBName) && field.IsNormal {
			sqlTag := scope.Dialect().DataTypeOf(field)
			alterSQL := fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)
			scope.Raw(alterSQL).Exec()
		}
		scope.createJoinTable(field)
	}
	scope.autoIndex()
	return scope
}
func ( scope * Scope ) autoIndex ( ) * Scope {
var indexes = map [ string ] [ ] string { }
var uniqueIndexes = map [ string ] [ ] string { }
for _ , field := range scope . GetStructFields ( ) {
if name , ok := field . TagSettings [ "INDEX" ] ; ok {
2016-06-16 02:06:22 +03:00
names := strings . Split ( name , "," )
for _ , name := range names {
if name == "INDEX" || name == "" {
name = fmt . Sprintf ( "idx_%v_%v" , scope . TableName ( ) , field . DBName )
}
indexes [ name ] = append ( indexes [ name ] , field . DBName )
2016-03-08 16:45:20 +03:00
}
}
if name , ok := field . TagSettings [ "UNIQUE_INDEX" ] ; ok {
2016-06-16 02:06:22 +03:00
names := strings . Split ( name , "," )
for _ , name := range names {
if name == "UNIQUE_INDEX" || name == "" {
name = fmt . Sprintf ( "uix_%v_%v" , scope . TableName ( ) , field . DBName )
}
uniqueIndexes [ name ] = append ( uniqueIndexes [ name ] , field . DBName )
2016-03-08 16:45:20 +03:00
}
}
}
for name , columns := range indexes {
scope . NewDB ( ) . Model ( scope . Value ) . AddIndex ( name , columns ... )
}
for name , columns := range uniqueIndexes {
scope . NewDB ( ) . Model ( scope . Value ) . AddUniqueIndex ( name , columns ... )
}
return scope
}
// getColumnAsArray reads the named struct fields from each of values and
// returns one row of field values per struct.
//
// Each value is dereferenced through any pointers first. A slice contributes
// one row per element (each element passed through indirect), a struct
// contributes a single row, and any other kind is ignored. columns are Go
// struct field names looked up with FieldByName.
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
	// extractRow collects the requested fields of one struct value.
	extractRow := func(object reflect.Value) []interface{} {
		var row []interface{}
		for _, column := range columns {
			row = append(row, object.FieldByName(column).Interface())
		}
		return row
	}

	for _, value := range values {
		reflected := reflect.ValueOf(value)
		for reflected.Kind() == reflect.Ptr {
			reflected = reflected.Elem()
		}

		switch reflected.Kind() {
		case reflect.Slice:
			for i := 0; i < reflected.Len(); i++ {
				results = append(results, extractRow(indirect(reflected.Index(i))))
			}
		case reflect.Struct:
			results = append(results, extractRow(reflected))
		}
	}
	return results
}
func ( scope * Scope ) getColumnAsScope ( column string ) * Scope {
indirectScopeValue := scope . IndirectValue ( )
switch indirectScopeValue . Kind ( ) {
case reflect . Slice :
if fieldStruct , ok := scope . GetModelStruct ( ) . ModelType . FieldByName ( column ) ; ok {
fieldType := fieldStruct . Type
if fieldType . Kind ( ) == reflect . Slice || fieldType . Kind ( ) == reflect . Ptr {
fieldType = fieldType . Elem ( )
}
2016-07-10 16:34:37 +03:00
resultsMap := map [ interface { } ] bool { }
2016-03-08 16:45:20 +03:00
results := reflect . New ( reflect . SliceOf ( reflect . PtrTo ( fieldType ) ) ) . Elem ( )
for i := 0 ; i < indirectScopeValue . Len ( ) ; i ++ {
result := indirect ( indirect ( indirectScopeValue . Index ( i ) ) . FieldByName ( column ) )
if result . Kind ( ) == reflect . Slice {
for j := 0 ; j < result . Len ( ) ; j ++ {
2016-07-10 16:34:37 +03:00
if elem := result . Index ( j ) ; elem . CanAddr ( ) && resultsMap [ elem . Addr ( ) ] != true {
resultsMap [ elem . Addr ( ) ] = true
2016-03-08 16:45:20 +03:00
results = reflect . Append ( results , elem . Addr ( ) )
}
}
2016-07-10 16:34:37 +03:00
} else if result . CanAddr ( ) && resultsMap [ result . Addr ( ) ] != true {
resultsMap [ result . Addr ( ) ] = true
2016-03-08 16:45:20 +03:00
results = reflect . Append ( results , result . Addr ( ) )
}
}
return scope . New ( results . Interface ( ) )
}
case reflect . Struct :
if field := indirectScopeValue . FieldByName ( column ) ; field . CanAddr ( ) {
return scope . New ( field . Addr ( ) . Interface ( ) )
}
}
return nil
}