diff --git a/alembic/versions/501dad2868bc_move_sessions_to_main_database.py b/alembic/versions/501dad2868bc_move_sessions_to_main_database.py new file mode 100644 index 00000000..1878f30b --- /dev/null +++ b/alembic/versions/501dad2868bc_move_sessions_to_main_database.py @@ -0,0 +1,111 @@ +"""Move sessions to main database + +Revision ID: 501dad2868bc +Revises: 7d47d84380b6 +Create Date: 2018-03-02 19:15:53.826985 + +""" +from alembic import op +import sqlalchemy as sa +import sqlite3 +import os + +# revision identifiers, used by Alembic. +revision = '501dad2868bc' +down_revision = '7d47d84380b6' +branch_labels = None +depends_on = None + + +def upgrade(): + Session = op.create_table('telethon_sessions', + sa.Column('session_id', sa.VARCHAR(), nullable=False), + sa.Column('dc_id', sa.INTEGER(), nullable=False), + sa.Column('server_address', sa.VARCHAR(), nullable=True), + sa.Column('port', sa.INTEGER(), nullable=True), + sa.Column('auth_key', sa.BLOB(), nullable=True), + sa.PrimaryKeyConstraint('session_id', 'dc_id')) + SentFile = op.create_table('telethon_sent_files', + sa.Column('session_id', sa.VARCHAR(), nullable=False), + sa.Column('md5_digest', sa.BLOB(), nullable=False), + sa.Column('file_size', sa.INTEGER(), nullable=False), + sa.Column('type', sa.INTEGER(), nullable=False), + sa.Column('id', sa.INTEGER(), nullable=True), + sa.Column('hash', sa.INTEGER(), nullable=True), + sa.PrimaryKeyConstraint('session_id', 'md5_digest', 'file_size', + 'type')) + Entity = op.create_table('telethon_entities', + sa.Column('session_id', sa.VARCHAR(), nullable=False), + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('hash', sa.INTEGER(), nullable=False), + sa.Column('username', sa.VARCHAR(), nullable=True), + sa.Column('phone', sa.INTEGER(), nullable=True), + sa.Column('name', sa.VARCHAR(), nullable=True), + sa.PrimaryKeyConstraint('session_id', 'id')) + Version = op.create_table('telethon_version', + sa.Column('version', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('version')) + conn = op.get_bind() + sessions = [os.path.basename(f) for f in os.listdir(".") if f.endswith(".session")] + for session in sessions: + session_to_sqlalchemy(conn, session, Session, SentFile, Entity) + + +def session_to_sqlalchemy(conn, path, Session, SentFile, Entity): + session_conn = sqlite3.connect(path) + session_id = os.path.splitext(path)[0] + c = session_conn.cursor() + + auth_data_tuples = c.execute("SELECT * FROM sessions").fetchall() + auth_data_dicts = [] + for row in auth_data_tuples: + dc_id, server_address, port, auth_key = row + auth_data_dicts.append({ + "session_id": session_id, + "dc_id": dc_id, + "server_address": server_address, + "port": port, + "auth_key": auth_key, + }) + if auth_data_dicts: + conn.execute(Session.insert().values(auth_data_dicts)) + + sent_file_tuples = c.execute("SELECT * FROM sent_files").fetchall() + sent_file_dicts = [] + for row in sent_file_tuples: + md5_digest, file_size, type, id, hash = row + sent_file_dicts.append({ + "session_id": session_id, + "md5_digest": md5_digest, + "file_size": file_size, + "type": type, + "id": id, + "hash": hash, + }) + if sent_file_dicts: + conn.execute(SentFile.insert().values(sent_file_dicts)) + + entity_tuples = c.execute("SELECT * FROM entities").fetchall() + entity_dicts = [] + for row in entity_tuples: + id, hash, username, phone, name = row + entity_dicts.append({ + "session_id": session_id, + "id": id, + "hash": hash, + "username": username, + "phone": phone, + "name": name, + }) + if entity_dicts: + conn.execute(Entity.insert().values(entity_dicts)) + + c.close() + session_conn.close() + + +def downgrade(): + op.drop_table('telethon_version') + op.drop_table('telethon_entities') + op.drop_table('telethon_sent_files') + op.drop_table('telethon_sessions') diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py index 0bcf7b95..8e981c41 100644 --- a/mautrix_telegram/__main__.py +++ b/mautrix_telegram/__main__.py @@ -22,6 +22,7 @@ import asyncio import sqlalchemy as sql from sqlalchemy import orm +from telethon.sessions import AlchemySessionContainer from mautrix_appservice import AppService from .base import Base @@ -76,14 +77,17 @@ db_factory = orm.sessionmaker(bind=db_engine) db_session = orm.scoping.scoped_session(db_factory) Base.metadata.bind = db_engine +telethon_session_container = AlchemySessionContainer(engine=db_engine, session=db_session, + table_base=Base, table_prefix="telethon_", + manage_tables=False) + loop = asyncio.get_event_loop() appserv = AppService(config["homeserver.address"], config["homeserver.domain"], config["appservice.as_token"], config["appservice.hs_token"], config["appservice.bot_username"], log="mau.as", loop=loop) - -context = Context(appserv, db_session, config, loop, None, None) +context = Context(appserv, db_session, config, loop, None, None, telethon_session_container) if config["appservice.public.enabled"]: public = PublicBridgeWebsite(loop) diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py index 89fa5e4c..b8dc492c 100644 --- a/mautrix_telegram/abstract_user.py +++ b/mautrix_telegram/abstract_user.py @@ -30,6 +30,7 @@ MAX_DELETIONS = 10 class AbstractUser: + session_container = None loop = None log = None db = None @@ -46,9 +47,10 @@ class AbstractUser: self.log.debug(f"Initializing client for {self.name}") device = f"{platform.system()} {platform.release()}" sysversion = MautrixTelegramClient.__version__ - self.client = MautrixTelegramClient(self.name, - config["telegram.api_id"], - config["telegram.api_hash"], + self.session = self.session_container.new_session(self.name) + self.client = MautrixTelegramClient(session=self.session, + api_id=config["telegram.api_id"], + api_hash=config["telegram.api_hash"], loop=self.loop, app_version=__version__, system_version=sysversion, @@ -290,4 +292,5 @@ class AbstractUser: def init(context): global config, MAX_DELETIONS AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, _ = context + AbstractUser.session_container = context.telethon_session_container MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10) diff --git a/mautrix_telegram/context.py b/mautrix_telegram/context.py index 6960321c..5b7fb8cf 100644 --- a/mautrix_telegram/context.py +++ b/mautrix_telegram/context.py @@ -17,13 +17,14 @@ class Context: - def __init__(self, az, db, config, loop, bot, mx): + def __init__(self, az, db, config, loop, bot, mx, telethon_session_container): self.az = az self.db = db self.config = config self.loop = loop self.bot = bot self.mx = mx + self.telethon_session_container = telethon_session_container def __iter__(self): yield self.az