haibei/pkg/gormx/gorm.go
2025-06-19 10:33:58 +08:00

163 lines
3.7 KiB
Go

package gormx
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"strings"
"time"
sdmysql "github.com/go-sql-driver/mysql"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/plugin/dbresolver"
)
// ResolverConfig describes one dbresolver registration performed by New:
// the source and replica DSNs and the tables routed to them.
type ResolverConfig struct {
	// DBType selects the dialector used for every DSN in this entry. It is
	// matched case-insensitively; entries with any other value are skipped.
	DBType string // mysql/postgres/sqlite3
	// Sources lists DSNs registered as dbresolver sources.
	Sources []string
	// Replicas lists DSNs registered as dbresolver replicas.
	Replicas []string
	// Tables lists the table names this entry is registered for.
	// NOTE(review): with an empty list dbresolver presumably treats the entry
	// as the global/default resolver — confirm against dbresolver docs.
	Tables []string
}
// Config controls how New opens the primary database, names tables, logs,
// sizes the connection pool and registers optional read/write resolvers.
type Config struct {
	// Debug switches gorm from a discarding logger to logger.Default and
	// enables gorm debug mode on the returned *gorm.DB.
	Debug bool
	// PrepareStmt is passed through to gorm.Config.PrepareStmt.
	PrepareStmt bool
	// DBType is matched case-insensitively; other values make New fail.
	DBType string // mysql/postgres/sqlite3
	// DSN is the driver-specific data source name of the primary database.
	// For mysql the database it names is created if missing.
	DSN string
	// MaxLifetime is the maximum connection lifetime, in seconds.
	MaxLifetime int
	// MaxIdleTime is the maximum time a connection may sit idle, in seconds.
	MaxIdleTime int
	// MaxOpenConns limits open connections (primary and resolver pools).
	MaxOpenConns int
	// MaxIdleConns limits idle connections (primary and resolver pools).
	MaxIdleConns int
	// TablePrefix is prepended to every table name; names are singular.
	TablePrefix string
	// Resolver lists optional dbresolver registrations; empty disables it.
	Resolver []ResolverConfig
}
// New opens a *gorm.DB according to cfg.
//
// cfg.DBType selects the dialector ("mysql", "postgres" or "sqlite3",
// matched case-insensitively); any other value is an error. For mysql the
// database named in cfg.DSN is created first if it does not exist. For
// sqlite3 the directory that will hold the database file is created.
//
// Tables use singular names prefixed with cfg.TablePrefix. Logging is
// discarded unless cfg.Debug is set; in that case gorm's default logger and
// debug mode are enabled.
//
// When cfg.Resolver is non-empty, a dbresolver plugin is registered with the
// configured sources and replicas. Pool limits from cfg apply to both the
// primary pool and the resolver pools. MaxLifetime and MaxIdleTime are in
// seconds.
//
// If a step fails after the connection pool has been opened, the pool is
// closed before the error is returned.
func New(cfg Config) (*gorm.DB, error) {
	var dialector gorm.Dialector
	switch strings.ToLower(cfg.DBType) {
	case "mysql":
		if err := createDatabaseWithMySQL(cfg.DSN); err != nil {
			return nil, err
		}
		dialector = mysql.Open(cfg.DSN)
	case "postgres":
		dialector = postgres.Open(cfg.DSN)
	case "sqlite3":
		if err := ensureSQLiteDir(cfg.DSN); err != nil {
			return nil, err
		}
		dialector = sqlite.Open(cfg.DSN)
	default:
		return nil, fmt.Errorf("unsupported database type: %s", cfg.DBType)
	}

	ormCfg := &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			TablePrefix:   cfg.TablePrefix,
			SingularTable: true,
		},
		Logger:      logger.Discard,
		PrepareStmt: cfg.PrepareStmt,
	}
	if cfg.Debug {
		ormCfg.Logger = logger.Default
	}

	db, err := gorm.Open(dialector, ormCfg)
	if err != nil {
		return nil, err
	}

	if len(cfg.Resolver) > 0 {
		resolver, rerr := buildResolver(cfg)
		if rerr == nil {
			rerr = db.Use(resolver)
		}
		if rerr != nil {
			// gorm.Open already created a pool; don't leak it.
			closeQuietly(db)
			return nil, rerr
		}
	}

	if cfg.Debug {
		db = db.Debug()
	}

	sqlDB, err := db.DB()
	if err != nil {
		return nil, err
	}
	sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
	sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
	sqlDB.SetConnMaxLifetime(time.Duration(cfg.MaxLifetime) * time.Second)
	sqlDB.SetConnMaxIdleTime(time.Duration(cfg.MaxIdleTime) * time.Second)
	return db, nil
}

// buildResolver constructs the dbresolver plugin described by cfg.Resolver.
// It applies cfg's pool limits to the resolver-managed connections.
func buildResolver(cfg Config) (*dbresolver.DBResolver, error) {
	resolver := &dbresolver.DBResolver{}
	for _, r := range cfg.Resolver {
		dbType := strings.ToLower(r.DBType)
		var open func(dsn string) gorm.Dialector
		switch dbType {
		case "mysql":
			open = mysql.Open
		case "postgres":
			open = postgres.Open
		case "sqlite3":
			open = sqlite.Open
		default:
			// Unsupported entries are skipped rather than fatal. Still, make
			// the misconfiguration visible instead of dropping it silently.
			zap.L().Warn("skip resolver with unsupported database type",
				zap.String("dbType", r.DBType),
				zap.Strings("tables", r.Tables))
			continue
		}

		// openAll turns DSNs into dialectors. For sqlite it first creates
		// each DSN's own directory (not the primary DSN's).
		openAll := func(dsns []string) ([]gorm.Dialector, error) {
			var out []gorm.Dialector
			for _, dsn := range dsns {
				if dbType == "sqlite3" {
					if err := ensureSQLiteDir(dsn); err != nil {
						return nil, err
					}
				}
				out = append(out, open(dsn))
			}
			return out, nil
		}

		replicas, err := openAll(r.Replicas)
		if err != nil {
			return nil, err
		}
		sources, err := openAll(r.Sources)
		if err != nil {
			return nil, err
		}

		resolver.Register(dbresolver.Config{
			Sources:  sources,
			Replicas: replicas,
		}, stringSliceToInterfaceSlice(r.Tables)...)

		// Log counts only: DSNs usually embed credentials.
		zap.L().Info("use resolver",
			zap.Strings("tables", r.Tables),
			zap.Int("replicas", len(r.Replicas)),
			zap.Int("sources", len(r.Sources)))
	}
	resolver.SetMaxIdleConns(cfg.MaxIdleConns).
		SetMaxOpenConns(cfg.MaxOpenConns).
		SetConnMaxLifetime(time.Duration(cfg.MaxLifetime) * time.Second).
		SetConnMaxIdleTime(time.Duration(cfg.MaxIdleTime) * time.Second)
	return resolver, nil
}

// ensureSQLiteDir creates the directory that will hold the SQLite database
// file named by dsn. A leading "file:" URI prefix and any "?query" suffix are
// ignored when deriving the path. DSNs without a directory component (e.g.
// "app.db" or ":memory:") resolve to ".", which is always accepted.
func ensureSQLiteDir(dsn string) error {
	path := strings.TrimPrefix(dsn, "file:")
	if i := strings.IndexByte(path, '?'); i >= 0 {
		path = path[:i]
	}
	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, os.ModePerm); err != nil {
		return fmt.Errorf("create sqlite directory %q: %w", dir, err)
	}
	return nil
}

// closeQuietly releases db's underlying *sql.DB on an error path. The close
// error is ignored because the caller is already returning a more relevant
// error.
func closeQuietly(db *gorm.DB) {
	if sqlDB, err := db.DB(); err == nil {
		_ = sqlDB.Close()
	}
}
// stringSliceToInterfaceSlice copies every element of s into a freshly
// allocated []interface{} of the same length. Go has no implicit conversion
// from []string to []interface{}, so this is needed to spread the values into
// a variadic ...interface{} parameter. A nil or empty s yields an empty,
// non-nil slice.
func stringSliceToInterfaceSlice(s []string) []interface{} {
	out := make([]interface{}, 0, len(s))
	for idx := range s {
		out = append(out, s[idx])
	}
	return out
}
// createDatabaseWithMySQL makes sure the database named in a MySQL DSN
// exists, creating it with the utf8mb4 character set if necessary.
//
// It connects with the same DSN minus the database name, so the protocol
// (tcp/unix), TLS settings, timeouts and other parameters are preserved. A
// DSN without a database name is a no-op.
func createDatabaseWithMySQL(dsn string) error {
	cfg, err := sdmysql.ParseDSN(dsn)
	if err != nil {
		return fmt.Errorf("parse mysql dsn: %w", err)
	}
	dbName := cfg.DBName
	if dbName == "" {
		// Nothing to create; the empty identifier would only produce invalid SQL.
		return nil
	}

	// Connect to the server without selecting a database, since the database
	// may not exist yet. FormatDSN keeps every other setting of the original DSN.
	cfg.DBName = ""
	db, err := sql.Open("mysql", cfg.FormatDSN())
	if err != nil {
		return fmt.Errorf("open mysql server connection: %w", err)
	}
	defer db.Close()

	// Identifiers cannot be bound as query parameters, so quote the name
	// with backticks and escape embedded backticks by doubling them.
	quoted := strings.ReplaceAll(dbName, "`", "``")
	query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s` DEFAULT CHARACTER SET = `utf8mb4`;", quoted)
	if _, err := db.Exec(query); err != nil {
		return fmt.Errorf("create mysql database %q: %w", dbName, err)
	}
	return nil
}