diff --git a/database/cryptostore.go b/database/cryptostore.go index 5cca49b..dbfca25 100644 --- a/database/cryptostore.go +++ b/database/cryptostore.go @@ -22,7 +22,6 @@ import ( "database/sql" "fmt" "strings" - "sync" "github.com/lib/pq" "github.com/pkg/errors" @@ -44,9 +43,6 @@ type SQLCryptoStore struct { Account *crypto.OlmAccount GhostIDFormat string - - OGSLock sync.RWMutex - OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession } var _ crypto.Store = (*SQLCryptoStore)(nil) @@ -57,8 +53,6 @@ func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore { log: db.log.Sub("CryptoStore"), PickleKey: []byte("maunium.net/go/mautrix-whatsapp"), DeviceID: deviceID, - - OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession), } } @@ -255,24 +249,46 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send }, nil } -func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error { - store.OGSLock.Lock() - store.OutGroupSessions[roomID] = session - store.OGSLock.Unlock() - return nil +func (store *SQLCryptoStore) AddOutboundGroupSession(session *crypto.OutboundGroupSession) error { + sessionBytes := session.Internal.Pickle(store.PickleKey) + _, err := store.db.Exec("INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)", + session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime) + return err +} + +func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error { + sessionBytes := session.Internal.Pickle(store.PickleKey) + _, err := store.db.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5", + sessionBytes, session.MessageCount, session.UseTime, session.RoomID, session.ID()) + return err } func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) { - store.OGSLock.RLock() - defer store.OGSLock.RUnlock() - return store.OutGroupSessions[roomID], nil + var ogs crypto.OutboundGroupSession + var sessionBytes []byte + err := store.db.QueryRow(` + SELECT session, shared, max_messages, message_count, max_age, created_at, last_used + FROM crypto_megolm_outbound_session WHERE room_id=$1`, + roomID, + ).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.UseTime) + if err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err + } + intOGS := olm.NewBlankOutboundGroupSession() + err = intOGS.Unpickle(sessionBytes, store.PickleKey) + if err != nil { + return nil, err + } + ogs.Internal = *intOGS + ogs.RoomID = roomID + return &ogs, nil } -func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error { - store.OGSLock.Lock() - delete(store.OutGroupSessions, roomID) - store.OGSLock.Unlock() - return nil +func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error { + _, err := store.db.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1", roomID) + return err } func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool { @@ -389,7 +405,7 @@ func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID { queryString[i] = fmt.Sprintf("$%d", i+1) params[i] = user } - rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN (" + strings.Join(queryString, ",") + ")", params...) + rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...) } if err != nil { store.log.Warnln("Failed to filter tracked users:", err) diff --git a/database/upgrades/2020-05-12-outbound-group-session-store.go b/database/upgrades/2020-05-12-outbound-group-session-store.go new file mode 100644 index 0000000..2d635ef --- /dev/null +++ b/database/upgrades/2020-05-12-outbound-group-session-store.go @@ -0,0 +1,26 @@ +package upgrades + +import ( + "database/sql" +) + +func init() { + upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error { + // TODO use DATETIME instead of timestamp and BLOB instead of bytea for sqlite + _, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session ( + room_id VARCHAR(255) PRIMARY KEY, + session_id CHAR(43) NOT NULL UNIQUE, + session bytea NOT NULL, + shared BOOLEAN NOT NULL, + max_messages INTEGER NOT NULL, + message_count INTEGER NOT NULL, + max_age BIGINT NOT NULL, + created_at timestamp NOT NULL, + last_used timestamp NOT NULL + )`) + if err != nil { + return err + } + return nil + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index ec8e6e7..22c8384 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -28,7 +28,7 @@ type upgrade struct { fn upgradeFunc } -const NumberOfUpgrades = 14 +const NumberOfUpgrades = 15 var upgrades [NumberOfUpgrades]upgrade diff --git a/go.mod b/go.mod index f8d6064..574c106 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( gopkg.in/yaml.v2 v2.2.8 maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.1.1 - maunium.net/go/mautrix v0.4.3 + maunium.net/go/mautrix v0.4.4 ) replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6 diff --git a/go.sum b/go.sum index 9dbe6fa..4562111 100644 --- a/go.sum +++ b/go.sum @@ -92,3 +92,5 @@ maunium.net/go/mautrix v0.4.2 h1:GBU++Z7o/fLPcEsNMkNOUsnDknwV/MGPQ0BN4ikK6tw= maunium.net/go/mautrix v0.4.2/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho= maunium.net/go/mautrix v0.4.3 h1:fVoJy992TjBEvuK5NeO9fpBh+9JuSFsxaEdGjFp/7h4= maunium.net/go/mautrix v0.4.3/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho= +maunium.net/go/mautrix v0.4.4 h1:C5yYDzUdRtJj/9Vot5YBPQUsWmn19sTySew7f4ACLhM= +maunium.net/go/mautrix v0.4.4/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=