Stop handling events from custom puppets
This commit is contained in:
@@ -110,8 +110,10 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
|
||||
context.mx = MatrixHandler(context)
|
||||
init_formatter(context)
|
||||
init_portal(context)
|
||||
init_puppet(context)
|
||||
startup_actions = init_user(context) + [start, context.mx.init_as_bot()]
|
||||
startup_actions = (init_puppet(context) +
|
||||
init_user(context) +
|
||||
[start,
|
||||
context.mx.init_as_bot()])
|
||||
|
||||
if context.bot:
|
||||
startup_actions.append(context.bot.start())
|
||||
|
||||
@@ -124,7 +124,7 @@ class AbstractUser:
|
||||
self.log.debug("%s connected: %s", self.mxid, self.connected)
|
||||
return self
|
||||
|
||||
async def ensure_started(self, even_if_no_session=False):
|
||||
async def ensure_started(self, even_if_no_session=False) -> "AbstractUser":
|
||||
if not self.puppet_whitelisted:
|
||||
return self
|
||||
self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)",
|
||||
|
||||
@@ -56,15 +56,11 @@ async def ping_bot(evt: CommandEvent):
|
||||
"account")
|
||||
async def login_matrix(evt: CommandEvent):
|
||||
puppet = pu.Puppet.get(evt.sender.tgid)
|
||||
prev_info = puppet.custom_mxid, puppet.access_token
|
||||
puppet.custom_mxid = evt.sender.mxid
|
||||
puppet.access_token = " ".join(evt.args)
|
||||
puppet.refresh_intents()
|
||||
if not await puppet.get_profile():
|
||||
puppet.custom_mxid, puppet.access_token = prev_info
|
||||
puppet.refresh_intents()
|
||||
resp = puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
|
||||
if resp == 2:
|
||||
return await evt.reply("You can only log in as your own Matrix user.")
|
||||
elif resp == 1:
|
||||
return await evt.reply("Failed to verify access token.")
|
||||
puppet.save()
|
||||
return await evt.reply(
|
||||
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
|
||||
|
||||
|
||||
+11
-12
@@ -17,14 +17,14 @@
|
||||
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
||||
BigInteger, String, Boolean, Text)
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship, Query
|
||||
import json
|
||||
|
||||
from .base import Base
|
||||
|
||||
|
||||
class Portal(Base):
|
||||
query = None
|
||||
query = None # type: Query
|
||||
__tablename__ = "portal"
|
||||
|
||||
# Telegram chat information
|
||||
@@ -42,9 +42,8 @@ class Portal(Base):
|
||||
about = Column(String, nullable=True)
|
||||
photo_id = Column(String, nullable=True)
|
||||
|
||||
|
||||
class Message(Base):
|
||||
query = None
|
||||
query = None # type: Query
|
||||
__tablename__ = "message"
|
||||
|
||||
mxid = Column(String)
|
||||
@@ -56,7 +55,7 @@ class Message(Base):
|
||||
|
||||
|
||||
class UserPortal(Base):
|
||||
query = None
|
||||
query = None # type: Query
|
||||
__tablename__ = "user_portal"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
|
||||
@@ -70,7 +69,7 @@ class UserPortal(Base):
|
||||
|
||||
|
||||
class User(Base):
|
||||
query = None
|
||||
query = None # type: Query
|
||||
__tablename__ = "user"
|
||||
|
||||
mxid = Column(String, primary_key=True)
|
||||
@@ -83,7 +82,7 @@ class User(Base):
|
||||
|
||||
|
||||
class RoomState(Base):
|
||||
query = None
|
||||
query = None # type: Query
|
||||
__tablename__ = "mx_room_state"
|
||||
|
||||
room_id = Column(String, primary_key=True)
|
||||
@@ -107,7 +106,7 @@ class RoomState(Base):
|
||||
|
||||
|
||||
class UserProfile(Base):
|
||||
query = None
|
||||
query = None # type: Query
|
||||
__tablename__ = "mx_user_profile"
|
||||
|
||||
room_id = Column(String, primary_key=True)
|
||||
@@ -125,7 +124,7 @@ class UserProfile(Base):
|
||||
|
||||
|
||||
class Contact(Base):
|
||||
query = None
|
||||
query = None # type: Query
|
||||
__tablename__ = "contact"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True)
|
||||
@@ -133,7 +132,7 @@ class Contact(Base):
|
||||
|
||||
|
||||
class Puppet(Base):
|
||||
query = None
|
||||
query = None # type: Query
|
||||
__tablename__ = "puppet"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
@@ -149,14 +148,14 @@ class Puppet(Base):
|
||||
|
||||
# Fucking Telegram not telling bots what chats they are in 3:<
|
||||
class BotChat(Base):
|
||||
query = None
|
||||
query = None # type: Query
|
||||
__tablename__ = "bot_chat"
|
||||
id = Column(Integer, primary_key=True)
|
||||
type = Column(String, nullable=False)
|
||||
|
||||
|
||||
class TelegramFile(Base):
|
||||
query = None
|
||||
query = None # type: Query
|
||||
__tablename__ = "telegram_file"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
|
||||
@@ -824,7 +824,12 @@ class Portal:
|
||||
mxid=event_id))
|
||||
self.db.commit()
|
||||
|
||||
async def handle_matrix_message(self, sender, message, event_id):
|
||||
async def handle_matrix_message(self, sender: u.User, message: dict, event_id: str):
|
||||
puppet = p.Puppet.get_by_custom_mxid(sender.mxid)
|
||||
if puppet and message.get("net.maunium.telegram.puppet", False):
|
||||
self.log.debug("Ignoring puppet-sent message by confirmed puppet user %s", sender.mxid)
|
||||
return
|
||||
|
||||
logged_in = not await sender.needs_relaybot(self)
|
||||
client = sender.client if logged_in else self.bot.client
|
||||
sender_id = sender.tgid if logged_in else self.bot.tgid
|
||||
|
||||
@@ -36,6 +36,7 @@ class Puppet:
|
||||
username_template = None
|
||||
hs_domain = None
|
||||
cache = {}
|
||||
by_custom_mxid = {}
|
||||
|
||||
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
|
||||
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
|
||||
@@ -60,22 +61,51 @@ class Puppet:
|
||||
self.refresh_intents()
|
||||
|
||||
self.cache[id] = self
|
||||
if self.custom_mxid:
|
||||
self.by_custom_mxid[self.custom_mxid] = self
|
||||
|
||||
def refresh_intents(self):
|
||||
self.is_real_user = self.custom_mxid and self.access_token
|
||||
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
|
||||
if self.is_real_user else self.default_mxid_intent)
|
||||
|
||||
async def get_profile(self):
|
||||
try:
|
||||
return await self.intent.get_profile(self.custom_mxid)
|
||||
except MatrixError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def tgid(self):
|
||||
return self.id
|
||||
|
||||
async def switch_mxid(self, access_token, mxid):
|
||||
prev_mxid = self.custom_mxid
|
||||
self.custom_mxid = mxid
|
||||
self.access_token = access_token
|
||||
self.refresh_intents()
|
||||
|
||||
err = await self.test_custom_mxid()
|
||||
if err != 0:
|
||||
return err
|
||||
|
||||
try:
|
||||
del self.by_custom_mxid[prev_mxid]
|
||||
except KeyError:
|
||||
pass
|
||||
self.mxid = self.custom_mxid or self.default_mxid
|
||||
self.by_custom_mxid[self.mxid] = self
|
||||
self.save()
|
||||
return 0
|
||||
|
||||
async def test_custom_mxid(self):
|
||||
if not self.is_real_user:
|
||||
return 0
|
||||
|
||||
mxid = await self.intent.whoami()
|
||||
if not mxid or mxid != self.custom_mxid:
|
||||
self.custom_mxid = None
|
||||
self.access_token = None
|
||||
self.refresh_intents()
|
||||
if mxid != self.custom_mxid:
|
||||
return 2
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def is_logged_in(self):
|
||||
return True
|
||||
|
||||
@@ -212,6 +242,30 @@ class Puppet:
|
||||
tgid = cls.get_id_from_mxid(mxid)
|
||||
return cls.get(tgid, create) if tgid else None
|
||||
|
||||
@classmethod
|
||||
def get_by_custom_mxid(cls, mxid):
|
||||
if not mxid:
|
||||
raise ValueError("Matrix ID can't be empty")
|
||||
|
||||
try:
|
||||
return cls.by_custom_mxid[mxid]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
puppet = DBPuppet.query.filter(DBPuppet.custom_mxid == mxid).one_or_none()
|
||||
if puppet:
|
||||
puppet = cls.from_db(puppet)
|
||||
return puppet
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_all_with_custom_mxid(cls):
|
||||
return [cls.by_custom_mxid[puppet.mxid]
|
||||
if puppet.custom_mxid in cls.by_custom_mxid
|
||||
else cls.from_db(puppet)
|
||||
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
|
||||
|
||||
@classmethod
|
||||
def get_id_from_mxid(cls, mxid):
|
||||
match = cls.mxid_regex.match(mxid)
|
||||
@@ -261,3 +315,4 @@ def init(context):
|
||||
Puppet.hs_domain = config["homeserver"]["domain"]
|
||||
localpart = Puppet.username_template.format(userid="(.+)")
|
||||
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
|
||||
return [puppet.test_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict
|
||||
from typing import Dict, Awaitable, Optional
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
@@ -185,6 +185,9 @@ class User(AbstractUser):
|
||||
# endregion
|
||||
# region Telegram actions that need custom methods
|
||||
|
||||
def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]":
|
||||
return super().ensure_started(even_if_no_session)
|
||||
|
||||
async def update_info(self, info: User = None):
|
||||
info = info or await self.client.get_me()
|
||||
changed = False
|
||||
@@ -309,7 +312,7 @@ class User(AbstractUser):
|
||||
# region Class instance lookup
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid, create=True):
|
||||
def get_by_mxid(cls, mxid, create=True) -> "Optional[User]":
|
||||
if not mxid:
|
||||
raise ValueError("Matrix ID can't be empty")
|
||||
|
||||
@@ -332,7 +335,7 @@ class User(AbstractUser):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid):
|
||||
def get_by_tgid(cls, tgid) -> "Optional[User]":
|
||||
try:
|
||||
return cls.by_tgid[tgid]
|
||||
except KeyError:
|
||||
@@ -346,7 +349,7 @@ class User(AbstractUser):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username):
|
||||
def find_by_username(cls, username) -> "Optional[User]":
|
||||
if not username:
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user