Improve type hints and set version to 0.4.0+dev
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
__version__ = "0.3.0"
|
||||
__version__ = "0.4.0+dev"
|
||||
__author__ = "Tulir Asokan <tulir@maunium.net>"
|
||||
|
||||
@@ -35,10 +35,10 @@ from alchemysession import AlchemySessionContainer
|
||||
|
||||
from . import portal as po, puppet as pu, __version__
|
||||
from .db import Message as DBMessage
|
||||
from .types import TelegramID
|
||||
from .tgclient import MautrixTelegramClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types import TelegramID
|
||||
from .context import Context
|
||||
from .config import Config
|
||||
from .bot import Bot
|
||||
@@ -210,13 +210,13 @@ class AbstractUser(ABC):
|
||||
|
||||
@staticmethod
|
||||
async def update_pinned_messages(update: UpdateChannelPinnedMessage) -> None:
|
||||
portal = po.Portal.get_by_tgid(update.channel_id)
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.channel_id))
|
||||
if portal and portal.mxid:
|
||||
await portal.receive_telegram_pin_id(update.id)
|
||||
|
||||
@staticmethod
|
||||
async def update_participants(update: UpdateChatParticipants) -> None:
|
||||
portal = po.Portal.get_by_tgid(update.participants.chat_id)
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.participants.chat_id))
|
||||
if portal and portal.mxid:
|
||||
await portal.update_telegram_participants(update.participants.participants)
|
||||
|
||||
@@ -225,7 +225,7 @@ class AbstractUser(ABC):
|
||||
self.log.debug("Unexpected read receipt peer: %s", update.peer)
|
||||
return
|
||||
|
||||
portal = po.Portal.get_by_tgid(update.peer.user_id, self.tgid)
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.peer.user_id), self.tgid)
|
||||
if not portal or not portal.mxid:
|
||||
return
|
||||
|
||||
@@ -234,38 +234,38 @@ class AbstractUser(ABC):
|
||||
if not message:
|
||||
return
|
||||
|
||||
puppet = pu.Puppet.get(update.peer.user_id)
|
||||
puppet = pu.Puppet.get(TelegramID(update.peer.user_id))
|
||||
await puppet.intent.mark_read(portal.mxid, message.mxid)
|
||||
|
||||
async def update_admin(self,
|
||||
update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]) -> None:
|
||||
# TODO duplication not checked
|
||||
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
|
||||
if not portal or not portal.mxid:
|
||||
return
|
||||
|
||||
if isinstance(update, UpdateChatAdmins):
|
||||
await portal.set_telegram_admins_enabled(update.enabled)
|
||||
elif isinstance(update, UpdateChatParticipantAdmin):
|
||||
await portal.set_telegram_admin(update.user_id)
|
||||
await portal.set_telegram_admin(TelegramID(update.user_id))
|
||||
else:
|
||||
self.log.warning("Unexpected admin status update: %s", update)
|
||||
|
||||
async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]) -> None:
|
||||
if isinstance(update, UpdateUserTyping):
|
||||
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
|
||||
else:
|
||||
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
|
||||
|
||||
if not portal or not portal.mxid:
|
||||
return
|
||||
|
||||
sender = pu.Puppet.get(update.user_id)
|
||||
sender = pu.Puppet.get(TelegramID(update.user_id))
|
||||
await portal.handle_telegram_typing(sender, update)
|
||||
|
||||
async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]) -> None:
|
||||
# TODO duplication not checked
|
||||
puppet = pu.Puppet.get(update.user_id)
|
||||
puppet = pu.Puppet.get(TelegramID(update.user_id))
|
||||
if isinstance(update, UpdateUserName):
|
||||
if await puppet.update_displayname(self, update):
|
||||
puppet.save()
|
||||
@@ -276,7 +276,7 @@ class AbstractUser(ABC):
|
||||
self.log.warning("Unexpected other user info update: %s", update)
|
||||
|
||||
async def update_status(self, update: UpdateUserStatus) -> None:
|
||||
puppet = pu.Puppet.get(update.user_id)
|
||||
puppet = pu.Puppet.get(TelegramID(update.user_id))
|
||||
if isinstance(update.status, UserStatusOnline):
|
||||
await puppet.default_mxid_intent.set_presence("online")
|
||||
elif isinstance(update.status, UserStatusOffline):
|
||||
@@ -289,10 +289,10 @@ class AbstractUser(ABC):
|
||||
Optional[pu.Puppet],
|
||||
Optional[po.Portal]]:
|
||||
if isinstance(update, UpdateShortChatMessage):
|
||||
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
|
||||
sender = pu.Puppet.get(update.from_id)
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
|
||||
sender = pu.Puppet.get(TelegramID(update.from_id))
|
||||
elif isinstance(update, UpdateShortMessage):
|
||||
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
|
||||
sender = pu.Puppet.get(self.tgid if update.out else update.user_id)
|
||||
elif isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage,
|
||||
UpdateEditMessage, UpdateEditChannelMessage)):
|
||||
@@ -338,7 +338,7 @@ class AbstractUser(ABC):
|
||||
if len(update.messages) > MAX_DELETIONS:
|
||||
return
|
||||
|
||||
portal = po.Portal.get_by_tgid(update.channel_id)
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.channel_id))
|
||||
if not portal:
|
||||
return
|
||||
|
||||
|
||||
+29
-26
@@ -61,15 +61,15 @@ class Bot(AbstractUser):
|
||||
|
||||
async def init_permissions(self) -> None:
|
||||
whitelist = config["bridge.relaybot.whitelist"] or []
|
||||
for id in whitelist:
|
||||
if isinstance(id, str):
|
||||
entity = await self.client.get_input_entity(id)
|
||||
for user_id in whitelist:
|
||||
if isinstance(user_id, str):
|
||||
entity = await self.client.get_input_entity(user_id)
|
||||
if isinstance(entity, InputUser):
|
||||
id = entity.user_id
|
||||
user_id = entity.user_id
|
||||
else:
|
||||
id = None
|
||||
if isinstance(id, int):
|
||||
self.tg_whitelist.append(id)
|
||||
user_id = None
|
||||
if isinstance(user_id, int):
|
||||
self.tg_whitelist.append(user_id)
|
||||
|
||||
async def start(self, delete_unless_authenticated: bool = False) -> 'Bot':
|
||||
await super().start(delete_unless_authenticated)
|
||||
@@ -85,20 +85,20 @@ class Bot(AbstractUser):
|
||||
self.username = info.username
|
||||
self.mxid = pu.Puppet.get_mxid_from_id(self.tgid)
|
||||
|
||||
chat_ids = [id for id, type in self.chats.items() if type == "chat"]
|
||||
chat_ids = [chat_id for chat_id, chat_type in self.chats.items() if chat_type == "chat"]
|
||||
response = await self.client(GetChatsRequest(chat_ids))
|
||||
for chat in response.chats:
|
||||
if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated:
|
||||
self.remove_chat(chat.id)
|
||||
|
||||
channel_ids = [InputChannel(id, 0)
|
||||
for id, type in self.chats.items()
|
||||
if type == "channel"]
|
||||
for id in channel_ids:
|
||||
channel_ids = [InputChannel(chat_id, 0)
|
||||
for chat_id, chat_type in self.chats.items()
|
||||
if chat_type == "channel"]
|
||||
for channel_id in channel_ids:
|
||||
try:
|
||||
await self.client(GetChannelsRequest([id]))
|
||||
await self.client(GetChannelsRequest([channel_id]))
|
||||
except (ChannelPrivateError, ChannelInvalidError):
|
||||
self.remove_chat(id.channel_id)
|
||||
self.remove_chat(channel_id.channel_id)
|
||||
|
||||
if config["bridge.catch_up"]:
|
||||
try:
|
||||
@@ -112,18 +112,18 @@ class Bot(AbstractUser):
|
||||
def unregister_portal(self, portal: po.Portal) -> None:
|
||||
self.remove_chat(portal.tgid)
|
||||
|
||||
def add_chat(self, id: int, type: str) -> None:
|
||||
if id not in self.chats:
|
||||
self.chats[id] = type
|
||||
self.db.add(BotChat(id=id, type=type))
|
||||
def add_chat(self, chat_id: int, chat_type: str) -> None:
|
||||
if chat_id not in self.chats:
|
||||
self.chats[chat_id] = chat_type
|
||||
self.db.add(BotChat(id=chat_id, type=chat_type))
|
||||
self.db.commit()
|
||||
|
||||
def remove_chat(self, id: int) -> None:
|
||||
def remove_chat(self, chat_id: int) -> None:
|
||||
try:
|
||||
del self.chats[id]
|
||||
del self.chats[chat_id]
|
||||
except KeyError:
|
||||
pass
|
||||
existing_chat = BotChat.query.get(id)
|
||||
existing_chat = BotChat.query.get(chat_id)
|
||||
if existing_chat:
|
||||
self.db.delete(existing_chat)
|
||||
self.db.commit()
|
||||
@@ -191,7 +191,8 @@ class Bot(AbstractUser):
|
||||
await portal.main_intent.invite(portal.mxid, user.mxid)
|
||||
return await reply(f"Invited `{user.mxid}` to the portal.")
|
||||
|
||||
def handle_command_id(self, message: Message, reply: ReplyFunc) -> Awaitable[Message]:
|
||||
@staticmethod
|
||||
def handle_command_id(message: Message, reply: ReplyFunc) -> Awaitable[Message]:
|
||||
# Provide the prefixed ID to the user so that the user wouldn't need to specify whether the
|
||||
# chat is a normal group or a supergroup/channel when using the ID.
|
||||
if isinstance(message.to_id, PeerChannel):
|
||||
@@ -241,16 +242,16 @@ class Bot(AbstractUser):
|
||||
to_id = message.to_id
|
||||
if isinstance(to_id, PeerChannel):
|
||||
to_id = to_id.channel_id
|
||||
type = "channel"
|
||||
chat_type = "channel"
|
||||
elif isinstance(to_id, PeerChat):
|
||||
to_id = to_id.chat_id
|
||||
type = "chat"
|
||||
chat_type = "chat"
|
||||
else:
|
||||
return
|
||||
|
||||
action = message.action
|
||||
if isinstance(action, MessageActionChatAddUser) and self.tgid in action.users:
|
||||
self.add_chat(to_id, type)
|
||||
self.add_chat(to_id, chat_type)
|
||||
elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid:
|
||||
self.remove_chat(to_id)
|
||||
|
||||
@@ -265,7 +266,9 @@ class Bot(AbstractUser):
|
||||
and update.message.entities and len(update.message.entities) > 0
|
||||
and isinstance(update.message.entities[0], MessageEntityBotCommand))
|
||||
if is_command:
|
||||
return await self.handle_command(update.message)
|
||||
await self.handle_command(update.message)
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_in_chat(self, peer_id) -> bool:
|
||||
return peer_id in self.chats
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Any, Awaitable, Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
import asyncio
|
||||
|
||||
from telethon.errors import (
|
||||
|
||||
@@ -33,7 +33,8 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[Mat
|
||||
empty_portals = [] # type: List[po.Portal]
|
||||
|
||||
rooms = await intent.get_joined_rooms()
|
||||
for room in rooms:
|
||||
for room_str in rooms:
|
||||
room = MatrixRoomID(room_str)
|
||||
portal = po.Portal.get_by_mxid(room)
|
||||
if not portal:
|
||||
try:
|
||||
@@ -41,7 +42,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[Mat
|
||||
except MatrixRequestError:
|
||||
members = []
|
||||
if len(members) == 2:
|
||||
other_member = members[0] if members[0] != intent.mxid else members[1]
|
||||
other_member = MatrixUserID(members[0] if members[0] != intent.mxid else members[1])
|
||||
if pu.Puppet.get_id_from_mxid(other_member):
|
||||
unidentified_rooms.append(room)
|
||||
else:
|
||||
@@ -128,9 +129,9 @@ async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
|
||||
rooms_to_clean += empty_portals
|
||||
elif command == "clean-range":
|
||||
try:
|
||||
range = evt.args[1]
|
||||
group, range = range[0], range[1:]
|
||||
start, end = range.split("-")
|
||||
clean_range = evt.args[1]
|
||||
group, clean_range = clean_range[0], clean_range[1:]
|
||||
start, end = clean_range.split("-")
|
||||
start, end = int(start), int(end)
|
||||
if group == "M":
|
||||
group = [room_id for (room_id, user_id) in management_rooms]
|
||||
|
||||
@@ -14,8 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Any, Awaitable, Callable, Coroutine, Dict, List, NamedTuple, Optional, Union
|
||||
from collections import namedtuple
|
||||
from typing import Awaitable, Callable, Dict, List, NamedTuple, Optional
|
||||
import markdown
|
||||
import logging
|
||||
|
||||
@@ -160,16 +159,16 @@ class CommandProcessor:
|
||||
orig_command = command
|
||||
command = command.lower()
|
||||
try:
|
||||
command_handler = command_handlers[command]
|
||||
handler = command_handlers[command]
|
||||
except KeyError:
|
||||
if sender.command_status and "next" in sender.command_status:
|
||||
args.insert(0, orig_command)
|
||||
evt.command = ""
|
||||
command_handler = sender.command_status["next"]
|
||||
handler = sender.command_status["next"]
|
||||
else:
|
||||
command_handler = command_handlers["unknown-command"]
|
||||
handler = command_handlers["unknown-command"]
|
||||
try:
|
||||
await command_handler(evt)
|
||||
await handler(evt)
|
||||
except FloodWaitError as e:
|
||||
return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}")
|
||||
except Exception:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Dict, Callable, Coroutine, Optional, Tuple, Union, cast
|
||||
from typing import Dict, Callable, Optional, Tuple, Coroutine
|
||||
import asyncio
|
||||
|
||||
from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
|
||||
@@ -85,18 +85,20 @@ async def user_has_power_level(room: str, intent, sender: u.User, event: str, de
|
||||
|
||||
async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
|
||||
action: Optional[str] = None
|
||||
) -> Tuple[Union[Dict, po.Portal], bool]:
|
||||
) -> Optional[po.Portal]:
|
||||
room_id = MatrixRoomID(evt.args[0]) if len(evt.args) > 0 else evt.room_id
|
||||
|
||||
portal = po.Portal.get_by_mxid(room_id)
|
||||
if not portal:
|
||||
that_this = "This" if room_id == evt.room_id else "That"
|
||||
return await evt.reply(f"{that_this} is not a portal room."), False
|
||||
await evt.reply(f"{that_this} is not a portal room.")
|
||||
return None
|
||||
|
||||
if not await user_has_power_level(portal.mxid, evt.az.intent, evt.sender, permission):
|
||||
action = action or f"{permission.replace('_', ' ')}s"
|
||||
return await evt.reply(f"You do not have the permissions to {action} that portal."), False
|
||||
return portal, True
|
||||
await evt.reply(f"You do not have the permissions to {action} that portal.")
|
||||
return None
|
||||
return portal
|
||||
|
||||
|
||||
def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str,
|
||||
@@ -123,10 +125,9 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
|
||||
"Only works for group chats; to delete a private chat portal, simply "
|
||||
"leave the room.")
|
||||
async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
|
||||
result, ok = await _get_portal_and_check_permission(evt, "unbridge")
|
||||
if not ok:
|
||||
portal = await _get_portal_and_check_permission(evt, "unbridge")
|
||||
if not portal:
|
||||
return None
|
||||
portal = cast('po.Portal', result)
|
||||
|
||||
evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid,
|
||||
portal.cleanup_and_delete, "delete",
|
||||
@@ -145,10 +146,9 @@ async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
|
||||
help_section=SECTION_PORTAL_MANAGEMENT,
|
||||
help_text="Remove puppets from the current portal room and forget the portal.")
|
||||
async def unbridge(evt: CommandEvent) -> Optional[Dict]:
|
||||
result, ok = await _get_portal_and_check_permission(evt, "unbridge")
|
||||
if not ok:
|
||||
portal = await _get_portal_and_check_permission(evt, "unbridge")
|
||||
if not portal:
|
||||
return None
|
||||
portal = cast('po.Portal', result)
|
||||
|
||||
evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid,
|
||||
portal.unbridge, "unbridge",
|
||||
@@ -231,7 +231,7 @@ async def bridge(evt: CommandEvent) -> Dict:
|
||||
|
||||
|
||||
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"
|
||||
) -> Tuple[bool, Coroutine[None, None, None]]:
|
||||
) -> Tuple[bool, Optional[Coroutine[None, None, None]]]:
|
||||
if not portal.mxid:
|
||||
await evt.reply("The portal seems to have lost its Matrix room between you"
|
||||
"calling `$cmdprefix+sp bridge` and this command.\n\n"
|
||||
|
||||
@@ -92,7 +92,7 @@ async def private_message(evt: CommandEvent) -> Optional[Dict]:
|
||||
f"{pu.Puppet.get_displayname(user, False)}")
|
||||
|
||||
|
||||
async def _join(evt: CommandEvent, arg: str) -> Tuple[TypeUpdates, Dict]:
|
||||
async def _join(evt: CommandEvent, arg: str) -> Tuple[Optional[TypeUpdates], Optional[Dict]]:
|
||||
if arg.startswith("joinchat/"):
|
||||
invite_hash = arg[len("joinchat/"):]
|
||||
try:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Generator, Optional, Tuple, Union, TYPE_CHECKING
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import asyncio
|
||||
|
||||
+26
-24
@@ -14,14 +14,15 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict
|
||||
|
||||
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
||||
BigInteger, String, Boolean, Text)
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.orm import relationship, Query
|
||||
from typing import Dict, Optional, List
|
||||
import json
|
||||
|
||||
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
|
||||
from .types import TelegramID
|
||||
from .base import Base
|
||||
|
||||
|
||||
@@ -30,13 +31,13 @@ class Portal(Base):
|
||||
__tablename__ = "portal"
|
||||
|
||||
# Telegram chat information
|
||||
tgid = Column(Integer, primary_key=True)
|
||||
tg_receiver = Column(Integer, primary_key=True)
|
||||
tgid = Column(Integer, primary_key=True) # type: TelegramID
|
||||
tg_receiver = Column(Integer, primary_key=True) # type: TelegramID
|
||||
peer_type = Column(String, nullable=False)
|
||||
megagroup = Column(Boolean)
|
||||
|
||||
# Matrix portal information
|
||||
mxid = Column(String, unique=True, nullable=True)
|
||||
mxid = Column(String, unique=True, nullable=True) # type: Optional[MatrixRoomID]
|
||||
|
||||
# Telegram chat metadata
|
||||
username = Column(String, nullable=True)
|
||||
@@ -49,10 +50,10 @@ class Message(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "message"
|
||||
|
||||
mxid = Column(String)
|
||||
mx_room = Column(String)
|
||||
tgid = Column(Integer, primary_key=True)
|
||||
tg_space = Column(Integer, primary_key=True)
|
||||
mxid = Column(String) # type: MatrixEventID
|
||||
mx_room = Column(String) # type: MatrixRoomID
|
||||
tgid = Column(Integer, primary_key=True) # type: TelegramID
|
||||
tg_space = Column(Integer, primary_key=True) # type: TelegramID
|
||||
|
||||
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
|
||||
|
||||
@@ -62,9 +63,9 @@ class UserPortal(Base):
|
||||
__tablename__ = "user_portal"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
|
||||
primary_key=True)
|
||||
portal = Column(Integer, primary_key=True)
|
||||
portal_receiver = Column(Integer, primary_key=True)
|
||||
primary_key=True) # type: TelegramID
|
||||
portal = Column(Integer, primary_key=True) # type: TelegramID
|
||||
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
|
||||
|
||||
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
|
||||
("portal.tgid", "portal.tg_receiver"),
|
||||
@@ -75,12 +76,13 @@ class User(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "user"
|
||||
|
||||
mxid = Column(String, primary_key=True)
|
||||
tgid = Column(Integer, nullable=True, unique=True)
|
||||
mxid = Column(String, primary_key=True) # type: MatrixUserID
|
||||
tgid = Column(Integer, nullable=True, unique=True) # type: Optional[TelegramID]
|
||||
tg_username = Column(String, nullable=True)
|
||||
saved_contacts = Column(Integer, default=0, nullable=False)
|
||||
contacts = relationship("Contact", uselist=True,
|
||||
cascade="save-update, merge, delete, delete-orphan")
|
||||
cascade="save-update, merge, delete, delete-orphan"
|
||||
) # type: List[Contact]
|
||||
portals = relationship("Portal", secondary="user_portal")
|
||||
|
||||
|
||||
@@ -88,7 +90,7 @@ class RoomState(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "mx_room_state"
|
||||
|
||||
room_id = Column(String, primary_key=True)
|
||||
room_id = Column(String, primary_key=True) # type: MatrixRoomID
|
||||
_power_levels_text = Column("power_levels", Text, nullable=True)
|
||||
_power_levels_json = {} # type: Dict
|
||||
|
||||
@@ -112,8 +114,8 @@ class UserProfile(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "mx_user_profile"
|
||||
|
||||
room_id = Column(String, primary_key=True)
|
||||
user_id = Column(String, primary_key=True)
|
||||
room_id = Column(String, primary_key=True) # type: MatrixRoomID
|
||||
user_id = Column(String, primary_key=True) # type: MatrixUserID
|
||||
membership = Column(String, nullable=False, default="leave")
|
||||
displayname = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
@@ -130,19 +132,19 @@ class Contact(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "contact"
|
||||
|
||||
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True)
|
||||
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True)
|
||||
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
|
||||
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
|
||||
|
||||
|
||||
class Puppet(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "puppet"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
custom_mxid = Column(String, nullable=True)
|
||||
id = Column(Integer, primary_key=True) # type: TelegramID
|
||||
custom_mxid = Column(String, nullable=True) # type: Optional[MatrixUserID]
|
||||
access_token = Column(String, nullable=True)
|
||||
displayname = Column(String, nullable=True)
|
||||
displayname_source = Column(Integer, nullable=True)
|
||||
displayname_source = Column(Integer, nullable=True) # type: Optional[TelegramID]
|
||||
username = Column(String, nullable=True)
|
||||
photo_id = Column(String, nullable=True)
|
||||
is_bot = Column(Boolean, nullable=True)
|
||||
@@ -153,7 +155,7 @@ class Puppet(Base):
|
||||
class BotChat(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "bot_chat"
|
||||
id = Column(Integer, primary_key=True)
|
||||
id = Column(Integer, primary_key=True) # type: TelegramID
|
||||
type = Column(String, nullable=False)
|
||||
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M
|
||||
MessageEntityBotCommand, TypeMessageEntity)
|
||||
|
||||
from ... import user as u, puppet as pu, portal as po
|
||||
from ...types import MatrixUserID
|
||||
from ..util import html_to_unicode
|
||||
from .parser_common import MatrixParserCommon, ParsedMessage
|
||||
|
||||
@@ -55,7 +56,7 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
|
||||
) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]:
|
||||
mention = self.mention_regex.match(url) # type: Match
|
||||
if mention:
|
||||
mxid = mention.group(1)
|
||||
mxid = MatrixUserID(mention.group(1))
|
||||
user = (pu.Puppet.get_by_mxid(mxid)
|
||||
or u.User.get_by_mxid(mxid, create=False))
|
||||
if not user:
|
||||
|
||||
@@ -26,12 +26,13 @@ from telethon.tl.types import (MessageEntityMention as Mention,
|
||||
InputMessageEntityMentionName as InputMentionName)
|
||||
|
||||
from ... import user as u, puppet as pu, portal as po
|
||||
from ...types import MatrixUserID
|
||||
from ..util import html_to_unicode
|
||||
from .parser_common import MatrixParserCommon, ParsedMessage
|
||||
|
||||
|
||||
def parse_html(html: str) -> ParsedMessage:
|
||||
return MatrixParser.parse(html)
|
||||
def parse_html(input_html: str) -> ParsedMessage:
|
||||
return MatrixParser.parse(input_html)
|
||||
|
||||
|
||||
class Entity:
|
||||
@@ -248,7 +249,7 @@ class MatrixParser(MatrixParserCommon):
|
||||
|
||||
mention = cls.mention_regex.match(href)
|
||||
if mention:
|
||||
mxid = mention.group(1)
|
||||
mxid = MatrixUserID(mention.group(1))
|
||||
user = (pu.Puppet.get_by_mxid(mxid)
|
||||
or u.User.get_by_mxid(mxid, create=False))
|
||||
if not user:
|
||||
|
||||
@@ -28,8 +28,8 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
|
||||
from mautrix_appservice import MatrixRequestError
|
||||
from mautrix_appservice.intent_api import IntentAPI
|
||||
|
||||
from ..types import TelegramID
|
||||
from .. import user as u, puppet as pu, portal as po
|
||||
from ..types import TelegramID
|
||||
from ..db import Message as DBMessage
|
||||
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
|
||||
trim_reply_fallback_text, unicode_to_html)
|
||||
@@ -76,7 +76,7 @@ async def _add_forward_header(source, text: str, html: Optional[str],
|
||||
fwd_from_html = f"<a href='https://matrix.to/#/{user.mxid}'>{fwd_from_text}</a>"
|
||||
|
||||
if not fwd_from_text:
|
||||
puppet = pu.Puppet.get(fwd_from.from_id, create=False)
|
||||
puppet = pu.Puppet.get(TelegramID(fwd_from.from_id), create=False)
|
||||
if puppet and puppet.displayname:
|
||||
fwd_from_text = puppet.displayname or puppet.mxid
|
||||
fwd_from_html = f"<a href='https://matrix.to/#/{puppet.mxid}'>{fwd_from_text}</a>"
|
||||
@@ -247,7 +247,7 @@ def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -
|
||||
elif entity_type == MessageEntityMention:
|
||||
skip_entity = _parse_mention(html, entity_text)
|
||||
elif entity_type == MessageEntityMentionName:
|
||||
skip_entity = _parse_name_mention(html, entity_text, entity.user_id)
|
||||
skip_entity = _parse_name_mention(html, entity_text, TelegramID(entity.user_id))
|
||||
elif entity_type == MessageEntityEmail:
|
||||
html.append(f"<a href='mailto:{entity_text}'>{entity_text}</a>")
|
||||
elif entity_type in {MessageEntityTextUrl, MessageEntityUrl}:
|
||||
|
||||
@@ -25,11 +25,7 @@ from .types import MatrixEvent, MatrixEventID, MatrixRoomID, MatrixUserID
|
||||
from . import user as u, portal as po, puppet as pu, commands as com
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mautrix_appservice import AppService
|
||||
from .context import Context
|
||||
from sqlalchemy.orm import scoped_session
|
||||
from .config import Config
|
||||
from .bot import Bot
|
||||
|
||||
|
||||
class MatrixHandler:
|
||||
@@ -130,7 +126,7 @@ class MatrixHandler:
|
||||
|
||||
if not inviter.whitelisted:
|
||||
await self.az.intent.send_notice(
|
||||
room_id, text=None,
|
||||
room_id, text="",
|
||||
html="You are not whitelisted to use this bridge.<br/><br/>"
|
||||
"If you are the owner of this bridge, see the "
|
||||
"<code>bridge.permissions</code> section in your config file.")
|
||||
|
||||
+25
-24
@@ -102,17 +102,17 @@ class Portal:
|
||||
mx_alias_regex = None # type: Pattern
|
||||
hs_domain = None # type: str
|
||||
|
||||
by_mxid = {} # type: Dict[str, Portal]
|
||||
by_tgid = {} # type: Dict[Tuple[int, int], Portal]
|
||||
by_mxid = {} # type: Dict[MatrixRoomID, Portal]
|
||||
by_tgid = {} # type: Dict[Tuple[TelegramID, TelegramID], Portal]
|
||||
|
||||
def __init__(self, tgid: TelegramID, peer_type: str, tg_receiver: Optional[int] = None,
|
||||
def __init__(self, tgid: TelegramID, peer_type: str, tg_receiver: Optional[TelegramID] = None,
|
||||
mxid: Optional[MatrixRoomID] = None, username: Optional[str] = None,
|
||||
megagroup: Optional[bool] = False, title: Optional[str] = None,
|
||||
about: Optional[str] = None, photo_id: Optional[str] = None,
|
||||
db_instance: DBPortal = None) -> None:
|
||||
self.mxid = mxid # type: Optional[MatrixRoomID]
|
||||
self.tgid = tgid # type: TelegramID
|
||||
self.tg_receiver = tg_receiver or tgid # type: int
|
||||
self.tg_receiver = tg_receiver or tgid # type: TelegramID
|
||||
self.peer_type = peer_type # type: str
|
||||
self.username = username # type: str
|
||||
self.megagroup = megagroup # type: bool
|
||||
@@ -141,7 +141,7 @@ class Portal:
|
||||
# region Propegrties
|
||||
|
||||
@property
|
||||
def tgid_full(self) -> Tuple[int, int]:
|
||||
def tgid_full(self) -> Tuple[TelegramID, TelegramID]:
|
||||
return self.tgid, self.tg_receiver
|
||||
|
||||
@property
|
||||
@@ -174,7 +174,7 @@ class Portal:
|
||||
# endregion
|
||||
# region Filtering
|
||||
|
||||
def allow_bridging(self, tgid: Optional[int] = None) -> bool:
|
||||
def allow_bridging(self, tgid: Optional[TelegramID] = None) -> bool:
|
||||
tgid = tgid or self.tgid
|
||||
if self.peer_type == "user":
|
||||
return True
|
||||
@@ -270,8 +270,8 @@ class Portal:
|
||||
else:
|
||||
raise ValueError("Invalid invite identifier given to invite_matrix()")
|
||||
|
||||
async def update_matrix_room(self, user: 'AbstractUser', entity: TypeChat, direct: bool,
|
||||
puppet: p.Puppet = None, levels: Dict = None,
|
||||
async def update_matrix_room(self, user: 'AbstractUser', entity: Union[TypeChat, User],
|
||||
direct: bool, puppet: p.Puppet = None, levels: Dict = None,
|
||||
users: List[User] = None,
|
||||
participants: List[TypeParticipant] = None) -> None:
|
||||
if not direct:
|
||||
@@ -359,7 +359,7 @@ class Portal:
|
||||
if not room_id:
|
||||
raise Exception(f"Failed to create room for {self.tgid_log}")
|
||||
|
||||
self.mxid = room_id
|
||||
self.mxid = MatrixRoomID(room_id)
|
||||
self.by_mxid[self.mxid] = self
|
||||
self.save()
|
||||
self.az.state_store.set_power_levels(self.mxid, power_levels)
|
||||
@@ -420,7 +420,7 @@ class Portal:
|
||||
async def sync_telegram_users(self, source: "AbstractUser", users: List[User]) -> None:
|
||||
allowed_tgids = set()
|
||||
for entity in users:
|
||||
puppet = p.Puppet.get(entity.id)
|
||||
puppet = p.Puppet.get(TelegramID(entity.id))
|
||||
if entity.bot:
|
||||
self.add_bot_chat(entity)
|
||||
allowed_tgids.add(entity.id)
|
||||
@@ -464,7 +464,7 @@ class Portal:
|
||||
) -> None:
|
||||
puppet = p.Puppet.get(user_id)
|
||||
if source:
|
||||
entity = await source.client.get_entity(PeerUser(user_id))
|
||||
entity = await source.client.get_entity(PeerUser(user_id)) # type: User
|
||||
await puppet.update_info(source, entity)
|
||||
await puppet.intent.join_room(self.mxid)
|
||||
|
||||
@@ -571,8 +571,7 @@ class Portal:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _get_users(self,
|
||||
user: 'AbstractUser',
|
||||
async def _get_users(self, user: 'AbstractUser',
|
||||
entity: Union[TypeInputPeer, InputUser, TypeChat, TypeUser]
|
||||
) -> Tuple[List[TypeUser], List[TypeParticipant]]:
|
||||
if self.peer_type == "chat":
|
||||
@@ -640,7 +639,8 @@ class Portal:
|
||||
return []
|
||||
authenticated = []
|
||||
has_bot = self.has_bot
|
||||
for member in members:
|
||||
for member_str in members:
|
||||
member = MatrixUserID(member_str)
|
||||
if p.Puppet.get_id_from_mxid(member) or member == self.main_intent.mxid:
|
||||
continue
|
||||
user = await u.User.get_by_mxid(member).ensure_started()
|
||||
@@ -657,7 +657,7 @@ class Portal:
|
||||
except MatrixRequestError:
|
||||
members = []
|
||||
for user in members:
|
||||
puppet = p.Puppet.get_by_mxid(user, create=False)
|
||||
puppet = p.Puppet.get_by_mxid(MatrixUserID(user), create=False)
|
||||
if user != intent.mxid and (not puppets_only or puppet):
|
||||
try:
|
||||
if puppet:
|
||||
@@ -729,7 +729,7 @@ class Portal:
|
||||
or user.mxid_localpart)
|
||||
|
||||
def set_typing(self, user: 'u.User', typing: bool = True,
|
||||
action: type = SendMessageTypingAction) -> bool:
|
||||
action: type = SendMessageTypingAction) -> Awaitable[bool]:
|
||||
return user.client(SetTypingRequest(
|
||||
self.peer, action() if typing else SendMessageCancelAction()))
|
||||
|
||||
@@ -1077,7 +1077,8 @@ class Portal:
|
||||
async def _get_telegram_users_in_matrix_room(self) -> List[int]:
|
||||
user_tgids = set()
|
||||
user_mxids = await self.main_intent.get_room_members(self.mxid, ("join", "invite"))
|
||||
for user in user_mxids:
|
||||
for user_str in user_mxids:
|
||||
user = MatrixUserID(user_str)
|
||||
if user == self.az.bot_mxid:
|
||||
continue
|
||||
mx_user = u.User.get_by_mxid(user, create=False)
|
||||
@@ -1101,7 +1102,7 @@ class Portal:
|
||||
if not entity:
|
||||
raise ValueError("Upgrade may have failed: output channel not found.")
|
||||
self.peer_type = "channel"
|
||||
self.migrate_and_save(entity.id)
|
||||
self.migrate_and_save(TelegramID(entity.id))
|
||||
await self.update_info(source, entity)
|
||||
|
||||
async def set_telegram_username(self, source: 'u.User', username: str) -> None:
|
||||
@@ -1176,7 +1177,7 @@ class Portal:
|
||||
return None
|
||||
|
||||
async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message,
|
||||
relates_to: Dict = {}) -> None:
|
||||
relates_to: Dict = None) -> Optional[Dict]:
|
||||
largest_size = self._get_largest_photo_size(evt.media.photo)
|
||||
file = await util.transfer_file_to_matrix(self.db, source.client, intent,
|
||||
largest_size.location)
|
||||
@@ -1516,14 +1517,14 @@ class Portal:
|
||||
await self.remove_avatar(source, save=True)
|
||||
elif isinstance(action, MessageActionChatAddUser):
|
||||
for user_id in action.users:
|
||||
await self.add_telegram_user(user_id, source)
|
||||
await self.add_telegram_user(TelegramID(user_id), source)
|
||||
elif isinstance(action, MessageActionChatJoinedByLink):
|
||||
await self.add_telegram_user(sender.id, source)
|
||||
elif isinstance(action, MessageActionChatDeleteUser):
|
||||
await self.delete_telegram_user(action.user_id, sender)
|
||||
await self.delete_telegram_user(TelegramID(action.user_id), sender)
|
||||
elif isinstance(action, MessageActionChatMigrateTo):
|
||||
self.peer_type = "channel"
|
||||
self.migrate_and_save(action.channel_id)
|
||||
self.migrate_and_save(TelegramID(action.channel_id))
|
||||
await sender.intent.send_emote(self.mxid, "upgraded this group to a supergroup.")
|
||||
elif isinstance(action, MessageActionPinMessage):
|
||||
await self.receive_telegram_pin_sender(sender)
|
||||
@@ -1620,7 +1621,7 @@ class Portal:
|
||||
levels["events"]["m.room.power_levels"] = admin_power_level
|
||||
|
||||
for participant in participants:
|
||||
puppet = p.Puppet.get(participant.user_id)
|
||||
puppet = p.Puppet.get(TelegramID(participant.user_id))
|
||||
user = u.User.get_by_tgid(participant.user_id)
|
||||
new_level = self._get_level_from_participant(participant, levels)
|
||||
|
||||
@@ -1792,7 +1793,7 @@ class Portal:
|
||||
entity_id = entity.user_id
|
||||
else:
|
||||
raise ValueError(f"Unknown entity type {entity_type.__name__}")
|
||||
return cls.get_by_tgid(entity_id,
|
||||
return cls.get_by_tgid(TelegramID(entity_id),
|
||||
receiver_id if type_name == "user" else entity_id,
|
||||
type_name if create else None)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Coroutine, Dict, List, NewType, Optional, Pattern, TYPE_CHECKING
|
||||
from typing import Awaitable, Coroutine, Dict, List, Optional, Pattern, TYPE_CHECKING
|
||||
from difflib import SequenceMatcher
|
||||
import re
|
||||
import logging
|
||||
@@ -37,7 +37,6 @@ if TYPE_CHECKING:
|
||||
from . import user as u
|
||||
from .abstract_user import AbstractUser
|
||||
|
||||
|
||||
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
|
||||
|
||||
config = None # type: Config
|
||||
@@ -87,7 +86,7 @@ class Puppet:
|
||||
self.by_custom_mxid[self.custom_mxid] = self
|
||||
|
||||
@property
|
||||
def mxid(self):
|
||||
def mxid(self) -> MatrixUserID:
|
||||
return self.custom_mxid or self.default_mxid
|
||||
|
||||
@property
|
||||
@@ -109,7 +108,8 @@ class Puppet:
|
||||
return (self.az.intent.user(self.custom_mxid, self.access_token)
|
||||
if self.is_real_user else self.default_mxid_intent)
|
||||
|
||||
async def switch_mxid(self, access_token: str, mxid: MatrixUserID) -> PuppetError:
|
||||
async def switch_mxid(self, access_token: Optional[str],
|
||||
mxid: Optional[MatrixUserID]) -> PuppetError:
|
||||
prev_mxid = self.custom_mxid
|
||||
self.custom_mxid = mxid
|
||||
self.access_token = access_token
|
||||
@@ -217,7 +217,7 @@ class Puppet:
|
||||
for events in ephemeral.values()
|
||||
for event in self.filter_events(events)]
|
||||
|
||||
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
|
||||
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
|
||||
coro = asyncio.gather(*events, loop=self.loop)
|
||||
asyncio.ensure_future(coro, loop=self.loop)
|
||||
|
||||
@@ -355,10 +355,10 @@ class Puppet:
|
||||
if displayname != self.displayname:
|
||||
await self.default_mxid_intent.set_display_name(displayname)
|
||||
self.displayname = displayname
|
||||
self.displayname_source = TelegramID(source.tgid)
|
||||
self.displayname_source = source.tgid
|
||||
return True
|
||||
elif source.is_relaybot or self.displayname_source is None:
|
||||
self.displayname_source = TelegramID(source.tgid)
|
||||
self.displayname_source = source.tgid
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Dict, NewType
|
||||
|
||||
# MatrixId = NewType('MatrixId', str)
|
||||
MatrixUserID = NewType('MatrixUserID', str)
|
||||
MatrixRoomID = NewType('MatrixRoomID', str)
|
||||
MatrixEventID = NewType('MatrixEventID', str)
|
||||
|
||||
@@ -49,7 +49,8 @@ class User(AbstractUser):
|
||||
|
||||
def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
|
||||
username: Optional[str] = None, db_contacts: Optional[List[DBContact]] = None,
|
||||
saved_contacts: int = 0, is_bot: bool = False, db_portals: List[DBPortal] = [],
|
||||
saved_contacts: int = 0, is_bot: bool = False,
|
||||
db_portals: Optional[List[DBPortal]] = None,
|
||||
db_instance: Optional[DBUser] = None) -> None:
|
||||
super().__init__()
|
||||
self.mxid = mxid # type: MatrixUserID
|
||||
@@ -105,9 +106,11 @@ class User(AbstractUser):
|
||||
|
||||
@db_portals.setter
|
||||
def db_portals(self, portals: List[DBPortal]) -> None:
|
||||
self.portals = {(portal.tgid, portal.tg_receiver):
|
||||
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
|
||||
for portal in portals} if portals else {}
|
||||
self.portals = {
|
||||
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
|
||||
portal.tg_receiver)
|
||||
for portal in portals
|
||||
} if portals else {}
|
||||
|
||||
# region Database conversion
|
||||
|
||||
@@ -119,14 +122,14 @@ class User(AbstractUser):
|
||||
|
||||
def new_db_instance(self) -> DBUser:
|
||||
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
|
||||
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
|
||||
contacts=self.db_contacts, saved_contacts=self.saved_contacts,
|
||||
portals=self.db_portals)
|
||||
|
||||
def save(self) -> None:
|
||||
self.db_instance.tgid = self.tgid
|
||||
self.db_instance.username = self.username
|
||||
self.db_instance.contacts = self.db_contacts
|
||||
self.db_instance.saved_contacts = self.saved_contacts or 0
|
||||
self.db_instance.saved_contacts = self.saved_contacts
|
||||
self.db_instance.portals = self.db_portals
|
||||
self.db.commit()
|
||||
|
||||
@@ -143,7 +146,7 @@ class User(AbstractUser):
|
||||
@classmethod
|
||||
def from_db(cls, db_user: DBUser) -> 'User':
|
||||
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
|
||||
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
|
||||
db_user.saved_contacts, False, db_user.portals, db_instance=db_user)
|
||||
|
||||
# endregion
|
||||
# region Telegram connection management
|
||||
@@ -182,9 +185,9 @@ class User(AbstractUser):
|
||||
else:
|
||||
portal = po.Portal.get_by_entity(message.to_id, receiver_id=self.tgid)
|
||||
elif isinstance(update, UpdateShortChatMessage):
|
||||
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id), peer_type="chat")
|
||||
elif isinstance(update, UpdateShortMessage):
|
||||
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
@@ -15,6 +15,9 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from aiohttp import web
|
||||
import abc
|
||||
import asyncio
|
||||
import logging
|
||||
@@ -28,22 +31,24 @@ from ...user import User
|
||||
|
||||
|
||||
class AuthAPI(abc.ABC):
|
||||
log = logging.getLogger("mau.web.auth")
|
||||
log = logging.getLogger("mau.web.auth") # type: logging.Logger
|
||||
|
||||
def __init__(self, loop):
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
self.loop = loop # type: asyncio.AbstractEventLoop
|
||||
|
||||
@abstractmethod
|
||||
def get_login_response(self, status=200, state="", username="", mxid="", message="", error="",
|
||||
errcode=""):
|
||||
def get_login_response(self, status: int = 200, state: str = "", username: str = "",
|
||||
mxid: str = "", message: str = "", error: str = "",
|
||||
errcode: str = "") -> web.Response:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_mx_login_response(self, status=200, state="", username="", mxid="", message="",
|
||||
error="", errcode=""):
|
||||
def get_mx_login_response(self, status: int = 200, state: str = "", username: str = "",
|
||||
mxid: str = "", message: str = "", error: str = "",
|
||||
errcode: str = "") -> web.Response:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def post_matrix_token(self, user: User, token):
|
||||
async def post_matrix_token(self, user: User, token: str) -> web.Response:
|
||||
puppet = Puppet.get(user.tgid)
|
||||
if puppet.is_real_user:
|
||||
return self.get_mx_login_response(state="already-logged-in", status=409,
|
||||
@@ -61,11 +66,11 @@ class AuthAPI(abc.ABC):
|
||||
|
||||
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
|
||||
|
||||
async def post_matrix_password(self, user, password):
|
||||
async def post_matrix_password(self, user: User, password: str) -> web.Response:
|
||||
return self.get_mx_login_response(mxid=user.mxid, status=501, error="Not yet implemented",
|
||||
errcode="not-yet-implemented")
|
||||
|
||||
async def post_login_phone(self, user, phone):
|
||||
async def post_login_phone(self, user: User, phone: str) -> web.Response:
|
||||
try:
|
||||
await user.client.sign_in(phone or "+123")
|
||||
return self.get_login_response(mxid=user.mxid, state="code", status=200,
|
||||
@@ -102,7 +107,7 @@ class AuthAPI(abc.ABC):
|
||||
errcode="unknown_error",
|
||||
error="Internal server error while requesting code.")
|
||||
|
||||
async def post_login_token(self, user, token):
|
||||
async def post_login_token(self, user: User, token: str) -> web.Response:
|
||||
try:
|
||||
user_info = await user.client.sign_in(bot_token=token)
|
||||
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
|
||||
@@ -123,7 +128,8 @@ class AuthAPI(abc.ABC):
|
||||
return self.get_login_response(mxid=user.mxid, state="token", status=500,
|
||||
error="Internal server error while sending token.")
|
||||
|
||||
async def post_login_code(self, user, code, password_in_data):
|
||||
async def post_login_code(self, user: User, code: int, password_in_data: bool
|
||||
) -> Optional[web.Response]:
|
||||
try:
|
||||
user_info = await user.client.sign_in(code=code)
|
||||
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
|
||||
@@ -156,7 +162,7 @@ class AuthAPI(abc.ABC):
|
||||
errcode="unknown_error",
|
||||
error="Internal server error while sending code.")
|
||||
|
||||
async def post_login_password(self, user, password):
|
||||
async def post_login_password(self, user: User, password: str) -> web.Response:
|
||||
try:
|
||||
user_info = await user.client.sign_in(password=password)
|
||||
asyncio.ensure_future(user.post_login(user_info), loop=self.loop)
|
||||
|
||||
@@ -35,15 +35,16 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ProvisioningAPI(AuthAPI):
|
||||
log = logging.getLogger("mau.web.provisioning")
|
||||
log = logging.getLogger("mau.web.provisioning") # type: logging.Logger
|
||||
|
||||
def __init__(self, context: "Context") -> None:
|
||||
super().__init__(context.loop)
|
||||
self.secret = context.config["appservice.provisioning.shared_secret"]
|
||||
self.secret = context.config["appservice.provisioning.shared_secret"] # type: str
|
||||
self.az = context.az # type: AppService
|
||||
self.context = context # type: Context
|
||||
|
||||
self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware])
|
||||
self.app = web.Application(loop=context.loop, middlewares=[self.error_middleware]
|
||||
) # type: web.Application
|
||||
|
||||
portal_prefix = "/portal/{mxid:![^/]+}"
|
||||
self.app.router.add_route("GET", f"{portal_prefix}", self.get_portal_by_mxid)
|
||||
@@ -353,7 +354,8 @@ class ProvisioningAPI(AuthAPI):
|
||||
await user.log_out()
|
||||
|
||||
@staticmethod
|
||||
async def error_middleware(_, handler) -> Callable[[web.Request], Awaitable[web.Response]]:
|
||||
async def error_middleware(_, handler: Callable[[web.Request], Awaitable[web.Response]]
|
||||
) -> Callable[[web.Request], Awaitable[web.Response]]:
|
||||
async def middleware_handler(request: web.Request) -> web.Response:
|
||||
try:
|
||||
return await handler(request)
|
||||
|
||||
@@ -14,14 +14,17 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional
|
||||
from aiohttp import web
|
||||
from mako.template import Template
|
||||
import pkg_resources
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
|
||||
from ...types import MatrixUserID
|
||||
from ...util import sign_token, verify_token
|
||||
from ...user import User
|
||||
from ...puppet import Puppet
|
||||
@@ -29,20 +32,20 @@ from ..common import AuthAPI
|
||||
|
||||
|
||||
class PublicBridgeWebsite(AuthAPI):
|
||||
log = logging.getLogger("mau.web.public")
|
||||
log = logging.getLogger("mau.web.public") # type: logging.Logger
|
||||
|
||||
def __init__(self, loop):
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
super().__init__(loop)
|
||||
self.secret_key = "".join(
|
||||
random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
|
||||
random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) # type: str
|
||||
|
||||
self.login = Template(
|
||||
pkg_resources.resource_string("mautrix_telegram", "web/public/login.html.mako"))
|
||||
self.login = Template(pkg_resources.resource_string(
|
||||
"mautrix_telegram", "web/public/login.html.mako")) # type: Template
|
||||
|
||||
self.mx_login = Template(
|
||||
pkg_resources.resource_string("mautrix_telegram", "web/public/matrix-login.html.mako"))
|
||||
self.mx_login = Template(pkg_resources.resource_string(
|
||||
"mautrix_telegram", "web/public/matrix-login.html.mako")) # type: Template
|
||||
|
||||
self.app = web.Application(loop=loop)
|
||||
self.app = web.Application(loop=loop) # type: web.Application
|
||||
self.app.router.add_route("GET", "/login", self.get_login)
|
||||
self.app.router.add_route("POST", "/login", self.post_login)
|
||||
self.app.router.add_route("GET", "/matrix-login", self.get_matrix_login)
|
||||
@@ -50,21 +53,21 @@ class PublicBridgeWebsite(AuthAPI):
|
||||
self.app.router.add_static("/", pkg_resources.resource_filename("mautrix_telegram",
|
||||
"web/public/"))
|
||||
|
||||
def make_token(self, mxid, endpoint="/login", expires_in=900):
|
||||
def make_token(self, mxid: str, endpoint: str = "/login", expires_in: int = 900) -> str:
|
||||
return sign_token(self.secret_key, {
|
||||
"mxid": mxid,
|
||||
"endpoint": endpoint,
|
||||
"expiry": int(time.time()) + expires_in,
|
||||
})
|
||||
|
||||
def verify_token(self, token, endpoint="/login"):
|
||||
def verify_token(self, token: str, endpoint: str = "/login") -> Optional[MatrixUserID]:
|
||||
token = verify_token(self.secret_key, token)
|
||||
if token and (token.get("expiry", 0) > int(time.time()) and
|
||||
token.get("endpoint", None) == endpoint):
|
||||
return token.get("mxid", None)
|
||||
return MatrixUserID(token.get("mxid", None))
|
||||
return None
|
||||
|
||||
async def get_login(self, request):
|
||||
async def get_login(self, request: web.Request) -> web.Response:
|
||||
state = "bot_token" if request.rel_url.query.get("mode", "") == "bot" else "request"
|
||||
|
||||
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login")
|
||||
@@ -83,7 +86,7 @@ class PublicBridgeWebsite(AuthAPI):
|
||||
|
||||
return self.get_login_response(mxid=user.mxid, username=user.username)
|
||||
|
||||
async def get_matrix_login(self, request):
|
||||
async def get_matrix_login(self, request: web.Request) -> web.Response:
|
||||
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/matrix-login")
|
||||
if not mxid:
|
||||
return self.get_mx_login_response(status=401, state="invalid-token")
|
||||
@@ -105,19 +108,21 @@ class PublicBridgeWebsite(AuthAPI):
|
||||
|
||||
return self.get_mx_login_response(mxid=user.mxid)
|
||||
|
||||
def get_login_response(self, status=200, state="", username="", mxid="", message="", error="",
|
||||
errcode=""):
|
||||
def get_login_response(self, status: int = 200, state: str = "", username: str = "",
|
||||
mxid: str = "", message: str = "", error: str = "",
|
||||
errcode: str = "") -> web.Response:
|
||||
return web.Response(status=status, content_type="text/html",
|
||||
text=self.login.render(username=username, state=state, error=error,
|
||||
message=message, mxid=mxid))
|
||||
|
||||
def get_mx_login_response(self, status=200, state="", username="", mxid="", message="",
|
||||
error="", errcode=""):
|
||||
def get_mx_login_response(self, status: int = 200, state: str = "", username: str = "",
|
||||
mxid: str = "", message: str = "", error: str = "",
|
||||
errcode: str = "") -> web.Response:
|
||||
return web.Response(status=status, content_type="text/html",
|
||||
text=self.mx_login.render(username=username, state=state, error=error,
|
||||
message=message, mxid=mxid))
|
||||
|
||||
async def post_matrix_login(self, request):
|
||||
async def post_matrix_login(self, request: web.Request) -> web.Response:
|
||||
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/matrix-login")
|
||||
if not mxid:
|
||||
return self.get_mx_login_response(status=401, state="invalid-token")
|
||||
@@ -140,7 +145,7 @@ class PublicBridgeWebsite(AuthAPI):
|
||||
error="You must provide an access token or "
|
||||
"password.")
|
||||
|
||||
async def post_login(self, request):
|
||||
async def post_login(self, request: web.Request) -> web.Response:
|
||||
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login")
|
||||
if not mxid:
|
||||
return self.get_login_response(status=401, state="invalid-token")
|
||||
|
||||
Reference in New Issue
Block a user