From 9c48eeb534e349bf1cc2ae0a710d9ac56b82e028 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 16 Aug 2018 15:59:18 +0300 Subject: [PATCH] Add puppet and portal stuff and fix config stuff --- .gitignore | 3 + config/bridge.go | 26 +++---- config/registration.go | 5 +- database/database.go | 12 +++- database/portal.go | 4 +- database/puppet.go | 98 +++++++++++++++++++++++++ database/user.go | 38 +++++++--- example-config.yaml | 2 +- main.go | 6 +- matrix.go | 63 +++++++++++++++- portal.go | 89 +++++++++++++++++++++++ user.go | 158 +++++++++++++++++++++++++++++++++++++++++ 12 files changed, 474 insertions(+), 30 deletions(-) create mode 100644 database/puppet.go create mode 100644 portal.go diff --git a/.gitignore b/.gitignore index 138bbf8..58cd300 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ .idea *.session + +*.yaml +!example-config.yaml diff --git a/config/bridge.go b/config/bridge.go index c6e02f6..623b7aa 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -22,24 +22,26 @@ import ( ) type BridgeConfig struct { - RawUsernameTemplate string `yaml:"username_template"` - RawDisplaynameTemplate string `yaml:"displayname_template"` - UsernameTemplate *template.Template `yaml:"-"` - DisplaynameTemplate *template.Template `yaml:"-"` + UsernameTemplate string `yaml:"username_template"` + DisplaynameTemplate string `yaml:"displayname_template"` + usernameTemplate *template.Template `yaml:"-"` + displaynameTemplate *template.Template `yaml:"-"` } -func (bc BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - err := unmarshal(bc) +type umBridgeConfig BridgeConfig + +func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + err := unmarshal((*umBridgeConfig)(bc)) if err != nil { return err } - bc.UsernameTemplate, err = template.New("username").Parse(bc.RawUsernameTemplate) + bc.usernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate) if err != nil { return err } - bc.DisplaynameTemplate, err = template.New("displayname").Parse(bc.RawDisplaynameTemplate) + bc.displaynameTemplate, err = template.New("displayname").Parse(bc.DisplaynameTemplate) return err } @@ -54,7 +56,7 @@ type UsernameTemplateArgs struct { func (bc BridgeConfig) FormatDisplayname(displayname string) string { var buf bytes.Buffer - bc.DisplaynameTemplate.Execute(&buf, DisplaynameTemplateArgs{ + bc.displaynameTemplate.Execute(&buf, DisplaynameTemplateArgs{ Displayname: displayname, }) return buf.String() @@ -62,7 +64,7 @@ func (bc BridgeConfig) FormatDisplayname(displayname string) string { func (bc BridgeConfig) FormatUsername(receiver, userID string) string { var buf bytes.Buffer - bc.UsernameTemplate.Execute(&buf, UsernameTemplateArgs{ + bc.usernameTemplate.Execute(&buf, UsernameTemplateArgs{ Receiver: receiver, UserID: userID, }) @@ -70,7 +72,7 @@ func (bc BridgeConfig) FormatUsername(receiver, userID string) string { } func (bc BridgeConfig) MarshalYAML() (interface{}, error) { - bc.RawDisplaynameTemplate = bc.FormatDisplayname("{{.Displayname}}") - bc.RawUsernameTemplate = bc.FormatUsername("{{.Receiver}}", "{{.UserID}}") + bc.DisplaynameTemplate = bc.FormatDisplayname("{{.Displayname}}") + bc.UsernameTemplate = bc.FormatUsername("{{.Receiver}}", "{{.UserID}}") return bc, nil } diff --git a/config/registration.go b/config/registration.go index 2b78fe5..1ea1724 100644 --- a/config/registration.go +++ b/config/registration.go @@ -19,6 +19,7 @@ package config import ( "maunium.net/go/mautrix-appservice" "regexp" + "fmt" ) func (config *Config) NewRegistration() (*appservice.Registration, error) { @@ -53,7 +54,9 @@ func (config *Config) copyToRegistration(registration *appservice.Registration) registration.RateLimited = false registration.SenderLocalpart = config.AppService.Bot.Username - userIDRegex, err := regexp.Compile(config.Bridge.FormatUsername("[0-9]+", "[0-9]+")) + userIDRegex, err := regexp.Compile(fmt.Sprintf("@%s:%s", + config.Bridge.FormatUsername("[0-9]+", "[0-9]+"), + config.Homeserver.Domain)) if err != nil { return err } diff --git a/database/database.go b/database/database.go index 429cc70..341aa64 100644 --- a/database/database.go +++ b/database/database.go @@ -26,7 +26,9 @@ type Database struct { *sql.DB log *log.Sublogger - User *UserQuery + User *UserQuery + Portal *PortalQuery + Puppet *PuppetQuery } func New(file string) (*Database, error) { @@ -43,6 +45,14 @@ func New(file string) (*Database, error) { db: db, log: log.CreateSublogger("Database/User", log.LevelDebug), } + db.Portal = &PortalQuery{ + db: db, + log: log.CreateSublogger("Database/Portal", log.LevelDebug), + } + db.Puppet = &PuppetQuery{ + db: db, + log: log.CreateSublogger("Database/Puppet", log.LevelDebug), + } return db, nil } diff --git a/database/portal.go b/database/portal.go index af2feca..5f5b4e3 100644 --- a/database/portal.go +++ b/database/portal.go @@ -44,8 +44,8 @@ func (pq *PortalQuery) New() *Portal { } } -func (pq *PortalQuery) GetAll() (portals []*Portal) { - rows, err := pq.db.Query("SELECT * FROM portal") +func (pq *PortalQuery) GetAll(owner string) (portals []*Portal) { + rows, err := pq.db.Query("SELECT * FROM portal WHERE owner=?", owner) if err != nil || rows == nil { return nil } diff --git a/database/puppet.go b/database/puppet.go new file mode 100644 index 0000000..93b22b9 --- /dev/null +++ b/database/puppet.go @@ -0,0 +1,98 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2018 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package database + +import ( + log "maunium.net/go/maulogger" +) + +type PuppetQuery struct { + db *Database + log *log.Sublogger +} + +func (pq *PuppetQuery) CreateTable() error { + _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet ( + jid VARCHAR(255), + receiver VARCHAR(255), + + displayname VARCHAR(255), + avatar VARCHAR(255), + + PRIMARY KEY(jid, receiver) + )`) + return err +} + +func (pq *PuppetQuery) New() *Puppet { + return &Puppet{ + db: pq.db, + log: pq.log, + } +} + +func (pq *PuppetQuery) GetAll() (puppets []*Puppet) { + rows, err := pq.db.Query("SELECT * FROM puppet") + if err != nil || rows == nil { + return nil + } + defer rows.Close() + for rows.Next() { + puppets = append(puppets, pq.New().Scan(rows)) + } + return +} + +func (pq *PuppetQuery) Get(jid, receiver string) *Puppet { + row := pq.db.QueryRow("SELECT * FROM user WHERE jid=? AND receiver=?", jid, receiver) + if row == nil { + return nil + } + return pq.New().Scan(row) +} + +type Puppet struct { + db *Database + log *log.Sublogger + + JID string + Receiver string + + Displayname string + Avatar string +} + +func (puppet *Puppet) Scan(row Scannable) *Puppet { + err := row.Scan(&puppet.JID, &puppet.Receiver, &puppet.Displayname, &puppet.Avatar) + if err != nil { + puppet.log.Fatalln("Database scan failed:", err) + } + return puppet +} + +func (puppet *Puppet) Insert() error { + _, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?, ?)", + puppet.JID, puppet.Receiver, puppet.Displayname, puppet.Avatar) + return err +} + +func (puppet *Puppet) Update() error { + _, err := puppet.db.Exec("UPDATE puppet SET displayname=?, avatar=? WHERE jid=? AND receiver=?", + puppet.Displayname, puppet.Avatar, + puppet.JID, puppet.Receiver) + return err +} diff --git a/database/user.go b/database/user.go index 4d90f66..cbb9f47 100644 --- a/database/user.go +++ b/database/user.go @@ -28,7 +28,9 @@ type UserQuery struct { func (uq *UserQuery) CreateTable() error { _, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user ( - mxid VARCHAR(255) PRIMARY KEY, + mxid VARCHAR(255) PRIMARY KEY, + + management_room VARCHAR(255), client_id VARCHAR(255), client_token VARCHAR(255), @@ -71,29 +73,43 @@ type User struct { db *Database log *log.Sublogger - UserID string - - session whatsapp.Session + UserID string + ManagementRoom string + Session *whatsapp.Session } func (user *User) Scan(row Scannable) *User { - err := row.Scan(&user.UserID, &user.session.ClientId, &user.session.ClientToken, &user.session.ServerToken, - &user.session.EncKey, &user.session.MacKey, &user.session.Wid) + sess := whatsapp.Session{} + err := row.Scan(&user.UserID, &user.ManagementRoom, &sess.ClientId, &sess.ClientToken, &sess.ServerToken, + &sess.EncKey, &sess.MacKey, &sess.Wid) if err != nil { user.log.Fatalln("Database scan failed:", err) } + if len(sess.ClientId) > 0 { + user.Session = &sess + } else { + user.Session = nil + } return user } func (user *User) Insert() error { - _, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?)", user.UserID, user.session.ClientId, - user.session.ClientToken, user.session.ServerToken, user.session.EncKey, user.session.MacKey, user.session.Wid) + var sess whatsapp.Session + if user.Session != nil { + sess = *user.Session + } + _, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?)", user.UserID, user.ManagementRoom, + sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid) return err } func (user *User) Update() error { - _, err := user.db.Exec("UPDATE user SET client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=?, wid=? WHERE mxid=?", - user.session.ClientId, user.session.ClientToken, user.session.ServerToken, user.session.EncKey, user.session.MacKey, - user.session.Wid, user.UserID) + var sess whatsapp.Session + if user.Session != nil { + sess = *user.Session + } + _, err := user.db.Exec("UPDATE user SET management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=?, wid=? WHERE mxid=?", + user.ManagementRoom, + sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid, user.UserID) return err } diff --git a/example-config.yaml b/example-config.yaml index a06d4be..9c97d18 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -52,7 +52,7 @@ logging: # The directory for log files. Will be created if not found. directory: ./logs # Available variables: .date for the file date and .index for different log files on the same day. - file_name_format: {{.date}}-{{.index}.log + file_name_format: "{{.date}}-{{.index}.log" # Date format for file names in the Go time format: https://golang.org/pkg/time/#pkg-constants file_date_format: 2006-01-02 # Log file permissions. diff --git a/main.go b/main.go index eff3133..67a1927 100644 --- a/main.go +++ b/main.go @@ -67,6 +67,8 @@ type Bridge struct { Log *log.Logger MatrixListener *MatrixListener + + users map[string]*User } func NewBridge() *Bridge { @@ -133,7 +135,9 @@ func (bridge *Bridge) Main() { } func main() { - flag.SetHelpTitles("mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.", "[-h] [-c ] [-r ] [-g]") + flag.SetHelpTitles( + "mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.", + "mautrix-whatsapp [-h] [-c ] [-r ] [-g]") err := flag.Parse() if err != nil { fmt.Fprintln(os.Stderr, err) diff --git a/matrix.go b/matrix.go index ee70b4b..c71c3a0 100644 --- a/matrix.go +++ b/matrix.go @@ -18,10 +18,13 @@ package main import ( log "maunium.net/go/maulogger" + "maunium.net/go/mautrix-appservice" + "maunium.net/go/gomatrix" ) type MatrixListener struct { bridge *Bridge + as *appservice.AppService log *log.Sublogger stop chan struct{} } @@ -29,8 +32,9 @@ type MatrixListener struct { func NewMatrixListener(bridge *Bridge) *MatrixListener { return &MatrixListener{ bridge: bridge, + as: bridge.AppService, stop: make(chan struct{}, 1), - log: bridge.Log.CreateSublogger("Matrix", log.LevelDebug), + log: bridge.Log.CreateSublogger("Matrix", log.LevelDebug), } } @@ -39,12 +43,69 @@ func (ml *MatrixListener) Start() { select { case evt := <-ml.bridge.AppService.Events: log.Debugln("Received Matrix event:", evt) + switch evt.Type { + case gomatrix.StateMember: + ml.HandleMembership(evt) + case gomatrix.EventMessage: + ml.HandleMessage(evt) + } case <-ml.stop: return } } } +func (ml *MatrixListener) HandleBotInvite(evt *gomatrix.Event) { + cli := ml.as.BotClient() + + resp, err := cli.JoinRoom(evt.RoomID, "", nil) + if err != nil { + ml.log.Debugln("Failed to join room", evt.RoomID, "with invite from", evt.Sender) + return + } + + members, err := cli.JoinedMembers(resp.RoomID) + if err != nil { + ml.log.Debugln("Failed to get members in room", resp.RoomID, "after accepting invite from", evt.Sender) + cli.LeaveRoom(resp.RoomID) + return + } + + if len(members.Joined) < 2 { + ml.log.Debugln("Leaving empty room", resp.RoomID, "after accepting invite from", evt.Sender) + cli.LeaveRoom(resp.RoomID) + return + } + for mxid, _ := range members.Joined { + if mxid == cli.UserID || mxid == evt.Sender { + continue + } else if true { // TODO check if mxid is WhatsApp puppet + + continue + } + ml.log.Debugln("Leaving multi-user room", resp.RoomID, "after accepting invite from", evt.Sender) + cli.SendNotice(resp.RoomID, "This bridge is user-specific, please don't invite me into rooms with other users.") + cli.LeaveRoom(resp.RoomID) + return + } + + user := ml.bridge.GetUser(evt.Sender) + user.ManagementRoom = resp.RoomID + user.Update() + cli.SendNotice(user.ManagementRoom, "This room has been registered as your bridge management/status room.") + ml.log.Debugln(resp.RoomID, "registered as a management room with", evt.Sender) +} + +func (ml *MatrixListener) HandleMembership(evt *gomatrix.Event) { + if evt.Content.Membership == "invite" && evt.GetStateKey() == ml.as.BotMXID() { + ml.HandleBotInvite(evt) + } +} + +func (ml *MatrixListener) HandleMessage(evt *gomatrix.Event) { + +} + func (ml *MatrixListener) Stop() { ml.stop <- struct{}{} } diff --git a/portal.go b/portal.go new file mode 100644 index 0000000..5a82e91 --- /dev/null +++ b/portal.go @@ -0,0 +1,89 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2018 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + "maunium.net/go/mautrix-whatsapp/database" + log "maunium.net/go/maulogger" + "fmt" +) + +func (user *User) GetPortalByMXID(mxid string) *Portal { + portal, ok := user.portalsByMXID[mxid] + if !ok { + dbPortal := user.bridge.DB.Portal.GetByMXID(mxid) + if dbPortal == nil || dbPortal.Owner != user.UserID { + return nil + } + portal = user.NewPortal(dbPortal) + user.portalsByJID[portal.JID] = portal + if len(portal.MXID) > 0 { + user.portalsByMXID[portal.MXID] = portal + } + } + return portal +} + +func (user *User) GetPortalByJID(jid string) *Portal { + portal, ok := user.portalsByJID[jid] + if !ok { + dbPortal := user.bridge.DB.Portal.GetByJID(user.UserID, jid) + if dbPortal == nil { + return nil + } + portal = user.NewPortal(dbPortal) + user.portalsByJID[portal.JID] = portal + if len(portal.MXID) > 0 { + user.portalsByMXID[portal.MXID] = portal + } + } + return portal +} + +func (user *User) GetAllPortals() []*Portal { + dbPortals := user.bridge.DB.Portal.GetAll(user.UserID) + output := make([]*Portal, len(dbPortals)) + for index, dbPortal := range dbPortals { + portal, ok := user.portalsByJID[dbPortal.JID] + if !ok { + portal = user.NewPortal(dbPortal) + user.portalsByJID[dbPortal.JID] = portal + if len(dbPortal.MXID) > 0 { + user.portalsByMXID[dbPortal.MXID] = portal + } + } + output[index] = portal + } + return output +} + +func (user *User) NewPortal(dbPortal *database.Portal) *Portal { + return &Portal{ + Portal: dbPortal, + user: user, + bridge: user.bridge, + log: user.bridge.Log.CreateSublogger(fmt.Sprintf("Portal/%s/%s", user.UserID, dbPortal.JID), log.LevelDebug), + } +} + +type Portal struct { + *database.Portal + + user *User + bridge *Bridge + log *log.Sublogger +} diff --git a/user.go b/user.go index 460fbb3..0c3dc9a 100644 --- a/user.go +++ b/user.go @@ -15,3 +15,161 @@ // along with this program. If not, see . package main + +import ( + "maunium.net/go/mautrix-whatsapp/database" + "github.com/Rhymen/go-whatsapp" + "time" + "fmt" + "os" + "github.com/skip2/go-qrcode" + log "maunium.net/go/maulogger" +) + +type User struct { + *database.User + Conn *whatsapp.Conn + + bridge *Bridge + log *log.Sublogger + + portalsByMXID map[string]*Portal + portalsByJID map[string]*Portal + puppets map[string]*Portal +} + +func (bridge *Bridge) GetUser(userID string) *User { + user, ok := bridge.users[userID] + if !ok { + dbUser := bridge.DB.User.Get(userID) + if dbUser == nil { + dbUser = bridge.DB.User.New() + dbUser.Insert() + } + user = bridge.NewUser(dbUser) + bridge.users[user.UserID] = user + } + return user +} + +func (bridge *Bridge) GetAllUsers() []*User { + dbUsers := bridge.DB.User.GetAll() + output := make([]*User, len(dbUsers)) + for index, dbUser := range dbUsers { + user, ok := bridge.users[dbUser.UserID] + if !ok { + user = bridge.NewUser(dbUser) + bridge.users[user.UserID] = user + } + output[index] = user + } + return output +} + +func (bridge *Bridge) InitWhatsApp() { + users := bridge.GetAllUsers() + for _, user := range users { + user.Connect() + } +} + +func (bridge *Bridge) NewUser(dbUser *database.User) *User { + return &User{ + User: dbUser, + bridge: bridge, + log: bridge.Log.CreateSublogger(fmt.Sprintf("User/%s", dbUser.UserID), log.LevelDebug), + } +} + +func (user *User) Connect() { + var err error + user.Conn, err = whatsapp.NewConn(20 * time.Second) + if err != nil { + user.log.Errorln("Failed to connect to WhatsApp:", err) + return + } + user.Conn.AddHandler(user) + user.RestoreSession() +} + +func (user *User) RestoreSession() { + if user.Session != nil { + sess, err := user.Conn.RestoreSession(*user.Session) + if err != nil { + user.log.Errorln("Failed to restore session:", err) + user.Session = nil + return + } + user.Session = &sess + user.log.Debugln("Session restored") + } + return +} + +func (user *User) Login(roomID string) { + bot := user.bridge.AppService.BotClient() + + qrChan := make(chan string, 2) + go func() { + code := <-qrChan + if code == "error" { + return + } + qrCode, err := qrcode.Encode(code, qrcode.Low, 256) + if err != nil { + user.log.Errorln("Failed to encode QR code:", err) + bot.SendNotice(roomID, "Failed to encode QR code (see logs for details)") + return + } + + resp, err := bot.UploadBytes(qrCode, "image/png") + if err != nil { + user.log.Errorln("Failed to upload QR code:", err) + bot.SendNotice(roomID, "Failed to upload QR code (see logs for details)") + return + } + + bot.SendImage(roomID, string(qrCode), resp.ContentURI) + }() + session, err := user.Conn.Login(qrChan) + if err != nil { + user.log.Warnln("Failed to log in:", err) + bot.SendNotice(roomID, "Failed to log in: "+err.Error()) + qrChan <- "error" + return + } + user.Session = &session + user.Update() + bot.SendNotice(roomID, "Successfully logged in. Synchronizing chats...") + go user.Sync() +} + +func (user *User) Sync() { + chats, err := user.Conn.Chats() + if err != nil { + user.log.Warnln("Failed to get chats") + return + } + user.log.Debugln(chats) +} + +func (user *User) HandleError(err error) { + user.log.Errorln("WhatsApp error:", err) + fmt.Fprintf(os.Stderr, "%v", err) +} + +func (user *User) HandleTextMessage(message whatsapp.TextMessage) { + fmt.Println(message) +} + +func (user *User) HandleImageMessage(message whatsapp.ImageMessage) { + fmt.Println(message) +} + +func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) { + fmt.Println(message) +} + +func (user *User) HandleJsonMessage(message string) { + fmt.Println(message) +}