Compare commits

...

62 Commits

Author SHA1 Message Date
Cheese 02b7e26f6b
feat: add tidb integration test cases (#6014)
* feat: support tidb integration test

* feat: update the mysql driver version to test
2023-02-08 16:29:09 +08:00
Cr 878ac51e98
fix:throw model value required error (#6031)
* fix:throw model value required error

* chore:ingore typecheck

* chore:ingore errcheck

* refactor: use other error

* chore: gofumpt style
2023-02-08 13:40:41 +08:00
chyroc e1f46eb802
fix: ignore nil query (#6021) 2023-02-02 17:54:51 +08:00
Jinzhu 4d6b70ec88 Allow modify statement from dest 2023-02-02 17:15:08 +08:00
qiankunli cfbcedbf03
fix: support zeroValue tag on DeletedAt (#6011)
* fix: support zeroValue tag on DeletedAt

Signed-off-by: qiankunli <qiankun.li@qq.com>

* Update soft_delete_test.go

* Update tests_test.go

* Update soft_delete.go

---------

Signed-off-by: qiankunli <qiankun.li@qq.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-02-01 14:40:55 +08:00
Jinzhu d834dd60b7 Remove unnecessary code 2023-01-19 15:22:13 +08:00
Jinzhu 3d35ddba55 Fix use table.* as select/omit columns 2023-01-12 16:52:56 +08:00
Haibo baf1afa1fc
fix(schema): field is only unique when there is one unique index (#5974) 2023-01-11 14:05:39 +08:00
Jinzhu 2bc913787b support implicit table alias, close #5840 #5940 2023-01-02 21:46:27 +08:00
Jinzhu 3d91802b1d Fix unexpected alter table in auto migration, close #5942, #5943 2023-01-02 21:06:04 +08:00
Jinzhu b0e13d95b4 update github tests action 2023-01-01 22:27:49 +08:00
Jinzhu 4b768c8aff Upgrade tests deps 2023-01-01 22:22:08 +08:00
Haibo 16a272209a
fix(migrator): Tag default:'null' always causes field migration #5953 (#5954)
* fix(migrator): Tag default:'null' always causes field migration #5953

* Update migrate_test.go

* Update migrate_test.go

* Update migrate_test.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2023-01-01 22:14:28 +08:00
Haibo da2b2861de
fix(migrator): ignore relationships when migrating #5913 (#5946) 2023-01-01 19:54:28 +08:00
dependabot[bot] 7da24d1d52
chore(deps): bump actions/stale from 6 to 7 (#5945)
Bumps [actions/stale](https://github.com/actions/stale) from 6 to 7.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v6...v7)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-12-27 08:47:17 +08:00
Jinzhu ddd3cc2502 Add ParameterizedQueries option support for logger, close #5288 2022-12-25 11:37:23 +08:00
Cr 794edad60e
test(MigrateColumn): mock alter column to improve field compare (#5499)
* test(MigrateColumn): mock alter column to improve field compare

* Update migrate_test.go

* Update migrate_test.go

* Update migrate_test.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 17:42:16 +08:00
Cr 1935eb0adb
feat: support inner join (#5583)
* feat: support inner join

* test: mixed inner join and left join

* chore: code comment

* Update statement.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 12:27:38 +08:00
Defoo Li 775fa70af5
DryRun for migrator (#5689)
* DryRun for migrator

* Update migrator.go

* Update migrator.go

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-12-24 12:14:23 +08:00
Ning bbd2bbe521
fix:Issue migrating field with CURRENT_TIMESTAMP (#5906)
Co-authored-by: ningfei <accelerator314@outlook.com>
2022-12-24 11:02:11 +08:00
Nate Armstrong f3c6fc2533
Update func comments in chainable_api and FirstOr_ (#5935)
Add comments to functions in chainable_api. Depending on the method,
these comments add some additional context or details that are relevant
when reading the function, link to the actual docs at gorm.io/docs, or
provide examples of use. These comments should make GORM much more
pleasant to use with an IDE that provides hoverable comments, and are
minimal examples.

Also add in-code documentation to FirstOrInit and FirstOrCreate.

Almost all examples are directly pulled from the docs, with short
comments explaining the code. Most examples omit the `db.Model(&User{})`
for brevity, and would not actually work.

Co-authored-by: Nate Armstrong <nate.armstrong@eluv.io>
2022-12-23 16:51:01 +08:00
Edward McFarlane 4ec73c9bf4
Add test case for embedded value selects (#5901)
* Add test case for embedded value selects

* Revert recycle struct optimisation to avoid pointer overwrites
2022-12-19 11:49:05 +08:00
Cr d9525d4da4
fix: skip append relation field to default db value (#5885)
* fix: relation field returning

* chore: gofumpt style
2022-12-01 20:26:59 +08:00
wjw1758548031 f931def33d
clear code syntax (#5889)
* clear code syntax

* clear code syntax
2022-12-01 20:25:53 +08:00
Jinzhu f91313436a Fix group by with count logic 2022-11-21 11:10:56 +08:00
Cr 342310fba4
fix(FindInBatches): throw err if pk not exists (#5868) 2022-11-21 10:49:27 +08:00
kvii b6836c2d3e
fix bug in windows (#5844)
* fix bug in windows

* fix file name bug

* test in unix like platform
2022-11-21 10:48:13 +08:00
jessetang cef3de694d
cleanup(prepare_stmt.go): unnecessary map delete (#5849) 2022-11-13 11:12:09 +08:00
jessetang 1b9cd56c53
doc(README.md): add contributors (#5847) 2022-11-10 16:30:32 +08:00
kvii 871f1de6b9
fix logger path bug (#5836) 2022-11-05 11:52:08 +08:00
jessetang fb640cf7da
test(utils): add utils unit test (#5834) 2022-11-05 08:38:14 +08:00
jessetang 5c8ecc3a2a
feat: golangci add goimports and whitespace (#5835) 2022-11-05 08:37:37 +08:00
jessetang f82e9cfdbe
test(clause/joins): add join unit test (#5832) 2022-11-03 21:03:13 +08:00
Cr b2f42528a4
fix(Joins): args with select and omit (#5790)
* fix(Joins): args with select and omit

* chore: gofumpt style
2022-11-02 10:28:00 +08:00
Cr 9d82aa5673
test: invalid cache plan with prepare stmt (#5778)
* test: invalid cache plan with prepare stmt

* test: more test cases

* test: drop and rename column
2022-10-20 14:10:47 +08:00
Cr 5dd2bb4827
feat(PreparedStmtDB): support reset (#5782)
* feat(PreparedStmtDB): support reset

* fix: close all stmt

* test: fix test

* fix: delete one by one
2022-10-19 14:46:59 +08:00
Jinzhu 3f20a543fa Support use clause.Interface as query params 2022-10-18 18:01:55 +08:00
viatoriche / Maxim Panfilov 62593cfad0 add test: TestAutoMigrateInt8PG: shouldn't execute ALTER COLUMN TYPE smallint, close #5762 2022-10-18 17:28:06 +08:00
Jinzhu a0f4d3f7d2 Save as empty string for not nullable nil field serialized into json 2022-10-18 16:25:39 +08:00
Jinzhu ab5f80a8d8 Save as NULL for nil object serialized into json 2022-10-18 15:44:56 +08:00
Cr 186e8a9e14
fix: association without pks (#5779) 2022-10-18 11:58:42 +08:00
Jinzhu 2a788fb20c Upgrade tests go.mod 2022-10-17 17:01:42 +08:00
Jinzhu aa4312ee74 Don't display any GORM related package path as source 2022-10-17 17:01:42 +08:00
Jinzhu 08aa2f9888 Update README 2022-10-14 20:30:41 +08:00
Jinzhu 2c56954cb1 tests mariadb with returning support 2022-10-08 20:48:22 +08:00
Jinzhu e93dc3426e Test postgres autoincrement check 2022-10-08 17:16:32 +08:00
Jinzhu 983e96f142 Add tests for alter column type 2022-10-08 16:04:57 +08:00
Jinzhu 34fbe84580 Add TableName with NamingStrategy support, close #5726 2022-10-07 21:18:37 +08:00
robhafner e8f48b5c15
fix: limit=0 results (#5735) (#5736) 2022-10-07 20:14:14 +08:00
jesse.tang 4b22a55a75
fix: primaryFields are overwritten (#5721) 2022-10-07 18:29:28 +08:00
Wen Sun 9564b82975
Fix OnConstraint builder (#5738) 2022-10-07 13:46:20 +08:00
Cr 0b7113b618
fix: prepare deadlock (#5568)
* fix: prepare deadlock

* chore[ci skip]: code style

* chore[ci skip]: test remove unnecessary params

* fix: prepare deadlock

* fix: double check prepare

* test: more goroutines

* chore[ci skip]: improve code comments

Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-09-30 18:13:36 +08:00
Stephano George a3cc6c6088 Fix: wrong value when Find with Join with same column name, close #5723, #5711 2022-09-30 17:18:42 +08:00
jesse.tang be440e7512
fix possible nil panic in tests (#5720)
* fix maybe nil panic

* reset code
2022-09-30 11:14:34 +08:00
dependabot[bot] e1dd0dcbc4
chore(deps): bump actions/stale from 5 to 6 (#5717)
Bumps [actions/stale](https://github.com/actions/stale) from 5 to 6.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-09-30 11:13:01 +08:00
Nguyen Huu Tuan 328f301982
add some test case which related the logic (#5477) 2022-09-22 18:35:21 +08:00
kinggo 12237454ed fix: use preparestmt in trasaction will use new conn, close #5508 2022-09-22 16:47:43 +08:00
Cr 73bc53f061
feat: migrator support type aliases (#5627)
* feat: migrator support type aliases

* perf: check type
2022-09-22 15:56:32 +08:00
Cr 101a7c789f
fix: scan array (#5624)
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2022-09-22 15:51:47 +08:00
Jinzhu 3a72ba102e Allow shared foreign key for many2many jointable 2022-09-22 15:03:41 +08:00
jesse.tang 1f634c3937
support scan assign slice cap (#5634)
* support scan assign slice cap

* fix
2022-09-22 14:50:35 +08:00
Cr 5ed7b1a65e
fix: same embedded filed name (#5705) 2022-09-22 11:25:03 +08:00
64 changed files with 2280 additions and 338 deletions

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v5 uses: actions/stale@v7
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v5 uses: actions/stale@v7
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v5 uses: actions/stale@v7
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"

View File

@ -16,7 +16,7 @@ jobs:
sqlite: sqlite:
strategy: strategy:
matrix: matrix:
go: ['1.19', '1.18', '1.17', '1.16'] go: ['1.19', '1.18']
platform: [ubuntu-latest] # can not run in windows OS platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
@ -42,7 +42,7 @@ jobs:
strategy: strategy:
matrix: matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
go: ['1.19', '1.18', '1.17', '1.16'] go: ['1.19', '1.18']
platform: [ubuntu-latest] platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
@ -86,7 +86,7 @@ jobs:
strategy: strategy:
matrix: matrix:
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
go: ['1.19', '1.18', '1.17', '1.16'] go: ['1.19', '1.18']
platform: [ubuntu-latest] # can not run in macOS and Windows platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
@ -128,7 +128,7 @@ jobs:
sqlserver: sqlserver:
strategy: strategy:
matrix: matrix:
go: ['1.19', '1.18', '1.17', '1.16'] go: ['1.19', '1.18']
platform: [ubuntu-latest] # can not run test in macOS and windows platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
@ -167,3 +167,36 @@ jobs:
- name: Tests - name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
tidb:
strategy:
matrix:
dbversion: [ 'v6.5.0' ]
go: [ '1.19', '1.18' ]
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
steps:
- name: Setup TiDB
uses: Icemap/tidb-action@main
with:
port: 9940
version: ${{matrix.dbversion}}
- name: Set up Go 1.x
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh

View File

@ -9,3 +9,12 @@ linters:
- prealloc - prealloc
- unconvert - unconvert
- unparam - unparam
- goimports
- whitespace
linters-settings:
whitespace:
multi-func: true
goimports:
local-prefixes: gorm.io/gorm

View File

@ -4,9 +4,6 @@ The fantastic ORM library for Golang, aims to be developer friendly.
[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm)
[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) [![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions)
[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc)
@ -30,12 +27,18 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Getting Started ## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io) * GORM Guides [https://gorm.io](https://gorm.io)
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) * Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
## Contributing ## Contributing
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
## Contributors
Thank you for contributing to the GORM framework!
[![Contributors](https://contrib.rocks/image?repo=go-gorm/gorm)](https://github.com/go-gorm/gorm/graphs/contributors)
## License ## License
© Jinzhu, 2013~time.Now © Jinzhu, 2013~time.Now

View File

@ -93,6 +93,10 @@ func (p *processor) Execute(db *DB) *DB {
resetBuildClauses = true resetBuildClauses = true
} }
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
optimizer.ModifyStatement(stmt)
}
// assign model values // assign model values
if stmt.Model == nil { if stmt.Model == nil {
stmt.Model = stmt.Dest stmt.Model = stmt.Dest
@ -132,7 +136,11 @@ func (p *processor) Execute(db *DB) *DB {
if stmt.SQL.Len() > 0 { if stmt.SQL.Len() > 0 {
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected sql, vars := stmt.SQL.String(), stmt.Vars
if filter, ok := db.Logger.(ParamsFilter); ok {
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
}
return db.Dialector.Explain(sql, vars...), db.RowsAffected
}, db.Error) }, db.Error)
} }

View File

@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
cacheKey := utils.ToStringKey(relPrimaryValues...) cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true identityMap[cacheKey] = true
}
if isPtr { if isPtr {
elems = reflect.Append(elems, elem) elems = reflect.Append(elems, elem)
} else { } else {
@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
cacheKey := utils.ToStringKey(relPrimaryValues...) cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem) distinctElems = reflect.Append(distinctElems, elem)
} }

View File

@ -117,13 +117,21 @@ func BuildQuerySQL(db *gorm.DB) {
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
tableAliasName := relation.Name tableAliasName := relation.Name
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames { for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName, Table: tableAliasName,
Name: s, Name: s,
Alias: tableAliasName + "__" + s, Alias: tableAliasName + "__" + s,
}) })
} }
}
exprs := make([]clause.Expression, len(relation.References)) exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References { for idx, ref := range relation.References {
@ -177,7 +185,7 @@ func BuildQuerySQL(db *gorm.DB) {
} }
fromClause.Joins = append(fromClause.Joins, clause.Join{ fromClause.Joins = append(fromClause.Joins, clause.Join{
Type: clause.LeftJoin, Type: join.JoinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs}, ON: clause.Where{Exprs: exprs},
}) })

View File

@ -7,7 +7,7 @@ import (
func RowQuery(db *gorm.DB) { func RowQuery(db *gorm.DB) {
if db.Error == nil { if db.Error == nil {
BuildQuerySQL(db) BuildQuerySQL(db)
if db.DryRun { if db.DryRun || db.Error != nil {
return return
} }

View File

@ -10,9 +10,10 @@ import (
) )
// Model specify the model you would like to run db operations // Model specify the model you would like to run db operations
//
// // update all users's name to `hello` // // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello") // db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` // // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
// db.Model(&user).Update("name", "hello") // db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) { func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -21,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) {
} }
// Clauses Add clauses // Clauses Add clauses
//
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
// advanced techniques like specifying lock strength and optimizer hints. See the
// [docs] for more depth.
//
// // add a simple limit clause
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
// // tell the optimizer to use the `idx_user_name` index
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
// // specify the lock strength to UPDATE
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
//
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
var whereConds []interface{} var whereConds []interface{}
@ -41,15 +55,22 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
return return
} }
var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)
// Table specify the table you would like to run db operations // Table specify the table you would like to run db operations
//
// // Get a user
// db.Table("users").take(&result)
func (db *DB) Table(name string, args ...interface{}) (tx *DB) { func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
if results[1] != "" {
tx.Statement.Table = results[1] tx.Statement.Table = results[1]
} else {
tx.Statement.Table = results[2]
}
} }
} else if tables := strings.Split(name, "."); len(tables) == 2 { } else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
@ -65,6 +86,11 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
} }
// Distinct specify distinct fields that you want querying // Distinct specify distinct fields that you want querying
//
// // Select distinct names of users
// db.Distinct("name").Find(&results)
// // Select distinct name/age pairs from users
// db.Distinct("name", "age").Find(&results)
func (db *DB) Distinct(args ...interface{}) (tx *DB) { func (db *DB) Distinct(args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Distinct = true tx.Statement.Distinct = true
@ -75,6 +101,14 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
} }
// Select specify fields that you want when querying, creating, updating // Select specify fields that you want when querying, creating, updating
//
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
// Select accepts both string arguments and arrays.
//
// // Select name and age of user using multiple arguments
// db.Select("name", "age").Find(&users)
// // Select name and age of user using an array
// db.Select([]string{"name", "age"}).Find(&users)
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -152,6 +186,17 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
} }
// Where add conditions // Where add conditions
//
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
//
// // Find the first user with name jinzhu
// db.Where("name = ?", "jinzhu").First(&user)
// // Find the first user with name jinzhu and age 20
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
// // Find the first user with name jinzhu and age not equal to 20
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
//
// [docs]: https://gorm.io/docs/query.html#Conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -161,6 +206,11 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
} }
// Not add NOT conditions // Not add NOT conditions
//
// Not works similarly to where, and has the same syntax.
//
// // Find the first user with name not equal to jinzhu
// db.Not("name = ?", "jinzhu").First(&user)
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -170,6 +220,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
} }
// Or add OR conditions // Or add OR conditions
//
// Or is used to chain together queries with an OR.
//
// // Find the first user with name equal to jinzhu or john
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -179,26 +234,42 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
} }
// Joins specify Joins conditions // Joins specify Joins conditions
//
// db.Joins("Account").Find(&user) // db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) // db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.LeftJoin, query, args...)
}
// InnerJoins specify inner joins conditions
// db.InnerJoins("Account").Find(&user)
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.InnerJoin, query, args...)
}
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(args) == 1 { if len(args) == 1 {
if db, ok := args[0].(*DB); ok { if db, ok := args[0].(*DB); ok {
j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits}
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where}) j.On = &where
}
tx.Statement.Joins = append(tx.Statement.Joins, j)
return return
} }
} }
}
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
return return
} }
// Group specify the group method on the find // Group specify the group method on the find
//
// // Select the sum age of users with given names
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
func (db *DB) Group(name string) (tx *DB) { func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -210,6 +281,9 @@ func (db *DB) Group(name string) (tx *DB) {
} }
// Having specify HAVING conditions for GROUP BY // Having specify HAVING conditions for GROUP BY
//
// // Select the sum age of users with name jinzhu
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.GroupBy{ tx.Statement.AddClause(clause.GroupBy{
@ -218,7 +292,8 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
return return
} }
// Order specify order when retrieve records from database // Order specify order when retrieving records from database
//
// db.Order("name DESC") // db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) // db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
func (db *DB) Order(value interface{}) (tx *DB) { func (db *DB) Order(value interface{}) (tx *DB) {
@ -242,13 +317,27 @@ func (db *DB) Order(value interface{}) (tx *DB) {
} }
// Limit specify the number of records to be retrieved // Limit specify the number of records to be retrieved
//
// Limit conditions can be cancelled by using `Limit(-1)`.
//
// // retrieve 3 users
// db.Limit(3).Find(&users)
// // retrieve 3 users into users1, and all users into users2
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
func (db *DB) Limit(limit int) (tx *DB) { func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: limit}) tx.Statement.AddClause(clause.Limit{Limit: &limit})
return return
} }
// Offset specify the number of records to skip before starting to return the records // Offset specify the number of records to skip before starting to return the records
//
// Offset conditions can be cancelled by using `Offset(-1)`.
//
// // select the third user
// db.Offset(2).First(&user)
// // select the first user by cancelling an earlier chained offset
// db.Offset(5).Offset(-1).First(&user)
func (db *DB) Offset(offset int) (tx *DB) { func (db *DB) Offset(offset int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Offset: offset}) tx.Statement.AddClause(clause.Limit{Offset: offset})
@ -256,6 +345,7 @@ func (db *DB) Offset(offset int) (tx *DB) {
} }
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
//
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { // func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000) // return db.Where("amount > ?", 1000)
// } // }
@ -274,6 +364,8 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
} }
// Preload preload associations with given conditions // Preload preload associations with given conditions
//
// // get all users, and preload all non-cancelled orders
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -284,12 +376,41 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
return return
} }
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Attrs only adds attributes if the record is not found.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign an email if the record is not found, otherwise ignore provided email
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.attrs = attrs tx.Statement.attrs = attrs
return return
} }
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
// records will be updated even if they are found.
//
// // assign an email regardless of if the record is not found
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Assign(attrs ...interface{}) (tx *DB) { func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.assigns = attrs tx.Statement.assigns = attrs

View File

@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) {
func BenchmarkComplexSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
limit10 := 10
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clauses := []clause.Interface{ clauses := []clause.Interface{
@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) {
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
}}, }},
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Limit: &limit10, Offset: 20},
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
} }

View File

@ -20,6 +20,7 @@ type Builder interface {
Writer Writer
WriteQuoted(field interface{}) WriteQuoted(field interface{})
AddVar(Writer, ...interface{}) AddVar(Writer, ...interface{})
AddError(error) error
} }
// Clause // Clause

View File

@ -9,7 +9,7 @@ const (
RightJoin JoinType = "RIGHT" RightJoin JoinType = "RIGHT"
) )
// Join join clause for from // Join clause for from
type Join struct { type Join struct {
Type JoinType Type JoinType
Table Table Table Table

101
clause/joins_test.go Normal file
View File

@ -0,0 +1,101 @@
package clause_test
import (
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
func TestJoin(t *testing.T) {
results := []struct {
name string
join clause.Join
sql string
}{
{
name: "LEFT JOIN",
join: clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "RIGHT JOIN",
join: clause.Join{
Type: clause.RightJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "INNER JOIN",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "CROSS JOIN",
join: clause.Join{
Type: clause.CrossJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "USING",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
sql: "INNER JOIN `user` USING (`id`)",
},
{
name: "Expression",
join: clause.Join{
// Invalid
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
// Valid
Expression: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
},
sql: "INNER JOIN `user` USING (`id`)",
},
}
for _, result := range results {
t.Run(result.name, func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
result.join.Build(stmt)
if result.sql != stmt.SQL.String() {
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
}
})
}
}

View File

@ -4,7 +4,7 @@ import "strconv"
// Limit limit clause // Limit limit clause
type Limit struct { type Limit struct {
Limit int Limit *int
Offset int Offset int
} }
@ -15,12 +15,12 @@ func (limit Limit) Name() string {
// Build build where clause // Build build where clause
func (limit Limit) Build(builder Builder) { func (limit Limit) Build(builder Builder) {
if limit.Limit > 0 { if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ") builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(limit.Limit)) builder.WriteString(strconv.Itoa(*limit.Limit))
} }
if limit.Offset > 0 { if limit.Offset > 0 {
if limit.Limit > 0 { if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ') builder.WriteByte(' ')
} }
builder.WriteString("OFFSET ") builder.WriteString("OFFSET ")
@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = "" clause.Name = ""
if v, ok := clause.Expression.(Limit); ok { if v, ok := clause.Expression.(Limit); ok {
if limit.Limit == 0 && v.Limit != 0 { if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) {
limit.Limit = v.Limit limit.Limit = v.Limit
} }

View File

@ -8,6 +8,10 @@ import (
) )
func TestLimit(t *testing.T) { func TestLimit(t *testing.T) {
limit0 := 0
limit10 := 10
limit50 := 50
limitNeg10 := -10
results := []struct { results := []struct {
Clauses []clause.Interface Clauses []clause.Interface
Result string Result string
@ -15,11 +19,15 @@ func TestLimit(t *testing.T) {
}{ }{
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
Limit: 10, Limit: &limit10,
Offset: 20, Offset: 20,
}}, }},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT 0", nil,
},
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET 20", nil, "SELECT * FROM `users` OFFSET 20", nil,
@ -29,23 +37,23 @@ func TestLimit(t *testing.T) {
"SELECT * FROM `users` OFFSET 30", nil, "SELECT * FROM `users` OFFSET 30", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT 10", nil, "SELECT * FROM `users` LIMIT 10", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
"SELECT * FROM `users` OFFSET 30", nil, "SELECT * FROM `users` OFFSET 30", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
}, },
} }

View File

@ -16,6 +16,11 @@ func (OnConflict) Name() string {
// Build build onConflict clause // Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) { func (onConflict OnConflict) Build(builder Builder) {
if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ')
} else {
if len(onConflict.Columns) > 0 { if len(onConflict.Columns) > 0 {
builder.WriteByte('(') builder.WriteByte('(')
for idx, column := range onConflict.Columns { for idx, column := range onConflict.Columns {
@ -32,11 +37,6 @@ func (onConflict OnConflict) Build(builder Builder) {
onConflict.TargetWhere.Build(builder) onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ') builder.WriteByte(' ')
} }
if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ')
} }
if onConflict.DoNothing { if onConflict.DoNothing {

View File

@ -21,6 +21,8 @@ var (
ErrPrimaryKeyRequired = errors.New("primary key required") ErrPrimaryKeyRequired = errors.New("primary key required")
// ErrModelValueRequired model value required // ErrModelValueRequired model value required
ErrModelValueRequired = errors.New("model value required") ErrModelValueRequired = errors.New("model value required")
// ErrModelAccessibleFieldsRequired model accessible fields required
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
// ErrInvalidData unsupported data // ErrInvalidData unsupported data
ErrInvalidData = errors.New("unsupported data") ErrInvalidData = errors.New("unsupported data")
// ErrUnsupportedDriver unsupported driver // ErrUnsupportedDriver unsupported driver

View File

@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
var totalSize int var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok { if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok { if limit, ok := c.Expression.(clause.Limit); ok {
totalSize = limit.Limit if limit.Limit != nil {
totalSize = *limit.Limit
}
if totalSize > 0 && batchSize > totalSize { if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize batchSize = totalSize
@ -229,7 +231,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break break
} }
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
} }
@ -288,6 +294,16 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds. // FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
// Each conds must be a struct or map. // Each conds must be a struct or map.
//
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{ queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -315,12 +331,30 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. // FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map. // Each conds must be a struct or map.
//
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
//
// // assign an email if the record is not found
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
// // result.RowsAffected -> 1
//
// // assign email regardless of if record is found
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
// // result.RowsAffected -> 1
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}) })
if result := queryTx.Find(dest, conds...); result.Error == nil {
result := queryTx.Find(dest, conds...)
if result.Error != nil {
tx.Error = result.Error
return tx
}
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
if c, ok := result.Statement.Clauses["WHERE"]; ok { if c, ok := result.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
@ -349,16 +383,13 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
assigns[column] = eq.Value assigns[column] = eq.Value
case clause.Column: case clause.Column:
assigns[column.Name] = eq.Value assigns[column.Name] = eq.Value
default:
} }
} }
} }
return tx.Model(dest).Updates(assigns) return tx.Model(dest).Updates(assigns)
} }
} else {
tx.Error = result.Error
}
return tx return tx
} }
@ -459,7 +490,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
tx.Statement.Dest = count tx.Statement.Dest = count
tx = tx.callbacks.Query().Execute(tx) tx = tx.callbacks.Query().Execute(tx)
if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { if tx.RowsAffected != 1 {
*count = tx.RowsAffected *count = tx.RowsAffected
} }
@ -512,6 +543,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
} }
// Pluck queries a single column from a model, returning in the slice dest. E.g.: // Pluck queries a single column from a model, returning in the slice dest. E.g.:
//
// var ages []int64 // var ages []int64
// db.Model(&users).Pluck("age", &ages) // db.Model(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {

12
gorm.go
View File

@ -37,6 +37,8 @@ type Config struct {
DisableAutomaticPing bool DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating // DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool DisableForeignKeyConstraintWhenMigrating bool
// IgnoreRelationshipsWhenMigrating
IgnoreRelationshipsWhenMigrating bool
// DisableNestedTransaction disable nested transaction // DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool DisableNestedTransaction bool
// AllowGlobalUpdate allow global update // AllowGlobalUpdate allow global update
@ -179,7 +181,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
preparedStmt := &PreparedStmtDB{ preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool, ConnPool: db.ConnPool,
Stmts: map[string]Stmt{}, Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{}, Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100), PreparedSQL: make([]string, 0, 100),
} }
@ -248,11 +250,19 @@ func (db *DB) Session(config *Session) *DB {
if config.PrepareStmt { if config.PrepareStmt {
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt := v.(*PreparedStmtDB) preparedStmt := v.(*PreparedStmtDB)
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
tx.Statement.ConnPool = &PreparedStmtDB{ tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool, ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux, Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts, Stmts: preparedStmt.Stmts,
} }
}
txConfig.ConnPool = tx.Statement.ConnPool txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true txConfig.PrepareStmt = true
} }

View File

@ -26,6 +26,10 @@ type Plugin interface {
Initialize(*DB) error Initialize(*DB) error
} }
type ParamsFilter interface {
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
}
// ConnPool db conns pool interface // ConnPool db conns pool interface
type ConnPool interface { type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)

View File

@ -55,6 +55,7 @@ type Config struct {
SlowThreshold time.Duration SlowThreshold time.Duration
Colorful bool Colorful bool
IgnoreRecordNotFoundError bool IgnoreRecordNotFoundError bool
ParameterizedQueries bool
LogLevel LogLevel LogLevel LogLevel
} }
@ -181,6 +182,14 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
} }
} }
// Trace print sql message
func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if l.Config.ParameterizedQueries {
return sql, nil
}
return sql, params
}
type traceRecorder struct { type traceRecorder struct {
Interface Interface
BeginAt time.Time BeginAt time.Time

View File

@ -68,6 +68,7 @@ type Migrator interface {
// Database // Database
CurrentDatabase() string CurrentDatabase() string
FullDataTypeOf(*schema.Field) clause.Expr FullDataTypeOf(*schema.Field) clause.Expr
GetTypeAliases(databaseTypeName string) []string
// Tables // Tables
CreateTable(dst ...interface{}) error CreateTable(dst ...interface{}) error

View File

@ -8,9 +8,11 @@ import (
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
"time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
@ -30,6 +32,16 @@ type Config struct {
gorm.Dialector gorm.Dialector
} }
type printSQLLogger struct {
logger.Interface
}
func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
fmt.Println(sql + ";")
l.Interface.Trace(ctx, begin, fc, err)
}
// GormDataTypeInterface gorm data type interface // GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface { type GormDataTypeInterface interface {
GormDBDataType(*gorm.DB, *schema.Field) string GormDBDataType(*gorm.DB, *schema.Field) string
@ -92,18 +104,26 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
// AutoMigrate auto migrate values // AutoMigrate auto migrate values
func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) { for _, value := range m.ReorderModels(values, true) {
tx := m.DB.Session(&gorm.Session{}) queryTx := m.DB.Session(&gorm.Session{})
if !tx.Migrator().HasTable(value) { execTx := queryTx
if err := tx.Migrator().CreateTable(value); err != nil { if m.DB.DryRun {
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
if !queryTx.Migrator().HasTable(value) {
if err := execTx.Migrator().CreateTable(value); err != nil {
return err return err
} }
} else { } else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
columnTypes, err := m.DB.Migrator().ColumnTypes(value) columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil { if err != nil {
return err return err
} }
var (
parseIndexes = stmt.Schema.ParseIndexes()
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
)
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName] field := stmt.Schema.FieldsByDBName[dbName]
var foundColumn gorm.ColumnType var foundColumn gorm.ColumnType
@ -117,37 +137,40 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
if foundColumn == nil { if foundColumn == nil {
// not found, add column // not found, add column
if err := tx.Migrator().AddColumn(value, dbName); err != nil { if err := execTx.Migrator().AddColumn(value, dbName); err != nil {
return err return err
} }
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { } else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
// found, smart migrate // found, smart migrate
return err return err
} }
} }
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil && if constraint := rel.ParseConstraint(); constraint != nil &&
constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) {
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
return err return err
} }
} }
} }
} }
for _, chk := range stmt.Schema.ParseCheckConstraints() { for _, chk := range parseCheckConstraints {
if !tx.Migrator().HasConstraint(value, chk.Name) { if !queryTx.Migrator().HasConstraint(value, chk.Name) {
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
return err return err
} }
} }
} }
for _, idx := range stmt.Schema.ParseIndexes() { for _, idx := range parseIndexes {
if !tx.Migrator().HasIndex(value, idx.Name) { if !queryTx.Migrator().HasIndex(value, idx.Name) {
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
return err return err
} }
} }
@ -227,8 +250,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
} }
} }
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
if !m.DB.DisableForeignKeyConstraintWhenMigrating { if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil { if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema { if constraint.Schema == stmt.Schema {
sql, vars := buildConstraint(constraint) sql, vars := buildConstraint(constraint)
@ -406,13 +432,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName()) realDataType := strings.ToLower(columnType.DatabaseTypeName())
alterColumn := false var (
alterColumn bool
isSameType = fullDataType == realDataType
)
if !field.PrimaryKey {
// check type // check type
if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) { if !strings.HasPrefix(fullDataType, realDataType) {
alterColumn = true // check type aliases
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) {
isSameType = true
break
}
} }
if !isSameType {
alterColumn = true
}
}
}
if !isSameType {
// check size // check size
if length, ok := columnType.Length(); length != int64(field.Size) { if length, ok := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 { if length > 0 && field.Size > 0 {
@ -434,6 +477,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
alterColumn = true alterColumn = true
} }
} }
}
// check nullable // check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
@ -453,17 +497,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
// check default value // check default value
if !field.PrimaryKey { if !field.PrimaryKey {
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
dv, dvNotNull := columnType.DefaultValue() dv, dvNotNull := columnType.DefaultValue()
if dvNotNull && field.DefaultValueInterface == nil { if dvNotNull && !currentDefaultNotNull {
// defalut value -> null // defalut value -> null
alterColumn = true alterColumn = true
} else if !dvNotNull && field.DefaultValueInterface != nil { } else if !dvNotNull && currentDefaultNotNull {
// null -> default value // null -> default value
alterColumn = true alterColumn = true
} else if dv != field.DefaultValue { } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
// default value not equal // default value not equal
// not both null // not both null
if !(field.DefaultValueInterface == nil && !dvNotNull) { if currentDefaultNotNull || dvNotNull {
alterColumn = true alterColumn = true
} }
} }
@ -478,7 +524,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} }
if alterColumn && !field.IgnoreMigration { if alterColumn && !field.IgnoreMigration {
return m.DB.Migrator().AlterColumn(value, field.Name) return m.DB.Migrator().AlterColumn(value, field.DBName)
} }
return nil return nil
@ -782,7 +828,11 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
} }
parsedSchemas[dep.Statement.Schema] = true parsedSchemas[dep.Statement.Schema] = true
if !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range dep.Schema.Relationships.Relations { for _, rel := range dep.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
dep.Depends = append(dep.Depends, c.ReferenceSchema) dep.Depends = append(dep.Depends, c.ReferenceSchema)
} }
@ -804,6 +854,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
}(rel, reflect.New(rel.JoinTable.ModelType).Interface()) }(rel, reflect.New(rel.JoinTable.ModelType).Interface())
} }
} }
}
valuesMap[dep.Schema.Table] = dep valuesMap[dep.Schema.Table] = dep
@ -863,3 +914,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
return nil, errors.New("not support") return nil, errors.New("not support")
} }
// GetTypeAliases return database type aliases
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
return nil
}

View File

@ -9,10 +9,12 @@ import (
type Stmt struct { type Stmt struct {
*sql.Stmt *sql.Stmt
Transaction bool Transaction bool
prepared chan struct{}
prepareErr error
} }
type PreparedStmtDB struct { type PreparedStmtDB struct {
Stmts map[string]Stmt Stmts map[string]*Stmt
PreparedSQL []string PreparedSQL []string
Mux *sync.RWMutex Mux *sync.RWMutex
ConnPool ConnPool
@ -42,31 +44,72 @@ func (db *PreparedStmtDB) Close() {
} }
} }
func (db *PreparedStmtDB) Reset() {
db.Mux.Lock()
defer db.Mux.Unlock()
for _, stmt := range db.Stmts {
go stmt.Close()
}
db.PreparedSQL = make([]string, 0, 100)
db.Stmts = make(map[string]*Stmt)
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
db.Mux.RLock() db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock() db.Mux.RUnlock()
return stmt, nil // wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
} }
db.Mux.RUnlock() db.Mux.RUnlock()
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock()
// double check // double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
return stmt, nil db.Mux.Unlock()
} else if ok { // wait for other goroutines prepared
go stmt.Close() <-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
} }
return *stmt, nil
}
// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()
// prepare completed
defer close(cacheStmt.prepared)
// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query) stmt, err := conn.PrepareContext(ctx, query)
if err == nil { if err != nil {
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} cacheStmt.prepareErr = err
db.PreparedSQL = append(db.PreparedSQL, query) db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
} }
return db.Stmts[query], err db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()
return cacheStmt, nil
} }
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {

49
scan.go
View File

@ -65,7 +65,6 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
joinedSchemaMap := make(map[*schema.Field]interface{}) joinedSchemaMap := make(map[*schema.Field]interface{})
for idx, field := range fields { for idx, field := range fields {
if field == nil { if field == nil {
@ -164,7 +163,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
default: default:
var ( var (
fields = make([]*schema.Field, len(columns)) fields = make([]*schema.Field, len(columns))
selectedColumnsMap = make(map[string]int, len(columns))
joinFields [][2]*schema.Field joinFields [][2]*schema.Field
sch = db.Statement.Schema sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue reflectValue = db.Statement.ReflectValue
@ -200,26 +198,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
// Not Pluck // Not Pluck
if sch != nil { if sch != nil {
schFieldsCount := len(sch.Fields) matchedFieldCount := make(map[string]int, len(columns))
for idx, column := range columns { for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable { if field := sch.LookUpField(column); field != nil && field.Readable {
if curIndex, ok := selectedColumnsMap[column]; ok { fields[idx] = field
fields[idx] = field // handle duplicate fields if count, ok := matchedFieldCount[column]; ok {
offset := curIndex + 1 // handle duplicate fields
// handle sch inconsistent with database for _, selectField := range sch.Fields {
// like Raw(`...`).Scan
if schFieldsCount > offset {
for fieldIndex, selectField := range sch.Fields[offset:] {
if selectField.DBName == column && selectField.Readable { if selectField.DBName == column && selectField.Readable {
selectedColumnsMap[column] = curIndex + fieldIndex + 1 if count == 0 {
matchedFieldCount[column]++
fields[idx] = selectField fields[idx] = selectField
break break
} }
count--
} }
} }
} else { } else {
fields[idx] = field matchedFieldCount[column] = 1
selectedColumnsMap[column] = idx
} }
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok { if rel, ok := sch.Relationships.Relations[names[0]]; ok {
@ -243,12 +239,20 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
var elem reflect.Value var (
recyclableStruct := reflect.New(reflectValueType) elem reflect.Value
isArrayKind = reflectValue.Kind() == reflect.Array
)
if !update || reflectValue.Len() == 0 { if !update || reflectValue.Len() == 0 {
update = false update = false
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else if !isArrayKind {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
} }
for initialized || rows.Next() { for initialized || rows.Next() {
@ -269,20 +273,21 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
} }
} }
} else { } else {
if isPtr && db.RowsAffected > 0 {
elem = reflect.New(reflectValueType) elem = reflect.New(reflectValueType)
} else {
elem = recyclableStruct
}
} }
db.scanIntoStruct(rows, elem, values, fields, joinFields) db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update { if !update {
if isPtr { if !isPtr {
reflectValue = reflect.Append(reflectValue, elem) elem = elem.Elem()
}
if isArrayKind {
if reflectValue.Len() >= int(db.RowsAffected) {
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
}
} else { } else {
reflectValue = reflect.Append(reflectValue, elem.Elem()) reflectValue = reflect.Append(reflectValue, elem)
} }
} }
} }

View File

@ -65,7 +65,11 @@ func (schema *Schema) ParseIndexes() map[string]Index {
} }
} }
} }
for _, index := range indexes {
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
index.Fields[0].Field.Unique = true
}
}
return indexes return indexes
} }

View File

@ -65,7 +65,7 @@ func TestParseIndex(t *testing.T) {
"idx_name": { "idx_name": {
Name: "idx_name", Name: "idx_name",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", Unique: true}}},
}, },
"idx_user_indices_name3": { "idx_user_indices_name3": {
Name: "idx_user_indices_name3", Name: "idx_user_indices_name3",
@ -81,7 +81,7 @@ func TestParseIndex(t *testing.T) {
"idx_user_indices_name4": { "idx_user_indices_name4": {
Name: "idx_user_indices_name4", Name: "idx_user_indices_name4",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", Unique: true}}},
}, },
"idx_user_indices_name5": { "idx_user_indices_name5": {
Name: "idx_user_indices_name5", Name: "idx_user_indices_name5",
@ -102,12 +102,12 @@ func TestParseIndex(t *testing.T) {
}, },
"idx_id": { "idx_id": {
Name: "idx_id", Name: "idx_id",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", Unique: true}}},
}, },
"idx_oid": { "idx_oid": {
Name: "idx_oid", Name: "idx_oid",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", Unique: true}}},
}, },
"type": { "type": {
Name: "type", Name: "type",
@ -168,6 +168,9 @@ func TestParseIndex(t *testing.T) {
if rf.Field.Name != ef.Field.Name { if rf.Field.Name != ef.Field.Name {
t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name) t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name)
} }
if rf.Field.Unique != ef.Field.Unique {
t.Fatalf("index field '%s' should equal, expects %v, got %v", rf.Field.Name, rf.Field.Unique, ef.Field.Unique)
}
for _, name := range []string{"Expression", "Sort", "Collate", "Length"} { for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {

View File

@ -191,7 +191,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
err error err error
joinTableFields []reflect.StructField joinTableFields []reflect.StructField
fieldsMap = map[string]*Field{} fieldsMap = map[string]*Field{}
ownFieldsMap = map[string]bool{} // fix self join many2many ownFieldsMap = map[string]*Field{} // fix self join many2many
referFieldsMap = map[string]*Field{}
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
) )
@ -229,7 +230,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
joinFieldName = strings.Title(joinForeignKeys[idx]) joinFieldName = strings.Title(joinForeignKeys[idx])
} }
ownFieldsMap[joinFieldName] = true ownFieldsMap[joinFieldName] = ownField
fieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField
joinTableFields = append(joinTableFields, reflect.StructField{ joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName, Name: joinFieldName,
@ -242,9 +243,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
for idx, relField := range refForeignFields { for idx, relField := range refForeignFields {
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}
if _, ok := ownFieldsMap[joinFieldName]; ok { if _, ok := ownFieldsMap[joinFieldName]; ok {
if field.Name != relation.FieldSchema.Name { if field.Name != relation.FieldSchema.Name {
@ -254,6 +252,13 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
} }
} }
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}
referFieldsMap[joinFieldName] = relField
if _, ok := fieldsMap[joinFieldName]; !ok {
fieldsMap[joinFieldName] = relField fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{ joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName, Name: joinFieldName,
@ -263,6 +268,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
"column", "autoincrement", "index", "unique", "uniqueindex"), "column", "autoincrement", "index", "unique", "uniqueindex"),
}) })
} }
}
joinTableFields = append(joinTableFields, reflect.StructField{ joinTableFields = append(joinTableFields, reflect.StructField{
Name: strings.Title(schema.Name) + field.Name, Name: strings.Title(schema.Name) + field.Name,
@ -317,31 +323,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
f.Size = fieldsMap[f.Name].Size f.Size = fieldsMap[f.Name].Size
} }
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
if ownPrimaryField { if of, ok := ownFieldsMap[f.Name]; ok {
joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{ joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name], PrimaryKey: of,
ForeignKey: f, ForeignKey: f,
}) })
} else {
relation.References = append(relation.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
OwnPrimaryKey: true,
})
}
if rf, ok := referFieldsMap[f.Name]; ok {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil { if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field joinRefRel.Field = relation.Field
} }
joinRefRel.References = append(joinRefRel.References, &Reference{ joinRefRel.References = append(joinRefRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name], PrimaryKey: rf,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: rf,
ForeignKey: f, ForeignKey: f,
}) })
} }
relation.References = append(relation.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
OwnPrimaryKey: ownPrimaryField,
})
} }
} }
} }
@ -391,33 +403,30 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
case guessBelongs: case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs: case guessEmbeddedBelongs:
if field.OwnerSchema != nil { if field.OwnerSchema == nil {
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
} else {
reguessOrErr() reguessOrErr()
return return
} }
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
case guessHas: case guessHas:
case guessEmbeddedHas: case guessEmbeddedHas:
if field.OwnerSchema != nil { if field.OwnerSchema == nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
reguessOrErr() reguessOrErr()
return return
} }
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} }
if len(relation.foreignKeys) > 0 { if len(relation.foreignKeys) > 0 {
for _, foreignKey := range relation.foreignKeys { for _, foreignKey := range relation.foreignKeys {
if f := foreignSchema.LookUpField(foreignKey); f != nil { f := foreignSchema.LookUpField(foreignKey)
foreignFields = append(foreignFields, f) if f == nil {
} else {
reguessOrErr() reguessOrErr()
return return
} }
foreignFields = append(foreignFields, f)
} }
} else { } else {
var primaryFields []*Field
var primarySchemaName = primarySchema.Name var primarySchemaName = primarySchema.Name
if primarySchemaName == "" { if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name primarySchemaName = relation.FieldSchema.Name
@ -454,10 +463,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
} }
} }
if len(foreignFields) == 0 { switch {
case len(foreignFields) == 0:
reguessOrErr() reguessOrErr()
return return
} else if len(relation.primaryKeys) > 0 { case len(relation.primaryKeys) > 0:
for idx, primaryKey := range relation.primaryKeys { for idx, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil { if f := primarySchema.LookUpField(primaryKey); f != nil {
if len(primaryFields) < idx+1 { if len(primaryFields) < idx+1 {
@ -471,7 +481,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
return return
} }
} }
} else if len(primaryFields) == 0 { case len(primaryFields) == 0:
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
} else if len(primarySchema.PrimaryFields) == len(foreignFields) { } else if len(primarySchema.PrimaryFields) == len(foreignFields) {

View File

@ -10,7 +10,7 @@ import (
func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
t.Errorf("Failed to parse schema") t.Errorf("Failed to parse schema, got error %v", err)
} else { } else {
for _, rel := range relations { for _, rel := range relations {
checkSchemaRelation(t, s, rel) checkSchemaRelation(t, s, rel)
@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
}) })
} }
func TestMany2ManySharedForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
Kind string
ProfileRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
Kind string
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
References: []Reference{
{"Refer", "User", "UserRefer", "user_profiles", "", true},
{"Kind", "User", "Kind", "user_profiles", "", true},
{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
{"Kind", "Profile", "Kind", "user_profiles", "", false},
},
})
}
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type Profile struct { type Profile struct {
gorm.Model gorm.Model

View File

@ -71,6 +71,10 @@ type Tabler interface {
TableName() string TableName() string
} }
type TablerWithNamer interface {
TableName(Namer) string
}
// Parse get data type from dialector // Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "") return ParseWithSpecialTableName(dest, cacheStore, namer, "")
@ -125,6 +129,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if tabler, ok := modelValue.Interface().(Tabler); ok { if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName() tableName = tabler.TableName()
} }
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
tableName = tabler.TableName(namer)
}
if en, ok := namer.(embeddedNamer); ok { if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table tableName = en.Table
} }
@ -223,7 +230,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.HasDefaultValue && field.DefaultValueInterface == nil { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
} }
} }

View File

@ -100,6 +100,12 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
// Value implements serializer interface // Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue) result, err := json.Marshal(fieldValue)
if string(result) == "null" {
if field.TagSettings["NOT NULL"] != "" {
return "", nil
}
return nil, err
}
return string(result), err return string(result), err
} }

View File

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"reflect" "reflect"
"github.com/jinzhu/now"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
@ -45,10 +46,20 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error {
} }
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteQueryClause{Field: f}} return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
func parseZeroValueTag(f *schema.Field) sql.NullString {
if v, ok := f.TagSettings["ZEROVALUE"]; ok {
if _, err := now.Parse(v); err == nil {
return sql.NullString{String: v, Valid: true}
}
}
return sql.NullString{Valid: false}
} }
type SoftDeleteQueryClause struct { type SoftDeleteQueryClause struct {
ZeroValue sql.NullString
Field *schema.Field Field *schema.Field
} }
@ -78,17 +89,18 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
} }
stmt.AddClause(clause.Where{Exprs: []clause.Expression{ stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
}}) }})
stmt.Clauses["soft_delete_enabled"] = clause.Clause{} stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
} }
} }
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteUpdateClause{Field: f}} return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
} }
type SoftDeleteUpdateClause struct { type SoftDeleteUpdateClause struct {
ZeroValue sql.NullString
Field *schema.Field Field *schema.Field
} }
@ -109,10 +121,11 @@ func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
} }
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteDeleteClause{Field: f}} return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
} }
type SoftDeleteDeleteClause struct { type SoftDeleteDeleteClause struct {
ZeroValue sql.NullString
Field *schema.Field Field *schema.Field
} }

View File

@ -52,6 +52,9 @@ type join struct {
Name string Name string
Conds []interface{} Conds []interface{}
On *clause.Where On *clause.Where
Selects []string
Omits []string
JoinType clause.JoinType
} }
// StatementModifier statement modifier interface // StatementModifier statement modifier interface
@ -117,6 +120,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
} else if len(stmt.Schema.DBNames) > 0 { } else if len(stmt.Schema.DBNames) > 0 {
write(v.Raw, stmt.Schema.DBNames[0]) write(v.Raw, stmt.Schema.DBNames[0])
} else {
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
} }
} else { } else {
write(v.Raw, v.Name) write(v.Raw, v.Name)
@ -179,6 +184,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else { } else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
} }
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression: case clause.Expression:
v.Build(stmt) v.Build(stmt)
case driver.Valuer: case driver.Valuer:
@ -304,6 +313,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
conds := make([]clause.Expression, 0, 4) conds := make([]clause.Expression, 0, 4)
args = append([]interface{}{query}, args...) args = append([]interface{}{query}, args...)
for idx, arg := range args { for idx, arg := range args {
if arg == nil {
continue
}
if valuer, ok := arg.(driver.Valuer); ok { if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value() arg, _ = valuer.Value()
} }
@ -540,6 +552,7 @@ func (stmt *Statement) clone() *Statement {
} }
// SetColumn set column's value // SetColumn set column's value
//
// stmt.SetColumn("Name", "jinzhu") // Hooks Method // stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method // stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
@ -657,47 +670,41 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
results := map[string]bool{} results := map[string]bool{}
notRestricted := false notRestricted := false
// select columns processColumn := func(column string, result bool) {
for _, column := range stmt.Selects {
if stmt.Schema == nil { if stmt.Schema == nil {
results[column] = true results[column] = result
} else if column == "*" { } else if column == "*" {
notRestricted = true notRestricted = result
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
results[dbName] = true results[dbName] = result
} }
} else if column == clause.Associations { } else if column == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = true results[rel.Name] = result
} }
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true results[field.DBName] = result
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") { } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
results[matches[2]] = true if matches[2] == "*" {
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else { } else {
results[column] = true results[matches[2]] = result
}
} else {
results[column] = result
} }
} }
// select columns
for _, column := range stmt.Selects {
processColumn(column, true)
}
// omit columns // omit columns
for _, omit := range stmt.Omits { for _, column := range stmt.Omits {
if stmt.Schema == nil { processColumn(column, false)
results[omit] = false
} else if omit == "*" {
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = false
}
} else if omit == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = false
}
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
results[field.DBName] = false
} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
results[matches[1]] = false
} else {
results[omit] = false
}
} }
if stmt.Schema != nil { if stmt.Schema != nil {

View File

@ -35,6 +35,13 @@ func TestWhereCloneCorruption(t *testing.T) {
} }
} }
func TestNilCondition(t *testing.T) {
s := new(Statement)
if len(s.BuildCondition(nil)) != 0 {
t.Errorf("Nil condition should be empty")
}
}
func TestNameMatcher(t *testing.T) { func TestNameMatcher(t *testing.T) {
for k, v := range map[string][]string{ for k, v := range map[string][]string{
"table.name": {"table", "name"}, "table.name": {"table", "name"},

View File

@ -3,6 +3,7 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -137,6 +138,7 @@ func TestBelongsToAssociation(t *testing.T) {
unexistCompanyID := company.ID + 9999999 unexistCompanyID := company.ID + 9999999
user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID} user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID}
if err := DB.Create(&user).Error; err == nil { if err := DB.Create(&user).Error; err == nil {
tidbSkip(t, "not support the foreign key feature")
t.Errorf("should have gotten foreign key violation error") t.Errorf("should have gotten foreign key violation error")
} }
} }
@ -224,3 +226,28 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
AssertAssociationCount(t, users[0], "Company", 0, "After Delete") AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
} }
func TestBelongsToDefaultValue(t *testing.T) {
type Org struct {
ID string
}
type BelongsToUser struct {
OrgID string
Org Org `gorm:"default:NULL"`
}
tx := DB.Session(&gorm.Session{})
tx.Config.DisableForeignKeyConstraintWhenMigrating = true
AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false)
tx.Migrator().DropTable(&BelongsToUser{}, &Org{})
tx.AutoMigrate(&BelongsToUser{}, &Org{})
user := &BelongsToUser{
Org: Org{
ID: "BelongsToUser_Org_1",
},
}
err := DB.Create(&user).Error
AssertEqual(t, err, nil)
}

View File

@ -95,6 +95,8 @@ func TestMany2ManyAssociation(t *testing.T) {
} }
func TestMany2ManyOmitAssociations(t *testing.T) { func TestMany2ManyOmitAssociations(t *testing.T) {
tidbSkip(t, "not support the foreign key feature")
user := *GetUser("many2many_omit_associations", Config{Languages: 2}) user := *GetUser("many2many_omit_associations", Config{Languages: 2})
if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { if err := DB.Omit("Languages.*").Create(&user).Error; err == nil {

View File

@ -71,6 +71,8 @@ func TestAssociationNotNullClear(t *testing.T) {
} }
func TestForeignKeyConstraints(t *testing.T) { func TestForeignKeyConstraints(t *testing.T) {
tidbSkip(t, "not support the foreign key feature")
type Profile struct { type Profile struct {
ID uint ID uint
Name string Name string
@ -126,6 +128,8 @@ func TestForeignKeyConstraints(t *testing.T) {
} }
func TestForeignKeyConstraintsBelongsTo(t *testing.T) { func TestForeignKeyConstraintsBelongsTo(t *testing.T) {
tidbSkip(t, "not support the foreign key feature")
type Profile struct { type Profile struct {
ID uint ID uint
Name string Name string
@ -348,3 +352,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) {
AssertEqual(t, len(orgs), 0) AssertEqual(t, len(orgs), 0)
} }
} }
type AssociationEmptyUser struct {
ID uint
Name string
Pets []AssociationEmptyPet
}
type AssociationEmptyPet struct {
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
}
func TestAssociationEmptyPrimaryKey(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
t.Skip()
}
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
id := uint(100)
user := AssociationEmptyUser{
ID: id,
Name: "jinzhu",
Pets: []AssociationEmptyPet{
{AssociationEmptyUserID: &id, Name: "bar"},
{AssociationEmptyUserID: &id, Name: "foo"},
},
}
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
if err != nil {
t.Fatalf("Failed to create, got error: %v", err)
}
var result AssociationEmptyUser
err = DB.Preload("Pets").First(&result, &id).Error
if err != nil {
t.Fatalf("Failed to find, got error: %v", err)
}
AssertEqual(t, result, user)
}

View File

@ -113,6 +113,9 @@ func TestCallbacks(t *testing.T) {
for idx, data := range datas { for idx, data := range datas {
db, err := gorm.Open(nil, nil) db, err := gorm.Open(nil, nil)
if err != nil {
t.Fatal(err)
}
callbacks := db.Callback() callbacks := db.Callback()
for _, c := range data.callbacks { for _, c := range data.callbacks {

View File

@ -116,7 +116,7 @@ func TestConnPoolWrapper(t *testing.T) {
} }
}() }()
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
if err != nil { if err != nil {
t.Fatalf("Should open db success, but got %v", err) t.Fatalf("Should open db success, but got %v", err)
} }

View File

@ -141,7 +141,7 @@ func TestCount(t *testing.T) {
} }
DB.Create(sameUsers) DB.Create(sameUsers)
if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != int64(len(sameUsers)) {
t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) t.Fatalf("Count should be 3, but got count: %v err %v", count11, err)
} }

View File

@ -29,3 +29,8 @@ services:
- MSSQL_DB=gorm - MSSQL_DB=gorm
- MSSQL_USER=gorm - MSSQL_USER=gorm
- MSSQL_PASSWORD=LoremIpsum86 - MSSQL_PASSWORD=LoremIpsum86
tidb:
image: 'pingcap/tidb:v6.5.0'
ports:
- 9940:4000
command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 &

View File

@ -36,7 +36,7 @@ func TestEmbeddedStruct(t *testing.T) {
type EngadgetPost struct { type EngadgetPost struct {
BasePost BasePost `gorm:"Embedded"` BasePost BasePost `gorm:"Embedded"`
Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct
ImageUrl string ImageUrl string
} }
@ -74,13 +74,27 @@ func TestEmbeddedStruct(t *testing.T) {
t.Errorf("embedded struct's value should be scanned correctly") t.Errorf("embedded struct's value should be scanned correctly")
} }
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}})
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}})
var egNews EngadgetPost var egNews EngadgetPost
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err) t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if egNews.BasePost.Title != "engadget_news" { } else if egNews.BasePost.Title != "engadget_news" {
t.Errorf("embedded struct's value should be scanned correctly") t.Errorf("embedded struct's value should be scanned correctly")
} }
var egPosts []EngadgetPost
if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil {
t.Fatalf("no error should happen when query with embedded struct, but got %v", err)
}
expectAuthors := []string{"Edward", "George"}
for i, post := range egPosts {
t.Log(i, post.Author)
if want := expectAuthors[i]; post.Author.Name != want {
t.Errorf("expected author %s got %s", want, post.Author.Name)
}
}
} }
func TestEmbeddedPointerTypeStruct(t *testing.T) { func TestEmbeddedPointerTypeStruct(t *testing.T) {

View File

@ -3,17 +3,16 @@ module gorm.io/gorm/tests
go 1.16 go 1.16
require ( require (
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.5 github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.6 github.com/lib/pq v1.10.7
github.com/mattn/go-sqlite3 v1.14.14 // indirect github.com/mattn/go-sqlite3 v1.14.16 // indirect
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect golang.org/x/crypto v0.5.0 // indirect
gorm.io/driver/mysql v1.3.5 gorm.io/driver/mysql v1.4.6
gorm.io/driver/postgres v1.3.8 gorm.io/driver/postgres v1.4.6
gorm.io/driver/sqlite v1.3.6 gorm.io/driver/sqlite v1.4.4
gorm.io/driver/sqlserver v1.3.2 gorm.io/driver/sqlserver v1.4.2
gorm.io/gorm v1.23.8 gorm.io/gorm v1.24.5
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../

View File

@ -1,6 +1,7 @@
package tests_test package tests_test
import ( import (
"os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -235,3 +236,13 @@ func CheckUser(t *testing.T, user User, expect User) {
} }
}) })
} }
func tidbSkip(t *testing.T, reason string) {
if isTiDB() {
t.Skipf("This test case skipped, because of TiDB '%s'", reason)
}
}
func isTiDB() bool {
return os.Getenv("GORM_DIALECT") == "tidb"
}

View File

@ -229,3 +229,99 @@ func TestJoinWithSoftDeleted(t *testing.T) {
t.Fatalf("joins NamedPet and Account should not empty:%v", user2) t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
} }
} }
func TestInnerJoins(t *testing.T) {
user := *GetUser("inner-joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false})
DB.Create(&user)
var user2 User
var err error
err = DB.InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error
AssertEqual(t, err, nil)
CheckUser(t, user2, user)
// inner join and NamedPet is nil
err = DB.InnerJoins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error
AssertEqual(t, err, gorm.ErrRecordNotFound)
// mixed inner join and left join
var user3 User
err = DB.Joins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user3, "users.name = ?", user.Name).Error
AssertEqual(t, err, nil)
CheckUser(t, user3, user)
}
func TestJoinWithSameColumnName(t *testing.T) {
user := GetUser("TestJoinWithSameColumnName", Config{
Languages: 1,
Pets: 1,
})
DB.Create(user)
type UserSpeak struct {
UserID uint
LanguageCode string
}
type Result struct {
User
UserSpeak
Language
Pet
}
results := make([]Result, 0, 1)
DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id").
Joins("JOIN languages ON languages.code = user_speaks.language_code").
Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results)
if len(results) == 0 {
t.Fatalf("no record find")
} else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID {
t.Fatalf("wrong user id in pet")
} else if results[0].Pet.Name != user.Pets[0].Name {
t.Fatalf("wrong pet name")
}
}
func TestJoinArgsWithDB(t *testing.T) {
user := *GetUser("joins-args-db", Config{Pets: 2})
DB.Save(&user)
// test where
var user1 User
onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"})
if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2")
// test where and omit
onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name")
var user2 User
if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID)
AssertEqual(t, user2.NamedPet.Name, "")
// test where and select
onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name")
var user3 User
if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user3.NamedPet.ID, 0)
AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2")
// test select
onQuery4 := DB.Select("ID")
var user4 User
if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
if user4.NamedPet.ID == 0 {
t.Fatal("Pet ID can not be empty")
}
AssertEqual(t, user4.NamedPet.Name, "")
}

View File

@ -1,8 +1,10 @@
package tests_test package tests_test
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"os"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -10,6 +12,7 @@ import (
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -25,7 +28,7 @@ func TestMigrate(t *testing.T) {
} }
if err := DB.AutoMigrate(allModels...); err != nil { if err := DB.AutoMigrate(allModels...); err != nil {
t.Fatalf("Failed to auto migrate, but got error %v", err) t.Fatalf("Failed to auto migrate, got error %v", err)
} }
if tables, err := DB.Migrator().GetTables(); err != nil { if tables, err := DB.Migrator().GetTables(); err != nil {
@ -72,6 +75,44 @@ func TestMigrate(t *testing.T) {
t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1])
} }
} }
}
func TestAutoMigrateInt8PG(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
type Smallint int8
type MigrateInt struct {
Int8 Smallint
}
tracer := Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql)
}
},
}
DB.Migrator().DropTable(&MigrateInt{})
// The first AutoMigrate to make table with field with correct type
if err := DB.AutoMigrate(&MigrateInt{}); err != nil {
t.Fatalf("Failed to auto migrate: error: %v", err)
}
// make new session to set custom logger tracer
session := DB.Session(&gorm.Session{Logger: tracer})
// The second AutoMigrate to catch an error
if err := session.AutoMigrate(&MigrateInt{}); err != nil {
t.Fatalf("Failed to auto migrate: error: %v", err)
}
} }
func TestAutoMigrateSelfReferential(t *testing.T) { func TestAutoMigrateSelfReferential(t *testing.T) {
@ -333,7 +374,137 @@ func TestMigrateIndexes(t *testing.T) {
} }
} }
func TestTiDBMigrateColumns(t *testing.T) {
if !isTiDB() {
t.Skip()
}
// TiDB can't change column constraint and has auto_random feature
type ColumnStruct struct {
ID int `gorm:"primarykey;default:auto_random()"`
Name string
Age int `gorm:"default:18;comment:my age"`
Code string `gorm:"unique;comment:my code;"`
Code2 string
Code3 string `gorm:"unique"`
}
DB.Migrator().DropTable(&ColumnStruct{})
if err := DB.AutoMigrate(&ColumnStruct{}); err != nil {
t.Errorf("Failed to migrate, got %v", err)
}
type ColumnStruct2 struct {
ID int `gorm:"primarykey;default:auto_random()"`
Name string `gorm:"size:100"`
Code string `gorm:"unique;comment:my code2;default:hello"`
Code2 string `gorm:"comment:my code2;default:hello"`
}
if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil {
t.Fatalf("no error should happened when alter column, but got %v", err)
}
if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil {
t.Fatalf("no error should happened when auto migrate column, but got %v", err)
}
if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
t.Fatalf("no error should returns for ColumnTypes")
} else {
stmt := &gorm.Statement{DB: DB}
stmt.Parse(&ColumnStruct2{})
for _, columnType := range columnTypes {
switch columnType.Name() {
case "id":
if v, ok := columnType.PrimaryKey(); !ok || !v {
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "name":
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
}
if length, ok := columnType.Length(); !ok || length != 100 {
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
}
case "age":
if v, ok := columnType.DefaultValue(); !ok || v != "18" {
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.Comment(); !ok || v != "my age" {
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code":
if v, ok := columnType.Unique(); !ok || !v {
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
}
if v, ok := columnType.Comment(); !ok || v != "my code2" {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code2":
// Code2 string `gorm:"comment:my code2;default:hello"`
if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
}
if v, ok := columnType.Comment(); !ok || v != "my code2" {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
}
}
}
type NewColumnStruct struct {
gorm.Model
Name string
NewName string
}
if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil {
t.Fatalf("Failed to add column, got %v", err)
}
if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") {
t.Fatalf("Failed to find added column")
}
if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil {
t.Fatalf("Failed to add column, got %v", err)
}
if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") {
t.Fatalf("Found deleted column")
}
if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil {
t.Fatalf("Failed to add column, got %v", err)
}
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil {
t.Fatalf("Failed to add column, got %v", err)
}
if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") {
t.Fatalf("Failed to found renamed column")
}
if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil {
t.Fatalf("Failed to add column, got %v", err)
}
if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") {
t.Fatalf("Found deleted column")
}
}
func TestMigrateColumns(t *testing.T) { func TestMigrateColumns(t *testing.T) {
tidbSkip(t, "use another test case")
sqlite := DB.Dialector.Name() == "sqlite" sqlite := DB.Dialector.Name() == "sqlite"
sqlserver := DB.Dialector.Name() == "sqlserver" sqlserver := DB.Dialector.Name() == "sqlserver"
@ -400,7 +571,7 @@ func TestMigrateColumns(t *testing.T) {
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
} }
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
} }
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
@ -715,6 +886,32 @@ func TestPrimarykeyID(t *testing.T) {
} }
} }
func TestCurrentTimestamp(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
return
}
type CurrentTimestampTest struct {
ID string `gorm:"primary_key"`
TimeAt *time.Time `gorm:"type:datetime;not null;default:CURRENT_TIMESTAMP;unique"`
}
var err error
err = DB.Migrator().DropTable(&CurrentTimestampTest{})
if err != nil {
t.Errorf("DropTable err:%v", err)
}
err = DB.AutoMigrate(&CurrentTimestampTest{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
err = DB.AutoMigrate(&CurrentTimestampTest{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at"))
AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2"))
}
func TestUniqueColumn(t *testing.T) { func TestUniqueColumn(t *testing.T) {
if DB.Dialector.Name() != "mysql" { if DB.Dialector.Name() != "mysql" {
return return
@ -786,6 +983,8 @@ func TestUniqueColumn(t *testing.T) {
AssertEqual(t, "", value) AssertEqual(t, "", value)
AssertEqual(t, false, ok) AssertEqual(t, false, ok)
tidbSkip(t, "can't change column constraint")
// null -> empty string // null -> empty string
err = DB.Table("unique_tests").AutoMigrate(&UniqueTest3{}) err = DB.Table("unique_tests").AutoMigrate(&UniqueTest3{})
if err != nil { if err != nil {
@ -850,7 +1049,7 @@ func findColumnType(dest interface{}, columnName string) (
return return
} }
func TestInvalidCachedPlan(t *testing.T) { func TestInvalidCachedPlanSimpleProtocol(t *testing.T) {
if DB.Dialector.Name() != "postgres" { if DB.Dialector.Name() != "postgres" {
return return
} }
@ -885,6 +1084,101 @@ func TestInvalidCachedPlan(t *testing.T) {
} }
} }
func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true})
if err != nil {
t.Errorf("Open err:%v", err)
}
if debug := os.Getenv("DEBUG"); debug == "true" {
db.Logger = db.Logger.LogMode(logger.Info)
} else if debug == "false" {
db.Logger = db.Logger.LogMode(logger.Silent)
}
type Object1 struct {
ID uint
}
type Object2 struct {
ID uint
Field1 int `gorm:"type:int8"`
}
type Object3 struct {
ID uint
Field1 int `gorm:"type:int4"`
}
type Object4 struct {
ID uint
Field2 int
}
db.Migrator().DropTable("objects")
err = db.Table("objects").AutoMigrate(&Object1{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Create(&Object1{}).Error
if err != nil {
t.Errorf("create err:%v", err)
}
// AddColumn
err = db.Table("objects").AutoMigrate(&Object2{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object2{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
// AlterColumn
err = db.Table("objects").AutoMigrate(&Object3{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object3{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
// AddColumn
err = db.Table("objects").AutoMigrate(&Object4{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3")
if err != nil {
t.Errorf("RenameColumn err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
db.Table("objects").Migrator().DropColumn(&Object4{}, "field3")
if err != nil {
t.Errorf("RenameColumn err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
}
func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
type DiffType struct { type DiffType struct {
ID uint ID uint
@ -959,3 +1253,247 @@ func TestMigrateArrayTypeModel(t *testing.T) {
AssertEqual(t, nil, err) AssertEqual(t, nil, err)
AssertEqual(t, "integer[]", ct.DatabaseTypeName()) AssertEqual(t, "integer[]", ct.DatabaseTypeName())
} }
type mockMigrator struct {
gorm.Migrator
}
func (mm mockMigrator) AlterColumn(dst interface{}, field string) error {
err := mm.Migrator.AlterColumn(dst, field)
if err != nil {
return err
}
return fmt.Errorf("trigger alter column error, field: %s", field)
}
func TestMigrateDonotAlterColumn(t *testing.T) {
var wrapMockMigrator = func(m gorm.Migrator) mockMigrator {
return mockMigrator{
Migrator: m,
}
}
m := DB.Migrator()
mockM := wrapMockMigrator(m)
type NotTriggerUpdate struct {
ID uint
F1 uint16
F2 uint32
F3 int
F4 int64
F5 string
F6 float32
F7 float64
F8 time.Time
F9 bool
F10 []byte
}
var err error
err = mockM.DropTable(&NotTriggerUpdate{})
AssertEqual(t, err, nil)
err = mockM.AutoMigrate(&NotTriggerUpdate{})
AssertEqual(t, err, nil)
err = mockM.AutoMigrate(&NotTriggerUpdate{})
AssertEqual(t, err, nil)
}
func TestMigrateSameEmbeddedFieldName(t *testing.T) {
type UserStat struct {
GroundDestroyCount int
}
type GameUser struct {
gorm.Model
StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"`
}
type UserStat1 struct {
GroundDestroyCount string
}
type GroundRate struct {
GroundDestroyCount int
}
type GameUser1 struct {
gorm.Model
StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"`
GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"`
}
DB.Migrator().DropTable(&GameUser{})
err := DB.AutoMigrate(&GameUser{})
AssertEqual(t, nil, err)
err = DB.Table("game_users").AutoMigrate(&GameUser1{})
AssertEqual(t, nil, err)
_, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count")
AssertEqual(t, nil, err)
_, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count")
AssertEqual(t, nil, err)
}
func TestMigrateDefaultNullString(t *testing.T) {
if DB.Dialector.Name() == "sqlserver" {
// sqlserver driver treats NULL and 'NULL' the same
t.Skip("skip sqlserver")
}
type NullModel struct {
ID uint
Content string `gorm:"default:null"`
}
type NullStringModel struct {
ID uint
Content string `gorm:"default:'null'"`
}
tableName := "null_string_model"
DB.Migrator().DropTable(tableName)
err := DB.Table(tableName).AutoMigrate(&NullModel{})
AssertEqual(t, err, nil)
// default null -> 'null'
err = DB.Table(tableName).AutoMigrate(&NullStringModel{})
AssertEqual(t, err, nil)
columnType, err := findColumnType(tableName, "content")
AssertEqual(t, err, nil)
defVal, ok := columnType.DefaultValue()
AssertEqual(t, defVal, "null")
AssertEqual(t, ok, true)
// default 'null' -> 'null'
session := DB.Session(&gorm.Session{Logger: Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
if strings.HasPrefix(sql, "ALTER TABLE") {
t.Errorf("shouldn't execute: sql=%s", sql)
}
},
}})
err = session.Table(tableName).AutoMigrate(&NullStringModel{})
AssertEqual(t, err, nil)
columnType, err = findColumnType(tableName, "content")
AssertEqual(t, err, nil)
defVal, ok = columnType.DefaultValue()
AssertEqual(t, defVal, "null")
AssertEqual(t, ok, true)
// default 'null' -> null
err = DB.Table(tableName).AutoMigrate(&NullModel{})
AssertEqual(t, err, nil)
columnType, err = findColumnType(tableName, "content")
AssertEqual(t, err, nil)
defVal, ok = columnType.DefaultValue()
AssertEqual(t, defVal, "")
AssertEqual(t, ok, false)
}
func TestMigrateMySQLWithCustomizedTypes(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
t.Skip()
}
type MyTable struct {
Def string `gorm:"size:512;index:idx_def,unique"`
Abc string `gorm:"size:65000000"`
}
DB.Migrator().DropTable("my_tables")
sql := "CREATE TABLE `my_tables` (`def` varchar(512),`abc` longtext,UNIQUE INDEX `idx_def` (`def`))"
if err := DB.Exec(sql).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
session := DB.Session(&gorm.Session{Logger: Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
if strings.HasPrefix(sql, "ALTER TABLE") {
t.Errorf("shouldn't execute: sql=%s", sql)
}
},
}})
if err := session.AutoMigrate(&MyTable{}); err != nil {
t.Errorf("Failed, got error: %v", err)
}
}
func TestMigrateIgnoreRelations(t *testing.T) {
type RelationModel1 struct {
ID uint
}
type RelationModel2 struct {
ID uint
}
type RelationModel3 struct {
ID uint
RelationModel1ID uint
RelationModel1 *RelationModel1
RelationModel2ID uint
RelationModel2 *RelationModel2 `gorm:"-:migration"`
}
var err error
_ = DB.Migrator().DropTable(&RelationModel1{}, &RelationModel2{}, &RelationModel3{})
tx := DB.Session(&gorm.Session{})
tx.IgnoreRelationshipsWhenMigrating = true
err = tx.AutoMigrate(&RelationModel3{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
// RelationModel3 should be existed
_, err = findColumnType(&RelationModel3{}, "id")
AssertEqual(t, nil, err)
// RelationModel1 should not be existed
_, err = findColumnType(&RelationModel1{}, "id")
if err == nil {
t.Errorf("RelationModel1 should not be migrated")
}
// RelationModel2 should not be existed
_, err = findColumnType(&RelationModel2{}, "id")
if err == nil {
t.Errorf("RelationModel2 should not be migrated")
}
tx.IgnoreRelationshipsWhenMigrating = false
err = tx.AutoMigrate(&RelationModel3{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
// RelationModel3 should be existed
_, err = findColumnType(&RelationModel3{}, "id")
AssertEqual(t, nil, err)
// RelationModel1 should be existed
_, err = findColumnType(&RelationModel1{}, "id")
AssertEqual(t, nil, err)
// RelationModel2 should not be existed
_, err = findColumnType(&RelationModel2{}, "id")
if err == nil {
t.Errorf("RelationModel2 should not be migrated")
}
}

View File

@ -7,8 +7,60 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lib/pq" "github.com/lib/pq"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
) )
func TestPostgresReturningIDWhichHasStringType(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
t.Skip()
}
type Yasuo struct {
ID string `gorm:"default:gen_random_uuid()"`
Name string
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
}
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
}
DB.Migrator().DropTable(&Yasuo{})
if err := DB.AutoMigrate(&Yasuo{}); err != nil {
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
}
yasuo := Yasuo{Name: "jinzhu"}
if err := DB.Create(&yasuo).Error; err != nil {
t.Fatalf("should be able to create data, but got %v", err)
}
if yasuo.ID == "" {
t.Fatal("should be able to has ID, but got zero value")
}
var result Yasuo
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
yasuo.Name = "jinzhu1"
if err := DB.Save(&yasuo).Error; err != nil {
t.Errorf("Failed to update date, got error %v", err)
}
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err)
}
}
func TestPostgres(t *testing.T) { func TestPostgres(t *testing.T) {
if DB.Dialector.Name() != "postgres" { if DB.Dialector.Name() != "postgres" {
t.Skip() t.Skip()
@ -60,6 +112,45 @@ func TestPostgres(t *testing.T) {
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err) t.Errorf("No error should happen, but got %v", err)
} }
DB.Migrator().DropTable("log_usage")
if err := DB.Exec(`
CREATE TABLE public.log_usage (
log_id bigint NOT NULL
);
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
SEQUENCE NAME public.log_usage_log_id_seq
START WITH 1
INCREMENT BY 1
NO MINVALUE
NO MAXVALUE
CACHE 1
);
`).Error; err != nil {
t.Fatalf("failed to create table, got error %v", err)
}
columns, err := DB.Migrator().ColumnTypes("log_usage")
if err != nil {
t.Fatalf("failed to get columns, got error %v", err)
}
hasLogID := false
for _, column := range columns {
if column.Name() == "log_id" {
hasLogID = true
autoIncrement, ok := column.AutoIncrement()
if !ok || !autoIncrement {
t.Fatalf("column log_id should be auto incrementment")
}
}
}
if !hasLogID {
t.Fatalf("failed to found column log_id")
}
} }
type Post struct { type Post struct {
@ -98,3 +189,68 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) {
t.Errorf("Failed, got error: %v", err) t.Errorf("Failed, got error: %v", err)
} }
} }
func TestPostgresOnConstraint(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
t.Skip()
}
type Thing struct {
gorm.Model
SomeID string
OtherID string
Data string
}
DB.Migrator().DropTable(&Thing{})
DB.Migrator().CreateTable(&Thing{})
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
t.Error(err)
}
thing := Thing{
SomeID: "1234",
OtherID: "1234",
Data: "something",
}
DB.Create(&thing)
thing2 := Thing{
SomeID: "1234",
OtherID: "1234",
Data: "something else",
}
result := DB.Clauses(clause.OnConflict{
OnConstraint: "some_id_other_id_unique",
UpdateAll: true,
}).Create(&thing2)
if result.Error != nil {
t.Errorf("creating second thing: %v", result.Error)
}
var things []Thing
if err := DB.Find(&things).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
if len(things) > 1 {
t.Errorf("expected 1 thing got more")
}
}
type CompanyNew struct {
ID int
Name int
}
func TestAlterColumnDataType(t *testing.T) {
DB.AutoMigrate(Company{})
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
t.Fatalf("failed to alter column from string to int, got error %v", err)
}
DB.AutoMigrate(Company{})
}

View File

@ -2,6 +2,8 @@ package tests_test
import ( import (
"context" "context"
"errors"
"sync"
"testing" "testing"
"time" "time"
@ -88,3 +90,107 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
} }
tx2.Commit() tx2.Commit()
} }
func TestPreparedStmtDeadlock(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)
tx = tx.Session(&gorm.Session{PrepareStmt: true})
wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
user := User{Name: "jinzhu"}
tx.Create(&user)
var result User
tx.First(&result)
wg.Done()
}()
}
wg.Wait()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 2)
for _, stmt := range conn.Stmts {
if stmt == nil {
t.Fatalf("stmt cannot bee nil")
}
}
AssertEqual(t, sqlDB.Stats().InUse, 0)
}
func TestPreparedStmtError(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)
tx = tx.Session(&gorm.Session{PrepareStmt: true})
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
// err prepare
tag := Tag{Locale: "zh"}
tx.Table("users").Find(&tag)
wg.Done()
}()
}
wg.Wait()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 0)
AssertEqual(t, sqlDB.Stats().InUse, 0)
}
func TestPreparedStmtInTransaction(t *testing.T) {
user := User{Name: "jinzhu"}
if err := DB.Transaction(func(tx *gorm.DB) error {
tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user)
return errors.New("test")
}); err == nil {
t.Error(err)
}
var result User
if err := DB.First(&result, user.ID).Error; err == nil {
t.Errorf("Failed, got error: %v", err)
}
}
func TestPreparedStmtReset(t *testing.T) {
tx := DB.Session(&gorm.Session{PrepareStmt: true})
user := *GetUser("prepared_stmt_reset", Config{})
tx = tx.Create(&user)
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}
pdb.Mux.Lock()
if len(pdb.Stmts) == 0 {
pdb.Mux.Unlock()
t.Fatalf("prepared stmt can not be empty")
}
pdb.Mux.Unlock()
pdb.Reset()
pdb.Mux.Lock()
defer pdb.Mux.Unlock()
if len(pdb.Stmts) != 0 {
t.Fatalf("prepared stmt should be empty")
}
}

View File

@ -216,6 +216,30 @@ func TestFind(t *testing.T) {
} }
} }
// test array
var models2 [3]User
if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 {
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2))
} else {
for idx, user := range users {
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
CheckUser(t, models2[idx], user)
})
}
}
// test smaller array
var models3 [2]User
if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 {
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3))
} else {
for idx, user := range users[:2] {
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
CheckUser(t, models3[idx], user)
})
}
}
var none []User var none []User
if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
@ -384,6 +408,13 @@ func TestFindInBatchesWithError(t *testing.T) {
if totalBatch != 0 { if totalBatch != 0 {
t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch)
} }
if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
totalBatch += batch
return nil
}); result.Error != gorm.ErrPrimaryKeyRequired {
t.Fatal("expected errors to have occurred, but nothing happened")
}
} }
func TestFillSmallerStruct(t *testing.T) { func TestFillSmallerStruct(t *testing.T) {
@ -1335,3 +1366,17 @@ func TestQueryResetNullValue(t *testing.T) {
AssertEqual(t, q1, qs[0]) AssertEqual(t, q1, qs[0])
AssertEqual(t, q2, qs[1]) AssertEqual(t, q2, qs[1])
} }
func TestQueryError(t *testing.T) {
type P struct{}
var p1 P
err := DB.Take(&p1, 1).Error
AssertEqual(t, err, gorm.ErrModelAccessibleFieldsRequired)
var p2 interface{}
err = DB.Table("ps").Clauses(clause.Eq{Column: clause.Column{
Table: clause.CurrentTable, Name: clause.PrimaryKey,
}, Value: 1}).Scan(&p2).Error
AssertEqual(t, err, gorm.ErrModelValueRequired)
}

View File

@ -18,6 +18,8 @@ type SerializerStruct struct {
gorm.Model gorm.Model
Name []byte `gorm:"json"` Name []byte `gorm:"json"`
Roles Roles `gorm:"serializer:json"` Roles Roles `gorm:"serializer:json"`
Roles2 *Roles `gorm:"serializer:json"`
Roles3 *Roles `gorm:"serializer:json;not null"`
Contracts map[string]interface{} `gorm:"serializer:json"` Contracts map[string]interface{} `gorm:"serializer:json"`
JobInfo Job `gorm:"type:bytes;serializer:gob"` JobInfo Job `gorm:"type:bytes;serializer:gob"`
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
@ -108,7 +110,7 @@ func TestSerializer(t *testing.T) {
} }
var result SerializerStruct var result SerializerStruct
if err := DB.First(&result, data.ID).Error; err != nil { if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err) t.Fatalf("failed to query data, got error %v", err)
} }

View File

@ -7,6 +7,7 @@ import (
"regexp" "regexp"
"testing" "testing"
"github.com/jinzhu/now"
"gorm.io/gorm" "gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -39,6 +40,11 @@ func TestSoftDelete(t *testing.T) {
t.Fatalf("invalid sql generated, got %v", sql) t.Fatalf("invalid sql generated, got %v", sql)
} }
sql = DB.Session(&gorm.Session{DryRun: true}).Table("user u").Select("name").Find(&User{}).Statement.SQL.String()
if !regexp.MustCompile(`SELECT .name. FROM user u WHERE .u.\..deleted_at. IS NULL`).MatchString(sql) {
t.Errorf("Table with escape character, got %v", sql)
}
if DB.First(&User{}, "name = ?", user.Name).Error == nil { if DB.First(&User{}, "name = ?", user.Name).Error == nil {
t.Errorf("Can't find a soft deleted record") t.Errorf("Can't find a soft deleted record")
} }
@ -93,3 +99,71 @@ func TestDeletedAtOneOr(t *testing.T) {
t.Fatalf("invalid sql generated, got %v", actualSQL) t.Fatalf("invalid sql generated, got %v", actualSQL)
} }
} }
func TestSoftDeleteZeroValue(t *testing.T) {
type SoftDeleteBook struct {
ID uint
Name string
Pages uint
DeletedAt gorm.DeletedAt `gorm:"zeroValue:'1970-01-01 00:00:01'"`
}
DB.Migrator().DropTable(&SoftDeleteBook{})
if err := DB.AutoMigrate(&SoftDeleteBook{}); err != nil {
t.Fatalf("failed to auto migrate soft delete table")
}
book := SoftDeleteBook{Name: "jinzhu", Pages: 10}
DB.Save(&book)
var count int64
if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 {
t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count)
}
var pages uint
if DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages {
t.Errorf("Pages soft deleted record, expects: %v, got: %v", 0, pages)
}
if err := DB.Delete(&book).Error; err != nil {
t.Fatalf("No error should happen when soft delete user, but got %v", err)
}
zeroTime, _ := now.Parse("1970-01-01 00:00:01")
if book.DeletedAt.Time.Equal(zeroTime) {
t.Errorf("book's deleted at should not be zero, DeletedAt: %v", book.DeletedAt)
}
if DB.First(&SoftDeleteBook{}, "name = ?", book.Name).Error == nil {
t.Errorf("Can't find a soft deleted record")
}
count = 0
if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 0 {
t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count)
}
pages = 0
if err := DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error; err != nil || pages != 0 {
t.Fatalf("Age soft deleted record, expects: %v, got: %v, err %v", 0, pages, err)
}
if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; err != nil {
t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err)
}
count = 0
if DB.Unscoped().Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 {
t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count)
}
pages = 0
if DB.Unscoped().Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages {
t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, pages)
}
DB.Unscoped().Delete(&book)
if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) {
t.Errorf("Can't find permanently deleted record")
}
}

View File

@ -29,7 +29,7 @@ func TestRow(t *testing.T) {
} }
table := "gorm.users" table := "gorm.users"
if DB.Dialector.Name() != "mysql" { if DB.Dialector.Name() != "mysql" || isTiDB() {
table = "users" // other databases doesn't support select with `database.table` table = "users" // other databases doesn't support select with `database.table`
} }
@ -367,7 +367,7 @@ func TestToSQL(t *testing.T) {
t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") t.Skip("Skip SQL Server for this test, because it too difference with other dialects.")
} }
date, _ := time.Parse("2006-01-02", "2021-10-18") date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local)
// find // find
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) {
if DB.Statement.DryRun || DB.DryRun { if DB.Statement.DryRun || DB.DryRun {
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
} }
// UpdateColumns
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Raw("SELECT * FROM users ?", clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}},
})
})
assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql)
} }
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. // assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials.

View File

@ -5,6 +5,8 @@ import (
"testing" "testing"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -145,3 +147,27 @@ func TestTableWithAllFields(t *testing.T) {
AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3})
} }
type UserWithTableNamer struct {
gorm.Model
Name string
}
func (UserWithTableNamer) TableName(namer schema.Namer) string {
return namer.TableName("user")
}
func TestTableWithNamer(t *testing.T) {
var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{
NamingStrategy: schema.NamingStrategy{
TablePrefix: "t_",
}})
sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{})
})
if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) {
t.Errorf("Table with namer, got %v", sql)
}
}

View File

@ -1,6 +1,6 @@
#!/bin/bash -e #!/bin/bash -e
dialects=("sqlite" "mysql" "postgres" "sqlserver") dialects=("sqlite" "mysql" "postgres" "sqlserver" "tidb")
if [[ $(pwd) == *"gorm/tests"* ]]; then if [[ $(pwd) == *"gorm/tests"* ]]; then
cd .. cd ..

View File

@ -21,6 +21,7 @@ var (
mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai"
sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local"
) )
func init() { func init() {
@ -80,6 +81,12 @@ func OpenTestConnection() (db *gorm.DB, err error) {
dbDSN = sqlserverDSN dbDSN = sqlserverDSN
} }
db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{})
case "tidb":
log.Println("testing tidb...")
if dbDSN == "" {
dbDSN = tidbDSN
}
db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{})
default: default:
log.Println("testing sqlite3...") log.Println("testing sqlite3...")
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{})

34
tests/tracer_test.go Normal file
View File

@ -0,0 +1,34 @@
package tests_test
import (
"context"
"time"
"gorm.io/gorm/logger"
)
type Tracer struct {
Logger logger.Interface
Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
func (S Tracer) LogMode(level logger.LogLevel) logger.Interface {
return S.Logger.LogMode(level)
}
func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) {
S.Logger.Info(ctx, s, i...)
}
func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) {
S.Logger.Warn(ctx, s, i...)
}
func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) {
S.Logger.Error(ctx, s, i...)
}
func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
S.Logger.Trace(ctx, begin, fc, err)
S.Test(ctx, begin, fc, err)
}

View File

@ -102,7 +102,7 @@ func TestTransactionWithBlock(t *testing.T) {
return errors.New("the error message") return errors.New("the error message")
}) })
if err.Error() != "the error message" { if err != nil && err.Error() != "the error message" {
t.Fatalf("Transaction return error will equal the block returns error") t.Fatalf("Transaction return error will equal the block returns error")
} }

View File

@ -2,6 +2,7 @@ package tests
import ( import (
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
@ -13,7 +14,14 @@ func (DummyDialector) Name() string {
return "dummy" return "dummy"
} }
func (DummyDialector) Initialize(*gorm.DB) error { func (DummyDialector) Initialize(db *gorm.DB) error {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
LastInsertIDReversed: true,
})
return nil return nil
} }

View File

@ -3,8 +3,8 @@ package utils
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"path/filepath"
"reflect" "reflect"
"regexp"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -16,7 +16,18 @@ var gormSourceDir string
func init() { func init() {
_, file, _, _ := runtime.Caller(0) _, file, _, _ := runtime.Caller(0)
// compatible solution to get gorm source directory with various operating systems // compatible solution to get gorm source directory with various operating systems
gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") gormSourceDir = sourceDir(file)
}
func sourceDir(file string) string {
dir := filepath.Dir(file)
dir = filepath.Dir(dir)
s := filepath.Dir(dir)
if filepath.Base(s) != "gorm.io" {
s = dir
}
return filepath.ToSlash(s) + "/"
} }
// FileWithLineNum return the file name and line number of the current file // FileWithLineNum return the file name and line number of the current file

View File

@ -1,8 +1,13 @@
package utils package utils
import ( import (
"database/sql"
"database/sql/driver"
"errors"
"math"
"strings" "strings"
"testing" "testing"
"time"
) )
func TestIsValidDBNameChar(t *testing.T) { func TestIsValidDBNameChar(t *testing.T) {
@ -13,6 +18,29 @@ func TestIsValidDBNameChar(t *testing.T) {
} }
} }
func TestCheckTruth(t *testing.T) {
checkTruthTests := []struct {
v string
out bool
}{
{"123", true},
{"true", true},
{"", false},
{"false", false},
{"False", false},
{"FALSE", false},
{"\u0046alse", false},
}
for _, test := range checkTruthTests {
t.Run(test.v, func(t *testing.T) {
if out := CheckTruth(test.v); out != test.out {
t.Errorf("CheckTruth(%s) want: %t, got: %t", test.v, test.out, out)
}
})
}
}
func TestToStringKey(t *testing.T) { func TestToStringKey(t *testing.T) {
cases := []struct { cases := []struct {
values []interface{} values []interface{}
@ -29,3 +57,81 @@ func TestToStringKey(t *testing.T) {
} }
} }
} }
func TestContains(t *testing.T) {
containsTests := []struct {
name string
elems []string
elem string
out bool
}{
{"exists", []string{"1", "2", "3"}, "1", true},
{"not exists", []string{"1", "2", "3"}, "4", false},
}
for _, test := range containsTests {
t.Run(test.name, func(t *testing.T) {
if out := Contains(test.elems, test.elem); test.out != out {
t.Errorf("Contains(%v, %s) want: %t, got: %t", test.elems, test.elem, test.out, out)
}
})
}
}
type ModifyAt sql.NullTime
// Value return a Unix time.
func (n ModifyAt) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Time.Unix(), nil
}
func TestAssertEqual(t *testing.T) {
now := time.Now()
assertEqualTests := []struct {
name string
src, dst interface{}
out bool
}{
{"error equal", errors.New("1"), errors.New("1"), true},
{"error not equal", errors.New("1"), errors.New("2"), false},
{"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true},
{"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false},
}
for _, test := range assertEqualTests {
t.Run(test.name, func(t *testing.T) {
if out := AssertEqual(test.src, test.dst); test.out != out {
t.Errorf("AssertEqual(%v, %v) want: %t, got: %t", test.src, test.dst, test.out, out)
}
})
}
}
func TestToString(t *testing.T) {
tests := []struct {
name string
in interface{}
out string
}{
{"int", math.MaxInt64, "9223372036854775807"},
{"int8", int8(math.MaxInt8), "127"},
{"int16", int16(math.MaxInt16), "32767"},
{"int32", int32(math.MaxInt32), "2147483647"},
{"int64", int64(math.MaxInt64), "9223372036854775807"},
{"uint", uint(math.MaxUint64), "18446744073709551615"},
{"uint8", uint8(math.MaxUint8), "255"},
{"uint16", uint16(math.MaxUint16), "65535"},
{"uint32", uint32(math.MaxUint32), "4294967295"},
{"uint64", uint64(math.MaxUint64), "18446744073709551615"},
{"string", "abc", "abc"},
{"other", true, ""},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if out := ToString(test.in); test.out != out {
t.Fatalf("ToString(%v) want: %s, got: %s", test.in, test.out, out)
}
})
}
}

38
utils/utils_unix_test.go Normal file
View File

@ -0,0 +1,38 @@
//go:build unix
// +build unix
package utils
import (
"testing"
)
func TestSourceDir(t *testing.T) {
cases := []struct {
file string
want string
}{
{
file: "/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go",
want: "/Users/name/go/pkg/mod/gorm.io/",
},
{
file: "/go/work/proj/gorm/utils/utils.go",
want: "/go/work/proj/gorm/",
},
{
file: "/go/work/proj/gorm_alias/utils/utils.go",
want: "/go/work/proj/gorm_alias/",
},
{
file: "/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go",
want: "/go/work/proj/my.gorm.io/gorm@v1.2.3/",
},
}
for _, c := range cases {
s := sourceDir(c.file)
if s != c.want {
t.Fatalf("%s: expected %s, got %s", c.file, c.want, s)
}
}
}

View File

@ -0,0 +1,35 @@
package utils
import (
"testing"
)
func TestSourceDir(t *testing.T) {
cases := []struct {
file string
want string
}{
{
file: `C:/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go`,
want: `C:/Users/name/go/pkg/mod/gorm.io/`,
},
{
file: `C:/go/work/proj/gorm/utils/utils.go`,
want: `C:/go/work/proj/gorm/`,
},
{
file: `C:/go/work/proj/gorm_alias/utils/utils.go`,
want: `C:/go/work/proj/gorm_alias/`,
},
{
file: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go`,
want: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/`,
},
}
for _, c := range cases {
s := sourceDir(c.file)
if s != c.want {
t.Fatalf("%s: expected %s, got %s", c.file, c.want, s)
}
}
}