Blacken and isort code

This commit is contained in:
Tulir Asokan
2021-12-21 01:36:24 +02:00
parent f2af17d359
commit 6d25e9687e
55 changed files with 3752 additions and 2018 deletions
+13 -5
View File
@@ -15,15 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Database
from .upgrade import upgrade_table
from .bot_chat import BotChat
from .message import Message
from .portal import Portal
from .puppet import Puppet
from .telegram_file import TelegramFile
from .user import User
from .telethon_session import PgSession
from .upgrade import upgrade_table
from .user import User
def init(db: Database) -> None:
@@ -31,5 +30,14 @@ def init(db: Database) -> None:
table.db = db
__all__ = ["upgrade_table", "init", "Portal", "Message", "User", "Puppet", "TelegramFile",
"BotChat", "PgSession"]
__all__ = [
"upgrade_table",
"init",
"Portal",
"Message",
"User",
"Puppet",
"TelegramFile",
"BotChat",
"PgSession",
]
+1 -1
View File
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass
+8 -5
View File
@@ -15,12 +15,12 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass
from mautrix.types import RoomID, EventID
from mautrix.types import EventID, RoomID
from mautrix.util.async_db import Database
from ..types import TelegramID
@@ -92,9 +92,12 @@ class Message:
@classmethod
async def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int:
return await cls.db.fetchval(
"SELECT COUNT(tg_space) FROM message WHERE mxid=$1 AND mx_room=$2", mxid, mx_room
) or 0
return (
await cls.db.fetchval(
"SELECT COUNT(tg_space) FROM message WHERE mxid=$1 AND mx_room=$2", mxid, mx_room
)
or 0
)
@classmethod
async def find_last(cls, mx_room: RoomID, tg_space: TelegramID) -> Message | None:
+16 -5
View File
@@ -15,14 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import ClassVar, Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, ClassVar
import json
from asyncpg import Record
from attr import dataclass
import attr
from mautrix.types import RoomID, ContentURI
from mautrix.types import ContentURI, RoomID
from mautrix.util.async_db import Database
from ..types import TelegramID
@@ -93,9 +93,20 @@ class Portal:
@property
def _values(self):
return (self.tgid, self.tg_receiver, self.peer_type, self.mxid, self.avatar_url,
self.encrypted, self.username, self.title, self.about, self.photo_id,
self.megagroup, json.dumps(self.local_config) if self.local_config else None)
return (
self.tgid,
self.tg_receiver,
self.peer_type,
self.mxid,
self.avatar_url,
self.encrypted,
self.username,
self.title,
self.about,
self.photo_id,
self.megagroup,
json.dumps(self.local_config) if self.local_config else None,
)
async def save(self) -> None:
q = (
+18 -6
View File
@@ -15,13 +15,13 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass
from yarl import URL
from mautrix.types import UserID, SyncToken
from mautrix.types import SyncToken, UserID
from mautrix.util.async_db import Database
from ..types import TelegramID
@@ -92,10 +92,22 @@ class Puppet:
@property
def _values(self):
return (self.id, self.is_registered, self.displayname, self.displayname_source,
self.displayname_contact, self.displayname_quality, self.disable_updates,
self.username, self.photo_id, self.is_bot, self.custom_mxid, self.access_token,
self.next_batch, str(self.base_url) if self.base_url else None)
return (
self.id,
self.is_registered,
self.displayname,
self.displayname_source,
self.displayname_contact,
self.displayname_quality,
self.disable_updates,
self.username,
self.photo_id,
self.is_bot,
self.custom_mxid,
self.access_token,
self.next_batch,
str(self.base_url) if self.base_url else None,
)
async def save(self) -> None:
q = (
+13 -5
View File
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar
from attr import dataclass
@@ -68,7 +68,15 @@ class TelegramFile:
" thumbnail, decryption_info) "
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
)
await self.db.execute(q, self.id, self.mxc, self.mime_type, self.was_converted, self.size,
self.width, self.height,
self.thumbnail.id if self.thumbnail else None,
self.decryption_info.json() if self.decryption_info else None)
await self.db.execute(
q,
self.id,
self.mxc,
self.mime_type,
self.was_converted,
self.size,
self.width,
self.height,
self.thumbnail.id if self.thumbnail else None,
self.decryption_info.json() if self.decryption_info else None,
)
+10 -7
View File
@@ -15,14 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING
import datetime
from typing import TYPE_CHECKING, ClassVar
import asyncio
import datetime
from telethon.sessions import MemorySession
from telethon.tl.types import updates, PeerUser, PeerChat, PeerChannel
from telethon.crypto import AuthKey
from telethon import utils
from telethon.crypto import AuthKey
from telethon.sessions import MemorySession
from telethon.tl.types import PeerChannel, PeerChat, PeerUser, updates
from mautrix.util.async_db import Database
@@ -97,7 +97,10 @@ class PgSession(MemorySession):
)
_tables: ClassVar[tuple[str, ...]] = (
"telethon_sessions", "telethon_entities", "telethon_sent_files", "telethon_update_state"
"telethon_sessions",
"telethon_entities",
"telethon_sent_files",
"telethon_update_state",
)
async def delete(self) -> None:
@@ -196,7 +199,7 @@ class PgSession(MemorySession):
ids = (
utils.get_peer_id(PeerUser(key)),
utils.get_peer_id(PeerChat(key)),
utils.get_peer_id(PeerChannel(key))
utils.get_peer_id(PeerChannel(key)),
)
if self.db.scheme == "postgres":
return await self._select_entity("id=ANY($1)", ids)
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from asyncpg import Connection
from . import upgrade_table
legacy_version_query = "SELECT version_num FROM alembic_version"
@@ -40,8 +41,10 @@ async def upgrade_v1(conn: Connection, scheme: str) -> None:
async def migrate_legacy_to_v1(conn: Connection, scheme: str) -> None:
legacy_version = await conn.fetchval(legacy_version_query)
if legacy_version != last_legacy_version:
raise RuntimeError("Legacy database is not on last version. Please upgrade the old "
"database with alembic or drop it completely first.")
raise RuntimeError(
"Legacy database is not on last version. "
"Please upgrade the old database with alembic or drop it completely first."
)
if scheme != "sqlite":
await conn.execute(
"""
@@ -128,13 +131,24 @@ async def varchar_to_text(conn: Connection) -> None:
columns_to_adjust = {
"user": ("mxid", "tg_username", "tg_phone"),
"portal": (
"peer_type", "mxid", "username", "title", "about", "photo_id", "avatar_url", "config"
"peer_type",
"mxid",
"username",
"title",
"about",
"photo_id",
"avatar_url",
"config",
),
"message": ("mxid", "mx_room"),
"puppet": (
"displayname", "username", "photo_id",
) + (
"access_token", "custom_mxid", "next_batch", "base_url"
"displayname",
"username",
"photo_id",
"access_token",
"custom_mxid",
"next_batch",
"base_url",
),
"bot_chat": ("type",),
"telegram_file": ("id", "mxc", "mime_type", "thumbnail"),
+13 -6
View File
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Iterable, ClassVar, TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar, Iterable
from asyncpg import Record
from attr import dataclass
@@ -73,20 +73,25 @@ class User:
@property
def _values(self):
return (
self.mxid, self.tgid, self.tg_username, self.tg_phone, self.is_bot, self.saved_contacts
self.mxid,
self.tgid,
self.tg_username,
self.tg_phone,
self.is_bot,
self.saved_contacts,
)
async def save(self) -> None:
q = (
'UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, saved_contacts=$6 '
'WHERE mxid=$1'
"WHERE mxid=$1"
)
await self.db.execute(q, *self._values)
async def insert(self) -> None:
q = (
'INSERT INTO "user" (mxid, tgid, tg_username, tg_phone, is_bot, saved_contacts) '
'VALUES ($1, $2, $3, $4, $5, $6)'
"VALUES ($1, $2, $3, $4, $5, $6)"
)
await self.db.execute(q, *self._values)
@@ -122,8 +127,10 @@ class User:
await conn.executemany(q, records)
async def register_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
q = ('INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3) '
'ON CONFLICT ("user", portal, portal_receiver) DO NOTHING')
q = (
'INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3) '
'ON CONFLICT ("user", portal, portal_receiver) DO NOTHING'
)
await self.db.execute(q, self.tgid, tgid, tg_receiver)
async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None: