groupme/database/upgrades/upgrades.go

124 lines
2.2 KiB
Go
Raw Normal View History

package upgrades
import (
2021-02-13 00:53:35 -05:00
"errors"
"fmt"
"strings"
2021-02-13 00:53:35 -05:00
"gorm.io/gorm"
log "maunium.net/go/maulogger/v2"
)
// Dialect identifies which SQL database flavor the upgrades run against.
type Dialect int

// Supported dialects.
const (
	// Postgres selects PostgreSQL; its String form is "postgres".
	Postgres Dialect = 0
	// SQLite selects SQLite 3; its String form is "sqlite3".
	SQLite Dialect = 1
)

// String returns the lowercase name of d ("postgres" or "sqlite3"),
// or the empty string when d is not one of the known dialects.
func (d Dialect) String() string {
	if d == Postgres {
		return "postgres"
	}
	if d == SQLite {
		return "sqlite3"
	}
	return ""
}
2021-02-13 00:53:35 -05:00
type upgradeFunc func(*gorm.DB, context) error
type context struct {
dialect Dialect
2021-02-13 00:53:35 -05:00
db *gorm.DB
log log.Logger
}
type upgrade struct {
message string
fn upgradeFunc
}
2021-02-13 00:53:35 -05:00
type version struct {
gorm.Model
V int
}
// NumberOfUpgrades is the number of known schema upgrades, i.e. the latest
// database version Run can migrate to.
const NumberOfUpgrades = 1

// upgrades holds the migration steps: upgrades[i] moves a database from
// version i to version i+1. Entries are presumably registered elsewhere in
// this package (not visible here) — TODO confirm every slot is filled.
var upgrades [NumberOfUpgrades]upgrade

// UnsupportedDatabaseVersion is returned by Run when the stored schema
// version is outside the range this build knows how to handle.
var UnsupportedDatabaseVersion = errors.New("unsupported database version")
2021-02-13 00:53:35 -05:00
func GetVersion(db *gorm.DB) (int, error) {
var ver = version{V: 0}
result := db.FirstOrCreate(&ver, &ver)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) ||
errors.Is(result.Error, gorm.ErrInvalidField) {
db.Create(&ver)
print("create version")
} else {
return 0, result.Error
}
}
2021-02-13 00:53:35 -05:00
return int(ver.V), nil
}
2021-02-13 00:53:35 -05:00
func SetVersion(tx *gorm.DB, newVersion int) error {
err := tx.Where("v IS NOT NULL").Delete(&version{})
if err.Error != nil {
return err.Error
}
2021-02-13 00:53:35 -05:00
val := version{V: newVersion}
tx = tx.Create(&val)
return tx.Error
}
2021-02-13 00:53:35 -05:00
func Run(log log.Logger, dialectName string, db *gorm.DB) error {
var dialect Dialect
switch strings.ToLower(dialectName) {
case "postgres":
dialect = Postgres
case "sqlite3":
dialect = SQLite
default:
return fmt.Errorf("unknown dialect %s", dialectName)
}
2021-02-13 00:53:35 -05:00
db.AutoMigrate(&version{})
version, err := GetVersion(db)
2021-02-13 00:53:35 -05:00
if err != nil {
return err
}
if version > NumberOfUpgrades {
return UnsupportedDatabaseVersion
}
log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
for i, upgrade := range upgrades[version:] {
log.Infofln("Upgrading database to v%d: %s", version+i+1, upgrade.message)
2021-02-13 00:53:35 -05:00
err = db.Transaction(func(tx *gorm.DB) error {
err = upgrade.fn(tx, context{dialect, db, log})
if err != nil {
return err
}
err = SetVersion(tx, version+i+1)
if err != nil {
return err
}
return nil
})
if err != nil {
return err
}
2021-02-13 00:53:35 -05:00
}
return nil
}