Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ options:
{
"name": "Example Scenario", // name of scenario
"address": "127.0.0.1:8080", // port to run dispatcher on
"type": "sam", // whether to use sam or denim infrastructure (valid: sam, denim)
"clients": 1, // how many clients to register
/* How many groups of clients that should communicate.
Each group has at least one denim client that communicates with a denim client from another group.
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ name = "sam_dispatcher"
version = "0.1.0"
authors = [{ name = "SAM Research" }]

dependencies = ["fastapi", "uvicorn", "pydantic"]
dependencies = ["fastapi", "uvicorn", "pydantic", "asyncio"]

[project.scripts]
sam-dispatch = "sam_dispatcher.server:main"

[project.optional-dependencies]
test = ["pytest", "pytest-asyncio"]
70 changes: 51 additions & 19 deletions src/sam_dispatcher/server.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,72 @@
from fastapi import FastAPI
from argparse import ArgumentParser
import uvicorn
from .state import State, ClientReport
from fastapi.responses import JSONResponse
from fastapi import Request
from .state import State, ClientReport, AccountId
from fastapi import Request, Response, HTTPException
import asyncio

app = FastAPI()
state = None


def auth(request: Request):
try:
client_id = request.cookies.get("id")
except:
raise HTTPException(status_code=401)
host = request.client.host
_id = create_id(host, client_id)

if not state.is_auth(_id):
raise HTTPException(status_code=401)
return _id


def create_id(host: str, id: str):
return f"{host}#{id}"


@app.get("/client")
async def client(request: Request):
client_data = state.get_client(request.client.host)
async def client(request: Request, response: Response):
client_id = await state.next_client_id()
response.set_cookie(key="id", value=client_id)
client_data = await state.get_client(create_id(request.client.host, client_id))
if client_data is None:
return JSONResponse(
status_code=403, content={"error": "Clients have been depleted"}
)
raise HTTPException(status_code=403)
return client_data


@app.post("/ready")
async def ready(request: Request):
state.ready(request.client.host)
@app.post("/id")
async def upload_id(request: Request, account_id: AccountId):
_id = auth(request)
await state.set_account_id(_id, account_id.account_id)


@app.get("/start")
async def start():
return {"start": state.clients_ready, "epoch": state.start_time}
@app.get("/sync")
async def sync(request: Request):
return await state.start(auth(request))


@app.post("/upload")
async def upload(request: Request, report: ClientReport):
state.report(request.client.host, report)
_id = auth(request)
await state.report(_id, report)

if not state.all_clients_have_uploaded:
return
state.save_report()


@app.get("/health")
async def health():
return "OK"


async def setup_state(path: str):
global state
state = State(path)
await state.init_state()
return state.scenario.address.split(":")


def main():
Expand All @@ -43,8 +78,5 @@ def main():
args = parser.parse_args()
config_path: str = args.config

state = State(config_path)

ip, port = state.scenario.address.split(":")

ip, port = asyncio.run(setup_state(config_path))
uvicorn.run("sam_dispatcher.server:app", host=ip, port=int(port))
Loading