diff --git a/crypto.go b/crypto.go index 0c60603..53b63d1 100644 --- a/crypto.go +++ b/crypto.go @@ -68,6 +68,10 @@ func NewCryptoHelper(bridge *Bridge) Crypto { func (helper *CryptoHelper) Init() error { helper.log.Debugln("Initializing end-to-bridge encryption...") + + helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.bridge.AS.BotMXID(), + fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain)) + var err error helper.client, err = helper.loginBot() if err != nil { @@ -77,8 +81,6 @@ func (helper *CryptoHelper) Init() error { helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID) logger := &cryptoLogger{helper.baseLog} stateStore := &cryptoStateStore{helper.bridge} - helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.client.DeviceID, helper.client.UserID, - fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain)) helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore) helper.client.Logger = logger.int.Sub("Bot") @@ -89,27 +91,30 @@ func (helper *CryptoHelper) Init() error { } func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) { - deviceID := helper.bridge.DB.FindDeviceID() + deviceID := helper.store.FindDeviceID() if len(deviceID) > 0 { helper.log.Debugln("Found existing device ID for bot in database:", deviceID) } mac := hmac.New(sha512.New, []byte(helper.bridge.Config.Bridge.LoginSharedSecret)) mac.Write([]byte(helper.bridge.AS.BotMXID())) - resp, err := helper.bridge.AS.BotClient().Login(&mautrix.ReqLogin{ + client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, "", "") + if err != nil { + return nil, err + } + resp, err := client.Login(&mautrix.ReqLogin{ Type: "m.login.password", Identifier: mautrix.UserIdentifier{Type: "m.id.user", User: string(helper.bridge.AS.BotMXID())}, Password: hex.EncodeToString(mac.Sum(nil)), DeviceID: deviceID, InitialDeviceDisplayName: "WhatsApp Bridge", + StoreCredentials: true, }) if err != nil { return nil, err } - client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, helper.bridge.AS.BotMXID(), resp.AccessToken) - if err != nil { - return nil, err + if len(deviceID) == 0 { + helper.store.DeviceID = resp.DeviceID } - client.DeviceID = resp.DeviceID return client, nil } @@ -228,6 +233,8 @@ type cryptoStateStore struct { bridge *Bridge } +var _ crypto.StateStore = (*cryptoStateStore)(nil) + func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool { portal := c.bridge.GetPortalByMXID(id) if portal != nil { @@ -239,3 +246,8 @@ func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool { func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID { return c.bridge.StateStore.FindSharedRooms(id) } + +func (c *cryptoStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent { + // TODO implement + return nil +} diff --git a/database/cryptostore.go b/database/cryptostore.go index b691349..eab7e7e 100644 --- a/database/cryptostore.go +++ b/database/cryptostore.go @@ -35,9 +35,9 @@ type SQLCryptoStore struct { var _ crypto.Store = (*SQLCryptoStore)(nil) -func NewSQLCryptoStore(db *Database, deviceID id.DeviceID, userID id.UserID, ghostIDFormat string) *SQLCryptoStore { +func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore { return &SQLCryptoStore{ - SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, deviceID, + SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "", []byte("maunium.net/go/mautrix-whatsapp"), &cryptoLogger{db.log.Sub("CryptoStore")}), UserID: userID, @@ -45,10 +45,10 @@ func NewSQLCryptoStore(db *Database, deviceID id.DeviceID, userID id.UserID, gho } } -func (db *Database) FindDeviceID() (deviceID id.DeviceID) { - err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID) +func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) { + err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID) if err != nil && err != sql.ErrNoRows { - db.log.Warnln("Failed to scan device ID:", err) + store.Log.Warn("Failed to scan device ID: %v", err) } return } diff --git a/database/upgrades/2020-07-10-update-crypto-store.go b/database/upgrades/2020-07-10-update-crypto-store.go new file mode 100644 index 0000000..9baf6b0 --- /dev/null +++ b/database/upgrades/2020-07-10-update-crypto-store.go @@ -0,0 +1,13 @@ +package upgrades + +import ( + "database/sql" + + "maunium.net/go/mautrix/crypto" +) + +func init() { + upgrades[16] = upgrade{"Add account_id to crypto store", func(tx *sql.Tx, c context) error { + return crypto.SQLStoreMigrations[1](tx, c.dialect.String()) + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index 9b1d572..4566959 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -15,6 +15,17 @@ const ( SQLite ) +func (dialect Dialect) String() string { + switch dialect { + case Postgres: + return "postgres" + case SQLite: + return "sqlite3" + default: + return "" + } +} + type upgradeFunc func(*sql.Tx, context) error type context struct { @@ -28,7 +39,7 @@ type upgrade struct { fn upgradeFunc } -const NumberOfUpgrades = 16 +const NumberOfUpgrades = 17 var upgrades [NumberOfUpgrades]upgrade diff --git a/go.mod b/go.mod index 4574b2a..6a524f5 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( gopkg.in/yaml.v2 v2.3.0 maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.1.1 - maunium.net/go/mautrix v0.5.8 + maunium.net/go/mautrix v0.6.0 ) replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.3.4 diff --git a/go.sum b/go.sum index 37a4896..edc6a0c 100644 --- a/go.sum +++ b/go.sum @@ -204,3 +204,5 @@ maunium.net/go/mautrix v0.5.7 h1:tyRwllz3SZvMfD2YjaJPWopxmUCxZgQ2hl5/3/loHTE= maunium.net/go/mautrix v0.5.7/go.mod h1:FLbMANzwqlsX2Fgm7SDe+E4I3wSa4UxJRKqS5wGkCwA= maunium.net/go/mautrix v0.5.8 h1:jOE3U8WYSIc4qbYvyVaDhOaQcB3sDPN5A2zQ93YixZ0= maunium.net/go/mautrix v0.5.8/go.mod h1:Va/74MijqaS0DQ3aUqxmFO54/PMfr1LVsCOcGRHbYmo= +maunium.net/go/mautrix v0.6.0 h1:V32l4aygKk2XcH3fi8Yd0pFeSyYZJNRIvr8vdA2GtC8= +maunium.net/go/mautrix v0.6.0/go.mod h1:Va/74MijqaS0DQ3aUqxmFO54/PMfr1LVsCOcGRHbYmo=