Upgrade message content in db to new protocol schema

This commit is contained in:
Tulir Asokan 2019-05-24 01:09:42 +03:00
parent 8d0d5ff504
commit 95e62fae77
6 changed files with 68 additions and 7 deletions

View File

@ -6,7 +6,7 @@ import (
) )
func init() { func init() {
upgrades[0] = upgrade{"Initial schema", func(dialect Dialect, tx *sql.Tx) error { upgrades[0] = upgrade{"Initial schema", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error {
var byteType string var byteType string
if dialect == SQLite { if dialect == SQLite {
byteType = "BLOB" byteType = "BLOB"

View File

@ -5,7 +5,7 @@ import (
) )
func init() { func init() {
upgrades[1] = upgrade{"Add ON DELETE CASCADE to message table", func(dialect Dialect, tx *sql.Tx) error { upgrades[1] = upgrade{"Add ON DELETE CASCADE to message table", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error {
if dialect == SQLite { if dialect == SQLite {
// SQLite doesn't support constraint updates, but it isn't that careful with constraints anyway. // SQLite doesn't support constraint updates, but it isn't that careful with constraints anyway.
return nil return nil

View File

@ -5,7 +5,7 @@ import (
) )
func init() { func init() {
upgrades[2] = upgrade{"Add timestamp column to messages", func(dialect Dialect, tx *sql.Tx) error { upgrades[2] = upgrade{"Add timestamp column to messages", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error {
_, err := tx.Exec("ALTER TABLE message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0") _, err := tx.Exec("ALTER TABLE message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0")
if err != nil { if err != nil {
return err return err

View File

@ -5,7 +5,7 @@ import (
) )
func init() { func init() {
upgrades[3] = upgrade{"Add last_connection column to users", func(dialect Dialect, tx *sql.Tx) error { upgrades[3] = upgrade{"Add last_connection column to users", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN last_connection BIGINT NOT NULL DEFAULT 0`) _, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN last_connection BIGINT NOT NULL DEFAULT 0`)
if err != nil { if err != nil {
return err return err

View File

@ -0,0 +1,61 @@
package upgrades
import (
"database/sql"
"encoding/json"
"fmt"
)
func init() {
var keys = []string{"imageMessage", "contactMessage", "locationMessage", "extendedTextMessage", "documentMessage", "audioMessage", "videoMessage"}
upgrades[4] = upgrade{"Update message content to new protocol version", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error {
rows, err := db.Query("SELECT mxid, content FROM message")
if err != nil {
return err
}
for rows.Next() {
var mxid string
var rawContent []byte
err = rows.Scan(&mxid, &rawContent)
if err != nil {
fmt.Println("Error scanning:", err)
continue
}
var content map[string]interface{}
err = json.Unmarshal(rawContent, &content)
if err != nil {
fmt.Printf("Error unmarshaling content of %s: %v\n", mxid, err)
continue
}
for _, key := range keys {
val, ok := content[key].(map[string]interface{})
if !ok {
continue
}
ci, ok := val["contextInfo"].(map[string]interface{})
if !ok {
continue
}
qm, ok := ci["quotedMessage"].([]interface{})
if !ok {
continue
}
ci["quotedMessage"] = qm[0]
goto save
}
continue
save:
rawContent, err = json.Marshal(&content)
if err != nil {
fmt.Printf("Error marshaling updated content of %s: %v\n", mxid, err)
}
_, err = tx.Exec("UPDATE message SET content=$1 WHERE mxid=$2", rawContent, mxid)
if err != nil {
fmt.Printf("Error updating row of %s: %v\n", mxid, err)
}
}
return nil
}}
}

View File

@ -15,14 +15,14 @@ const (
SQLite SQLite
) )
type upgradeFunc func(Dialect, *sql.Tx) error type upgradeFunc func(Dialect, *sql.Tx, *sql.DB) error
type upgrade struct { type upgrade struct {
message string message string
fn upgradeFunc fn upgradeFunc
} }
var upgrades [4]upgrade var upgrades [5]upgrade
func getVersion(dialect Dialect, db *sql.DB) (int, error) { func getVersion(dialect Dialect, db *sql.DB) (int, error) {
_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)") _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
@ -70,7 +70,7 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error {
if err != nil { if err != nil {
return err return err
} }
err = upgrade.fn(dialect, tx) err = upgrade.fn(dialect, tx, db)
if err != nil { if err != nil {
return err return err
} }