2014-01-26 10:55:41 +04:00
package gorm
import (
"database/sql"
"database/sql/driver"
2014-01-29 08:00:57 +04:00
"errors"
2014-01-26 10:55:41 +04:00
"fmt"
2014-04-02 12:59:07 +04:00
"go/ast"
2014-01-26 10:55:41 +04:00
"reflect"
"regexp"
"strconv"
"strings"
2014-01-29 08:00:57 +04:00
"time"
2014-01-26 10:55:41 +04:00
)
func ( scope * Scope ) primaryCondiation ( value interface { } ) string {
2014-01-28 12:22:41 +04:00
return fmt . Sprintf ( "(%v = %v)" , scope . Quote ( scope . PrimaryKey ( ) ) , value )
2014-01-26 10:55:41 +04:00
}
func ( scope * Scope ) buildWhereCondition ( clause map [ string ] interface { } ) ( str string ) {
switch value := clause [ "query" ] . ( type ) {
case string :
// if string is number
if regexp . MustCompile ( "^\\s*\\d+\\s*$" ) . MatchString ( value ) {
id , _ := strconv . Atoi ( value )
return scope . primaryCondiation ( scope . AddToVars ( id ) )
} else {
str = value
}
case int , int64 , int32 :
return scope . primaryCondiation ( scope . AddToVars ( value ) )
case sql . NullInt64 :
return scope . primaryCondiation ( scope . AddToVars ( value . Int64 ) )
case [ ] int64 , [ ] int , [ ] int32 , [ ] string :
2014-01-28 12:22:41 +04:00
str = fmt . Sprintf ( "(%v in (?))" , scope . Quote ( scope . PrimaryKey ( ) ) )
2014-01-26 10:55:41 +04:00
clause [ "args" ] = [ ] interface { } { value }
case map [ string ] interface { } :
var sqls [ ] string
for key , value := range value {
2014-01-28 12:22:41 +04:00
sqls = append ( sqls , fmt . Sprintf ( "(%v = %v)" , scope . Quote ( key ) , scope . AddToVars ( value ) ) )
2014-01-26 10:55:41 +04:00
}
return strings . Join ( sqls , " AND " )
case interface { } :
var sqls [ ] string
2014-01-28 05:25:30 +04:00
for _ , field := range scope . New ( value ) . Fields ( ) {
if ! field . IsBlank {
2014-01-28 12:22:41 +04:00
sqls = append ( sqls , fmt . Sprintf ( "(%v = %v)" , scope . Quote ( field . DBName ) , scope . AddToVars ( field . Value ) ) )
2014-01-28 05:25:30 +04:00
}
2014-01-26 10:55:41 +04:00
}
return strings . Join ( sqls , " AND " )
}
args := clause [ "args" ] . ( [ ] interface { } )
for _ , arg := range args {
switch reflect . TypeOf ( arg ) . Kind ( ) {
case reflect . Slice : // For where("id in (?)", []int64{1,2})
values := reflect . ValueOf ( arg )
2014-01-28 13:09:43 +04:00
var tempMarks [ ] string
2014-01-26 10:55:41 +04:00
for i := 0 ; i < values . Len ( ) ; i ++ {
2014-01-28 13:09:43 +04:00
tempMarks = append ( tempMarks , scope . AddToVars ( values . Index ( i ) . Interface ( ) ) )
2014-01-26 10:55:41 +04:00
}
2014-01-28 13:09:43 +04:00
str = strings . Replace ( str , "?" , strings . Join ( tempMarks , "," ) , 1 )
2014-01-26 10:55:41 +04:00
default :
if valuer , ok := interface { } ( arg ) . ( driver . Valuer ) ; ok {
arg , _ = valuer . Value ( )
}
str = strings . Replace ( str , "?" , scope . AddToVars ( arg ) , 1 )
}
}
return
}
func ( scope * Scope ) buildNotCondition ( clause map [ string ] interface { } ) ( str string ) {
2014-01-28 13:09:43 +04:00
var notEqualSql string
2014-01-26 10:55:41 +04:00
switch value := clause [ "query" ] . ( type ) {
case string :
if regexp . MustCompile ( "^\\s*\\d+\\s*$" ) . MatchString ( value ) {
id , _ := strconv . Atoi ( value )
2014-01-28 12:22:41 +04:00
return fmt . Sprintf ( "(%v <> %v)" , scope . Quote ( scope . PrimaryKey ( ) ) , id )
2014-01-26 10:55:41 +04:00
} else if regexp . MustCompile ( "(?i) (=|<>|>|<|LIKE|IS) " ) . MatchString ( value ) {
str = fmt . Sprintf ( " NOT (%v) " , value )
2014-01-28 13:09:43 +04:00
notEqualSql = fmt . Sprintf ( "NOT (%v)" , value )
2014-01-26 10:55:41 +04:00
} else {
2014-01-28 12:22:41 +04:00
str = fmt . Sprintf ( "(%v NOT IN (?))" , scope . Quote ( value ) )
2014-01-28 13:09:43 +04:00
notEqualSql = fmt . Sprintf ( "(%v <> ?)" , scope . Quote ( value ) )
2014-01-26 10:55:41 +04:00
}
case int , int64 , int32 :
2014-01-28 12:22:41 +04:00
return fmt . Sprintf ( "(%v <> %v)" , scope . Quote ( scope . PrimaryKey ( ) ) , value )
2014-01-26 10:55:41 +04:00
case [ ] int64 , [ ] int , [ ] int32 , [ ] string :
if reflect . ValueOf ( value ) . Len ( ) > 0 {
2014-01-28 12:22:41 +04:00
str = fmt . Sprintf ( "(%v not in (?))" , scope . Quote ( scope . PrimaryKey ( ) ) )
2014-01-26 10:55:41 +04:00
clause [ "args" ] = [ ] interface { } { value }
} else {
return ""
}
case map [ string ] interface { } :
var sqls [ ] string
for key , value := range value {
2014-01-28 12:22:41 +04:00
sqls = append ( sqls , fmt . Sprintf ( "(%v <> %v)" , scope . Quote ( key ) , scope . AddToVars ( value ) ) )
2014-01-26 10:55:41 +04:00
}
return strings . Join ( sqls , " AND " )
case interface { } :
var sqls [ ] string
2014-01-28 05:25:30 +04:00
for _ , field := range scope . New ( value ) . Fields ( ) {
if ! field . IsBlank {
2014-01-28 12:22:41 +04:00
sqls = append ( sqls , fmt . Sprintf ( "(%v <> %v)" , scope . Quote ( field . DBName ) , scope . AddToVars ( field . Value ) ) )
2014-01-28 05:25:30 +04:00
}
2014-01-26 10:55:41 +04:00
}
return strings . Join ( sqls , " AND " )
}
args := clause [ "args" ] . ( [ ] interface { } )
for _ , arg := range args {
switch reflect . TypeOf ( arg ) . Kind ( ) {
case reflect . Slice : // For where("id in (?)", []int64{1,2})
values := reflect . ValueOf ( arg )
2014-01-28 13:09:43 +04:00
var tempMarks [ ] string
2014-01-26 10:55:41 +04:00
for i := 0 ; i < values . Len ( ) ; i ++ {
2014-01-28 13:09:43 +04:00
tempMarks = append ( tempMarks , scope . AddToVars ( values . Index ( i ) . Interface ( ) ) )
2014-01-26 10:55:41 +04:00
}
2014-01-28 13:09:43 +04:00
str = strings . Replace ( str , "?" , strings . Join ( tempMarks , "," ) , 1 )
2014-01-26 10:55:41 +04:00
default :
if scanner , ok := interface { } ( arg ) . ( driver . Valuer ) ; ok {
arg , _ = scanner . Value ( )
}
2014-01-28 13:09:43 +04:00
str = strings . Replace ( notEqualSql , "?" , scope . AddToVars ( arg ) , 1 )
2014-01-26 10:55:41 +04:00
}
}
return
}
// where adds a where condition to a cloned copy of scope.Search. The first
// element is the query and the rest are its arguments. It does nothing when
// called with no arguments.
func (scope *Scope) where(where ...interface{}) {
	if len(where) == 0 {
		return
	}
	query, args := where[0], where[1:]
	scope.Search = scope.Search.clone().where(query, args...)
}
func ( scope * Scope ) whereSql ( ) ( sql string ) {
2014-01-28 13:09:43 +04:00
var primaryCondiations , andConditions , orConditions [ ] string
2014-01-26 10:55:41 +04:00
2014-01-28 12:56:51 +04:00
if ! scope . Search . Unscope && scope . HasColumn ( "DeletedAt" ) {
2014-01-28 13:09:43 +04:00
primaryCondiations = append ( primaryCondiations , "(deleted_at IS NULL OR deleted_at <= '0001-01-02')" )
2014-01-26 10:55:41 +04:00
}
if ! scope . PrimaryKeyZero ( ) {
2014-01-28 13:09:43 +04:00
primaryCondiations = append ( primaryCondiations , scope . primaryCondiation ( scope . AddToVars ( scope . PrimaryKeyValue ( ) ) ) )
2014-01-26 10:55:41 +04:00
}
2014-01-28 12:56:51 +04:00
for _ , clause := range scope . Search . WhereConditions {
2014-01-28 13:09:43 +04:00
andConditions = append ( andConditions , scope . buildWhereCondition ( clause ) )
2014-01-26 10:55:41 +04:00
}
2014-01-28 12:56:51 +04:00
for _ , clause := range scope . Search . OrConditions {
2014-01-28 13:09:43 +04:00
orConditions = append ( orConditions , scope . buildWhereCondition ( clause ) )
2014-01-26 10:55:41 +04:00
}
2014-01-28 12:56:51 +04:00
for _ , clause := range scope . Search . NotConditions {
2014-01-28 13:09:43 +04:00
andConditions = append ( andConditions , scope . buildNotCondition ( clause ) )
2014-01-26 10:55:41 +04:00
}
2014-01-28 13:09:43 +04:00
orSql := strings . Join ( orConditions , " OR " )
combinedSql := strings . Join ( andConditions , " AND " )
if len ( combinedSql ) > 0 {
if len ( orSql ) > 0 {
combinedSql = combinedSql + " OR " + orSql
2014-01-26 10:55:41 +04:00
}
} else {
2014-01-28 13:09:43 +04:00
combinedSql = orSql
2014-01-26 10:55:41 +04:00
}
2014-01-28 13:09:43 +04:00
if len ( primaryCondiations ) > 0 {
sql = "WHERE " + strings . Join ( primaryCondiations , " AND " )
if len ( combinedSql ) > 0 {
sql = sql + " AND (" + combinedSql + ")"
2014-01-26 10:55:41 +04:00
}
2014-01-28 13:09:43 +04:00
} else if len ( combinedSql ) > 0 {
sql = "WHERE " + combinedSql
2014-01-26 10:55:41 +04:00
}
return
}
func ( s * Scope ) selectSql ( ) string {
2014-01-28 12:56:51 +04:00
if len ( s . Search . Select ) == 0 {
2014-01-26 10:55:41 +04:00
return "*"
} else {
2014-01-28 12:56:51 +04:00
return s . Search . Select
2014-01-26 10:55:41 +04:00
}
}
func ( s * Scope ) orderSql ( ) string {
2014-01-28 12:56:51 +04:00
if len ( s . Search . Orders ) == 0 {
2014-01-26 10:55:41 +04:00
return ""
} else {
2014-01-28 12:56:51 +04:00
return " ORDER BY " + strings . Join ( s . Search . Orders , "," )
2014-01-26 10:55:41 +04:00
}
}
func ( s * Scope ) limitSql ( ) string {
2014-01-28 12:56:51 +04:00
if len ( s . Search . Limit ) == 0 {
2014-01-26 10:55:41 +04:00
return ""
} else {
2014-01-28 12:56:51 +04:00
return " LIMIT " + s . Search . Limit
2014-01-26 10:55:41 +04:00
}
}
func ( s * Scope ) offsetSql ( ) string {
2014-01-28 12:56:51 +04:00
if len ( s . Search . Offset ) == 0 {
2014-01-26 10:55:41 +04:00
return ""
} else {
2014-01-28 12:56:51 +04:00
return " OFFSET " + s . Search . Offset
2014-01-26 10:55:41 +04:00
}
}
func ( s * Scope ) groupSql ( ) string {
2014-01-28 12:56:51 +04:00
if len ( s . Search . Group ) == 0 {
2014-01-26 10:55:41 +04:00
return ""
} else {
2014-01-28 12:56:51 +04:00
return " GROUP BY " + s . Search . Group
2014-01-26 10:55:41 +04:00
}
}
func ( s * Scope ) havingSql ( ) string {
2014-01-28 12:56:51 +04:00
if s . Search . HavingCondition == nil {
2014-01-26 10:55:41 +04:00
return ""
} else {
2014-01-28 12:56:51 +04:00
return " HAVING " + s . buildWhereCondition ( s . Search . HavingCondition )
2014-01-26 10:55:41 +04:00
}
}
func ( s * Scope ) joinsSql ( ) string {
2014-01-28 12:56:51 +04:00
return s . Search . Joins + " "
2014-01-26 10:55:41 +04:00
}
2014-01-28 07:37:32 +04:00
func ( scope * Scope ) prepareQuerySql ( ) {
2014-01-28 12:56:51 +04:00
if scope . Search . Raw {
2014-01-28 07:37:32 +04:00
scope . Raw ( strings . TrimLeft ( scope . CombinedConditionSql ( ) , "WHERE " ) )
} else {
scope . Raw ( fmt . Sprintf ( "SELECT %v FROM %v %v" , scope . selectSql ( ) , scope . TableName ( ) , scope . CombinedConditionSql ( ) ) )
}
return
}
2014-01-28 08:28:44 +04:00
func ( scope * Scope ) inlineCondition ( values ... interface { } ) * Scope {
2014-01-28 07:37:32 +04:00
if len ( values ) > 0 {
scope . Search = scope . Search . clone ( ) . where ( values [ 0 ] , values [ 1 : ] ... )
}
return scope
}
2014-01-29 08:00:57 +04:00
// callCallbacks invokes each callback in order with scope. After each call it
// checks scope.skipLeft, and if that is set it stops without running the
// remaining callbacks. The first callback always runs. It returns scope.
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
	for _, callback := range funcs {
		(*callback)(scope)
		if scope.skipLeft {
			return scope
		}
	}
	return scope
}
func ( scope * Scope ) updatedAttrsWithValues ( values map [ string ] interface { } , ignoreProtectedAttrs bool ) ( results map [ string ] interface { } , hasUpdate bool ) {
data := reflect . Indirect ( reflect . ValueOf ( scope . Value ) )
if ! data . CanAddr ( ) {
return values , true
}
for key , value := range values {
if field := data . FieldByName ( snakeToUpperCamel ( key ) ) ; field . IsValid ( ) {
2014-01-30 12:41:10 +04:00
func ( ) {
defer func ( ) {
if err := recover ( ) ; err != nil {
hasUpdate = true
setFieldValue ( field , value )
2014-01-29 08:00:57 +04:00
}
2014-01-30 12:41:10 +04:00
} ( )
if field . Interface ( ) != value {
switch field . Kind ( ) {
case reflect . Int , reflect . Int32 , reflect . Int64 :
if s , ok := value . ( string ) ; ok {
i , err := strconv . Atoi ( s )
if scope . Err ( err ) == nil {
value = i
}
}
2014-01-29 08:00:57 +04:00
2014-01-30 12:41:10 +04:00
if field . Int ( ) != reflect . ValueOf ( value ) . Int ( ) {
hasUpdate = true
setFieldValue ( field , value )
}
default :
2014-01-29 08:00:57 +04:00
hasUpdate = true
setFieldValue ( field , value )
}
}
2014-01-30 12:41:10 +04:00
} ( )
2014-01-29 08:00:57 +04:00
}
}
return
}
func ( scope * Scope ) sqlTagForField ( field * Field ) ( tag string ) {
tag , addationalTag , size := parseSqlTag ( field . Tag . Get ( scope . db . parent . tagIdentifier ) )
if tag == "-" {
field . IsIgnored = true
}
value := field . Value
reflectValue := reflect . ValueOf ( value )
switch reflectValue . Kind ( ) {
case reflect . Slice :
if _ , ok := value . ( [ ] byte ) ; ! ok {
return
}
case reflect . Struct :
2014-03-16 05:28:43 +04:00
if field . IsScanner ( ) {
reflectValue = reflectValue . Field ( 0 )
} else if ! field . IsTime ( ) {
2014-01-29 08:00:57 +04:00
return
}
}
2014-02-18 06:03:14 +04:00
if len ( tag ) == 0 {
2014-01-29 08:00:57 +04:00
if field . isPrimaryKey {
2014-03-16 05:28:43 +04:00
tag = scope . Dialect ( ) . PrimaryKeyTag ( reflectValue , size )
2014-01-29 08:00:57 +04:00
} else {
2014-03-16 05:28:43 +04:00
tag = scope . Dialect ( ) . SqlTag ( reflectValue , size )
2014-01-29 08:00:57 +04:00
}
}
if len ( addationalTag ) > 0 {
tag = tag + " " + addationalTag
}
return
}
// row prepares the query SQL and runs it with QueryRow against scope.DB().
// The start time is passed to scope.Trace when the method returns.
func (scope *Scope) row() *sql.Row {
	started := time.Now()
	defer scope.Trace(started)
	scope.prepareQuerySql()
	return scope.DB().QueryRow(scope.Sql, scope.SqlVars...)
}
// rows prepares the query SQL and runs it with Query against scope.DB(),
// returning the result set and error unchanged. The start time is passed to
// scope.Trace when the method returns.
func (scope *Scope) rows() (*sql.Rows, error) {
	started := time.Now()
	defer scope.Trace(started)
	scope.prepareQuerySql()
	return scope.DB().Query(scope.Sql, scope.SqlVars...)
}
// initialize fills scope.Value from the search, in order: each where
// condition's query, then InitAttrs, then AssignAttrs (later sources overwrite
// earlier ones). It returns scope.
func (scope *Scope) initialize() *Scope {
	apply := func(attrs map[string]interface{}) {
		scope.updatedAttrsWithValues(attrs, false)
	}
	for _, condition := range scope.Search.WhereConditions {
		apply(convertInterfaceToMap(condition["query"]))
	}
	apply(convertInterfaceToMap(scope.Search.InitAttrs))
	apply(convertInterfaceToMap(scope.Search.AssignAttrs))
	return scope
}
// pluck selects a single column and appends each returned value to the slice
// that value points to. Errors are recorded on scope, which is returned.
//
// value must be a pointer to a slice whose element type can be scanned from
// the column. A non-slice or non-pointer argument records an error instead of
// panicking. Scanning stops at the first row error, and iteration errors are
// reported via rows.Err.
func (scope *Scope) pluck(column string, value interface{}) *Scope {
	dest := reflect.Indirect(reflect.ValueOf(value))
	scope.Search = scope.Search.clone().selects(column)
	if dest.Kind() != reflect.Slice {
		scope.Err(errors.New("Results should be a slice"))
		return scope
	}
	// A slice passed by value is not settable, so dest.Set would panic.
	if !dest.CanSet() {
		scope.Err(errors.New("Results should be a pointer to a slice"))
		return scope
	}

	rows, err := scope.rows()
	if scope.Err(err) != nil {
		return scope
	}
	defer rows.Close()

	elemType := dest.Type().Elem()
	for rows.Next() {
		elem := reflect.New(elemType)
		// Don't append a zero value for a row that failed to scan.
		if scope.Err(rows.Scan(elem.Interface())) != nil {
			return scope
		}
		dest.Set(reflect.Append(dest, elem.Elem()))
	}
	scope.Err(rows.Err())
	return scope
}
// count runs "SELECT count(*)" for the current search and scans the result
// into value (a pointer supplied by the caller). Scan errors are recorded on
// scope, which is returned.
func (scope *Scope) count(value interface{}) *Scope {
	scope.Search = scope.Search.clone().selects("count(*)")
	countRow := scope.row()
	scope.Err(countRow.Scan(value))
	return scope
}
// typeName returns the Go type name of scope.Value after dereferencing one
// pointer level. For a slice it returns the element type's name instead.
func (scope *Scope) typeName() string {
	typ := reflect.Indirect(reflect.ValueOf(scope.Value)).Type()
	if typ.Kind() == reflect.Slice {
		typ = typ.Elem()
	}
	return typ.Name()
}
func ( scope * Scope ) related ( value interface { } , foreignKeys ... string ) * Scope {
2014-03-26 04:48:40 +04:00
toScope := scope . db . NewScope ( value )
2014-01-29 08:00:57 +04:00
for _ , foreignKey := range append ( foreignKeys , toScope . typeName ( ) + "Id" , scope . typeName ( ) + "Id" ) {
if foreignValue , ok := scope . FieldByName ( foreignKey ) ; ok {
return toScope . inlineCondition ( foreignValue ) . callCallbacks ( scope . db . parent . callback . queries )
} else if toScope . HasColumn ( foreignKey ) {
sql := fmt . Sprintf ( "%v = ?" , scope . Quote ( toSnake ( foreignKey ) ) )
return toScope . inlineCondition ( sql , scope . PrimaryKeyValue ( ) ) . callCallbacks ( scope . db . parent . callback . queries )
}
}
return scope
}
// createTable executes CREATE TABLE for scope's table. It includes one quoted
// column per field that is not ignored and has a non-empty SqlTag. It returns
// scope.
func (scope *Scope) createTable() *Scope {
	var columnDefs []string
	for _, field := range scope.Fields() {
		if field.IsIgnored || len(field.SqlTag) == 0 {
			continue
		}
		columnDefs = append(columnDefs, scope.Quote(field.DBName)+" "+field.SqlTag)
	}
	ddl := fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(columnDefs, ","))
	scope.Raw(ddl).Exec()
	return scope
}
// dropTable executes DROP TABLE for scope's table and returns scope.
func (scope *Scope) dropTable() *Scope {
	ddl := fmt.Sprintf("DROP TABLE %v", scope.TableName())
	scope.Raw(ddl).Exec()
	return scope
}
// modifyColumn changes the type of column to typ using ALTER TABLE ... MODIFY.
// The column name is quoted; typ is inserted verbatim.
func (scope *Scope) modifyColumn(column string, typ string) {
	ddl := fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.Quote(column), typ)
	scope.Raw(ddl).Exec()
}
// dropColumn removes the (quoted) column from scope's table.
func (scope *Scope) dropColumn(column string) {
	ddl := fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.Quote(column))
	scope.Raw(ddl).Exec()
}
// addIndex creates an index on column. The index is named names[0] when given,
// otherwise "index_<table>_on_<column>".
func (scope *Scope) addIndex(column string, names ...string) {
	var indexName string
	switch {
	case len(names) > 0:
		indexName = names[0]
	default:
		indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column)
	}
	ddl := fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.Quote(column))
	scope.Raw(ddl).Exec()
}
// removeIndex drops indexName from scope's table using DROP INDEX ... ON.
func (scope *Scope) removeIndex(indexName string) {
	ddl := fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())
	scope.Raw(ddl).Exec()
}
func ( scope * Scope ) autoMigrate ( ) * Scope {
2014-04-24 20:38:40 +04:00
// scope.db.source sample: root:@/testdatabase?parseTime=true
from := strings . Index ( scope . db . source , "/" )
to := strings . Index ( scope . db . source , "?" )
if to == - 1 {
to = len ( scope . db . source )
}
databaseName := scope . db . source [ from : to ]
2014-01-29 08:00:57 +04:00
var tableName string
2014-04-24 20:38:40 +04:00
scope . Raw ( fmt . Sprintf ( "SELECT table_name FROM INFORMATION_SCHEMA.tables where table_schema = %v AND table_name = %v" ,
scope . AddToVars ( databaseName ) ,
scope . AddToVars ( scope . TableName ( ) ) ) )
2014-01-29 08:00:57 +04:00
scope . DB ( ) . QueryRow ( scope . Sql , scope . SqlVars ... ) . Scan ( & tableName )
scope . SqlVars = [ ] interface { } { }
// If table doesn't exist
if len ( tableName ) == 0 {
scope . createTable ( )
} else {
for _ , field := range scope . Fields ( ) {
var column , data string
2014-04-24 20:38:40 +04:00
scope . Raw ( fmt . Sprintf ( "SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = %v AND table_name = %v AND column_name = %v" ,
scope . AddToVars ( databaseName ) ,
2014-01-29 08:00:57 +04:00
scope . AddToVars ( scope . TableName ( ) ) ,
scope . AddToVars ( field . DBName ) ,
) )
scope . DB ( ) . QueryRow ( scope . Sql , scope . SqlVars ... ) . Scan ( & column , & data )
scope . SqlVars = [ ] interface { } { }
// If column doesn't exist
if len ( column ) == 0 && len ( field . SqlTag ) > 0 && ! field . IsIgnored {
scope . Raw ( fmt . Sprintf ( "ALTER TABLE %v ADD %v %v;" , scope . TableName ( ) , field . DBName , field . SqlTag ) ) . Exec ( )
}
}
}
return scope
}
2014-04-02 12:59:07 +04:00
// getPrimaryKey returns the snake_case column name of the first exported field
// of scope.Value's struct type that carries a non-empty `primaryKey` tag. It
// falls back to "id".
//
// scope.Value may be a struct, a pointer to one, or a slice (or pointer to a
// slice) of structs or struct pointers. A nil value or a non-struct type also
// yields "id".
func (scope *Scope) getPrimaryKey() string {
	indirectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
	if !indirectValue.IsValid() {
		return "id"
	}

	scopeTyp := indirectValue.Type()
	if scopeTyp.Kind() == reflect.Slice {
		scopeTyp = scopeTyp.Elem()
	}
	// Dereference element pointers (e.g. []*User), since NumField panics on
	// pointer and other non-struct types.
	for scopeTyp.Kind() == reflect.Ptr {
		scopeTyp = scopeTyp.Elem()
	}
	if scopeTyp.Kind() != reflect.Struct {
		return "id"
	}

	for i := 0; i < scopeTyp.NumField(); i++ {
		fieldStruct := scopeTyp.Field(i)
		if !ast.IsExported(fieldStruct.Name) {
			continue
		}
		if fieldStruct.Tag.Get("primaryKey") != "" {
			return toSnake(fieldStruct.Name)
		}
	}
	// No primaryKey tag found: fall back to the conventional "id".
	return "id"
}