forked from mirror/gorm
Compare commits
62 Commits
allow_shar
...
master
Author | SHA1 | Date |
---|---|---|
Cheese | 02b7e26f6b | |
Cr | 878ac51e98 | |
chyroc | e1f46eb802 | |
Jinzhu | 4d6b70ec88 | |
qiankunli | cfbcedbf03 | |
Jinzhu | d834dd60b7 | |
Jinzhu | 3d35ddba55 | |
Haibo | baf1afa1fc | |
Jinzhu | 2bc913787b | |
Jinzhu | 3d91802b1d | |
Jinzhu | b0e13d95b4 | |
Jinzhu | 4b768c8aff | |
Haibo | 16a272209a | |
Haibo | da2b2861de | |
dependabot[bot] | 7da24d1d52 | |
Jinzhu | ddd3cc2502 | |
Cr | 794edad60e | |
Cr | 1935eb0adb | |
Defoo Li | 775fa70af5 | |
Ning | bbd2bbe521 | |
Nate Armstrong | f3c6fc2533 | |
Edward McFarlane | 4ec73c9bf4 | |
Cr | d9525d4da4 | |
wjw1758548031 | f931def33d | |
Jinzhu | f91313436a | |
Cr | 342310fba4 | |
kvii | b6836c2d3e | |
jessetang | cef3de694d | |
jessetang | 1b9cd56c53 | |
kvii | 871f1de6b9 | |
jessetang | fb640cf7da | |
jessetang | 5c8ecc3a2a | |
jessetang | f82e9cfdbe | |
Cr | b2f42528a4 | |
Cr | 9d82aa5673 | |
Cr | 5dd2bb4827 | |
Jinzhu | 3f20a543fa | |
viatoriche / Maxim Panfilov | 62593cfad0 | |
Jinzhu | a0f4d3f7d2 | |
Jinzhu | ab5f80a8d8 | |
Cr | 186e8a9e14 | |
Jinzhu | 2a788fb20c | |
Jinzhu | aa4312ee74 | |
Jinzhu | 08aa2f9888 | |
Jinzhu | 2c56954cb1 | |
Jinzhu | e93dc3426e | |
Jinzhu | 983e96f142 | |
Jinzhu | 34fbe84580 | |
robhafner | e8f48b5c15 | |
jesse.tang | 4b22a55a75 | |
Wen Sun | 9564b82975 | |
Cr | 0b7113b618 | |
Stephano George | a3cc6c6088 | |
jesse.tang | be440e7512 | |
dependabot[bot] | e1dd0dcbc4 | |
Nguyen Huu Tuan | 328f301982 | |
kinggo | 12237454ed | |
Cr | 73bc53f061 | |
Cr | 101a7c789f | |
Jinzhu | 3a72ba102e | |
jesse.tang | 1f634c3937 | |
Cr | 5ed7b1a65e |
|
@ -16,7 +16,7 @@ jobs:
|
|||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v5
|
||||
uses: actions/stale@v7
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||
|
|
|
@ -16,7 +16,7 @@ jobs:
|
|||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v5
|
||||
uses: actions/stale@v7
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||
|
|
|
@ -16,7 +16,7 @@ jobs:
|
|||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v5
|
||||
uses: actions/stale@v7
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
|
||||
|
|
|
@ -16,7 +16,7 @@ jobs:
|
|||
sqlite:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.19', '1.18', '1.17', '1.16']
|
||||
go: ['1.19', '1.18']
|
||||
platform: [ubuntu-latest] # can not run in windows OS
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
|
@ -42,7 +42,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
|
||||
go: ['1.19', '1.18', '1.17', '1.16']
|
||||
go: ['1.19', '1.18']
|
||||
platform: [ubuntu-latest]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
|
@ -86,7 +86,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
|
||||
go: ['1.19', '1.18', '1.17', '1.16']
|
||||
go: ['1.19', '1.18']
|
||||
platform: [ubuntu-latest] # can not run in macOS and Windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
|
@ -128,7 +128,7 @@ jobs:
|
|||
sqlserver:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.19', '1.18', '1.17', '1.16']
|
||||
go: ['1.19', '1.18']
|
||||
platform: [ubuntu-latest] # can not run test in macOS and windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
|
@ -167,3 +167,36 @@ jobs:
|
|||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
|
||||
|
||||
tidb:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: [ 'v6.5.0' ]
|
||||
go: [ '1.19', '1.18' ]
|
||||
platform: [ ubuntu-latest ]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
steps:
|
||||
- name: Setup TiDB
|
||||
uses: Icemap/tidb-action@main
|
||||
with:
|
||||
port: 9940
|
||||
version: ${{matrix.dbversion}}
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
|
||||
|
|
|
@ -9,3 +9,12 @@ linters:
|
|||
- prealloc
|
||||
- unconvert
|
||||
- unparam
|
||||
- goimports
|
||||
- whitespace
|
||||
|
||||
linters-settings:
|
||||
whitespace:
|
||||
multi-func: true
|
||||
goimports:
|
||||
local-prefixes: gorm.io/gorm
|
||||
|
||||
|
|
11
README.md
11
README.md
|
@ -4,9 +4,6 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||
|
||||
[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm)
|
||||
[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions)
|
||||
[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
|
||||
[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
|
||||
[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
|
||||
[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc)
|
||||
|
||||
|
@ -30,12 +27,18 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||
## Getting Started
|
||||
|
||||
* GORM Guides [https://gorm.io](https://gorm.io)
|
||||
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen)
|
||||
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
|
||||
|
||||
## Contributing
|
||||
|
||||
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
|
||||
|
||||
## Contributors
|
||||
|
||||
Thank you for contributing to the GORM framework!
|
||||
|
||||
[![Contributors](https://contrib.rocks/image?repo=go-gorm/gorm)](https://github.com/go-gorm/gorm/graphs/contributors)
|
||||
|
||||
## License
|
||||
|
||||
© Jinzhu, 2013~time.Now
|
||||
|
|
10
callbacks.go
10
callbacks.go
|
@ -93,6 +93,10 @@ func (p *processor) Execute(db *DB) *DB {
|
|||
resetBuildClauses = true
|
||||
}
|
||||
|
||||
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
|
||||
optimizer.ModifyStatement(stmt)
|
||||
}
|
||||
|
||||
// assign model values
|
||||
if stmt.Model == nil {
|
||||
stmt.Model = stmt.Dest
|
||||
|
@ -132,7 +136,11 @@ func (p *processor) Execute(db *DB) *DB {
|
|||
|
||||
if stmt.SQL.Len() > 0 {
|
||||
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
|
||||
sql, vars := stmt.SQL.String(), stmt.Vars
|
||||
if filter, ok := db.Logger.(ParamsFilter); ok {
|
||||
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
||||
}
|
||||
return db.Dialector.Explain(sql, vars...), db.RowsAffected
|
||||
}, db.Error)
|
||||
}
|
||||
|
||||
|
|
|
@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
|
@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, elem)
|
||||
}
|
||||
|
||||
|
|
|
@ -117,13 +117,21 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||
tableAliasName := relation.Name
|
||||
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
}
|
||||
|
||||
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: tableAliasName + "__" + s,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
|
@ -177,7 +185,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||
}
|
||||
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Type: clause.LeftJoin,
|
||||
Type: join.JoinType,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
})
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
func RowQuery(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
BuildQuerySQL(db)
|
||||
if db.DryRun {
|
||||
if db.DryRun || db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
137
chainable_api.go
137
chainable_api.go
|
@ -10,9 +10,10 @@ import (
|
|||
)
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
//
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
func (db *DB) Model(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
@ -21,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
// Clauses Add clauses
|
||||
//
|
||||
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
|
||||
// advanced techniques like specifying lock strength and optimizer hints. See the
|
||||
// [docs] for more depth.
|
||||
//
|
||||
// // add a simple limit clause
|
||||
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
|
||||
// // tell the optimizer to use the `idx_user_name` index
|
||||
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
|
||||
// // specify the lock strength to UPDATE
|
||||
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
|
||||
//
|
||||
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
|
||||
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
var whereConds []interface{}
|
||||
|
@ -41,15 +55,22 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
|
||||
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)
|
||||
|
||||
// Table specify the table you would like to run db operations
|
||||
//
|
||||
// // Get a user
|
||||
// db.Table("users").take(&result)
|
||||
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
|
||||
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
|
||||
if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
|
||||
if results[1] != "" {
|
||||
tx.Statement.Table = results[1]
|
||||
} else {
|
||||
tx.Statement.Table = results[2]
|
||||
}
|
||||
}
|
||||
} else if tables := strings.Split(name, "."); len(tables) == 2 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
||||
|
@ -65,6 +86,11 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
// Distinct specify distinct fields that you want querying
|
||||
//
|
||||
// // Select distinct names of users
|
||||
// db.Distinct("name").Find(&results)
|
||||
// // Select distinct name/age pairs from users
|
||||
// db.Distinct("name", "age").Find(&results)
|
||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Distinct = true
|
||||
|
@ -75,6 +101,14 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
// Select specify fields that you want when querying, creating, updating
|
||||
//
|
||||
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
|
||||
// Select accepts both string arguments and arrays.
|
||||
//
|
||||
// // Select name and age of user using multiple arguments
|
||||
// db.Select("name", "age").Find(&users)
|
||||
// // Select name and age of user using an array
|
||||
// db.Select([]string{"name", "age"}).Find(&users)
|
||||
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
|
@ -152,6 +186,17 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
|
|||
}
|
||||
|
||||
// Where add conditions
|
||||
//
|
||||
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
|
||||
//
|
||||
// // Find the first user with name jinzhu
|
||||
// db.Where("name = ?", "jinzhu").First(&user)
|
||||
// // Find the first user with name jinzhu and age 20
|
||||
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
|
||||
// // Find the first user with name jinzhu and age not equal to 20
|
||||
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
|
||||
//
|
||||
// [docs]: https://gorm.io/docs/query.html#Conditions
|
||||
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
|
@ -161,6 +206,11 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
// Not add NOT conditions
|
||||
//
|
||||
// Not works similarly to where, and has the same syntax.
|
||||
//
|
||||
// // Find the first user with name not equal to jinzhu
|
||||
// db.Not("name = ?", "jinzhu").First(&user)
|
||||
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
|
@ -170,6 +220,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
// Or add OR conditions
|
||||
//
|
||||
// Or is used to chain together queries with an OR.
|
||||
//
|
||||
// // Find the first user with name equal to jinzhu or john
|
||||
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
|
||||
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
|
@ -179,26 +234,42 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
// Joins specify Joins conditions
|
||||
//
|
||||
// db.Joins("Account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
|
||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||
return joins(db, clause.LeftJoin, query, args...)
|
||||
}
|
||||
|
||||
// InnerJoins specify inner joins conditions
|
||||
// db.InnerJoins("Account").Find(&user)
|
||||
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
|
||||
return joins(db, clause.InnerJoin, query, args...)
|
||||
}
|
||||
|
||||
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(args) == 1 {
|
||||
if db, ok := args[0].(*DB); ok {
|
||||
j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits}
|
||||
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where})
|
||||
j.On = &where
|
||||
}
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, j)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
|
||||
return
|
||||
}
|
||||
|
||||
// Group specify the group method on the find
|
||||
//
|
||||
// // Select the sum age of users with given names
|
||||
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
|
||||
func (db *DB) Group(name string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
|
@ -210,6 +281,9 @@ func (db *DB) Group(name string) (tx *DB) {
|
|||
}
|
||||
|
||||
// Having specify HAVING conditions for GROUP BY
|
||||
//
|
||||
// // Select the sum age of users with name jinzhu
|
||||
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
|
||||
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
|
@ -218,7 +292,8 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
// Order specify order when retrieve records from database
|
||||
// Order specify order when retrieving records from database
|
||||
//
|
||||
// db.Order("name DESC")
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
|
@ -242,13 +317,27 @@ func (db *DB) Order(value interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
// Limit specify the number of records to be retrieved
|
||||
//
|
||||
// Limit conditions can be cancelled by using `Limit(-1)`.
|
||||
//
|
||||
// // retrieve 3 users
|
||||
// db.Limit(3).Find(&users)
|
||||
// // retrieve 3 users into users1, and all users into users2
|
||||
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
|
||||
func (db *DB) Limit(limit int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Limit: limit})
|
||||
tx.Statement.AddClause(clause.Limit{Limit: &limit})
|
||||
return
|
||||
}
|
||||
|
||||
// Offset specify the number of records to skip before starting to return the records
|
||||
//
|
||||
// Offset conditions can be cancelled by using `Offset(-1)`.
|
||||
//
|
||||
// // select the third user
|
||||
// db.Offset(2).First(&user)
|
||||
// // select the first user by cancelling an earlier chained offset
|
||||
// db.Offset(5).Offset(-1).First(&user)
|
||||
func (db *DB) Offset(offset int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Offset: offset})
|
||||
|
@ -256,6 +345,7 @@ func (db *DB) Offset(offset int) (tx *DB) {
|
|||
}
|
||||
|
||||
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
|
||||
//
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
|
@ -274,6 +364,8 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
|
|||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
//
|
||||
// // get all users, and preload all non-cancelled orders
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
@ -284,12 +376,41 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
|
||||
//
|
||||
// Attrs only adds attributes if the record is not found.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign an email if the record is not found, otherwise ignore provided email
|
||||
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20}
|
||||
//
|
||||
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
|
||||
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
|
||||
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.attrs = attrs
|
||||
return
|
||||
}
|
||||
|
||||
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
|
||||
//
|
||||
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
|
||||
// records will be updated even if they are found.
|
||||
//
|
||||
// // assign an email regardless of if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
//
|
||||
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
|
||||
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
|
||||
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.assigns = attrs
|
||||
|
|
|
@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) {
|
|||
func BenchmarkComplexSelect(b *testing.B) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
|
||||
limit10 := 10
|
||||
for i := 0; i < b.N; i++ {
|
||||
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clauses := []clause.Interface{
|
||||
|
@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) {
|
|||
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
|
||||
}},
|
||||
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
|
||||
clause.Limit{Limit: 10, Offset: 20},
|
||||
clause.Limit{Limit: &limit10, Offset: 20},
|
||||
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ type Builder interface {
|
|||
Writer
|
||||
WriteQuoted(field interface{})
|
||||
AddVar(Writer, ...interface{})
|
||||
AddError(error) error
|
||||
}
|
||||
|
||||
// Clause
|
||||
|
|
|
@ -9,7 +9,7 @@ const (
|
|||
RightJoin JoinType = "RIGHT"
|
||||
)
|
||||
|
||||
// Join join clause for from
|
||||
// Join clause for from
|
||||
type Join struct {
|
||||
Type JoinType
|
||||
Table Table
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
package clause_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestJoin(t *testing.T) {
|
||||
results := []struct {
|
||||
name string
|
||||
join clause.Join
|
||||
sql string
|
||||
}{
|
||||
{
|
||||
name: "LEFT JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "RIGHT JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.RightJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "INNER JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "CROSS JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.CrossJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "USING",
|
||||
join: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
Using: []string{"id"},
|
||||
},
|
||||
sql: "INNER JOIN `user` USING (`id`)",
|
||||
},
|
||||
{
|
||||
name: "Expression",
|
||||
join: clause.Join{
|
||||
// Invalid
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
// Valid
|
||||
Expression: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
Using: []string{"id"},
|
||||
},
|
||||
},
|
||||
sql: "INNER JOIN `user` USING (`id`)",
|
||||
},
|
||||
}
|
||||
for _, result := range results {
|
||||
t.Run(result.name, func(t *testing.T) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
result.join.Build(stmt)
|
||||
if result.sql != stmt.SQL.String() {
|
||||
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -4,7 +4,7 @@ import "strconv"
|
|||
|
||||
// Limit limit clause
|
||||
type Limit struct {
|
||||
Limit int
|
||||
Limit *int
|
||||
Offset int
|
||||
}
|
||||
|
||||
|
@ -15,12 +15,12 @@ func (limit Limit) Name() string {
|
|||
|
||||
// Build build where clause
|
||||
func (limit Limit) Build(builder Builder) {
|
||||
if limit.Limit > 0 {
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
builder.WriteString("LIMIT ")
|
||||
builder.WriteString(strconv.Itoa(limit.Limit))
|
||||
builder.WriteString(strconv.Itoa(*limit.Limit))
|
||||
}
|
||||
if limit.Offset > 0 {
|
||||
if limit.Limit > 0 {
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
builder.WriteString("OFFSET ")
|
||||
|
@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) {
|
|||
clause.Name = ""
|
||||
|
||||
if v, ok := clause.Expression.(Limit); ok {
|
||||
if limit.Limit == 0 && v.Limit != 0 {
|
||||
if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) {
|
||||
limit.Limit = v.Limit
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,10 @@ import (
|
|||
)
|
||||
|
||||
func TestLimit(t *testing.T) {
|
||||
limit0 := 0
|
||||
limit10 := 10
|
||||
limit50 := 50
|
||||
limitNeg10 := -10
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
|
@ -15,11 +19,15 @@ func TestLimit(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
|
||||
Limit: 10,
|
||||
Limit: &limit10,
|
||||
Offset: 20,
|
||||
}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
|
||||
"SELECT * FROM `users` LIMIT 0", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
|
||||
"SELECT * FROM `users` OFFSET 20", nil,
|
||||
|
@ -29,23 +37,23 @@ func TestLimit(t *testing.T) {
|
|||
"SELECT * FROM `users` OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
|
||||
"SELECT * FROM `users` LIMIT 10", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
|
||||
"SELECT * FROM `users` OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
|
||||
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -16,6 +16,11 @@ func (OnConflict) Name() string {
|
|||
|
||||
// Build build onConflict clause
|
||||
func (onConflict OnConflict) Build(builder Builder) {
|
||||
if onConflict.OnConstraint != "" {
|
||||
builder.WriteString("ON CONSTRAINT ")
|
||||
builder.WriteString(onConflict.OnConstraint)
|
||||
builder.WriteByte(' ')
|
||||
} else {
|
||||
if len(onConflict.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range onConflict.Columns {
|
||||
|
@ -32,11 +37,6 @@ func (onConflict OnConflict) Build(builder Builder) {
|
|||
onConflict.TargetWhere.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if onConflict.OnConstraint != "" {
|
||||
builder.WriteString("ON CONSTRAINT ")
|
||||
builder.WriteString(onConflict.OnConstraint)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if onConflict.DoNothing {
|
||||
|
|
|
@ -21,6 +21,8 @@ var (
|
|||
ErrPrimaryKeyRequired = errors.New("primary key required")
|
||||
// ErrModelValueRequired model value required
|
||||
ErrModelValueRequired = errors.New("model value required")
|
||||
// ErrModelAccessibleFieldsRequired model accessible fields required
|
||||
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
|
||||
// ErrInvalidData unsupported data
|
||||
ErrInvalidData = errors.New("unsupported data")
|
||||
// ErrUnsupportedDriver unsupported driver
|
||||
|
|
|
@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||
var totalSize int
|
||||
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
totalSize = limit.Limit
|
||||
if limit.Limit != nil {
|
||||
totalSize = *limit.Limit
|
||||
}
|
||||
|
||||
if totalSize > 0 && batchSize > totalSize {
|
||||
batchSize = totalSize
|
||||
|
@ -229,7 +231,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||
break
|
||||
}
|
||||
|
||||
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||
if zero {
|
||||
tx.AddError(ErrPrimaryKeyRequired)
|
||||
break
|
||||
}
|
||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||
}
|
||||
|
||||
|
@ -288,6 +294,16 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
|
|||
|
||||
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
|
||||
// Each conds must be a struct or map.
|
||||
//
|
||||
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
|
@ -315,12 +331,30 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
|
||||
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
|
||||
// Each conds must be a struct or map.
|
||||
//
|
||||
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
// // result.RowsAffected -> 1
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
// // result.RowsAffected -> 1
|
||||
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
})
|
||||
if result := queryTx.Find(dest, conds...); result.Error == nil {
|
||||
|
||||
result := queryTx.Find(dest, conds...)
|
||||
if result.Error != nil {
|
||||
tx.Error = result.Error
|
||||
return tx
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
if c, ok := result.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
|
@ -349,16 +383,13 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
assigns[column] = eq.Value
|
||||
case clause.Column:
|
||||
assigns[column.Name] = eq.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Model(dest).Updates(assigns)
|
||||
}
|
||||
} else {
|
||||
tx.Error = result.Error
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
|
@ -459,7 +490,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
|||
tx.Statement.Dest = count
|
||||
tx = tx.callbacks.Query().Execute(tx)
|
||||
|
||||
if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
|
||||
if tx.RowsAffected != 1 {
|
||||
*count = tx.RowsAffected
|
||||
}
|
||||
|
||||
|
@ -512,6 +543,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
|
||||
//
|
||||
// var ages []int64
|
||||
// db.Model(&users).Pluck("age", &ages)
|
||||
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||
|
|
12
gorm.go
12
gorm.go
|
@ -37,6 +37,8 @@ type Config struct {
|
|||
DisableAutomaticPing bool
|
||||
// DisableForeignKeyConstraintWhenMigrating
|
||||
DisableForeignKeyConstraintWhenMigrating bool
|
||||
// IgnoreRelationshipsWhenMigrating
|
||||
IgnoreRelationshipsWhenMigrating bool
|
||||
// DisableNestedTransaction disable nested transaction
|
||||
DisableNestedTransaction bool
|
||||
// AllowGlobalUpdate allow global update
|
||||
|
@ -179,7 +181,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
|||
|
||||
preparedStmt := &PreparedStmtDB{
|
||||
ConnPool: db.ConnPool,
|
||||
Stmts: map[string]Stmt{},
|
||||
Stmts: make(map[string]*Stmt),
|
||||
Mux: &sync.RWMutex{},
|
||||
PreparedSQL: make([]string, 0, 100),
|
||||
}
|
||||
|
@ -248,11 +250,19 @@ func (db *DB) Session(config *Session) *DB {
|
|||
if config.PrepareStmt {
|
||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||
preparedStmt := v.(*PreparedStmtDB)
|
||||
switch t := tx.Statement.ConnPool.(type) {
|
||||
case Tx:
|
||||
tx.Statement.ConnPool = &PreparedStmtTX{
|
||||
Tx: t,
|
||||
PreparedStmtDB: preparedStmt,
|
||||
}
|
||||
default:
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
Mux: preparedStmt.Mux,
|
||||
Stmts: preparedStmt.Stmts,
|
||||
}
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
}
|
||||
|
|
|
@ -26,6 +26,10 @@ type Plugin interface {
|
|||
Initialize(*DB) error
|
||||
}
|
||||
|
||||
type ParamsFilter interface {
|
||||
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
|
||||
}
|
||||
|
||||
// ConnPool db conns pool interface
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
|
|
|
@ -55,6 +55,7 @@ type Config struct {
|
|||
SlowThreshold time.Duration
|
||||
Colorful bool
|
||||
IgnoreRecordNotFoundError bool
|
||||
ParameterizedQueries bool
|
||||
LogLevel LogLevel
|
||||
}
|
||||
|
||||
|
@ -181,6 +182,14 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
|
|||
}
|
||||
}
|
||||
|
||||
// Trace print sql message
|
||||
func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
if l.Config.ParameterizedQueries {
|
||||
return sql, nil
|
||||
}
|
||||
return sql, params
|
||||
}
|
||||
|
||||
type traceRecorder struct {
|
||||
Interface
|
||||
BeginAt time.Time
|
||||
|
|
|
@ -68,6 +68,7 @@ type Migrator interface {
|
|||
// Database
|
||||
CurrentDatabase() string
|
||||
FullDataTypeOf(*schema.Field) clause.Expr
|
||||
GetTypeAliases(databaseTypeName string) []string
|
||||
|
||||
// Tables
|
||||
CreateTable(dst ...interface{}) error
|
||||
|
|
|
@ -8,9 +8,11 @@ import (
|
|||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
|
@ -30,6 +32,16 @@ type Config struct {
|
|||
gorm.Dialector
|
||||
}
|
||||
|
||||
type printSQLLogger struct {
|
||||
logger.Interface
|
||||
}
|
||||
|
||||
func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
fmt.Println(sql + ";")
|
||||
l.Interface.Trace(ctx, begin, fc, err)
|
||||
}
|
||||
|
||||
// GormDataTypeInterface gorm data type interface
|
||||
type GormDataTypeInterface interface {
|
||||
GormDBDataType(*gorm.DB, *schema.Field) string
|
||||
|
@ -92,18 +104,26 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
|||
// AutoMigrate auto migrate values
|
||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, true) {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if !tx.Migrator().HasTable(value) {
|
||||
if err := tx.Migrator().CreateTable(value); err != nil {
|
||||
queryTx := m.DB.Session(&gorm.Session{})
|
||||
execTx := queryTx
|
||||
if m.DB.DryRun {
|
||||
queryTx.DryRun = false
|
||||
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
|
||||
}
|
||||
if !queryTx.Migrator().HasTable(value) {
|
||||
if err := execTx.Migrator().CreateTable(value); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
columnTypes, err := m.DB.Migrator().ColumnTypes(value)
|
||||
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
parseIndexes = stmt.Schema.ParseIndexes()
|
||||
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
|
||||
)
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
var foundColumn gorm.ColumnType
|
||||
|
@ -117,37 +137,40 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||
|
||||
if foundColumn == nil {
|
||||
// not found, add column
|
||||
if err := tx.Migrator().AddColumn(value, dbName); err != nil {
|
||||
if err := execTx.Migrator().AddColumn(value, dbName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
} else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
// found, smart migrate
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
if constraint := rel.ParseConstraint(); constraint != nil &&
|
||||
constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
if !tx.Migrator().HasConstraint(value, chk.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
||||
for _, chk := range parseCheckConstraints {
|
||||
if !queryTx.Migrator().HasConstraint(value, chk.Name) {
|
||||
if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||
if !tx.Migrator().HasIndex(value, idx.Name) {
|
||||
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||
for _, idx := range parseIndexes {
|
||||
if !queryTx.Migrator().HasIndex(value, idx.Name) {
|
||||
if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -227,8 +250,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||
}
|
||||
}
|
||||
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||
if constraint.Schema == stmt.Schema {
|
||||
sql, vars := buildConstraint(constraint)
|
||||
|
@ -406,13 +432,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
|
||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||
|
||||
alterColumn := false
|
||||
var (
|
||||
alterColumn bool
|
||||
isSameType = fullDataType == realDataType
|
||||
)
|
||||
|
||||
if !field.PrimaryKey {
|
||||
// check type
|
||||
if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) {
|
||||
alterColumn = true
|
||||
if !strings.HasPrefix(fullDataType, realDataType) {
|
||||
// check type aliases
|
||||
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
|
||||
for _, alias := range aliases {
|
||||
if strings.HasPrefix(fullDataType, alias) {
|
||||
isSameType = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isSameType {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isSameType {
|
||||
// check size
|
||||
if length, ok := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
|
@ -434,6 +477,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check nullable
|
||||
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
|
||||
|
@ -453,17 +497,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||
|
||||
// check default value
|
||||
if !field.PrimaryKey {
|
||||
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
|
||||
dv, dvNotNull := columnType.DefaultValue()
|
||||
if dvNotNull && field.DefaultValueInterface == nil {
|
||||
if dvNotNull && !currentDefaultNotNull {
|
||||
// defalut value -> null
|
||||
alterColumn = true
|
||||
} else if !dvNotNull && field.DefaultValueInterface != nil {
|
||||
} else if !dvNotNull && currentDefaultNotNull {
|
||||
// null -> default value
|
||||
alterColumn = true
|
||||
} else if dv != field.DefaultValue {
|
||||
} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
|
||||
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
|
||||
// default value not equal
|
||||
// not both null
|
||||
if !(field.DefaultValueInterface == nil && !dvNotNull) {
|
||||
if currentDefaultNotNull || dvNotNull {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
@ -478,7 +524,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||
}
|
||||
|
||||
if alterColumn && !field.IgnoreMigration {
|
||||
return m.DB.Migrator().AlterColumn(value, field.Name)
|
||||
return m.DB.Migrator().AlterColumn(value, field.DBName)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -782,7 +828,11 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
|||
}
|
||||
parsedSchemas[dep.Statement.Schema] = true
|
||||
|
||||
if !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range dep.Schema.Relationships.Relations {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
||||
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
||||
}
|
||||
|
@ -804,6 +854,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
|||
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
valuesMap[dep.Schema.Table] = dep
|
||||
|
||||
|
@ -863,3 +914,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
|
|||
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
|
||||
return nil, errors.New("not support")
|
||||
}
|
||||
|
||||
// GetTypeAliases return database type aliases
|
||||
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,10 +9,12 @@ import (
|
|||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
Transaction bool
|
||||
prepared chan struct{}
|
||||
prepareErr error
|
||||
}
|
||||
|
||||
type PreparedStmtDB struct {
|
||||
Stmts map[string]Stmt
|
||||
Stmts map[string]*Stmt
|
||||
PreparedSQL []string
|
||||
Mux *sync.RWMutex
|
||||
ConnPool
|
||||
|
@ -42,31 +44,72 @@ func (db *PreparedStmtDB) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Reset() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
for _, stmt := range db.Stmts {
|
||||
go stmt.Close()
|
||||
}
|
||||
db.PreparedSQL = make([]string, 0, 100)
|
||||
db.Stmts = make(map[string]*Stmt)
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
db.Mux.RLock()
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
return stmt, nil
|
||||
// wait for other goroutines prepared
|
||||
<-stmt.prepared
|
||||
if stmt.prepareErr != nil {
|
||||
return Stmt{}, stmt.prepareErr
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
db.Mux.RUnlock()
|
||||
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
// double check
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
return stmt, nil
|
||||
} else if ok {
|
||||
go stmt.Close()
|
||||
db.Mux.Unlock()
|
||||
// wait for other goroutines prepared
|
||||
<-stmt.prepared
|
||||
if stmt.prepareErr != nil {
|
||||
return Stmt{}, stmt.prepareErr
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
|
||||
// cache preparing stmt first
|
||||
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
|
||||
db.Stmts[query] = &cacheStmt
|
||||
db.Mux.Unlock()
|
||||
|
||||
// prepare completed
|
||||
defer close(cacheStmt.prepared)
|
||||
|
||||
// Reason why cannot lock conn.PrepareContext
|
||||
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
|
||||
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
|
||||
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
|
||||
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
|
||||
stmt, err := conn.PrepareContext(ctx, query)
|
||||
if err == nil {
|
||||
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
|
||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||
if err != nil {
|
||||
cacheStmt.prepareErr = err
|
||||
db.Mux.Lock()
|
||||
delete(db.Stmts, query)
|
||||
db.Mux.Unlock()
|
||||
return Stmt{}, err
|
||||
}
|
||||
|
||||
return db.Stmts[query], err
|
||||
db.Mux.Lock()
|
||||
cacheStmt.Stmt = stmt
|
||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||
db.Mux.Unlock()
|
||||
|
||||
return cacheStmt, nil
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||
|
|
49
scan.go
49
scan.go
|
@ -65,7 +65,6 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
|
|||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
joinedSchemaMap := make(map[*schema.Field]interface{})
|
||||
for idx, field := range fields {
|
||||
if field == nil {
|
||||
|
@ -164,7 +163,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||
default:
|
||||
var (
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
selectedColumnsMap = make(map[string]int, len(columns))
|
||||
joinFields [][2]*schema.Field
|
||||
sch = db.Statement.Schema
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
|
@ -200,26 +198,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||
|
||||
// Not Pluck
|
||||
if sch != nil {
|
||||
schFieldsCount := len(sch.Fields)
|
||||
matchedFieldCount := make(map[string]int, len(columns))
|
||||
for idx, column := range columns {
|
||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
if curIndex, ok := selectedColumnsMap[column]; ok {
|
||||
fields[idx] = field // handle duplicate fields
|
||||
offset := curIndex + 1
|
||||
// handle sch inconsistent with database
|
||||
// like Raw(`...`).Scan
|
||||
if schFieldsCount > offset {
|
||||
for fieldIndex, selectField := range sch.Fields[offset:] {
|
||||
fields[idx] = field
|
||||
if count, ok := matchedFieldCount[column]; ok {
|
||||
// handle duplicate fields
|
||||
for _, selectField := range sch.Fields {
|
||||
if selectField.DBName == column && selectField.Readable {
|
||||
selectedColumnsMap[column] = curIndex + fieldIndex + 1
|
||||
if count == 0 {
|
||||
matchedFieldCount[column]++
|
||||
fields[idx] = selectField
|
||||
break
|
||||
}
|
||||
count--
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fields[idx] = field
|
||||
selectedColumnsMap[column] = idx
|
||||
matchedFieldCount[column] = 1
|
||||
}
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
|
@ -243,12 +239,20 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var elem reflect.Value
|
||||
recyclableStruct := reflect.New(reflectValueType)
|
||||
var (
|
||||
elem reflect.Value
|
||||
isArrayKind = reflectValue.Kind() == reflect.Array
|
||||
)
|
||||
|
||||
if !update || reflectValue.Len() == 0 {
|
||||
update = false
|
||||
// if the slice cap is externally initialized, the externally initialized slice is directly used here
|
||||
if reflectValue.Cap() == 0 {
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||
} else if !isArrayKind {
|
||||
reflectValue.SetLen(0)
|
||||
db.Statement.ReflectValue.Set(reflectValue)
|
||||
}
|
||||
}
|
||||
|
||||
for initialized || rows.Next() {
|
||||
|
@ -269,20 +273,21 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
if isPtr && db.RowsAffected > 0 {
|
||||
elem = reflect.New(reflectValueType)
|
||||
} else {
|
||||
elem = recyclableStruct
|
||||
}
|
||||
}
|
||||
|
||||
db.scanIntoStruct(rows, elem, values, fields, joinFields)
|
||||
|
||||
if !update {
|
||||
if isPtr {
|
||||
reflectValue = reflect.Append(reflectValue, elem)
|
||||
if !isPtr {
|
||||
elem = elem.Elem()
|
||||
}
|
||||
if isArrayKind {
|
||||
if reflectValue.Len() >= int(db.RowsAffected) {
|
||||
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
|
||||
}
|
||||
} else {
|
||||
reflectValue = reflect.Append(reflectValue, elem.Elem())
|
||||
reflectValue = reflect.Append(reflectValue, elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,7 +65,11 @@ func (schema *Schema) ParseIndexes() map[string]Index {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, index := range indexes {
|
||||
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
|
||||
index.Fields[0].Field.Unique = true
|
||||
}
|
||||
}
|
||||
return indexes
|
||||
}
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ func TestParseIndex(t *testing.T) {
|
|||
"idx_name": {
|
||||
Name: "idx_name",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", Unique: true}}},
|
||||
},
|
||||
"idx_user_indices_name3": {
|
||||
Name: "idx_user_indices_name3",
|
||||
|
@ -81,7 +81,7 @@ func TestParseIndex(t *testing.T) {
|
|||
"idx_user_indices_name4": {
|
||||
Name: "idx_user_indices_name4",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", Unique: true}}},
|
||||
},
|
||||
"idx_user_indices_name5": {
|
||||
Name: "idx_user_indices_name5",
|
||||
|
@ -102,12 +102,12 @@ func TestParseIndex(t *testing.T) {
|
|||
},
|
||||
"idx_id": {
|
||||
Name: "idx_id",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", Unique: true}}},
|
||||
},
|
||||
"idx_oid": {
|
||||
Name: "idx_oid",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", Unique: true}}},
|
||||
},
|
||||
"type": {
|
||||
Name: "type",
|
||||
|
@ -168,6 +168,9 @@ func TestParseIndex(t *testing.T) {
|
|||
if rf.Field.Name != ef.Field.Name {
|
||||
t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name)
|
||||
}
|
||||
if rf.Field.Unique != ef.Field.Unique {
|
||||
t.Fatalf("index field '%s' should equal, expects %v, got %v", rf.Field.Name, rf.Field.Unique, ef.Field.Unique)
|
||||
}
|
||||
|
||||
for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
|
||||
if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
|
||||
|
|
|
@ -191,7 +191,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
err error
|
||||
joinTableFields []reflect.StructField
|
||||
fieldsMap = map[string]*Field{}
|
||||
ownFieldsMap = map[string]bool{} // fix self join many2many
|
||||
ownFieldsMap = map[string]*Field{} // fix self join many2many
|
||||
referFieldsMap = map[string]*Field{}
|
||||
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
|
||||
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
|
||||
)
|
||||
|
@ -229,7 +230,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
joinFieldName = strings.Title(joinForeignKeys[idx])
|
||||
}
|
||||
|
||||
ownFieldsMap[joinFieldName] = true
|
||||
ownFieldsMap[joinFieldName] = ownField
|
||||
fieldsMap[joinFieldName] = ownField
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: joinFieldName,
|
||||
|
@ -242,9 +243,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
|
||||
for idx, relField := range refForeignFields {
|
||||
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
|
||||
if len(joinReferences) > idx {
|
||||
joinFieldName = strings.Title(joinReferences[idx])
|
||||
}
|
||||
|
||||
if _, ok := ownFieldsMap[joinFieldName]; ok {
|
||||
if field.Name != relation.FieldSchema.Name {
|
||||
|
@ -254,6 +252,13 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
}
|
||||
}
|
||||
|
||||
if len(joinReferences) > idx {
|
||||
joinFieldName = strings.Title(joinReferences[idx])
|
||||
}
|
||||
|
||||
referFieldsMap[joinFieldName] = relField
|
||||
|
||||
if _, ok := fieldsMap[joinFieldName]; !ok {
|
||||
fieldsMap[joinFieldName] = relField
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: joinFieldName,
|
||||
|
@ -263,6 +268,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
"column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: strings.Title(schema.Name) + field.Name,
|
||||
|
@ -317,31 +323,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
f.Size = fieldsMap[f.Name].Size
|
||||
}
|
||||
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
|
||||
ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
|
||||
|
||||
if ownPrimaryField {
|
||||
if of, ok := ownFieldsMap[f.Name]; ok {
|
||||
joinRel := relation.JoinTable.Relationships.Relations[relName]
|
||||
joinRel.Field = relation.Field
|
||||
joinRel.References = append(joinRel.References, &Reference{
|
||||
PrimaryKey: fieldsMap[f.Name],
|
||||
PrimaryKey: of,
|
||||
ForeignKey: f,
|
||||
})
|
||||
} else {
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: of,
|
||||
ForeignKey: f,
|
||||
OwnPrimaryKey: true,
|
||||
})
|
||||
}
|
||||
|
||||
if rf, ok := referFieldsMap[f.Name]; ok {
|
||||
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
|
||||
if joinRefRel.Field == nil {
|
||||
joinRefRel.Field = relation.Field
|
||||
}
|
||||
joinRefRel.References = append(joinRefRel.References, &Reference{
|
||||
PrimaryKey: fieldsMap[f.Name],
|
||||
PrimaryKey: rf,
|
||||
ForeignKey: f,
|
||||
})
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: rf,
|
||||
ForeignKey: f,
|
||||
})
|
||||
}
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: fieldsMap[f.Name],
|
||||
ForeignKey: f,
|
||||
OwnPrimaryKey: ownPrimaryField,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -391,33 +403,30 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||
case guessBelongs:
|
||||
primarySchema, foreignSchema = relation.FieldSchema, schema
|
||||
case guessEmbeddedBelongs:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
} else {
|
||||
if field.OwnerSchema == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
case guessHas:
|
||||
case guessEmbeddedHas:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
} else {
|
||||
if field.OwnerSchema == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
}
|
||||
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
for _, foreignKey := range relation.foreignKeys {
|
||||
if f := foreignSchema.LookUpField(foreignKey); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
} else {
|
||||
f := foreignSchema.LookUpField(foreignKey)
|
||||
if f == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
foreignFields = append(foreignFields, f)
|
||||
}
|
||||
} else {
|
||||
var primaryFields []*Field
|
||||
var primarySchemaName = primarySchema.Name
|
||||
if primarySchemaName == "" {
|
||||
primarySchemaName = relation.FieldSchema.Name
|
||||
|
@ -454,10 +463,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||
}
|
||||
}
|
||||
|
||||
if len(foreignFields) == 0 {
|
||||
switch {
|
||||
case len(foreignFields) == 0:
|
||||
reguessOrErr()
|
||||
return
|
||||
} else if len(relation.primaryKeys) > 0 {
|
||||
case len(relation.primaryKeys) > 0:
|
||||
for idx, primaryKey := range relation.primaryKeys {
|
||||
if f := primarySchema.LookUpField(primaryKey); f != nil {
|
||||
if len(primaryFields) < idx+1 {
|
||||
|
@ -471,7 +481,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||
return
|
||||
}
|
||||
}
|
||||
} else if len(primaryFields) == 0 {
|
||||
case len(primaryFields) == 0:
|
||||
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
|
||||
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
|
||||
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
|
||||
if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
|
||||
t.Errorf("Failed to parse schema")
|
||||
t.Errorf("Failed to parse schema, got error %v", err)
|
||||
} else {
|
||||
for _, rel := range relations {
|
||||
checkSchemaRelation(t, s, rel)
|
||||
|
@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestMany2ManySharedForeignKey(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Kind string
|
||||
ProfileRefer uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
|
||||
Kind string
|
||||
Refer uint
|
||||
}
|
||||
|
||||
checkStructRelation(t, &User{}, Relation{
|
||||
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
|
||||
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
|
||||
References: []Reference{
|
||||
{"Refer", "User", "UserRefer", "user_profiles", "", true},
|
||||
{"Kind", "User", "Kind", "user_profiles", "", true},
|
||||
{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
|
||||
{"Kind", "Profile", "Kind", "user_profiles", "", false},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
|
|
|
@ -71,6 +71,10 @@ type Tabler interface {
|
|||
TableName() string
|
||||
}
|
||||
|
||||
type TablerWithNamer interface {
|
||||
TableName(Namer) string
|
||||
}
|
||||
|
||||
// Parse get data type from dialector
|
||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
|
||||
|
@ -125,6 +129,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||
if tabler, ok := modelValue.Interface().(Tabler); ok {
|
||||
tableName = tabler.TableName()
|
||||
}
|
||||
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
|
||||
tableName = tabler.TableName(namer)
|
||||
}
|
||||
if en, ok := namer.(embeddedNamer); ok {
|
||||
tableName = en.Table
|
||||
}
|
||||
|
@ -223,7 +230,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||
}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
if field.HasDefaultValue && field.DefaultValueInterface == nil {
|
||||
if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
|
||||
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -100,6 +100,12 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
|
|||
// Value implements serializer interface
|
||||
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||
result, err := json.Marshal(fieldValue)
|
||||
if string(result) == "null" {
|
||||
if field.TagSettings["NOT NULL"] != "" {
|
||||
return "", nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return string(result), err
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
@ -45,10 +46,20 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error {
|
|||
}
|
||||
|
||||
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteQueryClause{Field: f}}
|
||||
return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
func parseZeroValueTag(f *schema.Field) sql.NullString {
|
||||
if v, ok := f.TagSettings["ZEROVALUE"]; ok {
|
||||
if _, err := now.Parse(v); err == nil {
|
||||
return sql.NullString{String: v, Valid: true}
|
||||
}
|
||||
}
|
||||
return sql.NullString{Valid: false}
|
||||
}
|
||||
|
||||
type SoftDeleteQueryClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
|
@ -78,17 +89,18 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
|||
}
|
||||
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
|
||||
}})
|
||||
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
|
||||
}
|
||||
}
|
||||
|
||||
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteUpdateClause{Field: f}}
|
||||
return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
type SoftDeleteUpdateClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
|
@ -109,10 +121,11 @@ func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
|
|||
}
|
||||
|
||||
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
|
||||
return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
type SoftDeleteDeleteClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
|
|
61
statement.go
61
statement.go
|
@ -52,6 +52,9 @@ type join struct {
|
|||
Name string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Selects []string
|
||||
Omits []string
|
||||
JoinType clause.JoinType
|
||||
}
|
||||
|
||||
// StatementModifier statement modifier interface
|
||||
|
@ -117,6 +120,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||
} else if len(stmt.Schema.DBNames) > 0 {
|
||||
write(v.Raw, stmt.Schema.DBNames[0])
|
||||
} else {
|
||||
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
|
||||
}
|
||||
} else {
|
||||
write(v.Raw, v.Name)
|
||||
|
@ -179,6 +184,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||
} else {
|
||||
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||
}
|
||||
case clause.Interface:
|
||||
c := clause.Clause{Name: v.Name()}
|
||||
v.MergeClause(&c)
|
||||
c.Build(stmt)
|
||||
case clause.Expression:
|
||||
v.Build(stmt)
|
||||
case driver.Valuer:
|
||||
|
@ -304,6 +313,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||
conds := make([]clause.Expression, 0, 4)
|
||||
args = append([]interface{}{query}, args...)
|
||||
for idx, arg := range args {
|
||||
if arg == nil {
|
||||
continue
|
||||
}
|
||||
if valuer, ok := arg.(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
|
@ -540,6 +552,7 @@ func (stmt *Statement) clone() *Statement {
|
|||
}
|
||||
|
||||
// SetColumn set column's value
|
||||
//
|
||||
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
||||
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
||||
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
|
||||
|
@ -657,47 +670,41 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
|
|||
results := map[string]bool{}
|
||||
notRestricted := false
|
||||
|
||||
// select columns
|
||||
for _, column := range stmt.Selects {
|
||||
processColumn := func(column string, result bool) {
|
||||
if stmt.Schema == nil {
|
||||
results[column] = true
|
||||
results[column] = result
|
||||
} else if column == "*" {
|
||||
notRestricted = true
|
||||
notRestricted = result
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = true
|
||||
results[dbName] = result
|
||||
}
|
||||
} else if column == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = true
|
||||
results[rel.Name] = result
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = true
|
||||
results[field.DBName] = result
|
||||
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
|
||||
results[matches[2]] = true
|
||||
if matches[2] == "*" {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = result
|
||||
}
|
||||
} else {
|
||||
results[column] = true
|
||||
results[matches[2]] = result
|
||||
}
|
||||
} else {
|
||||
results[column] = result
|
||||
}
|
||||
}
|
||||
|
||||
// select columns
|
||||
for _, column := range stmt.Selects {
|
||||
processColumn(column, true)
|
||||
}
|
||||
|
||||
// omit columns
|
||||
for _, omit := range stmt.Omits {
|
||||
if stmt.Schema == nil {
|
||||
results[omit] = false
|
||||
} else if omit == "*" {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = false
|
||||
}
|
||||
} else if omit == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = false
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = false
|
||||
} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
|
||||
results[matches[1]] = false
|
||||
} else {
|
||||
results[omit] = false
|
||||
}
|
||||
for _, column := range stmt.Omits {
|
||||
processColumn(column, false)
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
|
|
|
@ -35,6 +35,13 @@ func TestWhereCloneCorruption(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNilCondition(t *testing.T) {
|
||||
s := new(Statement)
|
||||
if len(s.BuildCondition(nil)) != 0 {
|
||||
t.Errorf("Nil condition should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameMatcher(t *testing.T) {
|
||||
for k, v := range map[string][]string{
|
||||
"table.name": {"table", "name"},
|
||||
|
|
|
@ -3,6 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
|
@ -137,6 +138,7 @@ func TestBelongsToAssociation(t *testing.T) {
|
|||
unexistCompanyID := company.ID + 9999999
|
||||
user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID}
|
||||
if err := DB.Create(&user).Error; err == nil {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
t.Errorf("should have gotten foreign key violation error")
|
||||
}
|
||||
}
|
||||
|
@ -224,3 +226,28 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
|
|||
AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
|
||||
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
|
||||
}
|
||||
|
||||
func TestBelongsToDefaultValue(t *testing.T) {
|
||||
type Org struct {
|
||||
ID string
|
||||
}
|
||||
type BelongsToUser struct {
|
||||
OrgID string
|
||||
Org Org `gorm:"default:NULL"`
|
||||
}
|
||||
|
||||
tx := DB.Session(&gorm.Session{})
|
||||
tx.Config.DisableForeignKeyConstraintWhenMigrating = true
|
||||
AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false)
|
||||
|
||||
tx.Migrator().DropTable(&BelongsToUser{}, &Org{})
|
||||
tx.AutoMigrate(&BelongsToUser{}, &Org{})
|
||||
|
||||
user := &BelongsToUser{
|
||||
Org: Org{
|
||||
ID: "BelongsToUser_Org_1",
|
||||
},
|
||||
}
|
||||
err := DB.Create(&user).Error
|
||||
AssertEqual(t, err, nil)
|
||||
}
|
||||
|
|
|
@ -95,6 +95,8 @@ func TestMany2ManyAssociation(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMany2ManyOmitAssociations(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
user := *GetUser("many2many_omit_associations", Config{Languages: 2})
|
||||
|
||||
if err := DB.Omit("Languages.*").Create(&user).Error; err == nil {
|
||||
|
|
|
@ -71,6 +71,8 @@ func TestAssociationNotNullClear(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestForeignKeyConstraints(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
type Profile struct {
|
||||
ID uint
|
||||
Name string
|
||||
|
@ -126,6 +128,8 @@ func TestForeignKeyConstraints(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestForeignKeyConstraintsBelongsTo(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
type Profile struct {
|
||||
ID uint
|
||||
Name string
|
||||
|
@ -348,3 +352,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) {
|
|||
AssertEqual(t, len(orgs), 0)
|
||||
}
|
||||
}
|
||||
|
||||
type AssociationEmptyUser struct {
|
||||
ID uint
|
||||
Name string
|
||||
Pets []AssociationEmptyPet
|
||||
}
|
||||
|
||||
type AssociationEmptyPet struct {
|
||||
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
|
||||
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
|
||||
}
|
||||
|
||||
func TestAssociationEmptyPrimaryKey(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
t.Skip()
|
||||
}
|
||||
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||
|
||||
id := uint(100)
|
||||
user := AssociationEmptyUser{
|
||||
ID: id,
|
||||
Name: "jinzhu",
|
||||
Pets: []AssociationEmptyPet{
|
||||
{AssociationEmptyUserID: &id, Name: "bar"},
|
||||
{AssociationEmptyUserID: &id, Name: "foo"},
|
||||
},
|
||||
}
|
||||
|
||||
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create, got error: %v", err)
|
||||
}
|
||||
|
||||
var result AssociationEmptyUser
|
||||
err = DB.Preload("Pets").First(&result, &id).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to find, got error: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, result, user)
|
||||
}
|
||||
|
|
|
@ -113,6 +113,9 @@ func TestCallbacks(t *testing.T) {
|
|||
|
||||
for idx, data := range datas {
|
||||
db, err := gorm.Open(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
callbacks := db.Callback()
|
||||
|
||||
for _, c := range data.callbacks {
|
||||
|
|
|
@ -116,7 +116,7 @@ func TestConnPoolWrapper(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
|
||||
if err != nil {
|
||||
t.Fatalf("Should open db success, but got %v", err)
|
||||
}
|
||||
|
|
|
@ -141,7 +141,7 @@ func TestCount(t *testing.T) {
|
|||
}
|
||||
DB.Create(sameUsers)
|
||||
|
||||
if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 {
|
||||
if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != int64(len(sameUsers)) {
|
||||
t.Fatalf("Count should be 3, but got count: %v err %v", count11, err)
|
||||
}
|
||||
|
||||
|
|
|
@ -29,3 +29,8 @@ services:
|
|||
- MSSQL_DB=gorm
|
||||
- MSSQL_USER=gorm
|
||||
- MSSQL_PASSWORD=LoremIpsum86
|
||||
tidb:
|
||||
image: 'pingcap/tidb:v6.5.0'
|
||||
ports:
|
||||
- 9940:4000
|
||||
command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 &
|
||||
|
|
|
@ -36,7 +36,7 @@ func TestEmbeddedStruct(t *testing.T) {
|
|||
|
||||
type EngadgetPost struct {
|
||||
BasePost BasePost `gorm:"Embedded"`
|
||||
Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct
|
||||
Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct
|
||||
ImageUrl string
|
||||
}
|
||||
|
||||
|
@ -74,13 +74,27 @@ func TestEmbeddedStruct(t *testing.T) {
|
|||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
|
||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
|
||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}})
|
||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}})
|
||||
var egNews EngadgetPost
|
||||
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
|
||||
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||
} else if egNews.BasePost.Title != "engadget_news" {
|
||||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
|
||||
var egPosts []EngadgetPost
|
||||
if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil {
|
||||
t.Fatalf("no error should happen when query with embedded struct, but got %v", err)
|
||||
}
|
||||
expectAuthors := []string{"Edward", "George"}
|
||||
for i, post := range egPosts {
|
||||
t.Log(i, post.Author)
|
||||
if want := expectAuthors[i]; post.Author.Name != want {
|
||||
t.Errorf("expected author %s got %s", want, post.Author.Name)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
||||
|
|
17
tests/go.mod
17
tests/go.mod
|
@ -3,17 +3,16 @@ module gorm.io/gorm/tests
|
|||
go 1.16
|
||||
|
||||
require (
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.6
|
||||
github.com/mattn/go-sqlite3 v1.14.14 // indirect
|
||||
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
|
||||
gorm.io/driver/mysql v1.3.5
|
||||
gorm.io/driver/postgres v1.3.8
|
||||
gorm.io/driver/sqlite v1.3.6
|
||||
gorm.io/driver/sqlserver v1.3.2
|
||||
gorm.io/gorm v1.23.8
|
||||
github.com/lib/pq v1.10.7
|
||||
github.com/mattn/go-sqlite3 v1.14.16 // indirect
|
||||
golang.org/x/crypto v0.5.0 // indirect
|
||||
gorm.io/driver/mysql v1.4.6
|
||||
gorm.io/driver/postgres v1.4.6
|
||||
gorm.io/driver/sqlite v1.4.4
|
||||
gorm.io/driver/sqlserver v1.4.2
|
||||
gorm.io/gorm v1.24.5
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -235,3 +236,13 @@ func CheckUser(t *testing.T, user User, expect User) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func tidbSkip(t *testing.T, reason string) {
|
||||
if isTiDB() {
|
||||
t.Skipf("This test case skipped, because of TiDB '%s'", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func isTiDB() bool {
|
||||
return os.Getenv("GORM_DIALECT") == "tidb"
|
||||
}
|
||||
|
|
|
@ -229,3 +229,99 @@ func TestJoinWithSoftDeleted(t *testing.T) {
|
|||
t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInnerJoins(t *testing.T) {
|
||||
user := *GetUser("inner-joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false})
|
||||
|
||||
DB.Create(&user)
|
||||
|
||||
var user2 User
|
||||
var err error
|
||||
err = DB.InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error
|
||||
AssertEqual(t, err, nil)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
// inner join and NamedPet is nil
|
||||
err = DB.InnerJoins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error
|
||||
AssertEqual(t, err, gorm.ErrRecordNotFound)
|
||||
|
||||
// mixed inner join and left join
|
||||
var user3 User
|
||||
err = DB.Joins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user3, "users.name = ?", user.Name).Error
|
||||
AssertEqual(t, err, nil)
|
||||
CheckUser(t, user3, user)
|
||||
}
|
||||
|
||||
func TestJoinWithSameColumnName(t *testing.T) {
|
||||
user := GetUser("TestJoinWithSameColumnName", Config{
|
||||
Languages: 1,
|
||||
Pets: 1,
|
||||
})
|
||||
DB.Create(user)
|
||||
type UserSpeak struct {
|
||||
UserID uint
|
||||
LanguageCode string
|
||||
}
|
||||
type Result struct {
|
||||
User
|
||||
UserSpeak
|
||||
Language
|
||||
Pet
|
||||
}
|
||||
|
||||
results := make([]Result, 0, 1)
|
||||
DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id").
|
||||
Joins("JOIN languages ON languages.code = user_speaks.language_code").
|
||||
Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results)
|
||||
|
||||
if len(results) == 0 {
|
||||
t.Fatalf("no record find")
|
||||
} else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID {
|
||||
t.Fatalf("wrong user id in pet")
|
||||
} else if results[0].Pet.Name != user.Pets[0].Name {
|
||||
t.Fatalf("wrong pet name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinArgsWithDB(t *testing.T) {
|
||||
user := *GetUser("joins-args-db", Config{Pets: 2})
|
||||
DB.Save(&user)
|
||||
|
||||
// test where
|
||||
var user1 User
|
||||
onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"})
|
||||
if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2")
|
||||
|
||||
// test where and omit
|
||||
onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name")
|
||||
var user2 User
|
||||
if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID)
|
||||
AssertEqual(t, user2.NamedPet.Name, "")
|
||||
|
||||
// test where and select
|
||||
onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name")
|
||||
var user3 User
|
||||
if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
AssertEqual(t, user3.NamedPet.ID, 0)
|
||||
AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2")
|
||||
|
||||
// test select
|
||||
onQuery4 := DB.Select("ID")
|
||||
var user4 User
|
||||
if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
if user4.NamedPet.ID == 0 {
|
||||
t.Fatal("Pet ID can not be empty")
|
||||
}
|
||||
AssertEqual(t, user4.NamedPet.Name, "")
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -10,6 +12,7 @@ import (
|
|||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
@ -25,7 +28,7 @@ func TestMigrate(t *testing.T) {
|
|||
}
|
||||
|
||||
if err := DB.AutoMigrate(allModels...); err != nil {
|
||||
t.Fatalf("Failed to auto migrate, but got error %v", err)
|
||||
t.Fatalf("Failed to auto migrate, got error %v", err)
|
||||
}
|
||||
|
||||
if tables, err := DB.Migrator().GetTables(); err != nil {
|
||||
|
@ -72,6 +75,44 @@ func TestMigrate(t *testing.T) {
|
|||
t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1])
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAutoMigrateInt8PG(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
||||
type Smallint int8
|
||||
|
||||
type MigrateInt struct {
|
||||
Int8 Smallint
|
||||
}
|
||||
|
||||
tracer := Tracer{
|
||||
Logger: DB.Config.Logger,
|
||||
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
|
||||
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&MigrateInt{})
|
||||
|
||||
// The first AutoMigrate to make table with field with correct type
|
||||
if err := DB.AutoMigrate(&MigrateInt{}); err != nil {
|
||||
t.Fatalf("Failed to auto migrate: error: %v", err)
|
||||
}
|
||||
|
||||
// make new session to set custom logger tracer
|
||||
session := DB.Session(&gorm.Session{Logger: tracer})
|
||||
|
||||
// The second AutoMigrate to catch an error
|
||||
if err := session.AutoMigrate(&MigrateInt{}); err != nil {
|
||||
t.Fatalf("Failed to auto migrate: error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoMigrateSelfReferential(t *testing.T) {
|
||||
|
@ -333,7 +374,137 @@ func TestMigrateIndexes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTiDBMigrateColumns(t *testing.T) {
|
||||
if !isTiDB() {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
// TiDB can't change column constraint and has auto_random feature
|
||||
type ColumnStruct struct {
|
||||
ID int `gorm:"primarykey;default:auto_random()"`
|
||||
Name string
|
||||
Age int `gorm:"default:18;comment:my age"`
|
||||
Code string `gorm:"unique;comment:my code;"`
|
||||
Code2 string
|
||||
Code3 string `gorm:"unique"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&ColumnStruct{})
|
||||
|
||||
if err := DB.AutoMigrate(&ColumnStruct{}); err != nil {
|
||||
t.Errorf("Failed to migrate, got %v", err)
|
||||
}
|
||||
|
||||
type ColumnStruct2 struct {
|
||||
ID int `gorm:"primarykey;default:auto_random()"`
|
||||
Name string `gorm:"size:100"`
|
||||
Code string `gorm:"unique;comment:my code2;default:hello"`
|
||||
Code2 string `gorm:"comment:my code2;default:hello"`
|
||||
}
|
||||
|
||||
if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil {
|
||||
t.Fatalf("no error should happened when alter column, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil {
|
||||
t.Fatalf("no error should happened when auto migrate column, but got %v", err)
|
||||
}
|
||||
|
||||
if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
|
||||
t.Fatalf("no error should returns for ColumnTypes")
|
||||
} else {
|
||||
stmt := &gorm.Statement{DB: DB}
|
||||
stmt.Parse(&ColumnStruct2{})
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
switch columnType.Name() {
|
||||
case "id":
|
||||
if v, ok := columnType.PrimaryKey(); !ok || !v {
|
||||
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
}
|
||||
case "name":
|
||||
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
|
||||
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
|
||||
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||
}
|
||||
if length, ok := columnType.Length(); !ok || length != 100 {
|
||||
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
|
||||
}
|
||||
case "age":
|
||||
if v, ok := columnType.DefaultValue(); !ok || v != "18" {
|
||||
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
}
|
||||
if v, ok := columnType.Comment(); !ok || v != "my age" {
|
||||
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
}
|
||||
case "code":
|
||||
if v, ok := columnType.Unique(); !ok || !v {
|
||||
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
}
|
||||
if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
|
||||
}
|
||||
if v, ok := columnType.Comment(); !ok || v != "my code2" {
|
||||
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
}
|
||||
case "code2":
|
||||
// Code2 string `gorm:"comment:my code2;default:hello"`
|
||||
if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
|
||||
}
|
||||
if v, ok := columnType.Comment(); !ok || v != "my code2" {
|
||||
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type NewColumnStruct struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
NewName string
|
||||
}
|
||||
|
||||
if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil {
|
||||
t.Fatalf("Failed to add column, got %v", err)
|
||||
}
|
||||
|
||||
if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") {
|
||||
t.Fatalf("Failed to find added column")
|
||||
}
|
||||
|
||||
if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil {
|
||||
t.Fatalf("Failed to add column, got %v", err)
|
||||
}
|
||||
|
||||
if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") {
|
||||
t.Fatalf("Found deleted column")
|
||||
}
|
||||
|
||||
if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil {
|
||||
t.Fatalf("Failed to add column, got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil {
|
||||
t.Fatalf("Failed to add column, got %v", err)
|
||||
}
|
||||
|
||||
if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") {
|
||||
t.Fatalf("Failed to found renamed column")
|
||||
}
|
||||
|
||||
if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil {
|
||||
t.Fatalf("Failed to add column, got %v", err)
|
||||
}
|
||||
|
||||
if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") {
|
||||
t.Fatalf("Found deleted column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateColumns(t *testing.T) {
|
||||
tidbSkip(t, "use another test case")
|
||||
|
||||
sqlite := DB.Dialector.Name() == "sqlite"
|
||||
sqlserver := DB.Dialector.Name() == "sqlserver"
|
||||
|
||||
|
@ -400,7 +571,7 @@ func TestMigrateColumns(t *testing.T) {
|
|||
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
}
|
||||
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
|
||||
}
|
||||
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
|
||||
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
|
@ -715,6 +886,32 @@ func TestPrimarykeyID(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCurrentTimestamp(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
return
|
||||
}
|
||||
type CurrentTimestampTest struct {
|
||||
ID string `gorm:"primary_key"`
|
||||
TimeAt *time.Time `gorm:"type:datetime;not null;default:CURRENT_TIMESTAMP;unique"`
|
||||
}
|
||||
var err error
|
||||
err = DB.Migrator().DropTable(&CurrentTimestampTest{})
|
||||
if err != nil {
|
||||
t.Errorf("DropTable err:%v", err)
|
||||
}
|
||||
err = DB.AutoMigrate(&CurrentTimestampTest{})
|
||||
if err != nil {
|
||||
t.Fatalf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
err = DB.AutoMigrate(&CurrentTimestampTest{})
|
||||
if err != nil {
|
||||
t.Fatalf("AutoMigrate err:%v", err)
|
||||
}
|
||||
AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at"))
|
||||
AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2"))
|
||||
}
|
||||
|
||||
func TestUniqueColumn(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
return
|
||||
|
@ -786,6 +983,8 @@ func TestUniqueColumn(t *testing.T) {
|
|||
AssertEqual(t, "", value)
|
||||
AssertEqual(t, false, ok)
|
||||
|
||||
tidbSkip(t, "can't change column constraint")
|
||||
|
||||
// null -> empty string
|
||||
err = DB.Table("unique_tests").AutoMigrate(&UniqueTest3{})
|
||||
if err != nil {
|
||||
|
@ -850,7 +1049,7 @@ func findColumnType(dest interface{}, columnName string) (
|
|||
return
|
||||
}
|
||||
|
||||
func TestInvalidCachedPlan(t *testing.T) {
|
||||
func TestInvalidCachedPlanSimpleProtocol(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
@ -885,6 +1084,101 @@ func TestInvalidCachedPlan(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
||||
db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true})
|
||||
if err != nil {
|
||||
t.Errorf("Open err:%v", err)
|
||||
}
|
||||
if debug := os.Getenv("DEBUG"); debug == "true" {
|
||||
db.Logger = db.Logger.LogMode(logger.Info)
|
||||
} else if debug == "false" {
|
||||
db.Logger = db.Logger.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
type Object1 struct {
|
||||
ID uint
|
||||
}
|
||||
type Object2 struct {
|
||||
ID uint
|
||||
Field1 int `gorm:"type:int8"`
|
||||
}
|
||||
type Object3 struct {
|
||||
ID uint
|
||||
Field1 int `gorm:"type:int4"`
|
||||
}
|
||||
type Object4 struct {
|
||||
ID uint
|
||||
Field2 int
|
||||
}
|
||||
db.Migrator().DropTable("objects")
|
||||
|
||||
err = db.Table("objects").AutoMigrate(&Object1{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
err = db.Table("objects").Create(&Object1{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("create err:%v", err)
|
||||
}
|
||||
|
||||
// AddColumn
|
||||
err = db.Table("objects").AutoMigrate(&Object2{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object2{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
// AlterColumn
|
||||
err = db.Table("objects").AutoMigrate(&Object3{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object3{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
// AddColumn
|
||||
err = db.Table("objects").AutoMigrate(&Object4{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object4{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3")
|
||||
if err != nil {
|
||||
t.Errorf("RenameColumn err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object4{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
db.Table("objects").Migrator().DropColumn(&Object4{}, "field3")
|
||||
if err != nil {
|
||||
t.Errorf("RenameColumn err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object4{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
|
||||
type DiffType struct {
|
||||
ID uint
|
||||
|
@ -959,3 +1253,247 @@ func TestMigrateArrayTypeModel(t *testing.T) {
|
|||
AssertEqual(t, nil, err)
|
||||
AssertEqual(t, "integer[]", ct.DatabaseTypeName())
|
||||
}
|
||||
|
||||
type mockMigrator struct {
|
||||
gorm.Migrator
|
||||
}
|
||||
|
||||
func (mm mockMigrator) AlterColumn(dst interface{}, field string) error {
|
||||
err := mm.Migrator.AlterColumn(dst, field)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("trigger alter column error, field: %s", field)
|
||||
}
|
||||
|
||||
func TestMigrateDonotAlterColumn(t *testing.T) {
|
||||
var wrapMockMigrator = func(m gorm.Migrator) mockMigrator {
|
||||
return mockMigrator{
|
||||
Migrator: m,
|
||||
}
|
||||
}
|
||||
m := DB.Migrator()
|
||||
mockM := wrapMockMigrator(m)
|
||||
|
||||
type NotTriggerUpdate struct {
|
||||
ID uint
|
||||
F1 uint16
|
||||
F2 uint32
|
||||
F3 int
|
||||
F4 int64
|
||||
F5 string
|
||||
F6 float32
|
||||
F7 float64
|
||||
F8 time.Time
|
||||
F9 bool
|
||||
F10 []byte
|
||||
}
|
||||
|
||||
var err error
|
||||
err = mockM.DropTable(&NotTriggerUpdate{})
|
||||
AssertEqual(t, err, nil)
|
||||
err = mockM.AutoMigrate(&NotTriggerUpdate{})
|
||||
AssertEqual(t, err, nil)
|
||||
err = mockM.AutoMigrate(&NotTriggerUpdate{})
|
||||
AssertEqual(t, err, nil)
|
||||
}
|
||||
|
||||
func TestMigrateSameEmbeddedFieldName(t *testing.T) {
|
||||
type UserStat struct {
|
||||
GroundDestroyCount int
|
||||
}
|
||||
|
||||
type GameUser struct {
|
||||
gorm.Model
|
||||
StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"`
|
||||
}
|
||||
|
||||
type UserStat1 struct {
|
||||
GroundDestroyCount string
|
||||
}
|
||||
|
||||
type GroundRate struct {
|
||||
GroundDestroyCount int
|
||||
}
|
||||
|
||||
type GameUser1 struct {
|
||||
gorm.Model
|
||||
StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"`
|
||||
GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&GameUser{})
|
||||
err := DB.AutoMigrate(&GameUser{})
|
||||
AssertEqual(t, nil, err)
|
||||
|
||||
err = DB.Table("game_users").AutoMigrate(&GameUser1{})
|
||||
AssertEqual(t, nil, err)
|
||||
|
||||
_, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count")
|
||||
AssertEqual(t, nil, err)
|
||||
|
||||
_, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count")
|
||||
AssertEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func TestMigrateDefaultNullString(t *testing.T) {
|
||||
if DB.Dialector.Name() == "sqlserver" {
|
||||
// sqlserver driver treats NULL and 'NULL' the same
|
||||
t.Skip("skip sqlserver")
|
||||
}
|
||||
|
||||
type NullModel struct {
|
||||
ID uint
|
||||
Content string `gorm:"default:null"`
|
||||
}
|
||||
|
||||
type NullStringModel struct {
|
||||
ID uint
|
||||
Content string `gorm:"default:'null'"`
|
||||
}
|
||||
|
||||
tableName := "null_string_model"
|
||||
|
||||
DB.Migrator().DropTable(tableName)
|
||||
|
||||
err := DB.Table(tableName).AutoMigrate(&NullModel{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
// default null -> 'null'
|
||||
err = DB.Table(tableName).AutoMigrate(&NullStringModel{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
columnType, err := findColumnType(tableName, "content")
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
defVal, ok := columnType.DefaultValue()
|
||||
AssertEqual(t, defVal, "null")
|
||||
AssertEqual(t, ok, true)
|
||||
|
||||
// default 'null' -> 'null'
|
||||
session := DB.Session(&gorm.Session{Logger: Tracer{
|
||||
Logger: DB.Config.Logger,
|
||||
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
if strings.HasPrefix(sql, "ALTER TABLE") {
|
||||
t.Errorf("shouldn't execute: sql=%s", sql)
|
||||
}
|
||||
},
|
||||
}})
|
||||
err = session.Table(tableName).AutoMigrate(&NullStringModel{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
columnType, err = findColumnType(tableName, "content")
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
defVal, ok = columnType.DefaultValue()
|
||||
AssertEqual(t, defVal, "null")
|
||||
AssertEqual(t, ok, true)
|
||||
|
||||
// default 'null' -> null
|
||||
err = DB.Table(tableName).AutoMigrate(&NullModel{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
columnType, err = findColumnType(tableName, "content")
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
defVal, ok = columnType.DefaultValue()
|
||||
AssertEqual(t, defVal, "")
|
||||
AssertEqual(t, ok, false)
|
||||
}
|
||||
|
||||
func TestMigrateMySQLWithCustomizedTypes(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type MyTable struct {
|
||||
Def string `gorm:"size:512;index:idx_def,unique"`
|
||||
Abc string `gorm:"size:65000000"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable("my_tables")
|
||||
|
||||
sql := "CREATE TABLE `my_tables` (`def` varchar(512),`abc` longtext,UNIQUE INDEX `idx_def` (`def`))"
|
||||
if err := DB.Exec(sql).Error; err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
|
||||
session := DB.Session(&gorm.Session{Logger: Tracer{
|
||||
Logger: DB.Config.Logger,
|
||||
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
if strings.HasPrefix(sql, "ALTER TABLE") {
|
||||
t.Errorf("shouldn't execute: sql=%s", sql)
|
||||
}
|
||||
},
|
||||
}})
|
||||
|
||||
if err := session.AutoMigrate(&MyTable{}); err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateIgnoreRelations(t *testing.T) {
|
||||
type RelationModel1 struct {
|
||||
ID uint
|
||||
}
|
||||
type RelationModel2 struct {
|
||||
ID uint
|
||||
}
|
||||
type RelationModel3 struct {
|
||||
ID uint
|
||||
RelationModel1ID uint
|
||||
RelationModel1 *RelationModel1
|
||||
RelationModel2ID uint
|
||||
RelationModel2 *RelationModel2 `gorm:"-:migration"`
|
||||
}
|
||||
|
||||
var err error
|
||||
_ = DB.Migrator().DropTable(&RelationModel1{}, &RelationModel2{}, &RelationModel3{})
|
||||
|
||||
tx := DB.Session(&gorm.Session{})
|
||||
tx.IgnoreRelationshipsWhenMigrating = true
|
||||
|
||||
err = tx.AutoMigrate(&RelationModel3{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
// RelationModel3 should be existed
|
||||
_, err = findColumnType(&RelationModel3{}, "id")
|
||||
AssertEqual(t, nil, err)
|
||||
|
||||
// RelationModel1 should not be existed
|
||||
_, err = findColumnType(&RelationModel1{}, "id")
|
||||
if err == nil {
|
||||
t.Errorf("RelationModel1 should not be migrated")
|
||||
}
|
||||
|
||||
// RelationModel2 should not be existed
|
||||
_, err = findColumnType(&RelationModel2{}, "id")
|
||||
if err == nil {
|
||||
t.Errorf("RelationModel2 should not be migrated")
|
||||
}
|
||||
|
||||
tx.IgnoreRelationshipsWhenMigrating = false
|
||||
|
||||
err = tx.AutoMigrate(&RelationModel3{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
// RelationModel3 should be existed
|
||||
_, err = findColumnType(&RelationModel3{}, "id")
|
||||
AssertEqual(t, nil, err)
|
||||
|
||||
// RelationModel1 should be existed
|
||||
_, err = findColumnType(&RelationModel1{}, "id")
|
||||
AssertEqual(t, nil, err)
|
||||
|
||||
// RelationModel2 should not be existed
|
||||
_, err = findColumnType(&RelationModel2{}, "id")
|
||||
if err == nil {
|
||||
t.Errorf("RelationModel2 should not be migrated")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,8 +7,60 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestPostgresReturningIDWhichHasStringType(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Yasuo struct {
|
||||
ID string `gorm:"default:gen_random_uuid()"`
|
||||
Name string
|
||||
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
|
||||
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
|
||||
}
|
||||
|
||||
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
|
||||
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Yasuo{})
|
||||
|
||||
if err := DB.AutoMigrate(&Yasuo{}); err != nil {
|
||||
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
|
||||
}
|
||||
|
||||
yasuo := Yasuo{Name: "jinzhu"}
|
||||
if err := DB.Create(&yasuo).Error; err != nil {
|
||||
t.Fatalf("should be able to create data, but got %v", err)
|
||||
}
|
||||
|
||||
if yasuo.ID == "" {
|
||||
t.Fatal("should be able to has ID, but got zero value")
|
||||
}
|
||||
|
||||
var result Yasuo
|
||||
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
yasuo.Name = "jinzhu1"
|
||||
if err := DB.Save(&yasuo).Error; err != nil {
|
||||
t.Errorf("Failed to update date, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgres(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
t.Skip()
|
||||
|
@ -60,6 +112,45 @@ func TestPostgres(t *testing.T) {
|
|||
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable("log_usage")
|
||||
|
||||
if err := DB.Exec(`
|
||||
CREATE TABLE public.log_usage (
|
||||
log_id bigint NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
|
||||
SEQUENCE NAME public.log_usage_log_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1
|
||||
);
|
||||
`).Error; err != nil {
|
||||
t.Fatalf("failed to create table, got error %v", err)
|
||||
}
|
||||
|
||||
columns, err := DB.Migrator().ColumnTypes("log_usage")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get columns, got error %v", err)
|
||||
}
|
||||
|
||||
hasLogID := false
|
||||
for _, column := range columns {
|
||||
if column.Name() == "log_id" {
|
||||
hasLogID = true
|
||||
autoIncrement, ok := column.AutoIncrement()
|
||||
if !ok || !autoIncrement {
|
||||
t.Fatalf("column log_id should be auto incrementment")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasLogID {
|
||||
t.Fatalf("failed to found column log_id")
|
||||
}
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
|
@ -98,3 +189,68 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) {
|
|||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresOnConstraint(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Thing struct {
|
||||
gorm.Model
|
||||
SomeID string
|
||||
OtherID string
|
||||
Data string
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Thing{})
|
||||
DB.Migrator().CreateTable(&Thing{})
|
||||
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
thing := Thing{
|
||||
SomeID: "1234",
|
||||
OtherID: "1234",
|
||||
Data: "something",
|
||||
}
|
||||
|
||||
DB.Create(&thing)
|
||||
|
||||
thing2 := Thing{
|
||||
SomeID: "1234",
|
||||
OtherID: "1234",
|
||||
Data: "something else",
|
||||
}
|
||||
|
||||
result := DB.Clauses(clause.OnConflict{
|
||||
OnConstraint: "some_id_other_id_unique",
|
||||
UpdateAll: true,
|
||||
}).Create(&thing2)
|
||||
if result.Error != nil {
|
||||
t.Errorf("creating second thing: %v", result.Error)
|
||||
}
|
||||
|
||||
var things []Thing
|
||||
if err := DB.Find(&things).Error; err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(things) > 1 {
|
||||
t.Errorf("expected 1 thing got more")
|
||||
}
|
||||
}
|
||||
|
||||
type CompanyNew struct {
|
||||
ID int
|
||||
Name int
|
||||
}
|
||||
|
||||
func TestAlterColumnDataType(t *testing.T) {
|
||||
DB.AutoMigrate(Company{})
|
||||
|
||||
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
|
||||
t.Fatalf("failed to alter column from string to int, got error %v", err)
|
||||
}
|
||||
|
||||
DB.AutoMigrate(Company{})
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ package tests_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -88,3 +90,107 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
|
|||
}
|
||||
tx2.Commit()
|
||||
}
|
||||
|
||||
func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
tx, err := OpenTestConnection()
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
sqlDB, _ := tx.DB()
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
tx = tx.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
user := User{Name: "jinzhu"}
|
||||
tx.Create(&user)
|
||||
|
||||
var result User
|
||||
tx.First(&result)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
AssertEqual(t, ok, true)
|
||||
AssertEqual(t, len(conn.Stmts), 2)
|
||||
for _, stmt := range conn.Stmts {
|
||||
if stmt == nil {
|
||||
t.Fatalf("stmt cannot bee nil")
|
||||
}
|
||||
}
|
||||
|
||||
AssertEqual(t, sqlDB.Stats().InUse, 0)
|
||||
}
|
||||
|
||||
func TestPreparedStmtError(t *testing.T) {
|
||||
tx, err := OpenTestConnection()
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
sqlDB, _ := tx.DB()
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
tx = tx.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
// err prepare
|
||||
tag := Tag{Locale: "zh"}
|
||||
tx.Table("users").Find(&tag)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
AssertEqual(t, ok, true)
|
||||
AssertEqual(t, len(conn.Stmts), 0)
|
||||
AssertEqual(t, sqlDB.Stats().InUse, 0)
|
||||
}
|
||||
|
||||
func TestPreparedStmtInTransaction(t *testing.T) {
|
||||
user := User{Name: "jinzhu"}
|
||||
|
||||
if err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user)
|
||||
return errors.New("test")
|
||||
}); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var result User
|
||||
if err := DB.First(&result, user.ID).Error; err == nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmtReset(t *testing.T) {
|
||||
tx := DB.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
user := *GetUser("prepared_stmt_reset", Config{})
|
||||
tx = tx.Create(&user)
|
||||
|
||||
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
if !ok {
|
||||
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||
}
|
||||
|
||||
pdb.Mux.Lock()
|
||||
if len(pdb.Stmts) == 0 {
|
||||
pdb.Mux.Unlock()
|
||||
t.Fatalf("prepared stmt can not be empty")
|
||||
}
|
||||
pdb.Mux.Unlock()
|
||||
|
||||
pdb.Reset()
|
||||
pdb.Mux.Lock()
|
||||
defer pdb.Mux.Unlock()
|
||||
if len(pdb.Stmts) != 0 {
|
||||
t.Fatalf("prepared stmt should be empty")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -216,6 +216,30 @@ func TestFind(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// test array
|
||||
var models2 [3]User
|
||||
if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 {
|
||||
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2))
|
||||
} else {
|
||||
for idx, user := range users {
|
||||
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
CheckUser(t, models2[idx], user)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// test smaller array
|
||||
var models3 [2]User
|
||||
if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 {
|
||||
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3))
|
||||
} else {
|
||||
for idx, user := range users[:2] {
|
||||
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
CheckUser(t, models3[idx], user)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var none []User
|
||||
if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
|
||||
t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
|
||||
|
@ -384,6 +408,13 @@ func TestFindInBatchesWithError(t *testing.T) {
|
|||
if totalBatch != 0 {
|
||||
t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch)
|
||||
}
|
||||
|
||||
if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
|
||||
totalBatch += batch
|
||||
return nil
|
||||
}); result.Error != gorm.ErrPrimaryKeyRequired {
|
||||
t.Fatal("expected errors to have occurred, but nothing happened")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillSmallerStruct(t *testing.T) {
|
||||
|
@ -1335,3 +1366,17 @@ func TestQueryResetNullValue(t *testing.T) {
|
|||
AssertEqual(t, q1, qs[0])
|
||||
AssertEqual(t, q2, qs[1])
|
||||
}
|
||||
|
||||
func TestQueryError(t *testing.T) {
|
||||
type P struct{}
|
||||
var p1 P
|
||||
err := DB.Take(&p1, 1).Error
|
||||
AssertEqual(t, err, gorm.ErrModelAccessibleFieldsRequired)
|
||||
|
||||
var p2 interface{}
|
||||
|
||||
err = DB.Table("ps").Clauses(clause.Eq{Column: clause.Column{
|
||||
Table: clause.CurrentTable, Name: clause.PrimaryKey,
|
||||
}, Value: 1}).Scan(&p2).Error
|
||||
AssertEqual(t, err, gorm.ErrModelValueRequired)
|
||||
}
|
||||
|
|
|
@ -18,6 +18,8 @@ type SerializerStruct struct {
|
|||
gorm.Model
|
||||
Name []byte `gorm:"json"`
|
||||
Roles Roles `gorm:"serializer:json"`
|
||||
Roles2 *Roles `gorm:"serializer:json"`
|
||||
Roles3 *Roles `gorm:"serializer:json;not null"`
|
||||
Contracts map[string]interface{} `gorm:"serializer:json"`
|
||||
JobInfo Job `gorm:"type:bytes;serializer:gob"`
|
||||
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
|
||||
|
@ -108,7 +110,7 @@ func TestSerializer(t *testing.T) {
|
|||
}
|
||||
|
||||
var result SerializerStruct
|
||||
if err := DB.First(&result, data.ID).Error; err != nil {
|
||||
if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil {
|
||||
t.Fatalf("failed to query data, got error %v", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
@ -39,6 +40,11 @@ func TestSoftDelete(t *testing.T) {
|
|||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = DB.Session(&gorm.Session{DryRun: true}).Table("user u").Select("name").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`SELECT .name. FROM user u WHERE .u.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Errorf("Table with escape character, got %v", sql)
|
||||
}
|
||||
|
||||
if DB.First(&User{}, "name = ?", user.Name).Error == nil {
|
||||
t.Errorf("Can't find a soft deleted record")
|
||||
}
|
||||
|
@ -93,3 +99,71 @@ func TestDeletedAtOneOr(t *testing.T) {
|
|||
t.Fatalf("invalid sql generated, got %v", actualSQL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftDeleteZeroValue(t *testing.T) {
|
||||
type SoftDeleteBook struct {
|
||||
ID uint
|
||||
Name string
|
||||
Pages uint
|
||||
DeletedAt gorm.DeletedAt `gorm:"zeroValue:'1970-01-01 00:00:01'"`
|
||||
}
|
||||
DB.Migrator().DropTable(&SoftDeleteBook{})
|
||||
if err := DB.AutoMigrate(&SoftDeleteBook{}); err != nil {
|
||||
t.Fatalf("failed to auto migrate soft delete table")
|
||||
}
|
||||
|
||||
book := SoftDeleteBook{Name: "jinzhu", Pages: 10}
|
||||
DB.Save(&book)
|
||||
|
||||
var count int64
|
||||
if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 {
|
||||
t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count)
|
||||
}
|
||||
|
||||
var pages uint
|
||||
if DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages {
|
||||
t.Errorf("Pages soft deleted record, expects: %v, got: %v", 0, pages)
|
||||
}
|
||||
|
||||
if err := DB.Delete(&book).Error; err != nil {
|
||||
t.Fatalf("No error should happen when soft delete user, but got %v", err)
|
||||
}
|
||||
|
||||
zeroTime, _ := now.Parse("1970-01-01 00:00:01")
|
||||
if book.DeletedAt.Time.Equal(zeroTime) {
|
||||
t.Errorf("book's deleted at should not be zero, DeletedAt: %v", book.DeletedAt)
|
||||
}
|
||||
|
||||
if DB.First(&SoftDeleteBook{}, "name = ?", book.Name).Error == nil {
|
||||
t.Errorf("Can't find a soft deleted record")
|
||||
}
|
||||
|
||||
count = 0
|
||||
if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 0 {
|
||||
t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count)
|
||||
}
|
||||
|
||||
pages = 0
|
||||
if err := DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error; err != nil || pages != 0 {
|
||||
t.Fatalf("Age soft deleted record, expects: %v, got: %v, err %v", 0, pages, err)
|
||||
}
|
||||
|
||||
if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; err != nil {
|
||||
t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err)
|
||||
}
|
||||
|
||||
count = 0
|
||||
if DB.Unscoped().Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 {
|
||||
t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count)
|
||||
}
|
||||
|
||||
pages = 0
|
||||
if DB.Unscoped().Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages {
|
||||
t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, pages)
|
||||
}
|
||||
|
||||
DB.Unscoped().Delete(&book)
|
||||
if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
t.Errorf("Can't find permanently deleted record")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ func TestRow(t *testing.T) {
|
|||
}
|
||||
|
||||
table := "gorm.users"
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
if DB.Dialector.Name() != "mysql" || isTiDB() {
|
||||
table = "users" // other databases doesn't support select with `database.table`
|
||||
}
|
||||
|
||||
|
@ -367,7 +367,7 @@ func TestToSQL(t *testing.T) {
|
|||
t.Skip("Skip SQL Server for this test, because it too difference with other dialects.")
|
||||
}
|
||||
|
||||
date, _ := time.Parse("2006-01-02", "2021-10-18")
|
||||
date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local)
|
||||
|
||||
// find
|
||||
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
|
@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) {
|
|||
if DB.Statement.DryRun || DB.DryRun {
|
||||
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
|
||||
}
|
||||
|
||||
// UpdateColumns
|
||||
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Raw("SELECT * FROM users ?", clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}},
|
||||
})
|
||||
})
|
||||
assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql)
|
||||
}
|
||||
|
||||
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials.
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
|
@ -145,3 +147,27 @@ func TestTableWithAllFields(t *testing.T) {
|
|||
|
||||
AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3})
|
||||
}
|
||||
|
||||
type UserWithTableNamer struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
}
|
||||
|
||||
func (UserWithTableNamer) TableName(namer schema.Namer) string {
|
||||
return namer.TableName("user")
|
||||
}
|
||||
|
||||
func TestTableWithNamer(t *testing.T) {
|
||||
var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
TablePrefix: "t_",
|
||||
}})
|
||||
|
||||
sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{})
|
||||
})
|
||||
|
||||
if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) {
|
||||
t.Errorf("Table with namer, got %v", sql)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash -e
|
||||
|
||||
dialects=("sqlite" "mysql" "postgres" "sqlserver")
|
||||
dialects=("sqlite" "mysql" "postgres" "sqlserver" "tidb")
|
||||
|
||||
if [[ $(pwd) == *"gorm/tests"* ]]; then
|
||||
cd ..
|
||||
|
|
|
@ -21,6 +21,7 @@ var (
|
|||
mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
|
||||
postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai"
|
||||
sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
|
||||
tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -80,6 +81,12 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||
dbDSN = sqlserverDSN
|
||||
}
|
||||
db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{})
|
||||
case "tidb":
|
||||
log.Println("testing tidb...")
|
||||
if dbDSN == "" {
|
||||
dbDSN = tidbDSN
|
||||
}
|
||||
db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{})
|
||||
default:
|
||||
log.Println("testing sqlite3...")
|
||||
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{})
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type Tracer struct {
|
||||
Logger logger.Interface
|
||||
Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
|
||||
}
|
||||
|
||||
func (S Tracer) LogMode(level logger.LogLevel) logger.Interface {
|
||||
return S.Logger.LogMode(level)
|
||||
}
|
||||
|
||||
func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) {
|
||||
S.Logger.Info(ctx, s, i...)
|
||||
}
|
||||
|
||||
func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) {
|
||||
S.Logger.Warn(ctx, s, i...)
|
||||
}
|
||||
|
||||
func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) {
|
||||
S.Logger.Error(ctx, s, i...)
|
||||
}
|
||||
|
||||
func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
S.Logger.Trace(ctx, begin, fc, err)
|
||||
S.Test(ctx, begin, fc, err)
|
||||
}
|
|
@ -102,7 +102,7 @@ func TestTransactionWithBlock(t *testing.T) {
|
|||
return errors.New("the error message")
|
||||
})
|
||||
|
||||
if err.Error() != "the error message" {
|
||||
if err != nil && err.Error() != "the error message" {
|
||||
t.Fatalf("Transaction return error will equal the block returns error")
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package tests
|
|||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
|
@ -13,7 +14,14 @@ func (DummyDialector) Name() string {
|
|||
return "dummy"
|
||||
}
|
||||
|
||||
func (DummyDialector) Initialize(*gorm.DB) error {
|
||||
func (DummyDialector) Initialize(db *gorm.DB) error {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
|
||||
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@ package utils
|
|||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -16,7 +16,18 @@ var gormSourceDir string
|
|||
func init() {
|
||||
_, file, _, _ := runtime.Caller(0)
|
||||
// compatible solution to get gorm source directory with various operating systems
|
||||
gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "")
|
||||
gormSourceDir = sourceDir(file)
|
||||
}
|
||||
|
||||
func sourceDir(file string) string {
|
||||
dir := filepath.Dir(file)
|
||||
dir = filepath.Dir(dir)
|
||||
|
||||
s := filepath.Dir(dir)
|
||||
if filepath.Base(s) != "gorm.io" {
|
||||
s = dir
|
||||
}
|
||||
return filepath.ToSlash(s) + "/"
|
||||
}
|
||||
|
||||
// FileWithLineNum return the file name and line number of the current file
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsValidDBNameChar(t *testing.T) {
|
||||
|
@ -13,6 +18,29 @@ func TestIsValidDBNameChar(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCheckTruth(t *testing.T) {
|
||||
checkTruthTests := []struct {
|
||||
v string
|
||||
out bool
|
||||
}{
|
||||
{"123", true},
|
||||
{"true", true},
|
||||
{"", false},
|
||||
{"false", false},
|
||||
{"False", false},
|
||||
{"FALSE", false},
|
||||
{"\u0046alse", false},
|
||||
}
|
||||
|
||||
for _, test := range checkTruthTests {
|
||||
t.Run(test.v, func(t *testing.T) {
|
||||
if out := CheckTruth(test.v); out != test.out {
|
||||
t.Errorf("CheckTruth(%s) want: %t, got: %t", test.v, test.out, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToStringKey(t *testing.T) {
|
||||
cases := []struct {
|
||||
values []interface{}
|
||||
|
@ -29,3 +57,81 @@ func TestToStringKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
containsTests := []struct {
|
||||
name string
|
||||
elems []string
|
||||
elem string
|
||||
out bool
|
||||
}{
|
||||
{"exists", []string{"1", "2", "3"}, "1", true},
|
||||
{"not exists", []string{"1", "2", "3"}, "4", false},
|
||||
}
|
||||
for _, test := range containsTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if out := Contains(test.elems, test.elem); test.out != out {
|
||||
t.Errorf("Contains(%v, %s) want: %t, got: %t", test.elems, test.elem, test.out, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type ModifyAt sql.NullTime
|
||||
|
||||
// Value return a Unix time.
|
||||
func (n ModifyAt) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.Time.Unix(), nil
|
||||
}
|
||||
|
||||
func TestAssertEqual(t *testing.T) {
|
||||
now := time.Now()
|
||||
assertEqualTests := []struct {
|
||||
name string
|
||||
src, dst interface{}
|
||||
out bool
|
||||
}{
|
||||
{"error equal", errors.New("1"), errors.New("1"), true},
|
||||
{"error not equal", errors.New("1"), errors.New("2"), false},
|
||||
{"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true},
|
||||
{"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false},
|
||||
}
|
||||
for _, test := range assertEqualTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if out := AssertEqual(test.src, test.dst); test.out != out {
|
||||
t.Errorf("AssertEqual(%v, %v) want: %t, got: %t", test.src, test.dst, test.out, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in interface{}
|
||||
out string
|
||||
}{
|
||||
{"int", math.MaxInt64, "9223372036854775807"},
|
||||
{"int8", int8(math.MaxInt8), "127"},
|
||||
{"int16", int16(math.MaxInt16), "32767"},
|
||||
{"int32", int32(math.MaxInt32), "2147483647"},
|
||||
{"int64", int64(math.MaxInt64), "9223372036854775807"},
|
||||
{"uint", uint(math.MaxUint64), "18446744073709551615"},
|
||||
{"uint8", uint8(math.MaxUint8), "255"},
|
||||
{"uint16", uint16(math.MaxUint16), "65535"},
|
||||
{"uint32", uint32(math.MaxUint32), "4294967295"},
|
||||
{"uint64", uint64(math.MaxUint64), "18446744073709551615"},
|
||||
{"string", "abc", "abc"},
|
||||
{"other", true, ""},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if out := ToString(test.in); test.out != out {
|
||||
t.Fatalf("ToString(%v) want: %s, got: %s", test.in, test.out, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
//go:build unix
|
||||
// +build unix
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSourceDir(t *testing.T) {
|
||||
cases := []struct {
|
||||
file string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
file: "/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go",
|
||||
want: "/Users/name/go/pkg/mod/gorm.io/",
|
||||
},
|
||||
{
|
||||
file: "/go/work/proj/gorm/utils/utils.go",
|
||||
want: "/go/work/proj/gorm/",
|
||||
},
|
||||
{
|
||||
file: "/go/work/proj/gorm_alias/utils/utils.go",
|
||||
want: "/go/work/proj/gorm_alias/",
|
||||
},
|
||||
{
|
||||
file: "/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go",
|
||||
want: "/go/work/proj/my.gorm.io/gorm@v1.2.3/",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
s := sourceDir(c.file)
|
||||
if s != c.want {
|
||||
t.Fatalf("%s: expected %s, got %s", c.file, c.want, s)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSourceDir(t *testing.T) {
|
||||
cases := []struct {
|
||||
file string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
file: `C:/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go`,
|
||||
want: `C:/Users/name/go/pkg/mod/gorm.io/`,
|
||||
},
|
||||
{
|
||||
file: `C:/go/work/proj/gorm/utils/utils.go`,
|
||||
want: `C:/go/work/proj/gorm/`,
|
||||
},
|
||||
{
|
||||
file: `C:/go/work/proj/gorm_alias/utils/utils.go`,
|
||||
want: `C:/go/work/proj/gorm_alias/`,
|
||||
},
|
||||
{
|
||||
file: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go`,
|
||||
want: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/`,
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
s := sourceDir(c.file)
|
||||
if s != c.want {
|
||||
t.Fatalf("%s: expected %s, got %s", c.file, c.want, s)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue