Compare commits

...

22 Commits

Author SHA1 Message Date
Tulir Asokan 36a654bcfe Bump version to 0.5.0rc4 2019-03-16 17:36:25 +02:00
Tulir Asokan e16182ee6a Fix Context initialization in tests 2019-03-16 17:22:16 +02:00
Tulir Asokan 7c46bf4b9e Remove remaining traces of ORM 2019-03-16 17:13:28 +02:00
Tulir Asokan 7c82580b4b Merge pull request #290 from V02460/tests
Add pytest unit testing framework
2019-03-16 17:13:19 +02:00
Kai A. Hiller 1e1e9b03c0 Revert absolute imports back to relative 2019-03-14 10:33:43 +01:00
Tulir Asokan 0587145145 Always flush stdout when logging in db migrate script 2019-03-13 23:50:40 +02:00
Tulir Asokan 7840da94b5 Fix verbose flag in db migrate script 2019-03-13 23:41:44 +02:00
Tulir Asokan 010866e0d0 Add verbose option to db migration script 2019-03-13 23:28:31 +02:00
Tulir Asokan c54b057d90 Add __init__.py's so scripts would be included in builds 2019-03-13 23:28:31 +02:00
Tulir Asokan b55f3a9c4d Merge pull request #291 from t2bot/travis/error-reporting
Log startup exceptions
2019-03-10 13:08:48 +02:00
Travis Ralston aa09e738e6 Log startup exceptions 2019-03-09 20:19:15 -06:00
Kai A. Hiller 4254b85628 Add pytest unit testing framework 2019-03-08 19:11:02 +01:00
Tulir Asokan 7d5e946067 Fix potential errors caused by deleted portals when logging out (ref #286) 2019-03-02 04:09:39 +02:00
Tulir Asokan 9eda525d2a Fix handling missing argument in clear-db-cache (ref #286) 2019-03-02 04:09:23 +02:00
Tulir Asokan 8ef337f40b Remove lxml HTML parser as it was messing up emoji offset handling 2019-03-01 23:45:30 +02:00
Tulir Asokan f5ac584ed5 Escape HTML in displaynames before putting it in the relaybot format 2019-03-01 23:11:54 +02:00
Tulir Asokan a3534d802a Wrap database-changing statements in db.begin() 2019-02-24 02:53:50 +02:00
Tulir Asokan 92b689255b Bump minimum alchemysession version and fix migrate script imports 2019-02-20 01:46:24 +02:00
Tulir Asokan fb5167963a Fix repadding base64 2019-02-17 16:14:38 +02:00
Tulir Asokan 50ac4b6381 Handle cases where entity.default_banned_rights is None 2019-02-16 23:22:04 +02:00
Tulir Asokan d842fc73cb Handle AuthKeyError when terminating sessions 2019-02-16 23:21:47 +02:00
Tulir Asokan 531d118ed0 Fix saving new users to database. Actually fixes #284 2019-02-16 23:12:39 +02:00
40 changed files with 821 additions and 217 deletions
+1
View File
@@ -1,6 +1,7 @@
.idea/
.venv
env/
pip-selfcheck.json
*.pyc
__pycache__
+1 -1
View File
@@ -1,2 +1,2 @@
__version__ = "0.5.0rc3"
__version__ = "0.5.0rc4"
__author__ = "Tulir Asokan <tulir@maunium.net>"
+8 -10
View File
@@ -23,7 +23,6 @@ import sys
import copy
import signal
from sqlalchemy import orm
import sqlalchemy as sql
from mautrix_appservice import AppService
@@ -73,13 +72,10 @@ log = logging.getLogger("mau.init") # type: logging.Logger
log.debug(f"Initializing mautrix-telegram {__version__}")
db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db")
db_factory = orm.sessionmaker(bind=db_engine)
db_session = orm.scoping.scoped_session(db_factory)
Base.metadata.bind = db_engine
session_container = AlchemySessionContainer(engine=db_engine, session=db_session,
table_base=Base, table_prefix="telethon_",
manage_tables=False)
session_container = AlchemySessionContainer(engine=db_engine, table_base=Base, session=False,
table_prefix="telethon_", manage_tables=False)
session_container.core_mode = True
try:
@@ -102,8 +98,9 @@ appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
aiohttp_params={
"client_max_size": config["appservice.max_body_size"] * mebibyte
})
context = Context(appserv, db_session, config, loop, session_container)
bot = init_bot(config)
context = Context(appserv, config, loop, session_container, bot)
context.mx = MatrixHandler(context)
if config["appservice.public.enabled"]:
public_website = PublicBridgeWebsite(loop)
@@ -120,8 +117,6 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
start_ts = time()
init_db(db_engine)
init_abstract_user(context)
context.bot = init_bot(context)
context.mx = MatrixHandler(context)
init_formatter(context)
init_portal(context)
startup_actions = (init_puppet(context) +
@@ -150,3 +145,6 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
asyncio.gather(*[user.stop() for user in User.by_tgid.values()], loop=loop))
log.debug("Clients stopped, shutting down")
sys.exit(0)
except Exception as e:
log.exception("Unexpected error")
sys.exit(1)
+3 -8
View File
@@ -20,7 +20,6 @@ import asyncio
import logging
import platform
from sqlalchemy import orm
from telethon.tl.patched import MessageService, Message
from telethon.tl.types import (
Channel, ChannelForbidden, Chat, ChatForbidden, MessageActionChannelMigrateFrom, PeerUser,
@@ -56,7 +55,6 @@ class AbstractUser(ABC):
session_container = None # type: AlchemySessionContainer
loop = None # type: asyncio.AbstractEventLoop
log = None # type: logging.Logger
db = None # type: orm.Session
az = None # type: AppService
bot = None # type: Bot
ignore_incoming_bot_events = True # type: bool
@@ -175,11 +173,8 @@ class AbstractUser(ABC):
async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
if not self.puppet_whitelisted or self.connected:
return self
session_count = self.session_container.Session.query.filter(
self.session_container.Session.session_id == self.mxid).count()
self.log.debug("ensure_started(%s, even_if_no_session=%s, session_count=%s)",
self.mxid, even_if_no_session, session_count)
if even_if_no_session or session_count > 0:
self.log.debug("ensure_started(%s, even_if_no_session=%s)", self.mxid, even_if_no_session)
if even_if_no_session or self.session_container.has_session(self.mxid):
await self.start(delete_unless_authenticated=not even_if_no_session)
return self
@@ -388,7 +383,7 @@ class AbstractUser(ABC):
def init(context: "Context") -> None:
global config, MAX_DELETIONS
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core
AbstractUser.az, config, AbstractUser.loop, AbstractUser.relaybot = context.core
AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"]
AbstractUser.session_container = context.session_container
MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10)
+4 -3
View File
@@ -56,7 +56,7 @@ class Bot(AbstractUser):
self.username = None # type: str
self.is_relaybot = True # type: bool
self.is_bot = True # type: bool
self.chats = {chat.id: chat.type for chat in BotChat.all()} # type: Dict[int, str]
self.chats = {} # type: Dict[int, str]
self.tg_whitelist = [] # type: List[int]
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
or False) # type: bool
@@ -74,6 +74,7 @@ class Bot(AbstractUser):
self.tg_whitelist.append(user_id)
async def start(self, delete_unless_authenticated: bool = False) -> 'Bot':
self.chats = {chat.id: chat.type for chat in BotChat.all()}
await super().start(delete_unless_authenticated)
if not await self.is_logged_in():
await self.client.sign_in(bot_token=self.token)
@@ -280,9 +281,9 @@ class Bot(AbstractUser):
return "bot"
def init(context: 'Context') -> Optional[Bot]:
def init(cfg: 'Config') -> Optional[Bot]:
global config
config = context.config
config = cfg
token = config["telegram.bot_token"]
if token and not token.lower().startswith("disable"):
return Bot(token)
+187 -13
View File
@@ -14,9 +14,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable, Dict, List, NamedTuple, Optional
import traceback
"""This module contains classes handling commands issued by Matrix users."""
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Dict,
List,
NamedTuple,
Optional,
Union,
NewType,
)
import logging
import traceback
import commonmark
@@ -59,7 +71,28 @@ md_parser = commonmark.Parser()
md_renderer = HtmlEscapingRenderer()
def ensure_trailing_newline(s: str) -> str:
"""Returns the passed string, but with a guaranteed trailing newline."""
return s + ("" if s[-1] == "\n" else "\n")
class CommandEvent:
"""Holds information about a command issued in a Matrix room.
When a Matrix command was issued to the bot, CommandEvent will hold
information regarding the event.
Attributes:
room_id: The id of the Matrix room in which the command was issued.
event_id: The id of the matrix event which contained the command.
sender: The user who issued the command.
command: The issued command.
args: Arguments given with the issued command.
is_management: Determines whether the room in which the command wa
issued is a management room.
is_portal: Determines whether the room in which the command was issued
is a portal.
"""
def __init__(self, processor: 'CommandProcessor', room: MatrixRoomID, event: MatrixEventID,
sender: u.User, command: str, args: List[str], is_management: bool,
is_portal: bool) -> None:
@@ -78,28 +111,109 @@ class CommandEvent:
self.is_management = is_management
self.is_portal = is_portal
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
) -> Awaitable[Dict]:
message = message.replace("$cmdprefix+sp ",
"" if self.is_management else f"{self.command_prefix} ")
message = message.replace("$cmdprefix", self.command_prefix)
html = None
def reply(
self,
message: str,
allow_html: bool = False,
render_markdown: bool = True,
) -> Awaitable[Dict]:
"""Write a reply to the room in which the command was issued.
Replaces occurences of "$cmdprefix" in the message with the command
prefix and replaces occurences of "$cmdprefix+sp " with the command
prefix if the command was not issued in a management room.
If allow_html and render_markdown are both False, the message will not
be rendered to html and sending of html is disabled.
Args:
message: The message to post in the room.
allow_html: Escape html in the message or don't render html at all
if markdown is disabled.
render_markdown: Use markdown formatting to render the passed
message to html.
Returns:
Handler for the message sending function.
"""
message_cmd = self._replace_command_prefix(message)
html = self._render_message(
message_cmd, allow_html=allow_html, render_markdown=render_markdown
)
return self.az.intent.send_notice(self.room_id, message_cmd, html=html)
def mark_read(self) -> Awaitable[Dict]:
"""Marks the command as read by the bot."""
return self.az.intent.mark_read(self.room_id, self.event_id)
def _replace_command_prefix(self, message: str) -> str:
"""Returns the string with the proper command prefix entered."""
message = message.replace(
"$cmdprefix+sp ", "" if self.is_management else f"{self.command_prefix} "
)
return message.replace("$cmdprefix", self.command_prefix)
def _render_message(
self, message: str, allow_html: bool, render_markdown: bool
) -> Optional[str]:
"""Renders the message as HTML.
Args:
allow_html: Flag to allow custom HTML in the message.
render_markdown: If true, markdown styling is applied to the message.
Returns:
The message rendered as HTML.
None is returned if no styled output is required.
"""
html = ""
if render_markdown:
md_renderer.allow_html = allow_html
html = md_renderer.render(md_parser.parse(message))
elif allow_html:
html = message
return self.az.intent.send_notice(self.room_id, message, html=html)
def mark_read(self) -> Awaitable[Dict]:
return self.az.intent.mark_read(self.room_id, self.event_id)
return ensure_trailing_newline(html) if html else None
class CommandHandler:
"""A command which can be executed from a Matrix room.
The command manages its permission and help texts.
When called, it will check the permission of the command event and execute
the command or, in case of error, report back to the user.
Attributes:
needs_auth: Flag indicating if the sender is required to be logged in.
needs_puppeting: Flag indicating if the sender is required to use
Telegram puppeteering for this command.
needs_matrix_puppeting: Flag indicating if the sender is required to use
Matrix pupeteering.
needs_admin: Flag for whether only admin users can issue this command.
management_only: Whether the command can exclusively be issued in a
management room.
name: The name of this command.
help_section: Section of the help in which this command will appear.
"""
def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool,
needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool,
management_only: bool, name: str, help_text: str, help_args: str,
help_section: HelpSection) -> None:
"""
Args:
handler: The function handling the execution of this command.
needs_auth: Flag indicating if the sender is required to be logged in.
needs_puppeting: Flag indicating if the sender is required to use
Telegram puppeteering for this command.
needs_matrix_puppeting: Flag indicating if the sender is required to
use Matrix pupeteering.
needs_admin: Flag for whether only admin users can issue this command.
management_only: Whether the command can exclusively be issued
in a management room.
name: The name of this command.
help_text: The text displayed in the help for this command.
help_args: Help text for the arguments of this command.
help_section: Section of the help in which this command will appear.
"""
self._handler = handler
self.needs_auth = needs_auth
self.needs_puppeting = needs_puppeting
@@ -112,6 +226,14 @@ class CommandHandler:
self.help_section = help_section
async def get_permission_error(self, evt: CommandEvent) -> Optional[str]:
"""Returns the reason why the command could not be issued.
Args:
evt: The event for which to get the error information.
Returns:
A string describing the error or None if there was no error.
"""
if self.management_only and not evt.is_management:
return (f"`{evt.command}` is a restricted command: "
"you may only run it in management rooms.")
@@ -127,6 +249,22 @@ class CommandHandler:
def has_permission(self, is_management: bool, puppet_whitelisted: bool,
matrix_puppet_whitelisted: bool, is_admin: bool, is_logged_in: bool) -> bool:
"""Checks the permission for this command with the given status.
Args:
is_management: If the room in which the command will be issued is a
management room.
puppet_whitelited: If the connected Telegram account puppet is
allowed to issue the command.
matrix_puppet_whitelisted: If the connected Matrix account puppet is
allowed to issue the command.
is_admin: If the issuing user is an admin.
is_logged_in: If the issuing user is logged in.
Returns:
True if a user with the given state is allowed to issue the
command.
"""
return ((not self.management_only or is_management) and
(not self.needs_puppeting or puppet_whitelisted) and
(not self.needs_matrix_puppeting or matrix_puppet_whitelisted) and
@@ -134,6 +272,17 @@ class CommandHandler:
(not self.needs_auth or is_logged_in))
async def __call__(self, evt: CommandEvent) -> Dict:
"""Executes the command if evt was issued with proper rights.
Args:
evt: The CommandEvent for which to check permissions.
Returns:
The result of the command or the error message function.
Raises:
FloodWaitError
"""
error = await self.get_permission_error(evt)
if error is not None:
return await evt.reply(error)
@@ -141,10 +290,12 @@ class CommandHandler:
@property
def has_help(self) -> bool:
"""Returns true if this command has a help text."""
return bool(self.help_section) and bool(self._help_text)
@property
def help(self) -> str:
"""Returns the help text to this command."""
return f"**{self.name}** {self._help_args} - {self._help_text}"
@@ -173,16 +324,39 @@ def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] =
class CommandProcessor:
"""Handles the raw commands issued by a user to the Matrix bot."""
log = logging.getLogger("mau.commands")
def __init__(self, context: c.Context) -> None:
self.az, self.db, self.config, self.loop, self.tgbot = context.core
self.az, self.config, self.loop, self.tgbot = context.core
self.public_website = context.public_website
self.command_prefix = self.config["bridge.command_prefix"]
async def handle(self, room: MatrixRoomID, event_id: MatrixEventID, sender: u.User,
command: str, args: List[str], is_management: bool, is_portal: bool
) -> Optional[Dict]:
"""Handles the raw commands issued by a user to the Matrix bot.
If the command is not known, it might be a followup command and is
delegated to a command handler registered for that purpose in the
senders command_status as "next".
Args:
room: ID of the Matrix room in which the command was issued.
event_id: ID of the event by which the command was issued.
sender: The sender who issued the command.
command: The issued command, case insensitive.
args: Arguments given with the command.
is_management: Whether the room is a management room.
is_portal: Whether the room is a portal.
Returns:
The result of the error message function or None if no error
occured. Unknown and delegated commands do not count as errors.
"""
if not command_handlers or "unknown-command" not in command_handlers:
raise ValueError("command_handlers are not properly initialized.")
evt = CommandEvent(self, room, event_id, sender, command, args, is_management, is_portal)
orig_command = command
command = command.lower()
+1 -1
View File
@@ -52,7 +52,7 @@ async def set_power_level(evt: CommandEvent) -> Dict:
async def clear_db_cache(evt: CommandEvent) -> Dict:
try:
section = evt.args[0].lower()
except KeyError:
except IndexError:
return await evt.reply("**Usage:** `$cmdprefix+sp clear-db-cache <section>`")
if section == "portal":
po.Portal.by_tgid = {}
@@ -17,7 +17,7 @@
from typing import Dict, Optional
from telethon.errors import (UsernameInvalidError, UsernameNotModifiedError, UsernameOccupiedError,
HashInvalidError)
HashInvalidError, AuthKeyError)
from telethon.tl.types import Authorization
from telethon.tl.functions.account import (UpdateUsernameRequest, GetAuthorizationsRequest,
ResetAuthorizationRequest)
@@ -94,6 +94,11 @@ async def session(evt: CommandEvent) -> Optional[Dict]:
ok = await evt.sender.client(ResetAuthorizationRequest(hash=session_hash))
except HashInvalidError:
return await evt.reply("Invalid session hash.")
except AuthKeyError as e:
if e.message == "FRESH_RESET_AUTHORISATION_FORBIDDEN":
return await evt.reply("New sessions can't terminate other sessions. "
"Please wait a while.")
raise
if ok:
return await evt.reply("Session terminated successfully.")
else:
+1 -1
View File
@@ -255,7 +255,7 @@ async def vote(evt: CommandEvent) -> Optional[Dict]:
if not isinstance(msg.media, MessageMediaPoll):
return await evt.reply("Invalid poll ID (message doesn't look like a poll)")
options = [base64.b64decode(option + (4 - len(option) % 4) * "=")
options = [base64.b64decode(option + (3 - (len(option) + 3) % 4) * "=")
for option in evt.args[1:]]
try:
resp = await evt.sender.client(SendVoteRequest(peer=peer, msg_id=msg.id, options=options))
+8 -13
View File
@@ -19,8 +19,6 @@ from typing import Optional, Tuple, TYPE_CHECKING
if TYPE_CHECKING:
import asyncio
from sqlalchemy.orm import scoped_session
from alchemysession import AlchemySessionContainer
from mautrix_appservice import AppService
@@ -31,20 +29,17 @@ if TYPE_CHECKING:
class Context:
def __init__(self, az: 'AppService', db: 'scoped_session', config: 'Config',
loop: 'asyncio.AbstractEventLoop', session_container: 'AlchemySessionContainer'
) -> None:
def __init__(self, az: 'AppService', config: 'Config', loop: 'asyncio.AbstractEventLoop',
session_container: 'AlchemySessionContainer', bot: Optional['Bot']) -> None:
self.az = az # type: AppService
self.db = db # type: scoped_session
self.config = config # type: Config
self.loop = loop # type: asyncio.AbstractEventLoop
self.bot = None # type: Optional[Bot]
self.mx = None # type: MatrixHandler
self.bot = bot # type: Optional[Bot]
self.mx = None # type: Optional[MatrixHandler]
self.session_container = session_container # type: AlchemySessionContainer
self.public_website = None # type: PublicBridgeWebsite
self.provisioning_api = None # type: ProvisioningAPI
self.public_website = None # type: Optional[PublicBridgeWebsite]
self.provisioning_api = None # type: Optional[ProvisioningAPI]
@property
def core(self) -> Tuple['AppService', 'scoped_session', 'Config',
'asyncio.AbstractEventLoop', Optional['Bot']]:
return (self.az, self.db, self.config, self.loop, self.bot)
def core(self) -> Tuple['AppService', 'Config', 'asyncio.AbstractEventLoop', Optional['Bot']]:
return self.az, self.config, self.loop, self.bot
+6 -4
View File
@@ -44,14 +44,16 @@ class BaseBase:
pass
def update(self, **values) -> None:
self.db.execute(self.t.update()
.where(self._edit_identity)
.values(**values))
with self.db.begin() as conn:
conn.execute(self.t.update()
.where(self._edit_identity)
.values(**values))
for key, value in values.items():
setattr(self, key, value)
def delete(self) -> None:
self.db.execute(self.t.delete().where(self._edit_identity))
with self.db.begin() as conn:
conn.execute(self.t.delete().where(self._edit_identity))
Base = declarative_base(cls=BaseBase)
+4 -2
View File
@@ -30,7 +30,8 @@ class BotChat(Base):
@classmethod
def delete(cls, id: TelegramID) -> None:
cls.db.execute(cls.t.delete().where(cls.c.id == id))
with cls.db.begin() as conn:
conn.execute(cls.t.delete().where(cls.c.id == id))
@classmethod
def all(cls) -> Iterable['BotChat']:
@@ -40,4 +41,5 @@ class BotChat(Base):
yield cls(id=id, type=type)
def insert(self) -> None:
self.db.execute(self.t.insert().values(id=self.id, type=self.type))
with self.db.begin() as conn:
conn.execute(self.t.insert().values(id=self.id, type=self.type))
+11 -8
View File
@@ -68,20 +68,23 @@ class Message(Base):
@classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
cls.db.execute(cls.t.update()
.where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space))
.values(**values))
with cls.db.begin() as conn:
conn.execute(cls.t.update()
.where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space))
.values(**values))
@classmethod
def update_by_mxid(cls, s_mxid: MatrixEventID, s_mx_room: MatrixRoomID, **values) -> None:
cls.db.execute(cls.t.update()
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
with cls.db.begin() as conn:
conn.execute(cls.t.update()
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
@property
def _edit_identity(self):
return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)
def insert(self) -> None:
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
tg_space=self.tg_space))
with self.db.begin() as conn:
conn.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room,
tgid=self.tgid, tg_space=self.tg_space))
+5 -4
View File
@@ -74,7 +74,8 @@ class Portal(Base):
return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver)
def insert(self) -> None:
self.db.execute(self.t.insert().values(
tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
megagroup=self.megagroup, mxid=self.mxid, config=self.config, username=self.username,
title=self.title, about=self.about, photo_id=self.photo_id))
with self.db.begin() as conn:
conn.execute(self.t.insert().values(
tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
megagroup=self.megagroup, mxid=self.mxid, config=self.config,
username=self.username, title=self.title, about=self.about, photo_id=self.photo_id))
+6 -5
View File
@@ -79,8 +79,9 @@ class Puppet(Base):
return self.c.id == self.id
def insert(self) -> None:
self.db.execute(self.t.insert().values(
id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token,
displayname=self.displayname, displayname_source=self.displayname_source,
username=self.username, photo_id=self.photo_id, is_bot=self.is_bot,
matrix_registered=self.matrix_registered))
with self.db.begin() as conn:
conn.execute(self.t.insert().values(
id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token,
displayname=self.displayname, displayname_source=self.displayname_source,
username=self.username, photo_id=self.photo_id, is_bot=self.is_bot,
matrix_registered=self.matrix_registered))
+7 -5
View File
@@ -47,14 +47,16 @@ class RoomState(Base):
return None
def update(self) -> None:
self.db.execute(self.t.update()
.where(self.c.room_id == self.room_id)
.values(power_levels=self._power_levels_text))
with self.db.begin() as conn:
conn.execute(self.t.update()
.where(self.c.room_id == self.room_id)
.values(power_levels=self._power_levels_text))
@property
def _edit_identity(self):
return self.c.room_id == self.room_id
def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id,
power_levels=self._power_levels_text))
with self.db.begin() as conn:
conn.execute(self.t.insert().values(room_id=self.room_id,
power_levels=self._power_levels_text))
+7 -6
View File
@@ -15,7 +15,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, ForeignKey, Integer, BigInteger, String, Boolean
from sqlalchemy.orm import relationship
from typing import Optional
from .base import Base
@@ -33,7 +32,7 @@ class TelegramFile(Base):
width = Column(Integer, nullable=True)
height = Column(Integer, nullable=True)
thumbnail_id = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True)
thumbnail = relationship("TelegramFile", uselist=False)
thumbnail = None # type: Optional[TelegramFile]
@classmethod
def get(cls, id: str) -> Optional['TelegramFile']:
@@ -49,7 +48,9 @@ class TelegramFile(Base):
return None
def insert(self) -> None:
self.db.execute(self.t.insert().values(
id=self.id, mxc=self.mxc, mime_type=self.mime_type, was_converted=self.was_converted,
timestamp=self.timestamp, size=self.size, width=self.width, height=self.height,
thumbnail=self.thumbnail.id if self.thumbnail else self.thumbnail_id))
with self.db.begin() as conn:
conn.execute(self.t.insert().values(
id=self.id, mxc=self.mxc, mime_type=self.mime_type,
was_converted=self.was_converted, timestamp=self.timestamp, size=self.size,
width=self.width, height=self.height,
thumbnail=self.thumbnail.id if self.thumbnail else self.thumbnail_id))
+18 -16
View File
@@ -65,9 +65,10 @@ class User(Base):
return self.c.mxid == self.mxid
def insert(self) -> None:
self.db.execute(self.t.insert().values(
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
saved_contacts=self.saved_contacts))
with self.db.begin() as conn:
conn.execute(self.t.insert().values(
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username,
tg_phone=self.tg_phone, saved_contacts=self.saved_contacts))
@property
def contacts(self) -> Iterable[TelegramID]:
@@ -78,10 +79,11 @@ class User(Base):
@contacts.setter
def contacts(self, puppets: Iterable[TelegramID]) -> None:
self.db.execute(Contact.t.delete().where(Contact.c.user == self.tgid))
if puppets:
self.db.execute(Contact.t.insert(), [{"user": self.tgid, "contact": tgid}
for tgid in puppets])
with self.db.begin() as conn:
conn.execute(Contact.t.delete().where(Contact.c.user == self.tgid))
insert_puppets = [{"user": self.tgid, "contact": tgid} for tgid in puppets]
if insert_puppets:
conn.execute(Contact.t.insert(), insert_puppets)
@property
def portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
@@ -92,14 +94,15 @@ class User(Base):
@portals.setter
def portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
self.db.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid))
if portals:
self.db.execute(UserPortal.t.insert(),
[{
"user": self.tgid,
"portal": tgid,
"portal_receiver": tg_receiver
} for tgid, tg_receiver in portals])
with self.db.begin() as conn:
conn.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid))
insert_portals = [{
"user": self.tgid,
"portal": tgid,
"portal_receiver": tg_receiver
} for tgid, tg_receiver in portals]
if insert_portals:
conn.execute(UserPortal.t.insert(), insert_portals)
def delete(self) -> None:
super().delete()
@@ -125,4 +128,3 @@ class Contact(Base):
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
+7 -6
View File
@@ -50,7 +50,8 @@ class UserProfile(Base):
@classmethod
def delete_all(cls, room_id: MatrixRoomID) -> None:
cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id))
with cls.db.begin() as conn:
conn.execute(cls.t.delete().where(cls.c.room_id == room_id))
def update(self) -> None:
super().update(membership=self.membership, displayname=self.displayname,
@@ -61,8 +62,8 @@ class UserProfile(Base):
return and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)
def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
membership=self.membership,
displayname=self.displayname,
avatar_url=self.avatar_url))
with self.db.begin() as conn:
conn.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
membership=self.membership,
displayname=self.displayname,
avatar_url=self.avatar_url))
@@ -76,7 +76,6 @@ def matrix_to_telegram(html: str) -> ParsedMessage:
if should_bridge_plaintext_highlights:
html = plain_mention_regex.sub(plain_mention_to_html, html)
html = add_surrogates(html)
text, entities = parse_html(add_surrogates(html))
text = remove_surrogates(text.strip())
text, entities = cut_long_message(text, entities)
@@ -1,4 +1,58 @@
try:
from .html_reader_lxml import HTMLNode, read_html
except ImportError:
from .html_reader_htmlparser import HTMLNode, read_html
# -*- coding: future_fstrings -*-
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Tuple
from html.parser import HTMLParser
class HTMLNode(list):
def __init__(self, tag: str, attrs: List[Tuple[str, str]]):
super().__init__()
self.tag = tag # type: str
self.text = "" # type: str
self.tail = "" # type: str
self.attrib = dict(attrs) # type: Dict[str, str]
class NodeifyingParser(HTMLParser):
def __init__(self):
super().__init__()
self.stack = [HTMLNode("html", [])] # type: List[HTMLNode]
def handle_starttag(self, tag, attrs):
node = HTMLNode(tag, attrs)
self.stack[-1].append(node)
self.stack.append(node)
def handle_endtag(self, tag):
if tag == self.stack[-1].tag:
self.stack.pop()
def handle_data(self, data):
if len(self.stack[-1]) > 0:
self.stack[-1][-1].tail += data
else:
self.stack[-1].text += data
def error(self, message):
pass
def read_html(data: str) -> HTMLNode:
parser = NodeifyingParser()
parser.feed(data)
return parser.stack[0]
@@ -1,58 +0,0 @@
# -*- coding: future_fstrings -*-
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Tuple
from html.parser import HTMLParser
class HTMLNode(list):
def __init__(self, tag: str, attrs: List[Tuple[str, str]]):
super().__init__()
self.tag = tag # type: str
self.text = "" # type: str
self.tail = "" # type: str
self.attrib = dict(attrs) # type: Dict[str, str]
class NodeifyingParser(HTMLParser):
def __init__(self):
super().__init__()
self.stack = [HTMLNode("html", [])] # type: List[HTMLNode]
def handle_starttag(self, tag, attrs):
node = HTMLNode(tag, attrs)
self.stack[-1].append(node)
self.stack.append(node)
def handle_endtag(self, tag):
if tag == self.stack[-1].tag:
self.stack.pop()
def handle_data(self, data):
if len(self.stack[-1]) > 0:
self.stack[-1][-1].tail += data
else:
self.stack[-1].text += data
def error(self, message):
pass
def read_html(data: str) -> HTMLNode:
parser = NodeifyingParser()
parser.feed(data)
return parser.stack[0]
@@ -1,23 +0,0 @@
# -*- coding: future_fstrings -*-
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from lxml import html
HTMLNode = html.HtmlElement
def read_html(data: str) -> HTMLNode:
return html.fromstring(data)
+1 -1
View File
@@ -32,7 +32,7 @@ class MatrixHandler:
log = logging.getLogger("mau.mx") # type: logging.Logger
def __init__(self, context: 'Context') -> None:
self.az, self.db, self.config, _, self.tgbot = context.core
self.az, self.config, _, self.tgbot = context.core
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
self.previously_typing = [] # type: List[MatrixUserID]
+8 -4
View File
@@ -431,6 +431,10 @@ class Portal:
levels["events_default"] = 0
else:
dbr = entity.default_banned_rights
if not dbr:
self.log.debug(f"default_banned_rights is None in {entity}")
dbr = ChatBannedRights(invite_users=True, change_info=True, pin_messages=True,
send_stickers=False, send_messages=False, until_date=0)
levels["ban"] = 99
levels["kick"] = 50
levels["invite"] = 50 if dbr.invite_users else 0
@@ -440,7 +444,6 @@ class Portal:
levels["events"]["m.room.topic"] = 50 if dbr.change_info else 0
levels["events"][
"m.room.pinned_events"] = 50 if dbr.pin_messages else 0
levels["events"]["m.sticker"] = 50 if dbr.send_stickers else 0
levels["events"]["m.room.power_levels"] = 75
levels["events"]["m.room.history_visibility"] = 75
levels["state_default"] = 50
@@ -448,6 +451,7 @@ class Portal:
levels["events_default"] = (50 if (self.peer_type == "channel" and not entity.megagroup
or entity.default_banned_rights.send_messages)
else 0)
levels["events"]["m.sticker"] = 50 if dbr.send_stickers else levels["events_default"]
if "users" not in levels:
levels["users"] = {
self.main_intent.mxid: 100
@@ -771,7 +775,7 @@ class Portal:
tpl_args = dict(mxid=user.mxid,
username=user.mxid_localpart,
displayname=displayname)
displayname=escape_html(displayname))
tpl_args = {**tpl_args, **(arguments or {})}
message = Template(tpl).safe_substitute(tpl_args)
return {
@@ -903,7 +907,7 @@ class Portal:
displayname = await self.get_displayname(sender)
tpl_args = dict(sender_mxid=sender.mxid,
sender_username=sender.mxid_localpart,
sender_displayname=displayname,
sender_displayname=escape_html(displayname),
message=body)
message["formatted_body"] = Template(tpl).safe_substitute(tpl_args)
@@ -2035,7 +2039,7 @@ class Portal:
def init(context: Context) -> None:
global config
Portal.az, _, config, Portal.loop, Portal.bot = context.core
Portal.az, config, Portal.loop, Portal.bot = context.core
Portal.max_initial_member_sync = config["bridge.max_initial_member_sync"]
Portal.sync_channel_members = config["bridge.sync_channel_members"]
Portal.sync_matrix_state = config["bridge.sync_matrix_state"]
+4 -9
View File
@@ -14,8 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Awaitable, Coroutine, Dict, List, Iterable, Optional, Pattern, Union,
TYPE_CHECKING)
from typing import Awaitable, Any, Dict, List, Iterable, Optional, Pattern, Union, TYPE_CHECKING
from difflib import SequenceMatcher
from enum import Enum
from aiohttp import ServerDisconnectedError
@@ -23,8 +22,6 @@ import asyncio
import logging
import re
from sqlalchemy import orm
from telethon.tl.types import UserProfilePhoto, User, FileLocation, UpdateUserName, PeerUser
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
@@ -45,7 +42,6 @@ config = None # type: Config
class Puppet:
log = logging.getLogger("mau.puppet") # type: logging.Logger
db = None # type: orm.Session
az = None # type: AppService
mx = None # type: MatrixHandler
loop = None # type: asyncio.AbstractEventLoop
@@ -400,8 +396,7 @@ class Puppet:
if create:
puppet = cls(tgid)
cls.db.add(puppet.db_instance)
cls.db.commit()
puppet.db_instance.insert()
return puppet
return None
@@ -481,9 +476,9 @@ class Puppet:
# endregion
def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
def init(context: 'Context') -> List[Awaitable[Any]]: # [None, None, PuppetError]
global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core
Puppet.az, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"]
@@ -11,12 +11,19 @@ parser.add_argument("-f", "--from-url", type=str, required=True, metavar="<url>"
help="the old database path")
parser.add_argument("-t", "--to-url", type=str, required=True, metavar="<url>",
help="the new database path")
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose logs while migrating")
args = parser.parse_args()
verbose = args.verbose or False
def log(message, end="\n"):
if verbose:
print(message, end=end, flush=True)
def connect(to):
import mautrix_telegram.base as base
base.Base = declarative_base()
import mautrix_telegram.db.base as base
base.Base = declarative_base(cls=base.BaseBase)
from mautrix_telegram.db import (Portal, Message, UserPortal, User, RoomState, UserProfile,
Contact, Puppet, BotChat, TelegramFile)
db_engine = sql.create_engine(to)
@@ -45,15 +52,30 @@ def connect(to):
"TelegramFile": TelegramFile,
}
log("Connecting to old database")
session, tables = connect(args.from_url)
data = {}
for name, table in tables.items():
log("Reading table {name}...".format(name=name), end=" ")
data[name] = session.query(table).all()
log("Done!")
log("Connecting to new database")
session, tables = connect(args.to_url)
for name, table in tables.items():
log("Writing table {name}".format(name=name), end="")
length = len(data[name])
n = 0
for row in data[name]:
session.merge(row)
n += 5
if n >= length:
log(".", end="")
n = 0
log(" Done!")
log("Committing changes to database...", end=" ")
session.commit()
log("Done!")
+3 -4
View File
@@ -133,8 +133,7 @@ class User(AbstractUser):
def new_db_instance(self) -> DBUser:
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
contacts=self.db_contacts, saved_contacts=self.saved_contacts,
portals=self.db_portals)
saved_contacts=self.saved_contacts, portals=self.db_portals)
def save(self, contacts: bool = False, portals: bool = False) -> None:
self.db_instance.update(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone,
@@ -240,7 +239,7 @@ class User(AbstractUser):
if puppet.is_real_user:
await puppet.switch_mxid(None, None)
for _, portal in self.portals.items():
if not portal.mxid or portal.has_bot:
if not portal or portal.deleted or not portal.mxid or portal.has_bot:
continue
try:
await portal.main_intent.kick(portal.mxid, self.mxid, "Logged out of Telegram.")
@@ -318,7 +317,7 @@ class User(AbstractUser):
def unregister_portal(self, portal: po.Portal) -> None:
try:
del self.portals[portal.tgid_full]
self.save_portals()
self.save(portals=True)
except KeyError:
pass
+2
View File
@@ -0,0 +1,2 @@
[aliases]
test=pytest
+5 -3
View File
@@ -4,7 +4,6 @@ import mautrix_telegram
extras = {
"highlight_edits": ["lxml>=4.1.1,<5"],
"better_formatter": ["lxml>=4.1.1,<5"],
"fast_crypto": ["cryptg>=0.1,<0.2"],
"webp_convert": ["Pillow>=4.3.0,<6"],
"hq_thumbnails": ["moviepy>=1.0,<2.0"],
@@ -39,11 +38,14 @@ setuptools.setup(
"ruamel.yaml>=0.15.35,<0.16",
"future-fstrings>=0.4.2",
"python-magic>=0.4.15,<0.5",
"telethon>=1.5.5,<1.6",
"telethon-session-sqlalchemy>=0.2.8,<0.3",
"telethon>=1.5.5,<1.7",
"telethon-session-sqlalchemy>=0.2.11,<0.3",
],
extras_require=extras,
setup_requires=["pytest-runner"],
tests_require=["pytest", "pytest-asyncio", "pytest-mock"],
classifiers=[
"Development Status :: 4 - Beta",
"License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
View File
View File
+370
View File
@@ -0,0 +1,370 @@
from typing import Tuple
from unittest.mock import Mock
import pytest
from _pytest.fixtures import FixtureRequest
from pytest_mock import MockFixture
import mautrix_telegram.commands.handler
from mautrix_telegram.commands.handler import (CommandEvent, CommandHandler, CommandProcessor,
HelpSection)
from mautrix_telegram.config import Config
from mautrix_telegram.context import Context
from mautrix_telegram.types import MatrixEventID, MatrixRoomID, MatrixUserID
import mautrix_telegram.user as u
from tests.utils.helpers import AsyncMock, list_true_once_each
@pytest.fixture
def context(request: FixtureRequest) -> Context:
"""Returns a Context with mocked Attributes.
Uses the attribute cls.config as Config.
"""
# Config(path, registration_path, base_path)
config = getattr(request.cls, 'config', Config("", "", ""))
return Context(az=Mock(), config=config, loop=Mock(), session_container=Mock(), bot=Mock())
@pytest.fixture
def command_processor(context: Context) -> CommandProcessor:
"""Returns a mocked CommandProcessor."""
return CommandProcessor(context)
class TestCommandEvent:
config = Config("", "", "")
config["bridge.command_prefix"] = "tg"
config["bridge.permissions"] = {"*": "noperm"}
def test_reply(
self, command_processor: CommandProcessor, mocker: MockFixture
) -> None:
mocker.patch("mautrix_telegram.user.config", self.config)
evt = CommandEvent(
processor=command_processor,
room=MatrixRoomID("#mock_room:example.org"),
event=MatrixEventID("$H45H:example.org"),
sender=u.User(MatrixUserID("@sender:example.org")),
command="help",
args=[],
is_management=True,
is_portal=False,
)
mock_az = command_processor.az
message = "**This** <i>was</i><br/><strong>all</strong>fun*!"
# html, no markdown
evt.reply(message, allow_html=True, render_markdown=False)
mock_az.intent.send_notice.assert_called_with(
MatrixRoomID("#mock_room:example.org"),
"**This** <i>was</i><br/><strong>all</strong>fun*!",
html="**This** <i>was</i><br/><strong>all</strong>fun*!\n",
)
# html, markdown (default)
evt.reply(message, allow_html=True, render_markdown=True)
mock_az.intent.send_notice.assert_called_with(
MatrixRoomID("#mock_room:example.org"),
"**This** <i>was</i><br/><strong>all</strong>fun*!",
html=(
"<p><strong>This</strong> <i>was</i><br/>"
"<strong>all</strong>fun*!</p>\n"
),
)
# no html, no markdown
evt.reply(message, allow_html=False, render_markdown=False)
mock_az.intent.send_notice.assert_called_with(
MatrixRoomID("#mock_room:example.org"),
"**This** <i>was</i><br/><strong>all</strong>fun*!",
html=None,
)
# no html, markdown
evt.reply(message, allow_html=False, render_markdown=True)
mock_az.intent.send_notice.assert_called_with(
MatrixRoomID("#mock_room:example.org"),
"**This** <i>was</i><br/><strong>all</strong>fun*!",
html="<p><strong>This</strong> &lt;i&gt;was&lt;/i&gt;&lt;br/&gt;"
"&lt;strong&gt;all&lt;/strong&gt;fun*!</p>\n"
)
def test_reply_with_cmdprefix(self, command_processor: CommandProcessor, mocker: MockFixture
) -> None:
mocker.patch("mautrix_telegram.user.config", self.config)
evt = CommandEvent(
processor=command_processor,
room=MatrixRoomID("#mock_room:example.org"),
event=MatrixEventID("$H45H:example.org"),
sender=u.User(MatrixUserID("@sender:example.org")),
command="help",
args=[],
is_management=False,
is_portal=False,
)
mock_az = command_processor.az
evt.reply("$cmdprefix+sp ....$cmdprefix+sp...$cmdprefix $cmdprefix", allow_html=False,
render_markdown=False)
mock_az.intent.send_notice.assert_called_with(
MatrixRoomID("#mock_room:example.org"),
"tg ....tg+sp...tg tg",
html=None,
)
def test_reply_with_cmdprefix_in_management_room(self, command_processor: CommandProcessor,
mocker: MockFixture) -> None:
mocker.patch("mautrix_telegram.user.config", self.config)
evt = CommandEvent(
processor=command_processor,
room=MatrixRoomID("#mock_room:example.org"),
event=MatrixEventID("$H45H:example.org"),
sender=u.User(MatrixUserID("@sender:example.org")),
command="help",
args=[],
is_management=True,
is_portal=False,
)
mock_az = command_processor.az
evt.reply(
"$cmdprefix+sp ....$cmdprefix+sp...$cmdprefix $cmdprefix",
allow_html=True,
render_markdown=True,
)
mock_az.intent.send_notice.assert_called_with(
MatrixRoomID("#mock_room:example.org"),
"....tg+sp...tg tg",
html="<p>....tg+sp...tg tg</p>\n",
)
class TestCommandHandler:
config = Config("", "", "")
config["bridge.permissions"] = {"*": "noperm"}
@pytest.mark.parametrize(
(
"needs_auth,"
"needs_puppeting,"
"needs_matrix_puppeting,"
"needs_admin,"
"management_only,"
),
[l for l in list_true_once_each(length=5)]
)
@pytest.mark.asyncio
async def test_permissions_denied(
self,
needs_auth: bool,
needs_puppeting: bool,
needs_matrix_puppeting: bool,
needs_admin: bool,
management_only: bool,
command_processor: CommandProcessor,
boolean: bool,
mocker: MockFixture,
) -> None:
mocker.patch("mautrix_telegram.user.config", self.config)
command = "testcmd"
mock_handler = Mock()
command_handler = CommandHandler(
handler=mock_handler,
needs_auth=needs_auth,
needs_puppeting=needs_puppeting,
needs_matrix_puppeting=needs_matrix_puppeting,
needs_admin=needs_admin,
management_only=management_only,
name=command,
help_text="No real command",
help_args="mock mockmock",
help_section=HelpSection("Mock Section", 42, ""),
)
sender = u.User(MatrixUserID("@sender:example.org"))
sender.puppet_whitelisted = False
sender.matrix_puppet_whitelisted = False
sender.is_admin = False
event = CommandEvent(
processor=command_processor,
room=MatrixRoomID("#mock_room:example.org"),
event=MatrixEventID("$H45H:example.org"),
sender=sender,
command=command,
args=[],
is_management=False,
is_portal=boolean,
)
assert await command_handler.get_permission_error(event)
assert not command_handler.has_permission(False, False, False, False, False)
@pytest.mark.parametrize(
(
"is_management,"
"puppet_whitelisted,"
"matrix_puppet_whitelisted,"
"is_admin,"
"is_logged_in,"
),
[l for l in list_true_once_each(length=5)]
)
@pytest.mark.asyncio
async def test_permission_granted(
self,
is_management: bool,
puppet_whitelisted: bool,
matrix_puppet_whitelisted: bool,
is_admin: bool,
is_logged_in: bool,
command_processor: CommandProcessor,
boolean: bool,
mocker: MockFixture,
) -> None:
mocker.patch("mautrix_telegram.user.config", self.config)
command = "testcmd"
mock_handler = Mock()
command_handler = CommandHandler(
handler=mock_handler,
needs_auth=False,
needs_puppeting=False,
needs_matrix_puppeting=False,
needs_admin=False,
management_only=False,
name=command,
help_text="No real command",
help_args="mock mockmock",
help_section=HelpSection("Mock Section", 42, ""),
)
sender = u.User(MatrixUserID("@sender:example.org"))
sender.puppet_whitelisted = puppet_whitelisted
sender.matrix_puppet_whitelisted = matrix_puppet_whitelisted
sender.is_admin = is_admin
mocker.patch.object(u.User, 'is_logged_in', return_value=is_logged_in)
event = CommandEvent(
processor=command_processor,
room=MatrixRoomID("#mock_room:example.org"),
event=MatrixEventID("$H45H:example.org"),
sender=sender,
command=command,
args=[],
is_management=is_management,
is_portal=boolean,
)
assert not await command_handler.get_permission_error(event)
assert command_handler.has_permission(
is_management=is_management,
puppet_whitelisted=puppet_whitelisted,
matrix_puppet_whitelisted=matrix_puppet_whitelisted,
is_admin=is_admin,
is_logged_in=is_logged_in,
)
class TestCommandProcessor:
config = Config("", "", "")
config["bridge.command_prefix"] = "tg"
config["bridge.permissions"] = {"*": "relaybot"}
@pytest.mark.asyncio
async def test_handle(self, command_processor: CommandProcessor, boolean2: Tuple[bool, bool],
mocker: MockFixture) -> None:
mocker.patch('mautrix_telegram.user.config', self.config)
mocker.patch(
'mautrix_telegram.commands.handler.command_handlers',
{"help": AsyncMock(), "unknown-command": AsyncMock()}
)
sender = u.User(MatrixUserID("@sender:example.org"))
result = await command_processor.handle(
room=MatrixRoomID("#mock_room:example.org"),
event_id=MatrixEventID("$H45H:example.org"),
sender=sender,
command="hElp",
args=[],
is_management=boolean2[0],
is_portal=boolean2[1],
)
assert result is None
command_handlers = mautrix_telegram.commands.handler.command_handlers
command_handlers["help"].mock.assert_called_once() # type: ignore
@pytest.mark.asyncio
async def test_handle_unknown_command(self, command_processor: CommandProcessor,
boolean2: Tuple[bool, bool], mocker: MockFixture) -> None:
mocker.patch('mautrix_telegram.user.config', self.config)
mocker.patch(
'mautrix_telegram.commands.handler.command_handlers',
{"help": AsyncMock(), "unknown-command": AsyncMock()}
)
sender = u.User(MatrixUserID("@sender:example.org"))
sender.command_status = {}
result = await command_processor.handle(
room=MatrixRoomID("#mock_room:example.org"),
event_id=MatrixEventID("$H45H:example.org"),
sender=sender,
command="foo",
args=[],
is_management=boolean2[0],
is_portal=boolean2[1],
)
assert result is None
command_handlers = mautrix_telegram.commands.handler.command_handlers
command_handlers["help"].mock.assert_not_called() # type: ignore
command_handlers["unknown-command"].mock.assert_called_once() # type: ignore
@pytest.mark.asyncio
async def test_handle_delegated_handler(self, command_processor: CommandProcessor,
boolean2: Tuple[bool, bool],
mocker: MockFixture) -> None:
mocker.patch('mautrix_telegram.user.config', self.config)
mocker.patch(
'mautrix_telegram.commands.handler.command_handlers',
{"help": AsyncMock(), "unknown-command": AsyncMock()}
)
sender = u.User(MatrixUserID("@sender:example.org"))
sender.command_status = {"foo": AsyncMock(), "next": AsyncMock()}
result = await command_processor.handle(
room=MatrixRoomID("#mock_room:example.org"),
event_id=MatrixEventID("$H45H:example.org"),
sender=sender, # u.User
command="foo",
args=[],
is_management=boolean2[0],
is_portal=boolean2[1]
)
assert result is None
command_handlers = mautrix_telegram.commands.handler.command_handlers
command_handlers["help"].mock.assert_not_called() # type: ignore
command_handlers["unknown-command"].mock.assert_not_called() # type: ignore
sender.command_status["foo"].mock.assert_not_called() # type: ignore
sender.command_status["next"].mock.assert_called_once() # type: ignore
+3
View File
@@ -0,0 +1,3 @@
pytest_plugins = [
"tests.utils.fixtures",
]
View File
+27
View File
@@ -0,0 +1,27 @@
"""This module provides utility fixtures for testing."""
from typing import Tuple
from _pytest.fixtures import FixtureRequest
import pytest
@pytest.fixture(params=[True, False])
def boolean(request: FixtureRequest) -> bool:
return request.param
@pytest.fixture
def boolean1(boolean: bool) -> Tuple[bool]:
return boolean,
@pytest.fixture(params=[True, False])
def boolean2(request: FixtureRequest, boolean: bool) -> Tuple[bool, bool]:
return boolean, request.param
@pytest.fixture(params=[True, False])
def boolean3(request: FixtureRequest, boolean2: Tuple[bool, bool]) -> Tuple[bool, bool, bool]:
return boolean2[0], boolean2[1], request.param
# …
+24
View File
@@ -0,0 +1,24 @@
"""This module provides utility functions for testing."""
from typing import Generator, Tuple
from unittest.mock import Mock
def AsyncMock(*args, **kwargs):
"""Mocks a asyncronous coroutine which can be called with 'await'."""
m = Mock(*args, **kwargs)
async def mock_coro(*args, **kwargs):
return m(*args, **kwargs)
mock_coro.mock = m
return mock_coro
def list_true_once_each(length: int) -> Generator[Tuple[bool, ...], None, None]:
"""Yields tuples of bools with exactly one entry being True, starting left.
Args:
length: Length of the resulting tuples
"""
for i in range(length):
yield tuple(i == j for j in range(length))