2020-02-23 07:39:26 +03:00
package logger
import (
	"database/sql/driver"
	"fmt"
	"reflect"
	"regexp"
	"strconv"
	"strings"
	"time"
	"unicode"
	"unicode/utf8"

	"gorm.io/gorm/utils"
)
2021-02-07 05:09:32 +03:00
// Literals used by ExplainSQL when rendering parameter values.
const (
	// nullStr is the text written for nil values and nil pointers.
	nullStr = "NULL"
	// tmFmtZero is written, quoted, in place of a zero time.Time.
	tmFmtZero = "0000-00-00 00:00:00"
	// tmFmtWithMS is the time layout for non-zero times; the ".999" suffix
	// prints up to three fractional-second digits with trailing zeros removed.
	tmFmtWithMS = "2006-01-02 15:04:05.999"
)
2020-02-23 07:39:26 +03:00
// isPrintable reports whether s is valid UTF-8 text consisting solely of
// characters that unicode.IsPrint accepts. An empty or nil slice is printable.
//
// The bytes are decoded as UTF-8 instead of being inspected one at a time:
// treating each byte as its own code point misclassifies multi-byte
// characters (their continuation bytes can land on non-printable Latin-1
// code points) and lets some binary payloads pass as text.
func isPrintable(s []byte) bool {
	// Ranging over malformed UTF-8 yields U+FFFD, which unicode.IsPrint
	// accepts, so invalid input has to be rejected explicitly.
	if !utf8.Valid(s) {
		return false
	}
	for _, r := range string(s) {
		if !unicode.IsPrint(r) {
			return false
		}
	}
	return true
}
2021-04-19 16:03:39 +03:00
// convertibleTypes lists the types that ExplainSQL tries, in this order, to
// convert a value to when no other case of its type switch handles it.
var convertibleTypes = []reflect.Type{
	reflect.TypeOf(time.Time{}),
	reflect.TypeOf(true),
	reflect.TypeOf([]byte(nil)),
}
2020-02-23 08:22:08 +03:00
2022-02-09 10:17:19 +03:00
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
2020-06-01 16:26:23 +03:00
func ExplainSQL ( sql string , numericPlaceholder * regexp . Regexp , escaper string , avars ... interface { } ) string {
2022-02-09 10:17:19 +03:00
var (
convertParams func ( interface { } , int )
vars = make ( [ ] string , len ( avars ) )
)
2020-02-23 07:39:26 +03:00
2020-02-23 08:22:08 +03:00
convertParams = func ( v interface { } , idx int ) {
2020-02-23 07:39:26 +03:00
switch v := v . ( type ) {
case bool :
2020-09-01 16:03:37 +03:00
vars [ idx ] = strconv . FormatBool ( v )
2020-02-23 07:39:26 +03:00
case time . Time :
2020-02-23 14:41:29 +03:00
if v . IsZero ( ) {
2021-02-07 05:09:32 +03:00
vars [ idx ] = escaper + tmFmtZero + escaper
2020-02-23 14:41:29 +03:00
} else {
2021-02-07 05:09:32 +03:00
vars [ idx ] = escaper + v . Format ( tmFmtWithMS ) + escaper
2020-02-23 14:41:29 +03:00
}
2020-09-08 14:11:20 +03:00
case * time . Time :
if v != nil {
if v . IsZero ( ) {
2021-02-07 05:09:32 +03:00
vars [ idx ] = escaper + tmFmtZero + escaper
2020-09-08 14:11:20 +03:00
} else {
2021-02-07 05:09:32 +03:00
vars [ idx ] = escaper + v . Format ( tmFmtWithMS ) + escaper
2020-09-08 14:11:20 +03:00
}
} else {
2021-02-07 05:09:32 +03:00
vars [ idx ] = nullStr
2020-09-08 14:11:20 +03:00
}
2021-02-07 06:18:09 +03:00
case driver . Valuer :
2020-10-19 06:04:18 +03:00
reflectValue := reflect . ValueOf ( v )
if v != nil && reflectValue . IsValid ( ) && ( ( reflectValue . Kind ( ) == reflect . Ptr && ! reflectValue . IsNil ( ) ) || reflectValue . Kind ( ) != reflect . Ptr ) {
2021-02-07 06:18:09 +03:00
r , _ := v . Value ( )
convertParams ( r , idx )
2020-10-19 06:04:18 +03:00
} else {
2021-02-07 05:09:32 +03:00
vars [ idx ] = nullStr
2020-10-19 06:04:18 +03:00
}
2021-02-07 06:18:09 +03:00
case fmt . Stringer :
2020-09-08 14:11:20 +03:00
reflectValue := reflect . ValueOf ( v )
2022-02-09 10:17:19 +03:00
switch reflectValue . Kind ( ) {
case reflect . Int , reflect . Int8 , reflect . Int16 , reflect . Int32 , reflect . Int64 , reflect . Uint , reflect . Uint8 , reflect . Uint16 , reflect . Uint32 , reflect . Uint64 :
vars [ idx ] = fmt . Sprintf ( "%d" , reflectValue . Interface ( ) )
case reflect . Float32 , reflect . Float64 :
vars [ idx ] = fmt . Sprintf ( "%.6f" , reflectValue . Interface ( ) )
case reflect . Bool :
vars [ idx ] = fmt . Sprintf ( "%t" , reflectValue . Interface ( ) )
case reflect . String :
2021-02-07 06:18:09 +03:00
vars [ idx ] = escaper + strings . Replace ( fmt . Sprintf ( "%v" , v ) , escaper , "\\" + escaper , - 1 ) + escaper
2022-02-09 10:17:19 +03:00
default :
if v != nil && reflectValue . IsValid ( ) && ( ( reflectValue . Kind ( ) == reflect . Ptr && ! reflectValue . IsNil ( ) ) || reflectValue . Kind ( ) != reflect . Ptr ) {
vars [ idx ] = escaper + strings . Replace ( fmt . Sprintf ( "%v" , v ) , escaper , "\\" + escaper , - 1 ) + escaper
} else {
vars [ idx ] = nullStr
}
2020-09-08 14:11:20 +03:00
}
2020-02-23 07:39:26 +03:00
case [ ] byte :
if isPrintable ( v ) {
vars [ idx ] = escaper + strings . Replace ( string ( v ) , escaper , "\\" + escaper , - 1 ) + escaper
} else {
vars [ idx ] = escaper + "<binary>" + escaper
}
case int , int8 , int16 , int32 , int64 , uint , uint8 , uint16 , uint32 , uint64 :
2020-09-01 16:03:37 +03:00
vars [ idx ] = utils . ToString ( v )
2020-02-23 07:39:26 +03:00
case float64 , float32 :
vars [ idx ] = fmt . Sprintf ( "%.6f" , v )
case string :
vars [ idx ] = escaper + strings . Replace ( v , escaper , "\\" + escaper , - 1 ) + escaper
default :
2020-07-28 12:25:03 +03:00
rv := reflect . ValueOf ( v )
if v == nil || ! rv . IsValid ( ) || rv . Kind ( ) == reflect . Ptr && rv . IsNil ( ) {
2021-02-07 05:09:32 +03:00
vars [ idx ] = nullStr
2020-07-28 12:25:03 +03:00
} else if valuer , ok := v . ( driver . Valuer ) ; ok {
v , _ = valuer . Value ( )
convertParams ( v , idx )
} else if rv . Kind ( ) == reflect . Ptr && ! rv . IsZero ( ) {
convertParams ( reflect . Indirect ( rv ) . Interface ( ) , idx )
2020-02-23 07:39:26 +03:00
} else {
2021-04-19 16:03:39 +03:00
for _ , t := range convertibleTypes {
2020-07-28 12:25:03 +03:00
if rv . Type ( ) . ConvertibleTo ( t ) {
convertParams ( rv . Convert ( t ) . Interface ( ) , idx )
return
2020-02-23 08:22:08 +03:00
}
2020-04-15 18:58:26 +03:00
}
2020-07-28 12:25:03 +03:00
vars [ idx ] = escaper + strings . Replace ( fmt . Sprint ( v ) , escaper , "\\" + escaper , - 1 ) + escaper
2020-02-23 07:39:26 +03:00
}
}
}
2020-09-01 16:03:37 +03:00
for idx , v := range avars {
2020-02-23 08:22:08 +03:00
convertParams ( v , idx )
}
2020-02-23 07:39:26 +03:00
if numericPlaceholder == nil {
2020-09-14 07:37:16 +03:00
var idx int
var newSQL strings . Builder
for _ , v := range [ ] byte ( sql ) {
if v == '?' {
if len ( vars ) > idx {
newSQL . WriteString ( vars [ idx ] )
idx ++
continue
}
}
newSQL . WriteByte ( v )
2020-02-23 07:39:26 +03:00
}
2020-09-14 07:37:16 +03:00
sql = newSQL . String ( )
2020-02-23 07:39:26 +03:00
} else {
2020-02-23 08:22:08 +03:00
sql = numericPlaceholder . ReplaceAllString ( sql , "$$$1$$" )
2020-02-23 07:39:26 +03:00
for idx , v := range vars {
2020-09-01 16:03:37 +03:00
sql = strings . Replace ( sql , "$" + strconv . Itoa ( idx + 1 ) + "$" , v , 1 )
2020-02-23 07:39:26 +03:00
}
}
return sql
}