Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.media import ReplicationCopyMediaServlet
from synapse.replication.http.media import (
ReplicationCopyMediaServlet,
ReplicationDeleteMediaServlet,
)
from synapse.state import CREATE_KEY, POWER_KEY
from synapse.storage.databases.main.media_repository import (
LocalMedia,
Expand Down Expand Up @@ -186,6 +189,13 @@ async def copy_media(
"Sorry Mario, your MediaRepository related function is in another castle"
)

async def _remove_local_media_from_disk(
self, media_ids: List[str]
) -> Tuple[List[str], int]:
raise NotImplementedError(
"Sorry Mario, your MediaRepository related function is in another castle"
)

async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]:
raise NotImplementedError(
"Sorry Mario, your MediaRepository related function is in another castle"
Expand Down Expand Up @@ -581,6 +591,7 @@ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
# initialize replication endpoint here
self.copy_media_client = ReplicationCopyMediaServlet.make_client(hs)
self.delete_media_client = ReplicationDeleteMediaServlet.make_client(hs)

async def copy_media(
self, existing_mxc: MXCUri, auth_user: UserID, max_timeout_ms: int
Expand All @@ -597,6 +608,18 @@ async def copy_media(
)
return MXCUri.from_str(result["content_uri"])

async def _remove_local_media_from_disk(
self, media_ids: List[str]
) -> Tuple[List[str], int]:
"""
Call out to the worker responsible for handling media to delete this media object
"""
result = await self.delete_media_client(
instance_name=self.hs.config.worker.workers_doing_media_duty[0],
media_ids=media_ids,
)
return result["deleted"], result["count"]


class MediaRepository(AbstractMediaRepository):
def __init__(self, hs: "HomeServer"):
Expand Down
41 changes: 41 additions & 0 deletions synapse/replication/http/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,46 @@ async def _handle_request( # type: ignore[override]
return 200, {"content_uri": str(mxc_uri)}


class ReplicationDeleteMediaServlet(ReplicationEndpoint):
    """Request the MediaRepository to delete a piece of media from filesystem.

    Used by instances that do not handle media themselves to ask the media
    worker to remove local media files (e.g. after an event redaction is
    censored).

    Request format:

        DELETE /_synapse/replication/delete_media

        {
            "media_ids": [...], # List of media IDs to delete
        }

    NOTE(review): the docstring shows DELETE, but no METHOD override is set
    here, so the ReplicationEndpoint default verb is used — confirm this
    matches the documented request format.
    """

    NAME = "delete_media"
    # No path components — everything travels in the JSON body.
    PATH_ARGS = ()

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        # The local media repository on the instance serving this endpoint;
        # expected to be the real (non-stub) implementation on media workers.
        self.media_repo = hs.get_media_repository()

    @staticmethod
    async def _serialize_payload(  # type: ignore[override]
        media_ids: list[str],
    ) -> JsonDict:
        """Build the replication request body.

        Args:
            media_ids: The list of media IDs to delete.
        """
        return {"media_ids": media_ids}

    async def _handle_request(  # type: ignore[override]
        self,
        request: Request,
        content: JsonDict,
    ) -> Tuple[int, JsonDict]:
        # Delegate the actual disk/database removal to the media repository
        # and relay what it reports back to the calling instance.
        media_ids = content["media_ids"]
        deleted, count = await self.media_repo._remove_local_media_from_disk(media_ids)
        return 200, {"deleted": deleted, "count": count}


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    """Register this module's media replication servlets on the HTTP server."""
    for servlet in (
        ReplicationCopyMediaServlet(hs),
        ReplicationDeleteMediaServlet(hs),
    ):
        servlet.register(http_server)
8 changes: 7 additions & 1 deletion synapse/storage/databases/main/censor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
@wrap_as_background_process("_censor_redactions")
async def _censor_redactions(self) -> None:
"""Censors all redactions older than the configured period that haven't
been censored yet.
been censored yet and deletes any media attached to the redacted events.

By censor we mean update the event_json table with the redacted event.
"""
Expand Down Expand Up @@ -104,12 +104,16 @@ async def _censor_redactions(self) -> None:
)

updates = []
media = []

for redaction_id, event_id in rows:
redaction_event = await self.get_event(redaction_id, allow_none=True)
original_event = await self.get_event(
event_id, allow_rejected=True, allow_none=True
)
attached_media_ids = (
await self.hs.get_datastores().main.get_attached_media_ids(event_id)
)

# The SQL above ensures that we have both the redaction and
# original event, so if the `get_event` calls return None it
Expand All @@ -131,6 +135,7 @@ async def _censor_redactions(self) -> None:
pruned_json = None

updates.append((redaction_id, event_id, pruned_json))
media.extend(attached_media_ids)

def _update_censor_txn(txn: LoggingTransaction) -> None:
for redaction_id, event_id, pruned_json in updates:
Expand All @@ -145,6 +150,7 @@ def _update_censor_txn(txn: LoggingTransaction) -> None:
)

await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
await self.hs.get_media_repository()._remove_local_media_from_disk(media)

def _censor_event_txn(
self, txn: LoggingTransaction, event_id: str, pruned_json: str
Expand Down
37 changes: 37 additions & 0 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
import json
import logging
from enum import Enum
from http import HTTPStatus
Expand All @@ -45,6 +46,7 @@
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder

Expand Down Expand Up @@ -1022,6 +1024,11 @@ def delete_remote_media_txn(txn: LoggingTransaction) -> None:
"remote_media_cache_thumbnails",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
self.db_pool.simple_delete_txn(
txn,
"media_attachments",
keyvalues={"server_name": media_origin, "media_id": media_id},
)

await self.db_pool.runInteraction(
"delete_remote_media", delete_remote_media_txn
Expand Down Expand Up @@ -1413,3 +1420,33 @@ def _get_reference_count_txn(txn: LoggingTransaction) -> int:
return await self.db_pool.runInteraction(
"get_media_reference_count_for_sha256", _get_reference_count_txn
)

    async def get_attached_media_ids(self, event_id: str) -> list[str]:
        """
        Get a list of media_ids that are attached to a specific event_id.

        Only media hosted by this homeserver is returned (rows are filtered
        on ``server_name = self.hs.hostname``).

        Args:
            event_id: The event to look up attachments for.

        Returns:
            The media IDs whose restrictions record references the event.
        """

        def get_attached_media_ids_txn(txn: LoggingTransaction) -> list[str]:
            if isinstance(self.db_pool.engine, PostgresEngine):
                # Use GIN index for Postgres
                # The jsonb containment operator (@>) matches rows whose
                # restrictions_json contains {"restrictions": {"event_id": ...}}.
                sql = """
                SELECT media_id
                FROM media_attachments
                WHERE restrictions_json @> %s AND server_name = %s
                """
                json_param = json.dumps({"restrictions": {"event_id": event_id}})
                txn.execute(sql, (json_param, self.hs.hostname))
            else:
                # SQLite: extract restrictions.event_id via the JSON path
                # operators and compare as text.
                sql = """
                SELECT media_id
                FROM media_attachments
                WHERE restrictions_json->'restrictions'->>'event_id' = ? AND server_name = ?
                """
                txn.execute(sql, (event_id, self.hs.hostname))

            return [row[0] for row in txn.fetchall()]

        return await self.db_pool.runInteraction(
            "get_attached_media_ids",
            get_attached_media_ids_txn,
        )
168 changes: 168 additions & 0 deletions tests/replication/test_multi_media_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from twisted.web.server import Request

from synapse.api.constants import EventTypes, HistoryVisibility
from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
from synapse.media._base import FileInfo
from synapse.media.media_repository import MediaRepository
from synapse.rest import admin
Expand Down Expand Up @@ -967,6 +968,173 @@ def test_copy_remote_restricted_resource_fails_when_requester_does_not_have_acce
self.assertEqual(channel.code, 403)


class DeleteRestrictedMediaOnEventRedactionReplicationTestCase(
    BaseMultiWorkerStreamTestCase
):
    """
    Tests that media attached to redacted events are deleted after the retention period
    when `msc3911.enabled` is configured to be True.

    The censor-redactions background loop runs on the main process while media
    is served by a dedicated worker, so deletion must travel over replication.
    """

    servlets = [
        login.register_servlets,
        admin.register_servlets,
        room.register_servlets,
    ]
    # Give this test its own media directories so on-disk deletion can be
    # asserted without interference from other tests.
    use_isolated_media_paths = True

    def default_config(self) -> Dict[str, Any]:
        """Enable MSC3911, route media to a dedicated worker, and configure a
        7 day redaction retention period so the censor loop fires in-test."""
        config = super().default_config()
        config.update(
            {
                "experimental_features": {"msc3911": {"enabled": True}},
                "media_repo_instances": ["media_worker_1"],
                "run_background_tasks_on": MAIN_PROCESS_INSTANCE_NAME,
                "redaction_retention_period": "7d",
            }
        )
        config["instance_map"] = {
            "main": {"host": "testserv", "port": 8765},
            "media_worker_1": {"host": "testserv", "port": 1001},
        }

        return config

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        # A regular user who uploads/redacts, plus an admin (unused directly
        # by the test below but registered for completeness).
        self.user = self.register_user("user", "testpass")
        self.user_tok = self.login("user", "testpass")
        self.admin_user = self.register_user("admin", "pass", admin=True)
        self.admin_user_tok = self.login("admin", "pass")

    def test_delete_media_on_event_redaction(self) -> None:
        """
        Tests that media is deleted when its attached event is redacted.
        """
        # Make sure that censor_redaction loops runs on main hs
        assert self.hs.config.worker.run_background_tasks, (
            "Main HS should run background tasks"
        )
        assert self.hs.config.server.redaction_retention_period is not None, (
            "Redaction retention should be configured"
        )

        # Create media worker and it does not run the background tasks
        media_worker = self.make_worker_hs(
            "synapse.app.generic_worker",
            {
                "worker_name": "media_worker_1",
                "run_background_tasks_on": MAIN_PROCESS_INSTANCE_NAME,
            },
        )
        media_worker.get_media_repository_resource().register_servlets(
            self._hs_to_site[media_worker].resource, media_worker
        )
        media.register_servlets(media_worker, self._hs_to_site[media_worker].resource)
        media_repo = media_worker.get_media_repository()

        assert not media_worker.config.worker.run_background_tasks, (
            "Worker should not run background tasks"
        )

        # Create a private room
        room_id = self.helper.create_room_as(
            self.user,
            is_public=False,
            tok=self.user_tok,
        )

        # The media is created with user_tok, restricted so it is tied to the
        # event it gets attached to (MSC3911).
        content = io.BytesIO(SMALL_PNG)
        content_uri = self.get_success(
            media_repo.create_or_update_content(
                "image/png",
                "test_png_upload",
                content,
                67,
                UserID.from_string(self.user),
                restricted=True,
            )
        )
        media_id = content_uri.media_id

        # User sends a message with media
        channel = self.make_request(
            "PUT",
            f"/rooms/{room_id}/send/m.room.message/{str(time.time())}?org.matrix.msc3911.attach_media={str(content_uri)}",
            content={"msgtype": "m.text", "body": "Hi, this is a message"},
            access_token=self.user_tok,
        )
        assert channel.code == HTTPStatus.OK, channel.json_body
        assert "event_id" in channel.json_body
        event_id = channel.json_body["event_id"]

        # Redact the event
        channel = self.make_request(
            "POST",
            f"/_matrix/client/r0/rooms/{room_id}/redact/{event_id}",
            content={},
            access_token=self.user_tok,
        )
        assert channel.code == HTTPStatus.OK, channel.json_body

        # Verify the event is redacted before censoring
        event_dict = self.helper.get_event(room_id, event_id, self.user_tok)
        assert "redacted_because" in event_dict, "Event should be redacted"

        # Media should still be accessible before retention period is over
        channel = make_request(
            self.reactor,
            self._hs_to_site[media_worker],
            "GET",
            f"/_matrix/client/v1/media/download/{self.hs.hostname}/{media_id}?allow_redacted_media=true",
            shorthand=False,
            access_token=self.user_tok,
        )
        assert channel.code == 200, channel.result
        assert channel.result["body"] == SMALL_PNG

        # Fast forward 7 days and 6 minutes to make sure the censor_redactions looping
        # call detects the events are eligible for censorship.
        self.reactor.advance(7 * 24 * 60 * 60 + 6 * 60)

        # Since we fast forward the reactor time, give some moment for the background
        # censor redactions task to get caught up.
        self.pump(0.01)

        # Check that the media has been deleted from the database
        deleted_media = self.get_success(
            self.hs.get_datastores().main.get_local_media(media_id)
        )
        assert deleted_media is None, deleted_media

        # Check if the file is deleted from the storage as well.
        assert isinstance(media_repo, MediaRepository)
        assert not os.path.exists(media_repo.filepaths.local_media_filepath(media_id))

        # Verify the redaction was censored in the database
        redaction_censored = self.get_success(
            self.hs.get_datastores().main.db_pool.simple_select_one_onecol(
                table="redactions",
                keyvalues={"redacts": event_id},
                retcol="have_censored",
            )
        )
        assert redaction_censored, (
            "Redaction should have been censored by _censor_redactions loop"
        )

        # After deletion, downloading the media must now fail with 404 even
        # when redacted media is explicitly allowed.
        channel = make_request(
            self.reactor,
            self._hs_to_site[media_worker],
            "GET",
            f"/_matrix/client/v1/media/download/{self.hs.hostname}/{media_id}?allow_redacted_media=true",
            shorthand=False,
            access_token=self.user_tok,
        )
        assert channel.code == 404, channel.result
        assert channel.json_body["errcode"] == "M_NOT_FOUND"


def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
Loading
Loading