diff --git a/README.md b/README.md index aa8aaac..47c1f87 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ options: { "name": "Example Scenario", // name of scenario "address": "127.0.0.1:8080", // port to run dispatcher on + "type": "sam", // whether to use sam or denim infrastructure (valid: sam, denim) "clients": 1, // how many clients to register /* How many groups of clients that should communicate. Each group has at least one denim client that communicates with a denim client from another group. diff --git a/pyproject.toml b/pyproject.toml index 860f9d9..45d4cca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,10 @@ name = "sam_dispatcher" version = "0.1.0" authors = [{ name = "SAM Research" }] -dependencies = ["fastapi", "uvicorn", "pydantic"] +dependencies = ["fastapi", "uvicorn", "pydantic", "asyncio"] [project.scripts] sam-dispatch = "sam_dispatcher.server:main" + +[project.optional-dependencies] +test = ["pytest", "pytest-asyncio"] diff --git a/src/sam_dispatcher/server.py b/src/sam_dispatcher/server.py index d225423..c093b75 100644 --- a/src/sam_dispatcher/server.py +++ b/src/sam_dispatcher/server.py @@ -1,37 +1,72 @@ from fastapi import FastAPI from argparse import ArgumentParser import uvicorn -from .state import State, ClientReport -from fastapi.responses import JSONResponse -from fastapi import Request +from .state import State, ClientReport, AccountId +from fastapi import Request, Response, HTTPException +import asyncio app = FastAPI() state = None +def auth(request: Request): + try: + client_id = request.cookies.get("id") + except: + raise HTTPException(status_code=401) + host = request.client.host + _id = create_id(host, client_id) + + if not state.is_auth(_id): + raise HTTPException(status_code=401) + return _id + + +def create_id(host: str, id: str): + return f"{host}#{id}" + + @app.get("/client") -async def client(request: Request): - client_data = state.get_client(request.client.host) +async def client(request: Request, response: Response): + client_id = await state.next_client_id() + response.set_cookie(key="id", value=client_id) + client_data = await state.get_client(create_id(request.client.host, client_id)) if client_data is None: - return JSONResponse( - status_code=403, content={"error": "Clients have been depleted"} - ) + raise HTTPException(status_code=403) return client_data -@app.post("/ready") -async def ready(request: Request): - state.ready(request.client.host) +@app.post("/id") +async def upload_id(request: Request, account_id: AccountId): + _id = auth(request) + await state.set_account_id(_id, account_id.account_id) -@app.get("/start") -async def start(): - return {"start": state.clients_ready, "epoch": state.start_time} +@app.get("/sync") +async def sync(request: Request): + return await state.start(auth(request)) @app.post("/upload") async def upload(request: Request, report: ClientReport): - state.report(request.client.host, report) + _id = auth(request) + await state.report(_id, report) + + if not state.all_clients_have_uploaded: + return + state.save_report() + + +@app.get("/health") +async def health(): + return "OK" + + +async def setup_state(path: str): + global state + state = State(path) + await state.init_state() + return state.scenario.address.split(":") def main(): @@ -43,8 +78,5 @@ def main(): args = parser.parse_args() config_path: str = args.config - state = State(config_path) - - ip, port = state.scenario.address.split(":") - + ip, port = asyncio.run(setup_state(config_path)) uvicorn.run("sam_dispatcher.server:app", host=ip, port=int(port)) diff --git a/src/sam_dispatcher/state.py b/src/sam_dispatcher/state.py index f8dad3c..105ee61 100644 --- a/src/sam_dispatcher/state.py +++ b/src/sam_dispatcher/state.py @@ -5,10 +5,13 @@ import math import time from pathlib import Path +from asyncio import Lock +import asyncio class Scenario(BaseModel): name: str + type: str = Field() address: str = Field(alias="address", default="127.0.0.1:8080") clients: int = Field(alias="clients") groups: List[float] = Field(alias="groups") @@ -17,7 +20,6 @@ class Scenario(BaseModel): message_size_range: Tuple[int, int] = Field(alias="messageSizeRange") denim_probability: float = Field(alias="denimProbability") send_rate_range: Tuple[int, int] = Field(alias="sendRateRange") - start_epoch: int = Field(alias="startEpoch", default=10) report: str = Field(alias="report", default="report.json") @@ -29,6 +31,7 @@ class Friend(BaseModel): class Client(BaseModel): username: str = Field() + client_type: str = Field(alias="clientType") message_size_range: Tuple[int, int] = Field(alias="messageSizeRange") send_rate: int = Field(alias="sendRate") tick_millis: int = Field(alias="tickMillis") @@ -38,24 +41,33 @@ class Client(BaseModel): class MessageLog(BaseModel): - type: str = Field() # denim, regular, status, other protocol - to: str = Field() # server or username + type: str = Field() # denim, regular + to: str = Field() from_: str = Field(alias="from") size: int = Field() - timestamp: int = Field() + tick: int = Field() model_config = {"populate_by_name": True} class ClientReport(BaseModel): - websocket_port: int = Field(alias="websocketPort") + start_time: int = Field(alias="startTime") messages: List[MessageLog] +class AccountId(BaseModel): + account_id: str = Field(alias="accountId") + + class Report(BaseModel): scenario: Scenario - clients: dict[str, Client] - reports: dict[str, ClientReport] + ip_addresses: Dict[str, str] = Field(alias="ipAddresses") + clients: Dict[str, Client] + reports: Dict[str, ClientReport] + + +class StartInfo(BaseModel): + friends: dict[str, str] = Field() class ReportWriter: @@ -68,11 +80,12 @@ def write(self, path: str, report: Report): _dir = Path("reports/") _dir.mkdir(exist_ok=True) with open(_dir / path, "w") as f: - f.write(report.model_dump_json(indent=2)) + f.write(report.model_dump_json(indent=2, by_alias=True)) class State: def __init__(self, scenario: str | Scenario, writer: ReportWriter = None): + self.lock = Lock() if isinstance(scenario, str): with open(scenario, "r") as f: self.scenario = Scenario.model_validate_json(f.read()) @@ -89,16 +102,37 @@ def __init__(self, scenario: str | Scenario, writer: ReportWriter = None): raise RuntimeError("Groups must add up to 1") self.clients: dict[str, Client] = dict() - self.create_clients() - self.free_clients: set[str] = {c.username for c in self.clients.values()} + self.free_clients: set[str] = set() # ip, username - self.ips: dict[str, str] = dict() + self.usernames: dict[str, str] = dict() self.ready_clients: set[str] = set() - self.start_time = int(time.time()) + # username, account_id + self.account_ids: dict[str, str] = dict() self.reports: dict[str, ClientReport] = dict() + self.client_counter = 0 + + def _reset(self): + self.clients = dict() + self.free_clients = set() + self.usernames = dict() + self.ready_clients = set() + + self.reports = dict() + + self.client_counter = 0 + + async def next_client_id(self): + id = self.client_counter + async with self.lock: + self.client_counter += 1 + return id + + def is_auth(self, ip_id: str): + return ip_id in self.usernames + @property def client_amount(self): return len(self.clients) @@ -106,39 +140,62 @@ def client_amount(self): @property def clients_ready(self): return ( - all(ip in self.ready_clients for ip in self.ips) + all(ip in self.ready_clients for ip in self.usernames) + and all(user in self.account_ids for user in self.usernames.values()) and len(self.free_clients) == 0 ) - def ready(self, ip: str): - self.ready_clients.add(ip) - if self.clients_ready: - self.start_time = int(time.time()) + self.scenario.start_epoch - - def get_client(self, ip: str) -> Optional[Client]: - if len(self.free_clients) == 0: - return None - name = self.free_clients.pop() - self.ips[ip] = name - return self.clients[name] - - def report(self, ip: str, report: ClientReport): - username = self.ips[ip] - self.reports[username] = report - usernames = set(self.ips.values()) - if not all(user in usernames for user in self.reports): - return - self.save_report() + @property + def all_clients_have_uploaded(self): + usernames = set(self.usernames.values()) + return all(user in usernames for user in self.reports) + + async def set_account_id(self, ip_id: str, account_id: AccountId): + async with self.lock: + username = self.usernames[ip_id] + self.account_ids[username] = account_id + + async def start(self, ip_id: str) -> StartInfo: + await self._ready(ip_id) + while not self.clients_ready: + await asyncio.sleep(0.5) + + friends = self.clients[self.usernames[ip_id]].friends + friends = dict(map(lambda x: (x[0], self.account_ids[x[0]]), friends.items())) + + return StartInfo(friends=friends) + + async def get_client(self, ip_id: str) -> Optional[Client]: + async with self.lock: + if len(self.free_clients) == 0: + return None + name = self.free_clients.pop() + self.usernames[ip_id] = name + return self.clients[name] + + async def report(self, ip_id: str, report: ClientReport): + async with self.lock: + username = self.usernames[ip_id] + self.reports[username] = report def save_report(self): + ips = {v: k.split("#")[0] for k, v in self.usernames.items()} report = Report( - scenario=self.scenario, clients=self.clients, reports=self.reports + scenario=self.scenario, + ipAddresses=ips, + clients=self.clients, + reports=self.reports, ) self.writer.write(self.scenario.report, report) + async def _ready(self, ip: str): + async with self.lock: + self.ready_clients.add(ip) + @staticmethod - def init_client( + def _init_client( username: str, + client_type: str, msg_range: tuple[int, int], send_rate: int, tick_millis: int, @@ -147,6 +204,7 @@ def init_client( ): return Client( username=username, + clientType=client_type, messageSizeRange=msg_range, sendRate=send_rate, tickMillis=tick_millis, @@ -155,29 +213,33 @@ def init_client( friends=dict(), ) - def create_clients(self): - total_clients = self.scenario.clients - tick_millis = self.scenario.tick_millis - duration_ticks = self.scenario.duration_ticks - clients = dict() - - # initialize clients - for _ in range(total_clients): - username = str(uuid.uuid4()) - msg_range, send_rate = self.get_sizes_and_rate() - clients[username] = State.init_client( - username, - msg_range, - send_rate, - tick_millis, - duration_ticks, - self.scenario.denim_probability, - ) + async def init_state(self): + async with self.lock: + self._reset() + total_clients = self.scenario.clients + tick_millis = self.scenario.tick_millis + duration_ticks = self.scenario.duration_ticks + clients = dict() + + # initialize clients + for _ in range(total_clients): + username = str(uuid.uuid4()) + msg_range, send_rate = self._get_sizes_and_rate() + clients[username] = State._init_client( + username, + self.scenario.type, + msg_range, + send_rate, + tick_millis, + duration_ticks, + self.scenario.denim_probability, + ) - self.clients = clients - self.make_friends(self.clients) + self.clients = clients + self._make_friends(self.clients) + self.free_clients: set[str] = {c.username for c in self.clients.values()} - def make_friends(self, clients: dict[str, Client]): + def _make_friends(self, clients: dict[str, Client]): names = set(clients.keys()) client_amount = len(names) groups: list[list[str]] = [] @@ -229,12 +291,17 @@ def make_friends(self, clients: dict[str, Client]): for freq, friend in zip(frequencies, client.friends.values()): friend.frequency = freq - def get_sizes_and_rate(self): + def _get_sizes_and_rate(self): min_rate, max_rate = self.scenario.send_rate_range send_rate = random.randint(min_rate, max_rate) # normalized range [0, 1] - norm_rate = (send_rate - min_rate) / (max_rate - min_rate) + dividend = send_rate - min_rate + divisor = max_rate - min_rate + if divisor == 0: + norm_rate = 1 + else: + norm_rate = dividend / divisor min_size, max_size = self.scenario.message_size_range diff --git a/tests/test_state.py b/tests/test_state.py index 29312bc..d59794a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,10 +1,20 @@ import pytest -from sam_dispatcher.state import State, Scenario, Report, ClientReport, MessageLog +from sam_dispatcher.state import ( + State, + Scenario, + Report, + ClientReport, + MessageLog, + StartInfo, +) import math +import time +import asyncio scenarios = [ Scenario( name="test", + type="denim-on-sam", clients=10, groups=[0.5, 0.5], tickMillis=10, @@ -16,11 +26,12 @@ ] +@pytest.mark.asyncio @pytest.mark.parametrize("scenario", scenarios) -def test_can_get_client(scenario: Scenario): +async def test_can_get_client(scenario: Scenario): state = State(scenario) - - client = state.get_client("127.0.0.1") + await state.init_state() + client = await state.get_client("127.0.0.1") denim_friends = filter(lambda x: x.denim, client.friends.values()) assert client is not None @@ -35,10 +46,11 @@ def test_can_get_client(scenario: Scenario): assert all(f.denim for f in denim_friends) +@pytest.mark.asyncio @pytest.mark.parametrize("scenario", scenarios) -def test_groups_have_denim_connection(scenario: Scenario): +async def test_groups_have_denim_connection(scenario: Scenario): state = State(scenario) - + await state.init_state() groups = len(state.scenario.groups) count = 0 for c in state.clients.values(): @@ -48,9 +60,11 @@ def test_groups_have_denim_connection(scenario: Scenario): assert count == groups +@pytest.mark.asyncio @pytest.mark.parametrize("scenario", scenarios) -def test_no_friendless_clients(scenario: Scenario): +async def test_no_friendless_clients(scenario: Scenario): state = State(scenario) + await state.init_state() friend_counts = {f.username: len(f.friends) for f in state.clients.values()} friendless_clients = sum(1 for x in friend_counts.values() if x == 0) all_has_friends = all(friend_counts.values()) @@ -60,53 +74,65 @@ def test_no_friendless_clients(scenario: Scenario): ), f"Expected all clients to have friends found '{friendless_clients}' without any" +@pytest.mark.asyncio @pytest.mark.parametrize("scenario", scenarios) -def test_client_amount_persists(scenario: Scenario): +async def test_client_amount_persists(scenario: Scenario): state = State(scenario) + await state.init_state() expected = state.scenario.clients assert state.client_amount == expected +@pytest.mark.asyncio @pytest.mark.parametrize("scenario", scenarios) -def test_client_friend_freqs_add_to_one(scenario: Scenario): +async def test_client_friend_freqs_add_to_one(scenario: Scenario): state = State(scenario) + await state.init_state() for client in state.clients.values(): total = sum(f.frequency for f in client.friends.values()) assert math.isclose(total, 1.0, rel_tol=1e-9), f"Sum was {total}" class TestReportWriter: + report = None + def write(self, path: str, report: Report): self.report = report +@pytest.mark.asyncio @pytest.mark.parametrize("scenario", scenarios) -def test_ready_to_save(scenario: Scenario): +async def test_ready_to_save(scenario: Scenario): writer = TestReportWriter() state = State(scenario, writer) - time = state.start_time + await state.init_state() + start_time = int(time.time()) clients: list[tuple[str, str]] = [] + starts = [] for i in range(state.scenario.clients): ip = str(i) - client = state.get_client(ip) - state.ready(ip) + client = await state.get_client(ip) clients.append((ip, client.username)) + await state.set_account_id(ip, f"account_{ip}") + starts.append(state.start(ip)) - assert state.clients_ready - assert 10 <= state.start_time - time + await asyncio.gather(*starts) - expected_report = Report(scenario=scenario, clients=state.clients, reports=dict()) + expected_report = Report( + scenario=scenario, + ipAddresses={v: k for k, v in state.usernames.items()}, + clients=state.clients, + reports=dict(), + ) for ip, user in clients: report = ClientReport( - websocketPort=43434, - messages=[ - MessageLog(type="regular", to="x", from_="x", size=1, timestamp=10) - ], + startTime=0, + messages=[MessageLog(type="regular", to="x", from_="x", size=1, tick=10)], ) expected_report.reports[user] = report - state.report(ip, report) - + await state.report(ip, report) + state.save_report() report = writer.report assert report is not None assert report == expected_report