From 66ac04ba05679bb45fa411cb4269929392cb2563 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 21 Nov 2013 14:33:06 +0800 Subject: [PATCH] Fix MySQL compatibility --- dialect/mysql.go | 2 +- dialect/sqlite3.go | 2 +- do.go | 46 +++++++++++++++++++++++++++------------------- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/dialect/mysql.go b/dialect/mysql.go index 0e71612b..8e611c3b 100644 --- a/dialect/mysql.go +++ b/dialect/mysql.go @@ -9,7 +9,7 @@ import ( type mysql struct{} func (s *mysql) BinVar(i int) string { - return "?" + return "$$" // ? } func (s *mysql) SupportLastInsertId() bool { diff --git a/dialect/sqlite3.go b/dialect/sqlite3.go index f413bd14..cc91b122 100644 --- a/dialect/sqlite3.go +++ b/dialect/sqlite3.go @@ -9,7 +9,7 @@ import ( type sqlite3 struct{} func (s *sqlite3) BinVar(i int) string { - return "?" + return "$$" // ? } func (s *sqlite3) SupportLastInsertId() bool { diff --git a/do.go b/do.go index 1dbea5bf..83e90293 100644 --- a/do.go +++ b/do.go @@ -27,6 +27,10 @@ type Do struct { startedTransaction bool } +func (s *Do) setSql(sql string) { + s.sql = strings.Replace(sql, "$$", "?", -1) +} + func (s *Do) table() string { if len(s.tableName) == 0 { if len(s.search.tableName) == 0 { @@ -68,7 +72,7 @@ func (s *Do) trace(t time.Time) { } func (s *Do) raw(query string, values ...interface{}) *Do { - s.sql = s.buildWhereCondition(map[string]interface{}{"query": query, "args": values}) + s.setSql(s.buildWhereCondition(map[string]interface{}{"query": query, "args": values})) return s } @@ -98,13 +102,13 @@ func (s *Do) prepareCreateSql() { sqls = append(sqls, s.addToVars(value)) } - s.sql = fmt.Sprintf( + s.setSql(fmt.Sprintf( "INSERT INTO %v (%v) VALUES (%v) %v", s.table(), strings.Join(columns, ","), strings.Join(sqls, ","), s.dialect().ReturningStr(s.model.primaryKeyDb()), - ) + )) return } @@ -246,12 +250,12 @@ func (s *Do) prepareUpdateSql(include_self bool) { } } - s.sql = fmt.Sprintf( + s.setSql(fmt.Sprintf( "UPDATE %v SET %v %v", s.table(), strings.Join(sqls, ", "), s.combinedSql(), - ) + )) return } @@ -292,9 +296,9 @@ func (s *Do) delete() *Do { if !s.db.hasError() { if !s.search.unscope && s.model.hasColumn("DeletedAt") { - s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql()) + s.setSql(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql())) } else { - s.sql = fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql()) + s.setSql(fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql())) } s.exec() s.model.callMethod("AfterDelete") @@ -303,7 +307,7 @@ func (s *Do) delete() *Do { } func (s *Do) prepareQuerySql() { - s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.table(), s.combinedSql()) + s.setSql(fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.table(), s.combinedSql())) return } @@ -506,6 +510,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } + str = strings.Replace(str, "?", s.addToVars(arg), 1) } } @@ -686,24 +691,24 @@ func (s *Do) createTable() *Do { } } - s.sql = fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ",")) + s.setSql(fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))) s.exec() return s } func (s *Do) dropTable() *Do { - s.sql = fmt.Sprintf("DROP TABLE %v", s.table()) + s.setSql(fmt.Sprintf("DROP TABLE %v", s.table())) s.exec() return s } func (s *Do) modifyColumn(column string, typ string) { - s.sql = fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ) + s.setSql(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ)) s.exec() } func (s *Do) dropColumn(column string) { - s.sql = fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column) + s.setSql(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column)) s.exec() } @@ -715,19 +720,19 @@ func (s *Do) addIndex(column string, names ...string) { index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column) } - s.sql = fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column) + s.setSql(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column)) s.exec() } func (s *Do) removeIndex(index_name string) { - s.sql = fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.table()) + s.setSql(fmt.Sprintf("DROP INDEX %v ON %v", index_name, s.table())) s.exec() } func (s *Do) autoMigrate() *Do { var table_name string - sql := fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.table())) - s.db.db.QueryRow(sql, s.sqlVars...).Scan(&table_name) + s.setSql(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", s.addToVars(s.table()))) + s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&table_name) s.sqlVars = []interface{}{} // If table doesn't exist @@ -736,13 +741,16 @@ func (s *Do) autoMigrate() *Do { } else { for _, field := range s.model.fields("migration") { var column_name, data_type string - sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.table())) - s.db.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type) + s.setSql(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v", + s.addToVars(s.table()), + s.addToVars(field.dbName), + )) + s.db.db.QueryRow(s.sql, s.sqlVars...).Scan(&column_name, &data_type) s.sqlVars = []interface{}{} // If column doesn't exist if len(column_name) == 0 && len(field.sqlTag()) > 0 { - s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.table(), field.dbName, field.sqlTag()) + s.setSql(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.table(), field.dbName, field.sqlTag())) s.exec() } }