125 lines
2.6 KiB
Go
125 lines
2.6 KiB
Go
package util
|
|
|
|
import (
	"context"
	"errors"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
|
|
|
|
type Trans struct {
|
|
DB *gorm.DB
|
|
}
|
|
|
|
// TransFunc is the callback signature accepted by Trans.Exec: it receives a
// context and reports failure through its error result.
type TransFunc func(context.Context) error
|
|
|
|
func (a *Trans) Exec(ctx context.Context, fn TransFunc) error {
|
|
if _, ok := FromTrans(ctx); ok {
|
|
return fn(ctx)
|
|
}
|
|
|
|
return a.DB.Transaction(func(db *gorm.DB) error {
|
|
return fn(NewTrans(ctx, db))
|
|
})
|
|
}
|
|
|
|
func GetDB(ctx context.Context, defDB *gorm.DB) *gorm.DB {
|
|
db := defDB
|
|
if tdb, ok := FromTrans(ctx); ok {
|
|
db = tdb
|
|
}
|
|
if FromRowLock(ctx) {
|
|
db = db.Clauses(clause.Locking{Strength: "UPDATE"})
|
|
}
|
|
return db.WithContext(ctx)
|
|
}
|
|
|
|
func wrapQueryOptions(db *gorm.DB, opts QueryOptions) *gorm.DB {
|
|
if len(opts.SelectFields) > 0 {
|
|
db = db.Select(opts.SelectFields)
|
|
}
|
|
if len(opts.OmitFields) > 0 {
|
|
db = db.Omit(opts.OmitFields...)
|
|
}
|
|
if len(opts.OrderFields) > 0 {
|
|
db = db.Order(opts.OrderFields.ToSQL())
|
|
}
|
|
return db
|
|
}
|
|
|
|
func WrapPageQuery(ctx context.Context, db *gorm.DB, pp PaginationParam, opts QueryOptions, out interface{}) (*PaginationResult, error) {
|
|
if pp.OnlyCount {
|
|
var count int64
|
|
err := db.Count(&count).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &PaginationResult{Total: count}, nil
|
|
} else if !pp.Pagination {
|
|
pageSize := pp.PageSize
|
|
if pageSize > 0 {
|
|
db = db.Limit(pageSize)
|
|
}
|
|
|
|
db = wrapQueryOptions(db, opts)
|
|
err := db.Find(out).Error
|
|
return nil, err
|
|
}
|
|
|
|
total, err := FindPage(ctx, db, pp, opts, out)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &PaginationResult{
|
|
Total: total,
|
|
Current: pp.Current,
|
|
PageSize: pp.PageSize,
|
|
}, nil
|
|
}
|
|
|
|
func FindPage(ctx context.Context, db *gorm.DB, pp PaginationParam, opts QueryOptions, out interface{}) (int64, error) {
|
|
db = db.WithContext(ctx)
|
|
var count int64
|
|
err := db.Count(&count).Error
|
|
if err != nil {
|
|
return 0, err
|
|
} else if count == 0 {
|
|
return count, nil
|
|
}
|
|
|
|
current, pageSize := pp.Current, pp.PageSize
|
|
if current > 0 && pageSize > 0 {
|
|
db = db.Offset((current - 1) * pageSize).Limit(pageSize)
|
|
} else if pageSize > 0 {
|
|
db = db.Limit(pageSize)
|
|
}
|
|
|
|
db = wrapQueryOptions(db, opts)
|
|
err = db.Find(out).Error
|
|
return count, err
|
|
}
|
|
|
|
func FindOne(ctx context.Context, db *gorm.DB, opts QueryOptions, out interface{}) (bool, error) {
|
|
db = db.WithContext(ctx)
|
|
db = wrapQueryOptions(db, opts)
|
|
result := db.First(out)
|
|
if err := result.Error; err != nil {
|
|
if err == gorm.ErrRecordNotFound {
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func Exists(ctx context.Context, db *gorm.DB) (bool, error) {
|
|
db = db.WithContext(ctx)
|
|
var count int64
|
|
result := db.Count(&count)
|
|
if err := result.Error; err != nil {
|
|
return false, err
|
|
}
|
|
return count > 0, nil
|
|
}
|