hailin/pkg/util/db.go
2025-06-19 10:30:46 +08:00

125 lines
2.6 KiB
Go

package util
import (
	"context"
	"errors"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
type (
	// Trans runs units of work inside a database transaction opened on DB.
	Trans struct {
		DB *gorm.DB
	}

	// TransFunc is a unit of work executed by Trans.Exec. The context it
	// receives is either the caller's context (when a transaction is already
	// present in it) or one built by NewTrans around the new transaction
	// handle.
	TransFunc func(context.Context) error
)
// Exec runs fn inside a database transaction.
//
// If ctx already carries a transaction (as reported by FromTrans), fn is
// called directly with ctx so it joins that outer transaction rather than
// opening a nested one. Otherwise a new transaction is started on a.DB with
// gorm's Transaction helper, and fn receives a context built by NewTrans
// from the transaction handle. Per gorm's Transaction contract, the
// transaction is committed when fn returns nil and rolled back when fn
// returns an error; that error (or a commit error) is returned.
func (a *Trans) Exec(ctx context.Context, fn TransFunc) error {
	if _, ok := FromTrans(ctx); ok {
		return fn(ctx)
	}

	// Bind ctx to the root handle so BEGIN/COMMIT/ROLLBACK observe the
	// caller's cancellation and deadline, not only the statements run in fn.
	return a.DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		return fn(NewTrans(ctx, tx))
	})
}
func GetDB(ctx context.Context, defDB *gorm.DB) *gorm.DB {
db := defDB
if tdb, ok := FromTrans(ctx); ok {
db = tdb
}
if FromRowLock(ctx) {
db = db.Clauses(clause.Locking{Strength: "UPDATE"})
}
return db.WithContext(ctx)
}
// wrapQueryOptions applies the projection and ordering in opts to db and
// returns the resulting handle:
//   - non-empty SelectFields are passed to Select,
//   - non-empty OmitFields are passed to Omit,
//   - non-empty OrderFields are rendered with ToSQL and passed to Order.
//
// Zero-length option fields leave db unchanged.
func wrapQueryOptions(db *gorm.DB, opts QueryOptions) *gorm.DB {
	q := db
	if fields := opts.SelectFields; len(fields) > 0 {
		q = q.Select(fields)
	}
	if omitted := opts.OmitFields; len(omitted) > 0 {
		q = q.Omit(omitted...)
	}
	if order := opts.OrderFields; len(order) > 0 {
		q = q.Order(order.ToSQL())
	}
	return q
}
// WrapPageQuery executes the query built on db in one of three modes chosen
// by pp:
//
//   - pp.OnlyCount: counts the matching rows only. out is not written, and
//     the result carries just Total.
//   - !pp.Pagination: loads rows into out with opts applied, capped at
//     pp.PageSize when it is positive. No count is run, and the returned
//     *PaginationResult is nil.
//   - otherwise: delegates to FindPage and returns Total together with the
//     requested Current and PageSize.
//
// The query is bound to ctx in every mode. Any database error is returned
// with a nil result.
func WrapPageQuery(ctx context.Context, db *gorm.DB, pp PaginationParam, opts QueryOptions, out interface{}) (*PaginationResult, error) {
	// Bind ctx up front so the count-only and unpaginated branches honour
	// cancellation/deadlines the same way FindPage does.
	db = db.WithContext(ctx)

	switch {
	case pp.OnlyCount:
		var count int64
		if err := db.Count(&count).Error; err != nil {
			return nil, err
		}
		return &PaginationResult{Total: count}, nil

	case !pp.Pagination:
		if pp.PageSize > 0 {
			db = db.Limit(pp.PageSize)
		}
		if err := wrapQueryOptions(db, opts).Find(out).Error; err != nil {
			return nil, err
		}
		// Unpaginated queries carry no pagination metadata.
		return nil, nil
	}

	total, err := FindPage(ctx, db, pp, opts, out)
	if err != nil {
		return nil, err
	}
	return &PaginationResult{
		Total:    total,
		Current:  pp.Current,
		PageSize: pp.PageSize,
	}, nil
}
// FindPage counts the rows matched by db and, when there is at least one,
// loads the requested page into out with opts applied. Both queries are
// bound to ctx.
//
// Page selection:
//   - Current > 0 and PageSize > 0: OFFSET (Current-1)*PageSize, LIMIT PageSize.
//   - only PageSize > 0: LIMIT PageSize.
//   - otherwise: all matching rows.
//
// If the count fails, it returns (0, err). If the count is zero, out is left
// untouched and (0, nil) is returned. Otherwise the total is returned along
// with any error from loading the rows.
func FindPage(ctx context.Context, db *gorm.DB, pp PaginationParam, opts QueryOptions, out interface{}) (int64, error) {
	q := db.WithContext(ctx)

	var total int64
	if err := q.Count(&total).Error; err != nil {
		return 0, err
	}
	if total == 0 {
		// Nothing to fetch; skip the second round trip.
		return total, nil
	}

	page, size := pp.Current, pp.PageSize
	switch {
	case page > 0 && size > 0:
		q = q.Offset((page - 1) * size).Limit(size)
	case size > 0:
		q = q.Limit(size)
	}

	err := wrapQueryOptions(q, opts).Find(out).Error
	return total, err
}
// FindOne loads the first record matched by db (with opts applied, using
// gorm's First) into out, with the query bound to ctx.
//
// It returns (true, nil) when a record was found and (false, nil) when the
// query matched nothing. Any other failure returns (false, err). The
// not-found case is detected with errors.Is, so it is still recognised if
// gorm.ErrRecordNotFound arrives wrapped (e.g. by a plugin or callback).
func FindOne(ctx context.Context, db *gorm.DB, opts QueryOptions, out interface{}) (bool, error) {
	db = wrapQueryOptions(db.WithContext(ctx), opts)
	if err := db.First(out).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return false, nil
		}
		return false, err
	}
	return true, nil
}
// Exists reports whether the query built on db matches at least one row.
// It runs a COUNT bound to ctx and returns (false, err) if that query fails.
//
// NOTE(review): COUNT visits every matching row; on large result sets a
// LIMIT 1 probe may be cheaper — confirm against real query plans first.
func Exists(ctx context.Context, db *gorm.DB) (bool, error) {
	var matched int64
	if err := db.WithContext(ctx).Count(&matched).Error; err != nil {
		return false, err
	}
	return matched > 0, nil
}