Improve type hints and set version to 0.4.0+dev

This commit is contained in:
Tulir Asokan
2018-09-10 01:14:12 +03:00
parent 4b2cdc3d39
commit d4ea5f8b38
21 changed files with 200 additions and 181 deletions
+1 -1
View File
@@ -1,2 +1,2 @@
__version__ = "0.3.0"
__version__ = "0.4.0+dev"
__author__ = "Tulir Asokan <tulir@maunium.net>"
+16 -16
View File
@@ -35,10 +35,10 @@ from alchemysession import AlchemySessionContainer
from . import portal as po, puppet as pu, __version__
from .db import Message as DBMessage
from .types import TelegramID
from .tgclient import MautrixTelegramClient
if TYPE_CHECKING:
from .types import TelegramID
from .context import Context
from .config import Config
from .bot import Bot
@@ -210,13 +210,13 @@ class AbstractUser(ABC):
@staticmethod
async def update_pinned_messages(update: UpdateChannelPinnedMessage) -> None:
portal = po.Portal.get_by_tgid(update.channel_id)
portal = po.Portal.get_by_tgid(TelegramID(update.channel_id))
if portal and portal.mxid:
await portal.receive_telegram_pin_id(update.id)
@staticmethod
async def update_participants(update: UpdateChatParticipants) -> None:
portal = po.Portal.get_by_tgid(update.participants.chat_id)
portal = po.Portal.get_by_tgid(TelegramID(update.participants.chat_id))
if portal and portal.mxid:
await portal.update_telegram_participants(update.participants.participants)
@@ -225,7 +225,7 @@ class AbstractUser(ABC):
self.log.debug("Unexpected read receipt peer: %s", update.peer)
return
portal = po.Portal.get_by_tgid(update.peer.user_id, self.tgid)
portal = po.Portal.get_by_tgid(TelegramID(update.peer.user_id), self.tgid)
if not portal or not portal.mxid:
return
@@ -234,38 +234,38 @@ class AbstractUser(ABC):
if not message:
return
puppet = pu.Puppet.get(update.peer.user_id)
puppet = pu.Puppet.get(TelegramID(update.peer.user_id))
await puppet.intent.mark_read(portal.mxid, message.mxid)
async def update_admin(self,
update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]) -> None:
# TODO duplication not checked
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
if not portal or not portal.mxid:
return
if isinstance(update, UpdateChatAdmins):
await portal.set_telegram_admins_enabled(update.enabled)
elif isinstance(update, UpdateChatParticipantAdmin):
await portal.set_telegram_admin(update.user_id)
await portal.set_telegram_admin(TelegramID(update.user_id))
else:
self.log.warning("Unexpected admin status update: %s", update)
async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]) -> None:
if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
else:
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
if not portal or not portal.mxid:
return
sender = pu.Puppet.get(update.user_id)
sender = pu.Puppet.get(TelegramID(update.user_id))
await portal.handle_telegram_typing(sender, update)
async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]) -> None:
# TODO duplication not checked
puppet = pu.Puppet.get(update.user_id)
puppet = pu.Puppet.get(TelegramID(update.user_id))
if isinstance(update, UpdateUserName):
if await puppet.update_displayname(self, update):
puppet.save()
@@ -276,7 +276,7 @@ class AbstractUser(ABC):
self.log.warning("Unexpected other user info update: %s", update)
async def update_status(self, update: UpdateUserStatus) -> None:
puppet = pu.Puppet.get(update.user_id)
puppet = pu.Puppet.get(TelegramID(update.user_id))
if isinstance(update.status, UserStatusOnline):
await puppet.default_mxid_intent.set_presence("online")
elif isinstance(update.status, UserStatusOffline):
@@ -289,10 +289,10 @@ class AbstractUser(ABC):
Optional[pu.Puppet],
Optional[po.Portal]]:
if isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
sender = pu.Puppet.get(update.from_id)
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
sender = pu.Puppet.get(TelegramID(update.from_id))
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
sender = pu.Puppet.get(self.tgid if update.out else update.user_id)
elif isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage,
UpdateEditMessage, UpdateEditChannelMessage)):
@@ -338,7 +338,7 @@ class AbstractUser(ABC):
if len(update.messages) > MAX_DELETIONS:
return
portal = po.Portal.get_by_tgid(update.channel_id)
portal = po.Portal.get_by_tgid(TelegramID(update.channel_id))
if not portal:
return
+29 -26
View File
@@ -61,15 +61,15 @@ class Bot(AbstractUser):
async def init_permissions(self) -> None:
whitelist = config["bridge.relaybot.whitelist"] or []
for id in whitelist:
if isinstance(id, str):
entity = await self.client.get_input_entity(id)
for user_id in whitelist:
if isinstance(user_id, str):
entity = await self.client.get_input_entity(user_id)
if isinstance(entity, InputUser):
id = entity.user_id
user_id = entity.user_id
else:
id = None
if isinstance(id, int):
self.tg_whitelist.append(id)
user_id = None
if isinstance(user_id, int):
self.tg_whitelist.append(user_id)
async def start(self, delete_unless_authenticated: bool = False) -> 'Bot':
await super().start(delete_unless_authenticated)
@@ -85,20 +85,20 @@ class Bot(AbstractUser):
self.username = info.username
self.mxid = pu.Puppet.get_mxid_from_id(self.tgid)
chat_ids = [id for id, type in self.chats.items() if type == "chat"]
chat_ids = [chat_id for chat_id, chat_type in self.chats.items() if chat_type == "chat"]
response = await self.client(GetChatsRequest(chat_ids))
for chat in response.chats:
if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated:
self.remove_chat(chat.id)
channel_ids = [InputChannel(id, 0)
for id, type in self.chats.items()
if type == "channel"]
for id in channel_ids:
channel_ids = [InputChannel(chat_id, 0)
for chat_id, chat_type in self.chats.items()
if chat_type == "channel"]
for channel_id in channel_ids:
try:
await self.client(GetChannelsRequest([id]))
await self.client(GetChannelsRequest([channel_id]))
except (ChannelPrivateError, ChannelInvalidError):
self.remove_chat(id.channel_id)
self.remove_chat(channel_id.channel_id)
if config["bridge.catch_up"]:
try:
@@ -112,18 +112,18 @@ class Bot(AbstractUser):
def unregister_portal(self, portal: po.Portal) -> None:
self.remove_chat(portal.tgid)
def add_chat(self, id: int, type: str) -> None:
if id not in self.chats:
self.chats[id] = type
self.db.add(BotChat(id=id, type=type))
def add_chat(self, chat_id: int, chat_type: str) -> None:
if chat_id not in self.chats:
self.chats[chat_id] = chat_type
self.db.add(BotChat(id=chat_id, type=chat_type))
self.db.commit()
def remove_chat(self, id: int) -> None:
def remove_chat(self, chat_id: int) -> None:
try:
del self.chats[id]
del self.chats[chat_id]
except KeyError:
pass
existing_chat = BotChat.query.get(id)
existing_chat = BotChat.query.get(chat_id)
if existing_chat:
self.db.delete(existing_chat)
self.db.commit()
@@ -191,7 +191,8 @@ class Bot(AbstractUser):
await portal.main_intent.invite(portal.mxid, user.mxid)
return await reply(f"Invited `{user.mxid}` to the portal.")
def handle_command_id(self, message: Message, reply: ReplyFunc) -> Awaitable[Message]:
@staticmethod
def handle_command_id(message: Message, reply: ReplyFunc) -> Awaitable[Message]:
# Provide the prefixed ID to the user so that the user wouldn't need to specify whether the
# chat is a normal group or a supergroup/channel when using the ID.
if isinstance(message.to_id, PeerChannel):
@@ -241,16 +242,16 @@ class Bot(AbstractUser):
to_id = message.to_id
if isinstance(to_id, PeerChannel):
to_id = to_id.channel_id
type = "channel"
chat_type = "channel"
elif isinstance(to_id, PeerChat):
to_id = to_id.chat_id
type = "chat"
chat_type = "chat"
else:
return
action = message.action
if isinstance(action, MessageActionChatAddUser) and self.tgid in action.users:
self.add_chat(to_id, type)
self.add_chat(to_id, chat_type)
elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid:
self.remove_chat(to_id)
@@ -265,7 +266,9 @@ class Bot(AbstractUser):
and update.message.entities and len(update.message.entities) > 0
and isinstance(update.message.entities[0], MessageEntityBotCommand))
if is_command:
return await self.handle_command(update.message)
await self.handle_command(update.message)
return True
return False
def is_in_chat(self, peer_id) -> bool:
return peer_id in self.chats
+1 -1
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Awaitable, Dict, Optional
from typing import Any, Dict, Optional
import asyncio
from telethon.errors import (
+6 -5
View File
@@ -33,7 +33,8 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[Mat
empty_portals = [] # type: List[po.Portal]
rooms = await intent.get_joined_rooms()
for room in rooms:
for room_str in rooms:
room = MatrixRoomID(room_str)
portal = po.Portal.get_by_mxid(room)
if not portal:
try:
@@ -41,7 +42,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[Mat
except MatrixRequestError:
members = []
if len(members) == 2:
other_member = members[0] if members[0] != intent.mxid else members[1]
other_member = MatrixUserID(members[0] if members[0] != intent.mxid else members[1])
if pu.Puppet.get_id_from_mxid(other_member):
unidentified_rooms.append(room)
else:
@@ -128,9 +129,9 @@ async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
rooms_to_clean += empty_portals
elif command == "clean-range":
try:
range = evt.args[1]
group, range = range[0], range[1:]
start, end = range.split("-")
clean_range = evt.args[1]
group, clean_range = clean_range[0], clean_range[1:]
start, end = clean_range.split("-")
start, end = int(start), int(end)
if group == "M":
group = [room_id for (room_id, user_id) in management_rooms]
+5 -6
View File
@@ -14,8 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Awaitable, Callable, Coroutine, Dict, List, NamedTuple, Optional, Union
from collections import namedtuple
from typing import Awaitable, Callable, Dict, List, NamedTuple, Optional
import markdown
import logging
@@ -160,16 +159,16 @@ class CommandProcessor:
orig_command = command
command = command.lower()
try:
command_handler = command_handlers[command]
handler = command_handlers[command]
except KeyError:
if sender.command_status and "next" in sender.command_status:
args.insert(0, orig_command)
evt.command = ""
command_handler = sender.command_status["next"]
handler = sender.command_status["next"]
else:
command_handler = command_handlers["unknown-command"]
handler = command_handlers["unknown-command"]
try:
await command_handler(evt)
await handler(evt)
except FloodWaitError as e:
return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}")
except Exception:
+12 -12
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, Callable, Coroutine, Optional, Tuple, Union, cast
from typing import Dict, Callable, Optional, Tuple, Coroutine
import asyncio
from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
@@ -85,18 +85,20 @@ async def user_has_power_level(room: str, intent, sender: u.User, event: str, de
async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
action: Optional[str] = None
) -> Tuple[Union[Dict, po.Portal], bool]:
) -> Optional[po.Portal]:
room_id = MatrixRoomID(evt.args[0]) if len(evt.args) > 0 else evt.room_id
portal = po.Portal.get_by_mxid(room_id)
if not portal:
that_this = "This" if room_id == evt.room_id else "That"
return await evt.reply(f"{that_this} is not a portal room."), False
await evt.reply(f"{that_this} is not a portal room.")
return None
if not await user_has_power_level(portal.mxid, evt.az.intent, evt.sender, permission):
action = action or f"{permission.replace('_', ' ')}s"
return await evt.reply(f"You do not have the permissions to {action} that portal."), False
return portal, True
await evt.reply(f"You do not have the permissions to {action} that portal.")
return None
return portal
def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str,
@@ -123,10 +125,9 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
"Only works for group chats; to delete a private chat portal, simply "
"leave the room.")
async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
result, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok:
portal = await _get_portal_and_check_permission(evt, "unbridge")
if not portal:
return None
portal = cast('po.Portal', result)
evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid,
portal.cleanup_and_delete, "delete",
@@ -145,10 +146,9 @@ async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Remove puppets from the current portal room and forget the portal.")
async def unbridge(evt: CommandEvent) -> Optional[Dict]:
result, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok:
portal = await _get_portal_and_check_permission(evt, "unbridge")
if not portal:
return None
portal = cast('po.Portal', result)
evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid,
portal.unbridge, "unbridge",
@@ -231,7 +231,7 @@ async def bridge(evt: CommandEvent) -> Dict:
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"
) -> Tuple[bool, Coroutine[None, None, None]]:
) -> Tuple[bool, Optional[Coroutine[None, None, None]]]:
if not portal.mxid:
await evt.reply("The portal seems to have lost its Matrix room between you"
"calling `$cmdprefix+sp bridge` and this command.\n\n"
+1 -1
View File
@@ -92,7 +92,7 @@ async def private_message(evt: CommandEvent) -> Optional[Dict]:
f"{pu.Puppet.get_displayname(user, False)}")
async def _join(evt: CommandEvent, arg: str) -> Tuple[TypeUpdates, Dict]:
async def _join(evt: CommandEvent, arg: str) -> Tuple[Optional[TypeUpdates], Optional[Dict]]:
if arg.startswith("joinchat/"):
invite_hash = arg[len("joinchat/"):]
try:
+1 -1
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Generator, Optional, Tuple, Union, TYPE_CHECKING
from typing import Optional, Tuple, TYPE_CHECKING
if TYPE_CHECKING:
import asyncio
+26 -24
View File
@@ -14,14 +14,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text)
from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query
from typing import Dict, Optional, List
import json
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
from .types import TelegramID
from .base import Base
@@ -30,13 +31,13 @@ class Portal(Base):
__tablename__ = "portal"
# Telegram chat information
tgid = Column(Integer, primary_key=True)
tg_receiver = Column(Integer, primary_key=True)
tgid = Column(Integer, primary_key=True) # type: TelegramID
tg_receiver = Column(Integer, primary_key=True) # type: TelegramID
peer_type = Column(String, nullable=False)
megagroup = Column(Boolean)
# Matrix portal information
mxid = Column(String, unique=True, nullable=True)
mxid = Column(String, unique=True, nullable=True) # type: Optional[MatrixRoomID]
# Telegram chat metadata
username = Column(String, nullable=True)
@@ -49,10 +50,10 @@ class Message(Base):
query = None # type: Query
__tablename__ = "message"
mxid = Column(String)
mx_room = Column(String)
tgid = Column(Integer, primary_key=True)
tg_space = Column(Integer, primary_key=True)
mxid = Column(String) # type: MatrixEventID
mx_room = Column(String) # type: MatrixRoomID
tgid = Column(Integer, primary_key=True) # type: TelegramID
tg_space = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
@@ -62,9 +63,9 @@ class UserPortal(Base):
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True)
portal = Column(Integer, primary_key=True)
portal_receiver = Column(Integer, primary_key=True)
primary_key=True) # type: TelegramID
portal = Column(Integer, primary_key=True) # type: TelegramID
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"),
@@ -75,12 +76,13 @@ class User(Base):
query = None # type: Query
__tablename__ = "user"
mxid = Column(String, primary_key=True)
tgid = Column(Integer, nullable=True, unique=True)
mxid = Column(String, primary_key=True) # type: MatrixUserID
tgid = Column(Integer, nullable=True, unique=True) # type: Optional[TelegramID]
tg_username = Column(String, nullable=True)
saved_contacts = Column(Integer, default=0, nullable=False)
contacts = relationship("Contact", uselist=True,
cascade="save-update, merge, delete, delete-orphan")
cascade="save-update, merge, delete, delete-orphan"
) # type: List[Contact]
portals = relationship("Portal", secondary="user_portal")
@@ -88,7 +90,7 @@ class RoomState(Base):
query = None # type: Query
__tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True)
room_id = Column(String, primary_key=True) # type: MatrixRoomID
_power_levels_text = Column("power_levels", Text, nullable=True)
_power_levels_json = {} # type: Dict
@@ -112,8 +114,8 @@ class UserProfile(Base):
query = None # type: Query
__tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True)
user_id = Column(String, primary_key=True)
room_id = Column(String, primary_key=True) # type: MatrixRoomID
user_id = Column(String, primary_key=True) # type: MatrixUserID
membership = Column(String, nullable=False, default="leave")
displayname = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
@@ -130,19 +132,19 @@ class Contact(Base):
query = None # type: Query
__tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True)
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True)
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
class Puppet(Base):
query = None # type: Query
__tablename__ = "puppet"
id = Column(Integer, primary_key=True)
custom_mxid = Column(String, nullable=True)
id = Column(Integer, primary_key=True) # type: TelegramID
custom_mxid = Column(String, nullable=True) # type: Optional[MatrixUserID]
access_token = Column(String, nullable=True)
displayname = Column(String, nullable=True)
displayname_source = Column(Integer, nullable=True)
displayname_source = Column(Integer, nullable=True) # type: Optional[TelegramID]
username = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
is_bot = Column(Boolean, nullable=True)
@@ -153,7 +155,7 @@ class Puppet(Base):
class BotChat(Base):
query = None # type: Query
__tablename__ = "bot_chat"
id = Column(Integer, primary_key=True)
id = Column(Integer, primary_key=True) # type: TelegramID
type = Column(String, nullable=False)
@@ -26,6 +26,7 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M
MessageEntityBotCommand, TypeMessageEntity)
from ... import user as u, puppet as pu, portal as po
from ...types import MatrixUserID
from ..util import html_to_unicode
from .parser_common import MatrixParserCommon, ParsedMessage
@@ -55,7 +56,7 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]:
mention = self.mention_regex.match(url) # type: Match
if mention:
mxid = mention.group(1)
mxid = MatrixUserID(mention.group(1))
user = (pu.Puppet.get_by_mxid(mxid)
or u.User.get_by_mxid(mxid, create=False))
if not user:
@@ -26,12 +26,13 @@ from telethon.tl.types import (MessageEntityMention as Mention,
InputMessageEntityMentionName as InputMentionName)
from ... import user as u, puppet as pu, portal as po
from ...types import MatrixUserID
from ..util import html_to_unicode
from .parser_common import MatrixParserCommon, ParsedMessage
def parse_html(html: str) -> ParsedMessage:
return MatrixParser.parse(html)
def parse_html(input_html: str) -> ParsedMessage:
return MatrixParser.parse(input_html)
class Entity:
@@ -248,7 +249,7 @@ class MatrixParser(MatrixParserCommon):
mention = cls.mention_regex.match(href)
if mention:
mxid = mention.group(1)
mxid = MatrixUserID(mention.group(1))
user = (pu.Puppet.get_by_mxid(mxid)
or u.User.get_by_mxid(mxid, create=False))
if not user:
+3 -3
View File
@@ -28,8 +28,8 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
from mautrix_appservice import MatrixRequestError
from mautrix_appservice.intent_api import IntentAPI
from ..types import TelegramID
from .. import user as u, puppet as pu, portal as po
from ..types import TelegramID
from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
trim_reply_fallback_text, unicode_to_html)
@@ -76,7 +76,7 @@ async def _add_forward_header(source, text: str, html: Optional[str],
fwd_from_html = f"<a href='https://matrix.to/#/{user.mxid}'>{fwd_from_text}</a>"
if not fwd_from_text:
puppet = pu.Puppet.get(fwd_from.from_id, create=False)
puppet = pu.Puppet.get(TelegramID(fwd_from.from_id), create=False)
if puppet and puppet.displayname:
fwd_from_text = puppet.displayname or puppet.mxid
fwd_from_html = f"<a href='https://matrix.to/#/{puppet.mxid}'>{fwd_from_text}</a>"
@@ -247,7 +247,7 @@ def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -
elif entity_type == MessageEntityMention:
skip_entity = _parse_mention(html, entity_text)
elif entity_type == MessageEntityMentionName:
skip_entity = _parse_name_mention(html, entity_text, entity.user_id)
skip_entity = _parse_name_mention(html, entity_text, TelegramID(entity.user_id))
elif entity_type == MessageEntityEmail:
html.append(f"<a href='mailto:{entity_text}'>{entity_text}</a>")
elif entity_type in {MessageEntityTextUrl, MessageEntityUrl}:
+1 -5
View File
@@ -25,11 +25,7 @@ from .types import MatrixEvent, MatrixEventID, MatrixRoomID, MatrixUserID
from . import user as u, portal as po, puppet as pu, commands as com
if TYPE_CHECKING:
from mautrix_appservice import AppService
from .context import Context
from sqlalchemy.orm import scoped_session
from .config import Config
from .bot import Bot
class MatrixHandler:
@@ -130,7 +126,7 @@ class MatrixHandler:
if not inviter.whitelisted:
await self.az.intent.send_notice(
room_id, text=None,
room_id, text="",
html="You are not whitelisted to use this bridge.<br/><br/>"
"If you are the owner of this bridge, see the "
"<code>bridge.permissions</code> section in your config file.")
+25 -24
View File
@@ -102,17 +102,17 @@ class Portal:
mx_alias_regex = None # type: Pattern
hs_domain = None # type: str
by_mxid = {} # type: Dict[str, Portal]
by_tgid = {} # type: Dict[Tuple[int, int], Portal]
by_mxid = {} # type: Dict[MatrixRoomID, Portal]
by_tgid = {} # type: Dict[Tuple[TelegramID, TelegramID], Portal]
def __init__(self, tgid: TelegramID, peer_type: str, tg_receiver: Optional[int] = None,
def __init__(self, tgid: TelegramID, peer_type: str, tg_receiver: Optional[TelegramID] = None,
mxid: Optional[MatrixRoomID] = None, username: Optional[str] = None,
megagroup: Optional[bool] = False, title: Optional[str] = None,
about: Optional[str] = None, photo_id: Optional[str] = None,
db_instance: DBPortal = None) -> None:
self.mxid = mxid # type: Optional[MatrixRoomID]
self.tgid = tgid # type: TelegramID
self.tg_receiver = tg_receiver or tgid # type: int
self.tg_receiver = tg_receiver or tgid # type: TelegramID
self.peer_type = peer_type # type: str
self.username = username # type: str
self.megagroup = megagroup # type: bool
@@ -141,7 +141,7 @@ class Portal:
# region Propegrties
@property
def tgid_full(self) -> Tuple[int, int]:
def tgid_full(self) -> Tuple[TelegramID, TelegramID]:
return self.tgid, self.tg_receiver
@property
@@ -174,7 +174,7 @@ class Portal:
# endregion
# region Filtering
def allow_bridging(self, tgid: Optional[int] = None) -> bool:
def allow_bridging(self, tgid: Optional[TelegramID] = None) -> bool:
tgid = tgid or self.tgid
if self.peer_type == "user":
return True
@@ -270,8 +270,8 @@ class Portal:
else:
raise ValueError("Invalid invite identifier given to invite_matrix()")
async def update_matrix_room(self, user: 'AbstractUser', entity: TypeChat, direct: bool,
puppet: p.Puppet = None, levels: Dict = None,
async def update_matrix_room(self, user: 'AbstractUser', entity: Union[TypeChat, User],
direct: bool, puppet: p.Puppet = None, levels: Dict = None,
users: List[User] = None,
participants: List[TypeParticipant] = None) -> None:
if not direct:
@@ -359,7 +359,7 @@ class Portal:
if not room_id:
raise Exception(f"Failed to create room for {self.tgid_log}")
self.mxid = room_id
self.mxid = MatrixRoomID(room_id)
self.by_mxid[self.mxid] = self
self.save()
self.az.state_store.set_power_levels(self.mxid, power_levels)
@@ -420,7 +420,7 @@ class Portal:
async def sync_telegram_users(self, source: "AbstractUser", users: List[User]) -> None:
allowed_tgids = set()
for entity in users:
puppet = p.Puppet.get(entity.id)
puppet = p.Puppet.get(TelegramID(entity.id))
if entity.bot:
self.add_bot_chat(entity)
allowed_tgids.add(entity.id)
@@ -464,7 +464,7 @@ class Portal:
) -> None:
puppet = p.Puppet.get(user_id)
if source:
entity = await source.client.get_entity(PeerUser(user_id))
entity = await source.client.get_entity(PeerUser(user_id)) # type: User
await puppet.update_info(source, entity)
await puppet.intent.join_room(self.mxid)
@@ -571,8 +571,7 @@ class Portal:
return True
return False
async def _get_users(self,
user: 'AbstractUser',
async def _get_users(self, user: 'AbstractUser',
entity: Union[TypeInputPeer, InputUser, TypeChat, TypeUser]
) -> Tuple[List[TypeUser], List[TypeParticipant]]:
if self.peer_type == "chat":
@@ -640,7 +639,8 @@ class Portal:
return []
authenticated = []
has_bot = self.has_bot
for member in members:
for member_str in members:
member = MatrixUserID(member_str)
if p.Puppet.get_id_from_mxid(member) or member == self.main_intent.mxid:
continue
user = await u.User.get_by_mxid(member).ensure_started()
@@ -657,7 +657,7 @@ class Portal:
except MatrixRequestError:
members = []
for user in members:
puppet = p.Puppet.get_by_mxid(user, create=False)
puppet = p.Puppet.get_by_mxid(MatrixUserID(user), create=False)
if user != intent.mxid and (not puppets_only or puppet):
try:
if puppet:
@@ -729,7 +729,7 @@ class Portal:
or user.mxid_localpart)
def set_typing(self, user: 'u.User', typing: bool = True,
action: type = SendMessageTypingAction) -> bool:
action: type = SendMessageTypingAction) -> Awaitable[bool]:
return user.client(SetTypingRequest(
self.peer, action() if typing else SendMessageCancelAction()))
@@ -1077,7 +1077,8 @@ class Portal:
async def _get_telegram_users_in_matrix_room(self) -> List[int]:
user_tgids = set()
user_mxids = await self.main_intent.get_room_members(self.mxid, ("join", "invite"))
for user in user_mxids:
for user_str in user_mxids:
user = MatrixUserID(user_str)
if user == self.az.bot_mxid:
continue
mx_user = u.User.get_by_mxid(user, create=False)
@@ -1101,7 +1102,7 @@ class Portal:
if not entity:
raise ValueError("Upgrade may have failed: output channel not found.")
self.peer_type = "channel"
self.migrate_and_save(entity.id)
self.migrate_and_save(TelegramID(entity.id))
await self.update_info(source, entity)
async def set_telegram_username(self, source: 'u.User', username: str) -> None:
@@ -1176,7 +1177,7 @@ class Portal:
return None
async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message,
relates_to: Dict = {}) -> None:
relates_to: Dict = None) -> Optional[Dict]:
largest_size = self._get_largest_photo_size(evt.media.photo)
file = await util.transfer_file_to_matrix(self.db, source.client, intent,
largest_size.location)
@@ -1516,14 +1517,14 @@ class Portal:
await self.remove_avatar(source, save=True)
elif isinstance(action, MessageActionChatAddUser):
for user_id in action.users:
await self.add_telegram_user(user_id, source)
await self.add_telegram_user(TelegramID(user_id), source)
elif isinstance(action, MessageActionChatJoinedByLink):
await self.add_telegram_user(sender.id, source)
elif isinstance(action, MessageActionChatDeleteUser):
await self.delete_telegram_user(action.user_id, sender)
await self.delete_telegram_user(TelegramID(action.user_id), sender)
elif isinstance(action, MessageActionChatMigrateTo):
self.peer_type = "channel"
self.migrate_and_save(action.channel_id)
self.migrate_and_save(TelegramID(action.channel_id))
await sender.intent.send_emote(self.mxid, "upgraded this group to a supergroup.")
elif isinstance(action, MessageActionPinMessage):
await self.receive_telegram_pin_sender(sender)
@@ -1620,7 +1621,7 @@ class Portal:
levels["events"]["m.room.power_levels"] = admin_power_level
for participant in participants:
puppet = p.Puppet.get(participant.user_id)
puppet = p.Puppet.get(TelegramID(participant.user_id))
user = u.User.get_by_tgid(participant.user_id)
new_level = self._get_level_from_participant(participant, levels)
@@ -1792,7 +1793,7 @@ class Portal:
entity_id = entity.user_id
else:
raise ValueError(f"Unknown entity type {entity_type.__name__}")
return cls.get_by_tgid(entity_id,
return cls.get_by_tgid(TelegramID(entity_id),
receiver_id if type_name == "user" else entity_id,
type_name if create else None)
+7 -7
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Coroutine, Dict, List, NewType, Optional, Pattern, TYPE_CHECKING
from typing import Awaitable, Coroutine, Dict, List, Optional, Pattern, TYPE_CHECKING
from difflib import SequenceMatcher
import re
import logging
@@ -37,7 +37,6 @@ if TYPE_CHECKING:
from . import user as u
from .abstract_user import AbstractUser
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
config = None # type: Config
@@ -87,7 +86,7 @@ class Puppet:
self.by_custom_mxid[self.custom_mxid] = self
@property
def mxid(self):
def mxid(self) -> MatrixUserID:
return self.custom_mxid or self.default_mxid
@property
@@ -109,7 +108,8 @@ class Puppet:
return (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
async def switch_mxid(self, access_token: str, mxid: MatrixUserID) -> PuppetError:
async def switch_mxid(self, access_token: Optional[str],
mxid: Optional[MatrixUserID]) -> PuppetError:
prev_mxid = self.custom_mxid
self.custom_mxid = mxid
self.access_token = access_token
@@ -217,7 +217,7 @@ class Puppet:
for events in ephemeral.values()
for event in self.filter_events(events)]
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
coro = asyncio.gather(*events, loop=self.loop)
asyncio.ensure_future(coro, loop=self.loop)
@@ -355,10 +355,10 @@ class Puppet:
if displayname != self.displayname:
await self.default_mxid_intent.set_display_name(displayname)
self.displayname = displayname
self.displayname_source = TelegramID(source.tgid)
self.displayname_source = source.tgid
return True
elif source.is_relaybot or self.displayname_source is None:
self.displayname_source = TelegramID(source.tgid)
self.displayname_source = source.tgid
return True
return False
-1
View File
@@ -1,6 +1,5 @@
from typing import Dict, NewType
# MatrixId = NewType('MatrixId', str)
MatrixUserID = NewType('MatrixUserID', str)
MatrixRoomID = NewType('MatrixRoomID', str)
MatrixEventID = NewType('MatrixEventID', str)
+12 -9
View File
@@ -49,7 +49,8 @@ class User(AbstractUser):
def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
username: Optional[str] = None, db_contacts: Optional[List[DBContact]] = None,
saved_contacts: int = 0, is_bot: bool = False, db_portals: List[DBPortal] = [],
saved_contacts: int = 0, is_bot: bool = False,
db_portals: Optional[List[DBPortal]] = None,
db_instance: Optional[DBUser] = None) -> None:
super().__init__()
self.mxid = mxid # type: MatrixUserID
@@ -105,9 +106,11 @@ class User(AbstractUser):
@db_portals.setter
def db_portals(self, portals: List[DBPortal]) -> None:
self.portals = {(portal.tgid, portal.tg_receiver):
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
for portal in portals} if portals else {}
self.portals = {
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
portal.tg_receiver)
for portal in portals
} if portals else {}
# region Database conversion
@@ -119,14 +122,14 @@ class User(AbstractUser):
def new_db_instance(self) -> DBUser:
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
contacts=self.db_contacts, saved_contacts=self.saved_contacts,
portals=self.db_portals)
def save(self) -> None:
self.db_instance.tgid = self.tgid
self.db_instance.username = self.username
self.db_instance.contacts = self.db_contacts
self.db_instance.saved_contacts = self.saved_contacts or 0
self.db_instance.saved_contacts = self.saved_contacts
self.db_instance.portals = self.db_portals
self.db.commit()
@@ -143,7 +146,7 @@ class User(AbstractUser):
@classmethod
def from_db(cls, db_user: DBUser) -> 'User':
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
db_user.saved_contacts, False, db_user.portals, db_instance=db_user)
# endregion
# region Telegram connection management
@@ -182,9 +185,9 @@ class User(AbstractUser):
else:
portal = po.Portal.get_by_entity(message.to_id, receiver_id=self.tgid)
elif isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
else:
return False
+18 -12
View File
@@ -15,6 +15,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from abc import abstractmethod
from typing import Optional
from aiohttp import web
import abc
import asyncio
import logging
@@ -28,22 +31,24 @@ from ...user import User
class AuthAPI(abc.ABC):
log = logging.getLogger("mau.web.auth")
log = logging.getLogger("mau.web.auth") # type: logging.Logger
def __init__(self, loop):
def __init__(self, loop: asyncio.AbstractEventLoop):
self.loop = loop # type: asyncio.AbstractEventLoop
@abstractmethod
def get_login_response(self, status=200, state="", username="", mxid="", message="", error="",
errcode=""):
def get_login_response(self, status: int = 200, state: str = "", username: str = "",
mxid: str = "", message: str = "", error: str = "",
errcode: str = "") -> web.Response:
raise NotImplementedError()
@abstractmethod
def get_mx_login_response(self, status=200, state="", username="", mxid="", message="",
error="", errcode=""):
def get_mx_login_response(self, status: int = 200, state: str = "", username: str = "",
mxid: str = "", message: str = "", error: str = "",
errcode: str = "") -> web.Response:
raise NotImplementedError()
async def post_matrix_token(self, user: User, token):
async def post_matrix_token(self, user: User, token: str) -> web.Response:
puppet = Puppet.get(user.tgid)
if puppet.is_real_user:
return self.get_mx_login_response(state="already-logged-in", status=409,
@@ -61,11 +66,11 @@ class AuthAPI(abc.ABC):
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
async def post_matrix_password(self, user, password):
async def post_matrix_password(self, user: User, password: str) -> web.Response:
return self.get_mx_login_response(mxid=user.mxid, status=501, error="Not yet implemented",
errcode="not-yet-implemented")
async def post_login_phone(self, user, phone):
async def post_login_phone(self, user: User, phone: str) -> web.Response:
try:
await user.client.sign_in(phone or "+123")
return self.get_login_response(mxid=user.mxid, state="code", status=200,
@@ -102,7 +107,7 @@ class AuthAPI(abc.ABC):
errcode="unknown_error",
error="Internal server error while requesting code.")
async def post_login_token(self, user, token):
async def post_login_token(self, user: User, token: str) -> web.Response:
try:
user_info = await user.client.sign_in(bot_token=token)
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
@@ -123,7 +128,8 @@ class AuthAPI(abc.ABC):
return self.get_login_response(mxid=user.mxid, state="token", status=500,
error="Internal server error while sending token.")
async def post_login_code(self, user, code, password_in_data):
async def post_login_code(self, user: User, code: int, password_in_data: bool
) -> Optional[web.Response]:
try:
user_info = await user.client.sign_in(code=code)
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
@@ -156,7 +162,7 @@ class AuthAPI(abc.ABC):
errcode="unknown_error",
error="Internal server error while sending code.")
async def post_login_password(self, user, password):
async def post_login_password(self, user: User, password: str) -> web.Response:
try:
user_info = await user.client.sign_in(password=password)
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
@@ -35,15 +35,16 @@ if TYPE_CHECKING:
class ProvisioningAPI(AuthAPI):
log = logging.getLogger("mau.web.provisioning")
log = logging.getLogger("mau.web.provisioning") # type: logging.Logger
def __init__(self, context: "Context") -> None:
super().__init__(context.loop)
self.secret = context.config["appservice.provisioning.shared_secret"]
self.secret = context.config["appservice.provisioning.shared_secret"] # type: str
self.az = context.az # type: AppService
self.context = context # type: Context
self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware])
self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware]
) # type: web.Application
portal_prefix = "/portal/{mxid:![^/]+}"
self.app.router.add_route("GET", f"{portal_prefix}", self.get_portal_by_mxid)
@@ -353,7 +354,8 @@ class ProvisioningAPI(AuthAPI):
await user.log_out()
@staticmethod
async def error_middleware(_, handler) -> Callable[[web.Request], Awaitable[web.Response]]:
async def error_middleware(_, handler: Callable[[web.Request], Awaitable[web.Response]]
) -> Callable[[web.Request], Awaitable[web.Response]]:
async def middleware_handler(request: web.Request) -> web.Response:
try:
return await handler(request)
+24 -19
View File
@@ -14,14 +14,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from aiohttp import web
from mako.template import Template
import pkg_resources
import asyncio
import logging
import random
import string
import time
from ...types import MatrixUserID
from ...util import sign_token, verify_token
from ...user import User
from ...puppet import Puppet
@@ -29,20 +32,20 @@ from ..common import AuthAPI
class PublicBridgeWebsite(AuthAPI):
log = logging.getLogger("mau.web.public")
log = logging.getLogger("mau.web.public") # type: logging.Logger
def __init__(self, loop):
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
self.secret_key = "".join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) # type: str
self.login = Template(
pkg_resources.resource_string("mautrix_telegram", "web/public/login.html.mako"))
self.login = Template(pkg_resources.resource_string(
"mautrix_telegram", "web/public/login.html.mako")) # type: Template
self.mx_login = Template(
pkg_resources.resource_string("mautrix_telegram", "web/public/matrix-login.html.mako"))
self.mx_login = Template(pkg_resources.resource_string(
"mautrix_telegram", "web/public/matrix-login.html.mako")) # type: Template
self.app = web.Application(loop=loop)
self.app = web.Application(loop=loop) # type: web.Application
self.app.router.add_route("GET", "/login", self.get_login)
self.app.router.add_route("POST", "/login", self.post_login)
self.app.router.add_route("GET", "/matrix-login", self.get_matrix_login)
@@ -50,21 +53,21 @@ class PublicBridgeWebsite(AuthAPI):
self.app.router.add_static("/", pkg_resources.resource_filename("mautrix_telegram",
"web/public/"))
def make_token(self, mxid, endpoint="/login", expires_in=900):
def make_token(self, mxid: str, endpoint: str = "/login", expires_in: int = 900) -> str:
return sign_token(self.secret_key, {
"mxid": mxid,
"endpoint": endpoint,
"expiry": int(time.time()) + expires_in,
})
def verify_token(self, token, endpoint="/login"):
def verify_token(self, token: str, endpoint: str = "/login") -> Optional[MatrixUserID]:
token = verify_token(self.secret_key, token)
if token and (token.get("expiry", 0) > int(time.time()) and
token.get("endpoint", None) == endpoint):
return token.get("mxid", None)
return MatrixUserID(token.get("mxid", None))
return None
async def get_login(self, request):
async def get_login(self, request: web.Request) -> web.Response:
state = "bot_token" if request.rel_url.query.get("mode", "") == "bot" else "request"
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login")
@@ -83,7 +86,7 @@ class PublicBridgeWebsite(AuthAPI):
return self.get_login_response(mxid=user.mxid, username=user.username)
async def get_matrix_login(self, request):
async def get_matrix_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/matrix-login")
if not mxid:
return self.get_mx_login_response(status=401, state="invalid-token")
@@ -105,19 +108,21 @@ class PublicBridgeWebsite(AuthAPI):
return self.get_mx_login_response(mxid=user.mxid)
def get_login_response(self, status=200, state="", username="", mxid="", message="", error="",
errcode=""):
def get_login_response(self, status: int = 200, state: str = "", username: str = "",
mxid: str = "", message: str = "", error: str = "",
errcode: str = "") -> web.Response:
return web.Response(status=status, content_type="text/html",
text=self.login.render(username=username, state=state, error=error,
message=message, mxid=mxid))
def get_mx_login_response(self, status=200, state="", username="", mxid="", message="",
error="", errcode=""):
def get_mx_login_response(self, status: int = 200, state: str = "", username: str = "",
mxid: str = "", message: str = "", error: str = "",
errcode: str = "") -> web.Response:
return web.Response(status=status, content_type="text/html",
text=self.mx_login.render(username=username, state=state, error=error,
message=message, mxid=mxid))
async def post_matrix_login(self, request):
async def post_matrix_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/matrix-login")
if not mxid:
return self.get_mx_login_response(status=401, state="invalid-token")
@@ -140,7 +145,7 @@ class PublicBridgeWebsite(AuthAPI):
error="You must provide an access token or "
"password.")
async def post_login(self, request):
async def post_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login")
if not mxid:
return self.get_login_response(status=401, state="invalid-token")