Store outbound group sessions in database
This commit is contained in:
@ -22,7 +22,6 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
@ -44,9 +43,6 @@ type SQLCryptoStore struct {
|
||||
Account *crypto.OlmAccount
|
||||
|
||||
GhostIDFormat string
|
||||
|
||||
OGSLock sync.RWMutex
|
||||
OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession
|
||||
}
|
||||
|
||||
var _ crypto.Store = (*SQLCryptoStore)(nil)
|
||||
@ -57,8 +53,6 @@ func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
|
||||
log: db.log.Sub("CryptoStore"),
|
||||
PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
|
||||
DeviceID: deviceID,
|
||||
|
||||
OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession),
|
||||
}
|
||||
}
|
||||
|
||||
@ -255,24 +249,46 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error {
|
||||
store.OGSLock.Lock()
|
||||
store.OutGroupSessions[roomID] = session
|
||||
store.OGSLock.Unlock()
|
||||
return nil
|
||||
func (store *SQLCryptoStore) AddOutboundGroupSession(session *crypto.OutboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.db.Exec("INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
|
||||
session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.db.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5",
|
||||
sessionBytes, session.MessageCount, session.UseTime, session.RoomID, session.ID())
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
|
||||
store.OGSLock.RLock()
|
||||
defer store.OGSLock.RUnlock()
|
||||
return store.OutGroupSessions[roomID], nil
|
||||
var ogs crypto.OutboundGroupSession
|
||||
var sessionBytes []byte
|
||||
err := store.db.QueryRow(`
|
||||
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
|
||||
FROM crypto_megolm_outbound_session WHERE room_id=$1`,
|
||||
roomID,
|
||||
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.UseTime)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
intOGS := olm.NewBlankOutboundGroupSession()
|
||||
err = intOGS.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ogs.Internal = *intOGS
|
||||
ogs.RoomID = roomID
|
||||
return &ogs, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error {
|
||||
store.OGSLock.Lock()
|
||||
delete(store.OutGroupSessions, roomID)
|
||||
store.OGSLock.Unlock()
|
||||
return nil
|
||||
func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
||||
_, err := store.db.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1", roomID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
|
||||
@ -389,7 +405,7 @@ func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
|
||||
queryString[i] = fmt.Sprintf("$%d", i+1)
|
||||
params[i] = user
|
||||
}
|
||||
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN (" + strings.Join(queryString, ",") + ")", params...)
|
||||
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
|
||||
}
|
||||
if err != nil {
|
||||
store.log.Warnln("Failed to filter tracked users:", err)
|
||||
|
Reference in New Issue
Block a user