// mautrix-groupme - A Matrix-GroupMe puppeting bridge. // Copyright (C) 2022 Sumner Evans, Karmanyaah Malhotra // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package database import ( "strings" "time" log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" "github.com/beeper/groupme/types" ) type UserQuery struct { db *Database log log.Logger } func (uq *UserQuery) New() *User { return &User{ db: uq.db, log: uq.log, } } func (uq *UserQuery) GetAll() (users []*User) { ans := uq.db.Find(&users) if ans.Error != nil || len(users) == 0 { return nil } for _, i := range users { i.db = uq.db i.log = uq.log } return } func (uq *UserQuery) GetByMXID(userID id.UserID) *User { var user User ans := uq.db.Where("mxid = ?", userID).Take(&user) user.db = uq.db user.log = uq.log if ans.Error != nil { return nil } return &user } func (uq *UserQuery) GetByJID(userID types.GroupMeID) *User { var user User ans := uq.db.Where("jid = ?", userID).Limit(1).Find(&user) if ans.Error != nil || ans.RowsAffected == 0 { return nil } user.db = uq.db user.log = uq.log return &user } type User struct { db *Database log log.Logger MXID id.UserID `gorm:"primaryKey"` JID types.GroupMeID `gorm:"unique"` Token types.AuthToken ManagementRoom id.RoomID LastConnection uint64 `gorm:"notNull;default:0"` } //func (user *User) Scan(row Scannable) *User { // var jid, clientID, clientToken, serverToken sql.NullString // var encKey, macKey []byte // err := row.Scan(&user.MXID, &jid, &user.ManagementRoom, &user.LastConnection, &clientID, &clientToken, &serverToken, &encKey, &macKey) // if err != nil { // if err != sql.ErrNoRows { // user.log.Errorln("Database scan failed:", err) // } // return nil // } // if len(jid.String) > 0 && len(clientID.String) > 0 { // user.JID = jid.String + whatsappExt.NewUserSuffix // // user.Session = &whatsapp.Session{ // // ClientId: clientID.String, // // ClientToken: clientToken.String, // // ServerToken: serverToken.String, // // EncKey: encKey, // // MacKey: macKey, // // Wid: jid.String + whatsappExt.OldUserSuffix, // // } // } // else { // // user.Session = nil // // } // return user //} func stripSuffix(jid types.GroupMeID) string { if len(jid) == 0 { return jid } index := strings.IndexRune(jid, '@') if index < 0 { return jid } return jid[:index] } func (user *User) jidPtr() *string { if len(user.JID) > 0 { str := stripSuffix(user.JID) return &str } return nil } //func (user *User) sessionUnptr() (sess whatsapp.Session) { // // if user.Session != nil { // // sess = *user.Session // // } // return //} func (user *User) Insert() { ans := user.db.Create(&user) if ans.Error != nil { user.log.Warnfln("Failed to insert %s: %v", user.MXID, ans.Error) } } func (user *User) UpdateLastConnection() { user.LastConnection = uint64(time.Now().Unix()) user.Update() } func (user *User) Update() { ans := user.db.Save(&user) if ans.Error != nil { user.log.Warnfln("Failed to update user: %v", ans.Error) } } type PortalKeyWithMeta struct { PortalKey InCommunity bool } type UserPortal struct { UserJID types.GroupMeID `gorm:"primaryKey;"` PortalJID types.GroupMeID `gorm:"primaryKey;"` PortalReceiver types.GroupMeID `gorm:"primaryKey;"` InCommunity bool `gorm:"notNull;default:false;"` User User `gorm:"foreignKey:UserJID;references:jid;constraint:OnDelete:CASCADE;"` Portal Portal `gorm:"foreignKey:PortalJID,PortalReceiver;references:JID,Receiver;constraint:OnDelete:CASCADE;"` } func (user *User) SetPortalKeys(newKeys []PortalKeyWithMeta) error { tx := user.db.Begin() ans := tx.Where("user_jid = ?", *user.jidPtr()).Delete(UserPortal{}) if ans.Error != nil { _ = tx.Rollback() return ans.Error } for _, key := range newKeys { ans = tx.Create(&UserPortal{ UserJID: *user.jidPtr(), PortalJID: key.JID, PortalReceiver: key.Receiver, InCommunity: key.InCommunity, }) if ans.Error != nil { _ = tx.Rollback() return ans.Error } } println("portalkey transaction complete") return tx.Commit().Error } func (user *User) IsInPortal(key PortalKey) bool { var count int64 user.db.Find(&UserPortal{ UserJID: *user.jidPtr(), PortalJID: key.JID, PortalReceiver: key.Receiver, }).Count(&count) //TODO: efficient return count > 0 } func (user *User) GetPortalKeys() []PortalKey { var up []UserPortal ans := user.db.Where("user_jid = ?", *user.jidPtr()).Find(&up) if ans.Error != nil { user.log.Warnln("Failed to get user portal keys:", ans.Error) return nil } var keys []PortalKey for _, i := range up { key := PortalKey{ JID: i.PortalJID, Receiver: i.PortalReceiver, } keys = append(keys, key) } return keys } func (user *User) GetInCommunityMap() map[PortalKey]bool { var up []UserPortal ans := user.db.Where("user_jid = ?", *user.jidPtr()).Find(&up) if ans.Error != nil { user.log.Warnln("Failed to get user portal keys:", ans.Error) return nil } keys := make(map[PortalKey]bool) for _, i := range up { key := PortalKey{ JID: i.PortalJID, Receiver: i.PortalReceiver, } keys[key] = i.InCommunity } return keys }