package gormx import ( "database/sql" "fmt" "os" "path/filepath" "strings" "time" sdmysql "github.com/go-sql-driver/mysql" "go.uber.org/zap" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/plugin/dbresolver" ) type ResolverConfig struct { DBType string // mysql/postgres/sqlite3 Sources []string Replicas []string Tables []string } type Config struct { Debug bool PrepareStmt bool DBType string // mysql/postgres/sqlite3 DSN string MaxLifetime int MaxIdleTime int MaxOpenConns int MaxIdleConns int TablePrefix string Resolver []ResolverConfig } func New(cfg Config) (*gorm.DB, error) { var dialector gorm.Dialector switch strings.ToLower(cfg.DBType) { case "mysql": if err := createDatabaseWithMySQL(cfg.DSN); err != nil { return nil, err } dialector = mysql.Open(cfg.DSN) case "postgres": dialector = postgres.Open(cfg.DSN) case "sqlite3": _ = os.MkdirAll(filepath.Dir(cfg.DSN), os.ModePerm) dialector = sqlite.Open(cfg.DSN) default: return nil, fmt.Errorf("unsupported database type: %s", cfg.DBType) } ormCfg := &gorm.Config{ NamingStrategy: schema.NamingStrategy{ TablePrefix: cfg.TablePrefix, SingularTable: true, }, Logger: logger.Discard, PrepareStmt: cfg.PrepareStmt, } if cfg.Debug { ormCfg.Logger = logger.Default } db, err := gorm.Open(dialector, ormCfg) if err != nil { return nil, err } if len(cfg.Resolver) > 0 { resolver := &dbresolver.DBResolver{} for _, r := range cfg.Resolver { resolverCfg := dbresolver.Config{} var open func(dsn string) gorm.Dialector dbType := strings.ToLower(r.DBType) switch dbType { case "mysql": open = mysql.Open case "postgres": open = postgres.Open case "sqlite3": open = sqlite.Open default: continue } for _, replica := range r.Replicas { if dbType == "sqlite3" { _ = os.MkdirAll(filepath.Dir(cfg.DSN), os.ModePerm) } resolverCfg.Replicas = append(resolverCfg.Replicas, open(replica)) } for _, source := range r.Sources { if dbType == "sqlite3" { _ = os.MkdirAll(filepath.Dir(cfg.DSN), os.ModePerm) } resolverCfg.Sources = append(resolverCfg.Sources, open(source)) } tables := stringSliceToInterfaceSlice(r.Tables) resolver.Register(resolverCfg, tables...) zap.L().Info(fmt.Sprintf("Use resolver, #tables: %v, #replicas: %v, #sources: %v \n", tables, r.Replicas, r.Sources)) } resolver.SetMaxIdleConns(cfg.MaxIdleConns). SetMaxOpenConns(cfg.MaxOpenConns). SetConnMaxLifetime(time.Duration(cfg.MaxLifetime) * time.Second). SetConnMaxIdleTime(time.Duration(cfg.MaxIdleTime) * time.Second) if err := db.Use(resolver); err != nil { return nil, err } } if cfg.Debug { db = db.Debug() } sqlDB, err := db.DB() if err != nil { return nil, err } sqlDB.SetMaxIdleConns(cfg.MaxIdleConns) sqlDB.SetMaxOpenConns(cfg.MaxOpenConns) sqlDB.SetConnMaxLifetime(time.Duration(cfg.MaxLifetime) * time.Second) sqlDB.SetConnMaxIdleTime(time.Duration(cfg.MaxIdleTime) * time.Second) return db, nil } func stringSliceToInterfaceSlice(s []string) []interface{} { r := make([]interface{}, len(s)) for i, v := range s { r[i] = v } return r } func createDatabaseWithMySQL(dsn string) error { cfg, err := sdmysql.ParseDSN(dsn) if err != nil { return err } db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/", cfg.User, cfg.Passwd, cfg.Addr)) if err != nil { return err } defer db.Close() query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s` DEFAULT CHARACTER SET = `utf8mb4`;", cfg.DBName) _, err = db.Exec(query) return err }