Add way to migrate database
This commit is contained in:
		
							
								
								
									
										124
									
								
								database/migrate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								database/migrate.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,124 @@ | ||||
| package database | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func countRows(db *Database, table string) (int, error) { | ||||
| 	countRow := db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", table)) | ||||
| 	var count int | ||||
| 	err := countRow.Scan(&count) | ||||
| 	return count, err | ||||
| } | ||||
|  | ||||
| const VariableCountLimit = 512 | ||||
|  | ||||
| func migrateTable(old *Database, new *Database, table string, columns ...string) error { | ||||
| 	columnNames := strings.Join(columns, ",") | ||||
| 	fmt.Printf("Migrating %s: ", table) | ||||
| 	rowCount, err := countRows(old, table) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	fmt.Print("found ", rowCount, " rows of data, ") | ||||
| 	rows, err := old.Query(fmt.Sprintf("SELECT %s FROM \"%s\"", columnNames, table)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	serverColNames, err := rows.Columns() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	colCount := len(serverColNames) | ||||
| 	valueStringFormat := strings.Repeat("$%d, ", colCount) | ||||
| 	valueStringFormat = fmt.Sprintf("(%s)", valueStringFormat[:len(valueStringFormat)-2]) | ||||
| 	cols := make([]interface{}, colCount) | ||||
| 	colPtrs := make([]interface{}, colCount) | ||||
| 	for i := 0; i < colCount; i++ { | ||||
| 		colPtrs[i] = &cols[i] | ||||
| 	} | ||||
| 	batchSize := VariableCountLimit / colCount | ||||
| 	values := make([]interface{}, batchSize*colCount) | ||||
| 	valueStrings := make([]string, batchSize) | ||||
| 	var inserted int64 | ||||
| 	batchCount := int(math.Ceil(float64(rowCount) / float64(batchSize))) | ||||
| 	tx, err := new.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	fmt.Printf("migrating in %d batches: ", batchCount) | ||||
| 	for rowCount > 0 { | ||||
| 		var i int | ||||
| 		for ; rows.Next() && i < batchSize; i++ { | ||||
| 			colPtrs := make([]interface{}, colCount) | ||||
| 			valueStringArgs := make([]interface{}, colCount) | ||||
| 			for j := 0; j < colCount; j++ { | ||||
| 				pos := i*colCount + j | ||||
| 				colPtrs[j] = &values[pos] | ||||
| 				valueStringArgs[j] = pos + 1 | ||||
| 			} | ||||
| 			valueStrings[i] = fmt.Sprintf(valueStringFormat, valueStringArgs...) | ||||
| 			err = rows.Scan(colPtrs...) | ||||
| 			if err != nil { | ||||
| 				panic(err) | ||||
| 			} | ||||
| 		} | ||||
| 		slicedValues := values | ||||
| 		slicedValueStrings := valueStrings | ||||
| 		if i < len(valueStrings) { | ||||
| 			slicedValueStrings = slicedValueStrings[:i] | ||||
| 			slicedValues = slicedValues[:i*colCount] | ||||
| 		} | ||||
| 		res, err := tx.Exec(fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES %s", table, columnNames, strings.Join(slicedValueStrings, ",")), slicedValues...) | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
| 		count, _ := res.RowsAffected() | ||||
| 		inserted += count | ||||
| 		rowCount -= batchSize | ||||
| 		fmt.Print("#") | ||||
| 	} | ||||
| 	err = tx.Commit() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	fmt.Println(" -- done with", inserted, "rows inserted") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func Migrate(old *Database, new *Database) { | ||||
| 	err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url") | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	err = migrateTable(old, new, "user", "mxid", "jid", "management_room", "client_id", "client_token", "server_token", "enc_key", "mac_key", "last_connection") | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	err = migrateTable(old, new, "puppet", "jid", "avatar", "displayname", "name_quality", "custom_mxid", "access_token", "next_batch", "avatar_url") | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	err = migrateTable(old, new, "user_portal", "user_jid", "portal_jid", "portal_receiver", "in_community") | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	err = migrateTable(old, new, "message", "chat_jid", "chat_receiver", "jid", "mxid", "sender", "content", "timestamp") | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	err = migrateTable(old, new, "mx_registrations", "user_id") | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	err = migrateTable(old, new, "mx_user_profile", "room_id", "user_id", "membership") | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	err = migrateTable(old, new, "mx_room_state", "room_id", "power_levels") | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| } | ||||
| @@ -86,17 +86,16 @@ func init() { | ||||
| 			roomStateTable = strings.Replace(roomStateTable, "TEXT", "JSONB", 1) | ||||
| 		} | ||||
|  | ||||
| 		if data, err := ioutil.ReadFile("mx-state.json"); err != nil { | ||||
| 			ctx.log.Debugln("mx-state.json not found, not migrating state store") | ||||
| 			return nil | ||||
| 		} else if err = json.Unmarshal(data, &store); err != nil { | ||||
| 			return err | ||||
| 		} else if _, err := tx.Exec(userProfileTable); err != nil { | ||||
| 		if _, err := tx.Exec(userProfileTable); err != nil { | ||||
| 			return err | ||||
| 		} else if _, err = tx.Exec(roomStateTable); err != nil { | ||||
| 			return err | ||||
| 		} else if _, err = tx.Exec(registrationsTable); err != nil { | ||||
| 			return err | ||||
| 		} else if data, err := ioutil.ReadFile("mx-state.json"); err != nil { | ||||
| 			ctx.log.Debugln("mx-state.json not found, not migrating state store") | ||||
| 		} else if err = json.Unmarshal(data, &store); err != nil { | ||||
| 			return err | ||||
| 		} else if err = migrateRegistrations(tx, store.Registrations); err != nil { | ||||
| 			return err | ||||
| 		} else if err = migrateMemberships(tx, store.Memberships); err != nil { | ||||
|   | ||||
							
								
								
									
										32
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										32
									
								
								main.go
									
									
									
									
									
								
							| @@ -43,6 +43,7 @@ var configPath = flag.MakeFull("c", "config", "The path to your config file.", " | ||||
| var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() | ||||
| var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() | ||||
| var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if database is too new").Default("false").Bool() | ||||
| var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool() | ||||
| var wantHelp, _ = flag.MakeHelpFlag() | ||||
|  | ||||
| func (bridge *Bridge) GenerateRegistration() { | ||||
| @@ -67,6 +68,32 @@ func (bridge *Bridge) GenerateRegistration() { | ||||
| 	os.Exit(0) | ||||
| } | ||||
|  | ||||
| func (bridge *Bridge) MigrateDatabase() { | ||||
| 	oldDB, err := database.New(flag.Arg(0), flag.Arg(1)) | ||||
| 	if err != nil { | ||||
| 		fmt.Println("Failed to open old database:", err) | ||||
| 		os.Exit(30) | ||||
| 	} | ||||
| 	err = oldDB.Init() | ||||
| 	if err != nil { | ||||
| 		fmt.Println("Failed to upgrade old database:", err) | ||||
| 		os.Exit(31) | ||||
| 	} | ||||
|  | ||||
| 	newDB, err := database.New(bridge.Config.AppService.Database.Type, bridge.Config.AppService.Database.URI) | ||||
| 	if err != nil { | ||||
| 		bridge.Log.Fatalln("Failed to open new database:", err) | ||||
| 		os.Exit(32) | ||||
| 	} | ||||
| 	err = newDB.Init() | ||||
| 	if err != nil { | ||||
| 		fmt.Println("Failed to upgrade new database:", err) | ||||
| 		os.Exit(33) | ||||
| 	} | ||||
|  | ||||
| 	database.Migrate(oldDB, newDB) | ||||
| } | ||||
|  | ||||
| type Bridge struct { | ||||
| 	AS             *appservice.AppService | ||||
| 	EventProcessor *appservice.EventProcessor | ||||
| @@ -265,6 +292,9 @@ func (bridge *Bridge) Main() { | ||||
| 	if *generateRegistration { | ||||
| 		bridge.GenerateRegistration() | ||||
| 		return | ||||
| 	} else if *migrateFrom { | ||||
| 		bridge.MigrateDatabase() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	bridge.Init() | ||||
| @@ -285,7 +315,7 @@ func (bridge *Bridge) Main() { | ||||
| func main() { | ||||
| 	flag.SetHelpTitles( | ||||
| 		"mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.", | ||||
| 		"mautrix-whatsapp [-h] [-c <path>] [-r <path>] [-g]") | ||||
| 		"mautrix-whatsapp [-h] [-c <path>] [-r <path>] [-g] [--migrate-db <source type> <source uri>]") | ||||
| 	err := flag.Parse() | ||||
| 	if err != nil { | ||||
| 		fmt.Fprintln(os.Stderr, err) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user