From c07a08d88bc4ea7fccf90bcc08b6e2264cf0f78c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 10:43:41 +0800 Subject: [PATCH] Support mssql --- dialects/mssql/create.go | 95 ++++++++++++++++++++++++++++++++++++++++ dialects/mssql/mssql.go | 28 ++++++++++++ 2 files changed, 123 insertions(+) create mode 100644 dialects/mssql/create.go diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go new file mode 100644 index 00000000..4aecce10 --- /dev/null +++ b/dialects/mssql/create.go @@ -0,0 +1,95 @@ +package mssql + +import ( + "reflect" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/clause" +) + +func Create(db *gorm.DB) { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT") + db.Statement.WriteByte(' ') + + c := db.Statement.Clauses["VALUES"] + if values, ok := c.Expression.(clause.Values); ok { + if len(values.Columns) > 0 { + db.Statement.WriteByte('(') + for idx, column := range values.Columns { + if idx > 0 { + db.Statement.WriteByte(',') + } + db.Statement.WriteQuoted(column) + } + db.Statement.WriteByte(')') + + if db.Statement.Schema.PrioritizedPrimaryField != nil { + db.Statement.WriteString(" OUTPUT INSERTED.") + db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName) + } + + db.Statement.WriteString(" VALUES ") + + for idx, value := range values.Values { + if idx > 0 { + db.Statement.WriteByte(',') + } + + db.Statement.WriteByte('(') + db.Statement.AddVar(db.Statement, value...) + db.Statement.WriteByte(')') + } + } else { + db.Statement.WriteString("DEFAULT VALUES") + } + } + + db.Statement.WriteByte(' ') + db.Statement.Build("ON CONFLICT") + } + + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + defer rows.Close() + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for rows.Next() { + // for idx, field := range fields { + // values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + // } + + values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + if err := rows.Scan(values); err != nil { + db.AddError(err) + } + db.RowsAffected++ + } + case reflect.Struct: + // for idx, field := range fields { + // values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + // } + values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + + if rows.Next() { + err = rows.Scan(values) + } + } + } else { + db.AddError(err) + } +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index ad6782c7..35fcb484 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -26,10 +26,38 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) + db.Callback().Create().Replace("gorm:create", Create) db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) + + for k, v := range dialector.ClauseBuilders() { + db.ClauseBuilders[k] = v + } return } +func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { + return map[string]clause.ClauseBuilder{ + "LIMIT": func(c clause.Clause, builder clause.Builder) { + if limit, ok := c.Expression.(clause.Limit); ok { + if limit.Offset > 0 { + builder.WriteString("OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) + builder.WriteString("ROWS") + } + + if limit.Limit > 0 { + if limit.Offset == 0 { + builder.WriteString(" OFFSET 0 ROWS") + } + builder.WriteString(" FETCH NEXT ") + builder.WriteString(strconv.Itoa(limit.Limit)) + builder.WriteString(" ROWS ONLY") + } + } + }, + } +} + func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ DB: db,