Blacken and isort code
This commit is contained in:
@@ -15,15 +15,14 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from .upgrade import upgrade_table
|
||||
|
||||
from .bot_chat import BotChat
|
||||
from .message import Message
|
||||
from .portal import Portal
|
||||
from .puppet import Puppet
|
||||
from .telegram_file import TelegramFile
|
||||
from .user import User
|
||||
from .telethon_session import PgSession
|
||||
from .upgrade import upgrade_table
|
||||
from .user import User
|
||||
|
||||
|
||||
def init(db: Database) -> None:
|
||||
@@ -31,5 +30,14 @@ def init(db: Database) -> None:
|
||||
table.db = db
|
||||
|
||||
|
||||
__all__ = ["upgrade_table", "init", "Portal", "Message", "User", "Puppet", "TelegramFile",
|
||||
"BotChat", "PgSession"]
|
||||
__all__ = [
|
||||
"upgrade_table",
|
||||
"init",
|
||||
"Portal",
|
||||
"Message",
|
||||
"User",
|
||||
"Puppet",
|
||||
"TelegramFile",
|
||||
"BotChat",
|
||||
"PgSession",
|
||||
]
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
@@ -15,12 +15,12 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.types import RoomID, EventID
|
||||
from mautrix.types import EventID, RoomID
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ..types import TelegramID
|
||||
@@ -92,9 +92,12 @@ class Message:
|
||||
|
||||
@classmethod
|
||||
async def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int:
|
||||
return await cls.db.fetchval(
|
||||
"SELECT COUNT(tg_space) FROM message WHERE mxid=$1 AND mx_room=$2", mxid, mx_room
|
||||
) or 0
|
||||
return (
|
||||
await cls.db.fetchval(
|
||||
"SELECT COUNT(tg_space) FROM message WHERE mxid=$1 AND mx_room=$2", mxid, mx_room
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def find_last(cls, mx_room: RoomID, tg_space: TelegramID) -> Message | None:
|
||||
|
||||
@@ -15,14 +15,14 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
import json
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
import attr
|
||||
|
||||
from mautrix.types import RoomID, ContentURI
|
||||
from mautrix.types import ContentURI, RoomID
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ..types import TelegramID
|
||||
@@ -93,9 +93,20 @@ class Portal:
|
||||
|
||||
@property
|
||||
def _values(self):
|
||||
return (self.tgid, self.tg_receiver, self.peer_type, self.mxid, self.avatar_url,
|
||||
self.encrypted, self.username, self.title, self.about, self.photo_id,
|
||||
self.megagroup, json.dumps(self.local_config) if self.local_config else None)
|
||||
return (
|
||||
self.tgid,
|
||||
self.tg_receiver,
|
||||
self.peer_type,
|
||||
self.mxid,
|
||||
self.avatar_url,
|
||||
self.encrypted,
|
||||
self.username,
|
||||
self.title,
|
||||
self.about,
|
||||
self.photo_id,
|
||||
self.megagroup,
|
||||
json.dumps(self.local_config) if self.local_config else None,
|
||||
)
|
||||
|
||||
async def save(self) -> None:
|
||||
q = (
|
||||
|
||||
@@ -15,13 +15,13 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
from yarl import URL
|
||||
|
||||
from mautrix.types import UserID, SyncToken
|
||||
from mautrix.types import SyncToken, UserID
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ..types import TelegramID
|
||||
@@ -92,10 +92,22 @@ class Puppet:
|
||||
|
||||
@property
|
||||
def _values(self):
|
||||
return (self.id, self.is_registered, self.displayname, self.displayname_source,
|
||||
self.displayname_contact, self.displayname_quality, self.disable_updates,
|
||||
self.username, self.photo_id, self.is_bot, self.custom_mxid, self.access_token,
|
||||
self.next_batch, str(self.base_url) if self.base_url else None)
|
||||
return (
|
||||
self.id,
|
||||
self.is_registered,
|
||||
self.displayname,
|
||||
self.displayname_source,
|
||||
self.displayname_contact,
|
||||
self.displayname_quality,
|
||||
self.disable_updates,
|
||||
self.username,
|
||||
self.photo_id,
|
||||
self.is_bot,
|
||||
self.custom_mxid,
|
||||
self.access_token,
|
||||
self.next_batch,
|
||||
str(self.base_url) if self.base_url else None,
|
||||
)
|
||||
|
||||
async def save(self) -> None:
|
||||
q = (
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@@ -68,7 +68,15 @@ class TelegramFile:
|
||||
" thumbnail, decryption_info) "
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
)
|
||||
await self.db.execute(q, self.id, self.mxc, self.mime_type, self.was_converted, self.size,
|
||||
self.width, self.height,
|
||||
self.thumbnail.id if self.thumbnail else None,
|
||||
self.decryption_info.json() if self.decryption_info else None)
|
||||
await self.db.execute(
|
||||
q,
|
||||
self.id,
|
||||
self.mxc,
|
||||
self.mime_type,
|
||||
self.was_converted,
|
||||
self.size,
|
||||
self.width,
|
||||
self.height,
|
||||
self.thumbnail.id if self.thumbnail else None,
|
||||
self.decryption_info.json() if self.decryption_info else None,
|
||||
)
|
||||
|
||||
@@ -15,14 +15,14 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
import asyncio
|
||||
import datetime
|
||||
|
||||
from telethon.sessions import MemorySession
|
||||
from telethon.tl.types import updates, PeerUser, PeerChat, PeerChannel
|
||||
from telethon.crypto import AuthKey
|
||||
from telethon import utils
|
||||
from telethon.crypto import AuthKey
|
||||
from telethon.sessions import MemorySession
|
||||
from telethon.tl.types import PeerChannel, PeerChat, PeerUser, updates
|
||||
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
@@ -97,7 +97,10 @@ class PgSession(MemorySession):
|
||||
)
|
||||
|
||||
_tables: ClassVar[tuple[str, ...]] = (
|
||||
"telethon_sessions", "telethon_entities", "telethon_sent_files", "telethon_update_state"
|
||||
"telethon_sessions",
|
||||
"telethon_entities",
|
||||
"telethon_sent_files",
|
||||
"telethon_update_state",
|
||||
)
|
||||
|
||||
async def delete(self) -> None:
|
||||
@@ -196,7 +199,7 @@ class PgSession(MemorySession):
|
||||
ids = (
|
||||
utils.get_peer_id(PeerUser(key)),
|
||||
utils.get_peer_id(PeerChat(key)),
|
||||
utils.get_peer_id(PeerChannel(key))
|
||||
utils.get_peer_id(PeerChannel(key)),
|
||||
)
|
||||
if self.db.scheme == "postgres":
|
||||
return await self._select_entity("id=ANY($1)", ids)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from asyncpg import Connection
|
||||
|
||||
from . import upgrade_table
|
||||
|
||||
legacy_version_query = "SELECT version_num FROM alembic_version"
|
||||
@@ -40,8 +41,10 @@ async def upgrade_v1(conn: Connection, scheme: str) -> None:
|
||||
async def migrate_legacy_to_v1(conn: Connection, scheme: str) -> None:
|
||||
legacy_version = await conn.fetchval(legacy_version_query)
|
||||
if legacy_version != last_legacy_version:
|
||||
raise RuntimeError("Legacy database is not on last version. Please upgrade the old "
|
||||
"database with alembic or drop it completely first.")
|
||||
raise RuntimeError(
|
||||
"Legacy database is not on last version. "
|
||||
"Please upgrade the old database with alembic or drop it completely first."
|
||||
)
|
||||
if scheme != "sqlite":
|
||||
await conn.execute(
|
||||
"""
|
||||
@@ -128,13 +131,24 @@ async def varchar_to_text(conn: Connection) -> None:
|
||||
columns_to_adjust = {
|
||||
"user": ("mxid", "tg_username", "tg_phone"),
|
||||
"portal": (
|
||||
"peer_type", "mxid", "username", "title", "about", "photo_id", "avatar_url", "config"
|
||||
"peer_type",
|
||||
"mxid",
|
||||
"username",
|
||||
"title",
|
||||
"about",
|
||||
"photo_id",
|
||||
"avatar_url",
|
||||
"config",
|
||||
),
|
||||
"message": ("mxid", "mx_room"),
|
||||
"puppet": (
|
||||
"displayname", "username", "photo_id",
|
||||
) + (
|
||||
"access_token", "custom_mxid", "next_batch", "base_url"
|
||||
"displayname",
|
||||
"username",
|
||||
"photo_id",
|
||||
"access_token",
|
||||
"custom_mxid",
|
||||
"next_batch",
|
||||
"base_url",
|
||||
),
|
||||
"bot_chat": ("type",),
|
||||
"telegram_file": ("id", "mxc", "mime_type", "thumbnail"),
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, ClassVar, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar, Iterable
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
@@ -73,20 +73,25 @@ class User:
|
||||
@property
|
||||
def _values(self):
|
||||
return (
|
||||
self.mxid, self.tgid, self.tg_username, self.tg_phone, self.is_bot, self.saved_contacts
|
||||
self.mxid,
|
||||
self.tgid,
|
||||
self.tg_username,
|
||||
self.tg_phone,
|
||||
self.is_bot,
|
||||
self.saved_contacts,
|
||||
)
|
||||
|
||||
async def save(self) -> None:
|
||||
q = (
|
||||
'UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, saved_contacts=$6 '
|
||||
'WHERE mxid=$1'
|
||||
"WHERE mxid=$1"
|
||||
)
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = (
|
||||
'INSERT INTO "user" (mxid, tgid, tg_username, tg_phone, is_bot, saved_contacts) '
|
||||
'VALUES ($1, $2, $3, $4, $5, $6)'
|
||||
"VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
)
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
@@ -122,8 +127,10 @@ class User:
|
||||
await conn.executemany(q, records)
|
||||
|
||||
async def register_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
|
||||
q = ('INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3) '
|
||||
'ON CONFLICT ("user", portal, portal_receiver) DO NOTHING')
|
||||
q = (
|
||||
'INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3) '
|
||||
'ON CONFLICT ("user", portal, portal_receiver) DO NOTHING'
|
||||
)
|
||||
await self.db.execute(q, self.tgid, tgid, tg_receiver)
|
||||
|
||||
async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
|
||||
|
||||
Reference in New Issue
Block a user