Add option to sync portals in backfill queue

This commit is contained in:
Tulir Asokan
2022-10-14 13:55:12 +03:00
parent af2f20f7b2
commit 0bbf64d240
10 changed files with 315 additions and 116 deletions
+80 -27
View File
@@ -15,8 +15,10 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar
from datetime import datetime, timedelta
from enum import Enum
import json
from asyncpg import Record
from attr import dataclass
@@ -29,6 +31,11 @@ from ..types import TelegramID
fake_db = Database.create("") if TYPE_CHECKING else None
class BackfillType(Enum):
HISTORICAL = "historical"
SYNC_DIALOG = "sync_dialog"
@dataclass
class Backfill:
db: ClassVar[Database] = fake_db
@@ -36,9 +43,11 @@ class Backfill:
queue_id: int | None
user_mxid: UserID
priority: int
type: BackfillType
portal_tgid: TelegramID
portal_tg_receiver: TelegramID
anchor_msg_id: TelegramID | None
extra_data: dict[str, Any]
messages_per_batch: int
post_batch_delay: int
max_batches: int
@@ -50,10 +59,12 @@ class Backfill:
def new(
user_mxid: UserID,
priority: int,
type: BackfillType,
portal_tgid: TelegramID,
portal_tg_receiver: TelegramID,
messages_per_batch: int,
anchor_msg_id: TelegramID | None = None,
extra_data: dict[str, Any] | None = None,
post_batch_delay: int = 0,
max_batches: int = -1,
) -> "Backfill":
@@ -61,9 +72,11 @@ class Backfill:
queue_id=None,
user_mxid=user_mxid,
priority=priority,
type=type,
portal_tgid=portal_tgid,
portal_tg_receiver=portal_tg_receiver,
anchor_msg_id=anchor_msg_id,
extra_data=extra_data or {},
messages_per_batch=messages_per_batch,
post_batch_delay=post_batch_delay,
max_batches=max_batches,
@@ -76,14 +89,19 @@ class Backfill:
def _from_row(cls, row: Record | None) -> Backfill | None:
if row is None:
return None
return cls(**row)
data = {**row}
type = BackfillType(data.pop("type"))
extra_data = json.loads(data.pop("extra_data", None) or "{}")
return cls(**data, type=type, extra_data=extra_data)
columns = [
"user_mxid",
"priority",
"type",
"portal_tgid",
"portal_tg_receiver",
"anchor_msg_id",
"extra_data",
"messages_per_batch",
"post_batch_delay",
"max_batches",
@@ -118,22 +136,37 @@ class Backfill:
)
@classmethod
async def get(
async def delete_existing(
cls,
user_mxid: UserID,
portal_tgid: int,
portal_tg_receiver: int,
type: BackfillType,
) -> Backfill | None:
q = f"""
SELECT queue_id, {cls.columns_str}
FROM backfill_queue
WHERE user_mxid=$1
AND portal_tgid=$2
AND portal_tg_receiver=$3
ORDER BY priority, queue_id
LIMIT 1
WITH deleted_entries AS (
DELETE FROM backfill_queue
WHERE user_mxid=$1
AND portal_tgid=$2
AND portal_tg_receiver=$3
AND type=$4
AND dispatch_time IS NULL
AND completed_at IS NULL
RETURNING 1
)
WITH dispatched_entries AS (
SELECT 1 FROM backfill_queue
WHERE user_mxid=$1
AND portal_tgid=$2
AND portal_tg_receiver=$3
AND type=$4
AND dispatch_time IS NOT NULL
AND completed_at IS NULL
)
"""
return cls._from_row(await cls.db.fetchrow(q, user_mxid, portal_tgid, portal_tg_receiver))
return cls._from_row(
await cls.db.fetchrow(q, user_mxid, portal_tgid, portal_tg_receiver, type.value)
)
@classmethod
async def delete_all(cls, user_mxid: UserID) -> None:
@@ -144,27 +177,47 @@ class Backfill:
q = "DELETE FROM backfill_queue WHERE portal_tgid=$1 AND portal_tg_receiver=$2"
await cls.db.execute(q, tgid, tg_receiver)
async def insert(self) -> None:
async def insert(self) -> list[Backfill]:
delete_q = f"""
DELETE FROM backfill_queue
WHERE user_mxid=$1
AND portal_tgid=$2
AND portal_tg_receiver=$3
AND type=$4
AND dispatch_time IS NULL
AND completed_at IS NULL
RETURNING {self.columns_str}
"""
q = f"""
INSERT INTO backfill_queue ({self.columns_str})
VALUES ({','.join(f'${i+1}' for i in range(len(self.columns)))})
RETURNING queue_id
"""
row = await self.db.fetchrow(
q,
self.user_mxid,
self.priority,
self.portal_tgid,
self.portal_tg_receiver,
self.anchor_msg_id,
self.messages_per_batch,
self.post_batch_delay,
self.max_batches,
self.dispatch_time,
self.completed_at,
self.cooldown_timeout,
)
self.queue_id = row["queue_id"]
async with self.db.acquire() as conn, conn.transaction():
deleted_rows = await self.db.fetch(
delete_q,
self.user_mxid,
self.portal_tgid,
self.portal_tg_receiver,
self.type.value,
)
self.queue_id = await self.db.fetchval(
q,
self.user_mxid,
self.priority,
self.type.value,
self.portal_tgid,
self.portal_tg_receiver,
self.anchor_msg_id,
json.dumps(self.extra_data) if self.extra_data else None,
self.messages_per_batch,
self.post_batch_delay,
self.max_batches,
self.dispatch_time,
self.completed_at,
self.cooldown_timeout,
)
return [self._from_row(row) for row in deleted_rows]
async def mark_dispatched(self) -> None:
q = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2"