From 47e54605322aea8c7b1a263399bc2f2b1a99431e Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 4 Nov 2025 15:55:14 -0800 Subject: [PATCH 01/40] added mtoken --- lib/backend.py | 4 ++++ lib/data_types.py | 1 + lib/metrics.py | 6 ++++++ 3 files changed, 11 insertions(+) diff --git a/lib/backend.py b/lib/backend.py index 6a2f3c0..3f308ed 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -69,10 +69,14 @@ class Backend: report_addr: str = dataclasses.field( default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai") ) + mtoken: str = dataclasses.field( + default_factory=lambda: os.environ.get("MASTER_TOKEN", "") + ) def __post_init__(self): self.metrics = Metrics() self.metrics._set_version(self.version) + self.metrics._set_mtoken(self.mtoken) self._total_pubkey_fetch_errors = 0 self._pubkey = self._fetch_pubkey() self.__start_healthcheck: bool = False diff --git a/lib/data_types.py b/lib/data_types.py index 77883c5..a7f0fad 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -286,6 +286,7 @@ class AutoScalerData: """Data that is reported to autoscaler""" id: int + moken: str version: str loadtime: float cur_load: float diff --git a/lib/metrics.py b/lib/metrics.py index 5f15f74..93b166d 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -28,6 +28,7 @@ def get_url() -> str: @dataclass class Metrics: version: str = "0" + mtoken: str = "" last_metric_update: float = 0.0 last_request_served: float = 0.0 update_pending: bool = False @@ -142,12 +143,16 @@ def _model_errored(self, error_msg: str) -> None: def _set_version(self, version: str) -> None: self.version = version + def _set_mtoken(self, mtoken: str) -> None: + self.mtoken = mtoken + #######################################Private####################################### async def __send_delete_requests_and_reset(self): async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool: data = { "worker_id": self.id, + "mtoken": self.mtoken, "request_idxs": idxs, "success": success_flag, } @@ -209,6 +214,7 @@ async def __send_metrics_and_reset(self): def compute_autoscaler_data() -> AutoScalerData: return AutoScalerData( id=self.id, + mtoken=self.mtoken, version=self.version, loadtime=(loadtime_snapshot or 0.0), new_load=self.model_metrics.workload_processing, From f5134d4bf522679c81ac7a4bf0af8d37bbd9cde9 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 4 Nov 2025 16:59:39 -0800 Subject: [PATCH 02/40] Fix spelling mistake --- lib/data_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/data_types.py b/lib/data_types.py index a7f0fad..d948c60 100644 --- a/lib/data_types.py +++ b/lib/data_types.py @@ -286,7 +286,7 @@ class AutoScalerData: """Data that is reported to autoscaler""" id: int - moken: str + mtoken: str version: str loadtime: float cur_load: float From 106067d71612eaff7afe01f261bf7013d6e8bb9a Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 4 Nov 2025 17:15:59 -0800 Subject: [PATCH 03/40] bump version to 0.1.1 --- lib/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/backend.py b/lib/backend.py index 3f308ed..8002f3b 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -30,7 +30,7 @@ BenchmarkResult ) -VERSION = "0.1.0" +VERSION = "0.1.1" MSG_HISTORY_LEN = 100 log = logging.getLogger(__file__) From 8ae7b746052d7bd746f0b631131db5dd1892e265 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 5 Nov 2025 13:32:21 -0800 Subject: [PATCH 04/40] bump version to 0.2.0 --- lib/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/lib/backend.py b/lib/backend.py index 8002f3b..5cbb7ff 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -30,7 +30,7 @@ BenchmarkResult ) -VERSION = "0.1.1" +VERSION = "0.2.0" MSG_HISTORY_LEN = 100 log = logging.getLogger(__file__) From b7fe4ebb91e1475bfb69e3a91ba0f20e41a0ad36 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Fri, 7 Nov 2025 10:02:39 -0800 Subject: [PATCH 05/40] Obfuscate mtoken in logs --- lib/metrics.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/lib/metrics.py b/lib/metrics.py index 93b166d..6fb0fce 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -234,17 +234,25 @@ def compute_autoscaler_data() -> AutoScalerData: async def send_data(report_addr: str) -> bool: data = compute_autoscaler_data() - full_path = report_addr.rstrip("/") + "/worker_status/" + log_data = asdict(data) + def obfuscate(secret: str) -> str: + if secret is None: + return "" + return secret[:7] if len(secret) > 7 else ("*" * len(secret)) + + log_data["mtoken"] = obfuscate(log_data.get("mtoken")) log.debug( "\n".join( [ "#" * 60, f"sending data to autoscaler", - f"{json.dumps((asdict(data)), indent=2)}", + f"{json.dumps(log_data, indent=2)}", "#" * 60, ] ) ) + + full_path = report_addr.rstrip("/") + "/worker_status/" for attempt in range(1, 4): try: session = await self.http() From c6521cb6d4d300243193e153e71c68fe1a4bec57 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Fri, 7 Nov 2025 10:10:35 -0800 Subject: [PATCH 06/40] add ... --- lib/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/metrics.py b/lib/metrics.py index 6fb0fce..48774fe 100644 --- a/lib/metrics.py +++ b/lib/metrics.py @@ -238,7 +238,7 @@ async def send_data(report_addr: str) -> bool: def obfuscate(secret: str) -> str: if secret is None: return "" - return secret[:7] if len(secret) > 7 else ("*" * len(secret)) + return secret[:7] + "..." 
if len(secret) > 7 else ("*" * len(secret)) log_data["mtoken"] = obfuscate(log_data.get("mtoken")) log.debug( From b55bfa961124a2a124b78a57867f3732130e80f4 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 11 Nov 2025 17:09:28 -0800 Subject: [PATCH 07/40] Updated clients, include vastai-sdk, handle non-UTF-8 --- lib/backend.py | 2 +- requirements.txt | 1 + workers/comfyui-json/client.py | 177 ++------ workers/comfyui/client.py | 17 +- workers/openai/client.py | 805 +++++++++++++++------------------ workers/tgi/client.py | 168 +++---- 6 files changed, 455 insertions(+), 715 deletions(-) diff --git a/lib/backend.py b/lib/backend.py index 5cbb7ff..bf1d746 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -417,7 +417,7 @@ async def handle_log_line(log_line: str) -> None: async def tail_log(): log.debug(f"tailing file: {self.model_log_file}") - async with await open_file(self.model_log_file) as f: + async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore'): while True: line = await f.readline() if line: diff --git a/requirements.txt b/requirements.txt index 1d99304..13b194e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ Requests~=2.32 transformers~=4.52 utils==1.0.* hf_transfer>=0.1.9 +vastai-sdk>=0.2.0g \ No newline at end of file diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index e4ac92c..c877df2 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -1,156 +1,35 @@ -import logging -import uuid -import random -from urllib.parse import urljoin -import json - -import requests - -from lib.test_utils import print_truncate_res -from utils.endpoint_util import Endpoint -from utils.ssl import get_cert_file_path +from vastai import Serverless from .data_types import count_workload -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) - - -def call_text2image_workflow( - endpoint_group_name: str, api_key: str, server_url: str -) -> None: - """Simple Text2Image using the new modifier-based approach""" - - def make_request(url: str, payload: dict, timeout: int = None, verify=True, context: str = "request"): - """Helper function for making requests with consistent error handling""" - try: - response = requests.post( - url, - json=payload, - timeout=timeout, - verify=verify - ) - - response.raise_for_status() - return response.json() - - except requests.exceptions.HTTPError as http_err: - log.error(f"HTTP error occurred during {context}: {http_err}") - log.error(f"Status Code: {response.status_code}") - log.error("Response content:", response.text) - return None - except requests.exceptions.Timeout: - log.error(f"Timeout occurred during {context}: {url}") - return None - except requests.exceptions.ConnectionError: - log.error(f"Connection error occurred during {context}: {url}") - return None - except json.JSONDecodeError as json_err: - log.error(f"Failed to decode JSON response during {context}: {json_err}") - if 'response' in locals(): - print("Response content:", response.text) - return None - except Exception as err: - log.error(f"An unexpected error occurred during {context}: {err}") - if 'response' in locals(): - log.error("Response content (if available):", response.text) - return None - - WORKER_ENDPOINT = "/generate/sync" +import uuid +import random +import asyncio +import random - # This worker has concurrency = 1. 
All workloads have cost value 1.0 - COST = count_workload() - - # Route to get worker URL - route_payload = { - "endpoint": endpoint_group_name, - "api_key": api_key, - "cost": COST, - } - - # First request - get routing information - route_response = make_request( - url=urljoin(server_url, "/route/"), - payload=route_payload, - timeout=4, - context="route request" - ) - - if route_response is None: - return None - - if "url" not in route_response or not route_response["url"]: - log.error("Error: No worker in 'Ready' state. Please wait while the serverless engine removes errored workers or finishes loading new workers.") - return None - - if "status" in route_response: - print(f"Autoscaler status: {route_response['status']}") - return None - - # Extract data from route response - url = route_response["url"] - auth_data = dict( - signature=route_response["signature"], - cost=route_response["cost"], - endpoint=route_response["endpoint"], - reqnum=route_response["reqnum"], - url=route_response["url"], - request_idx=route_response["request_idx"], - ) - - # Build the payload for the worker request - worker_payload = { - "input": { - "request_id": str(uuid.uuid4()), - "modifier": "Text2Image", - "modifications": { - "prompt": "a beautiful landscape with mountains and lakes", - "width": 1024, - "height": 1024, - "steps": 20, - "seed": random.randint(0, 2**32 - 1) - }, - "workflow_json": {} # Empty since using modifier approach +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name + + payload = { + "input": { + "request_id": str(uuid.uuid4()), + "modifier": "Text2Image", + "modifications": { + "prompt": "a beautiful landscape with mountains and lakes", + "width": 1024, + "height": 1024, + "steps": 20, + "seed": random.randint(0, 2**32 - 1) + }, + "workflow_json": {} # Empty since using modifier approach + } } - } - - req_data = dict(payload=worker_payload, auth_data=auth_data) - worker_url = urljoin(url, WORKER_ENDPOINT) - print(f"url: {worker_url}") - - # Second request - call the worker endpoint - worker_response = make_request( - url=worker_url, - payload=req_data, - verify=get_cert_file_path(), - context="worker request" - ) - - return worker_response + + response = await endpoint.request("/generate/sync", payload, cost=count_workload()) + # Get the file from the path on the local machine using SCP or SFTP + # or configure S3 to upload to cloud storage. 
+ print(response["response"]["output"][0]["local_path"]) if __name__ == "__main__": - from lib.test_utils import test_args - - args = test_args.parse_args() - endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=args.endpoint_group_name, - account_api_key=args.api_key, - instance=args.instance, - ) - - if endpoint_api_key: - result = call_text2image_workflow( - api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, - ) - if result is None: - log.error("Text2Image workflow failed") - else: - print(result) - else: - log.error(f"Failed to get API key for endpoint {args.endpoint_group_name}") + asyncio.run(main()) \ No newline at end of file diff --git a/workers/comfyui/client.py b/workers/comfyui/client.py index 986ff22..7d1935e 100644 --- a/workers/comfyui/client.py +++ b/workers/comfyui/client.py @@ -7,20 +7,13 @@ from utils.endpoint_util import Endpoint from utils.ssl import get_cert_file_path -""" -NOTE: this client example uses a custom comfy workflow compatible with SD3 only -""" -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) +from vastai import Serverless -def call_default_workflow( - endpoint_group_name: str, api_key: str, server_url: str -) -> None: +ENDPOINT_NAME = "my-comfyui-endpoint" +COST = 100 # Use a constant cost for image generation + +def call_default_workflow(client: Serverless) -> None: WORKER_ENDPOINT = "/prompt" COST = 100 route_payload = { diff --git a/workers/openai/client.py b/workers/openai/client.py index e34cc90..1dadc68 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -1,14 +1,16 @@ +#!/usr/bin/env python3 import logging -import sys import json +import os +import sys import subprocess -from urllib.parse import urljoin -from typing import Dict, Any, Optional, Iterator, Union, List -import requests -from utils.endpoint_util import Endpoint -from utils.ssl import get_cert_file_path -from .data_types.client import CompletionConfig, ChatCompletionConfig +import argparse +from typing import Any, Dict, List, Optional + +from vastai import Serverless +import asyncio +# ---------------------- Logging ---------------------- logging.basicConfig( level=logging.DEBUG, format="%(asctime)s[%(levelname)-5s] %(message)s", @@ -16,135 +18,20 @@ ) log = logging.getLogger(__file__) +# ---------------------- Prompts ---------------------- COMPLETIONS_PROMPT = "the capital of USA is" CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." -TOOLS_PROMPT = "Can you list the files in the current working directory and tell me what you see? What do you think this directory might be for?" 
- - -class APIClient: - """Lightweight client focused solely on API communication""" - - # Remove the generic WORKER_ENDPOINT since we're now going direct - DEFAULT_COST = 100 - DEFAULT_TIMEOUT = 4 - - def __init__( - self, - endpoint_group_name: str, - api_key: str, - server_url: str, - endpoint_api_key: str, - ): - self.endpoint_group_name = endpoint_group_name - self.api_key = api_key - self.server_url = server_url - self.endpoint_api_key = endpoint_api_key - - def _get_worker_url(self, cost: int = DEFAULT_COST) -> Dict[str, Any]: - """Get worker URL and auth data from routing service""" - if not self.endpoint_api_key: - raise ValueError("No valid endpoint API key available") - - route_payload = { - "endpoint": self.endpoint_group_name, - "api_key": self.endpoint_api_key, - "cost": cost, - } - - response = requests.post( - urljoin(self.server_url, "/route/"), - json=route_payload, - timeout=self.DEFAULT_TIMEOUT, - ) - response.raise_for_status() - return response.json() - - def _create_auth_data(self, message: Dict[str, Any]) -> Dict[str, Any]: - """Create auth data from routing response""" - return { - "signature": message["signature"], - "cost": message["cost"], - "endpoint": message["endpoint"], - "reqnum": message["reqnum"], - "url": message["url"], - } - - def _make_request( - self, - payload: Dict[str, Any], - endpoint: str, - method: str = "POST", - stream: bool = False, - ) -> Union[Dict[str, Any], Iterator[str]]: - """Make request directly to the specific worker endpoint""" - # Get worker URL and auth data - cost = payload.get("max_tokens", self.DEFAULT_COST) - message = self._get_worker_url(cost=cost) - worker_url = message["url"] - auth_data = self._create_auth_data(message) - - req_data = {"payload": {"input": payload}, "auth_data": auth_data} - - url = urljoin(worker_url, endpoint) - log.debug(f"Making direct request to: {url}") - log.debug(f"Payload: {req_data}") - - # Make the request using the specified method - if method.upper() == "POST": - response = requests.post( - url, json=req_data, stream=stream, verify=get_cert_file_path() - ) - elif method.upper() == "GET": - response = requests.get( - url, params=req_data, stream=stream, verify=get_cert_file_path() - ) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - - response.raise_for_status() - - if stream: - return self._handle_streaming_response(response) - else: - return response.json() - - def _handle_streaming_response(self, response: requests.Response) -> Iterator[str]: - """Handle streaming response and yield tokens""" - try: - for line in response.iter_lines(decode_unicode=True): - if line: - if line.startswith("data: "): - data_str = line[6:] - if data_str.strip() == "[DONE]": - break - try: - data = json.loads(data_str) - yield data # Yield the full chunk - except json.JSONDecodeError: - continue - except Exception as e: - log.error(f"Error handling streaming response: {e}") - raise - - def call_completions( - self, config: CompletionConfig - ) -> Union[Dict[str, Any], Iterator[str]]: - payload = config.to_dict() - - return self._make_request( - payload=payload, endpoint="/v1/completions", stream=config.stream - ) - - def call_chat_completions( - self, config: ChatCompletionConfig - ) -> Union[Dict[str, Any], Iterator[str]]: - payload = config.to_dict() - - return self._make_request( - payload=payload, endpoint="/v1/chat/completions", stream=config.stream - ) +TOOLS_PROMPT = ( + "Can you list the files in the current working directory and tell me what you see? 
" + "What do you think this directory might be for?" +) +ENDPOINT_NAME = "my-vllm-endpoint" # change this to your vLLM endpoint name +DEFAULT_MODEL = "Qwen/Qwen3-8B" # must support tool calling +MAX_TOKENS = 1024 +DEFAULT_TEMPERATURE = 0.7 +# ---------------------- Tooling ---------------------- class ToolManager: """Handles tool definitions and execution""" @@ -164,7 +51,7 @@ def list_files() -> str: @staticmethod def get_ls_tool_definition() -> List[Dict[str, Any]]: - """Get the ls tool definition""" + """OpenAI-compatible tool schema""" return [ { "type": "function", @@ -178,269 +65,365 @@ def get_ls_tool_definition() -> List[Dict[str, Any]]: def execute_tool_call(self, tool_call: Dict[str, Any]) -> str: """Execute a tool call and return the result""" - function_name = tool_call["function"]["name"] - + function_name = (tool_call.get("function") or {}).get("name") if function_name == "list_files": return self.list_files() - else: - raise ValueError(f"Unknown tool function: {function_name}") + raise ValueError(f"Unknown tool function: {function_name}") + + +# ----- Helpers to handle streamed tool_calls assembly ----- +def _merge_tool_call_delta(state: Dict[int, Dict[str, Any]], tc_delta: Dict[str, Any]) -> None: + """ + OpenAI-style streaming sends partial tool_calls with an index and partial fields. + We merge into a per-index state dict until the assistant message finishes. + """ + idx = tc_delta.get("index") + if idx is None: + return + + entry = state.setdefault(idx, {"id": None, "function": {"name": None, "arguments": ""}, "type": "function"}) + + if tc_delta.get("id"): + entry["id"] = tc_delta["id"] + + fn_delta = tc_delta.get("function") or {} + if "name" in fn_delta and fn_delta["name"]: + entry["function"]["name"] = fn_delta["name"] + if "arguments" in fn_delta and fn_delta["arguments"]: + entry["function"]["arguments"] += fn_delta["arguments"] + + +def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[Dict[str, Any]]: + return [state[i] for i in sorted(state.keys())] + + +# ---- OpenAI-compatible calls (non-streaming) ---- +async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]: + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + payload = { + "input": { + "model": model, + "prompt": prompt, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + } + } + log.debug("POST /v1/completions %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) + return resp["response"] + +async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: + + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + + payload = { + "input": { + "model": model, + "messages": messages, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), + **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), + } + } + log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"]) + return resp["response"] + +# ---- Streaming variants ---- +async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs): + + endpoint = await 
client.get_endpoint(name=ENDPOINT_NAME) + + payload = { + "input": { + "model": model, + "prompt": prompt, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "stream": True, + **({"stop": kwargs["stop"]} if "stop" in kwargs else {}), + } + } + log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) + return resp["response"] # async generator + +async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs): + + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + + payload = { + "input": { + "model": model, + "messages": messages, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "stream": True, + **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), + **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), + } + } + log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True) + return resp["response"] # async generator + + +# ---------------------- Demo Runner ---------------------- class APIDemo: """Demo and testing functionality for the API client""" - def __init__( - self, client: APIClient, model: str, tool_manager: Optional[ToolManager] = None - ): + def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None): self.client = client self.model = model self.tool_manager = tool_manager or ToolManager() - def handle_streaming_response( - self, response_stream, show_reasoning: bool = True - ) -> str: - """ - Handle streaming chat response and display all output. 
- """ - + # ----- Streaming handler ----- + async def handle_streaming_response(self, stream, show_reasoning: bool = True) -> str: full_response = "" reasoning_content = "" - reasoning_started = False - content_started = False - - for chunk in response_stream: - # Normalize the chunk - if isinstance(chunk, str): - chunk = chunk.strip() - if chunk.startswith("data: "): - chunk = chunk[6:].strip() - if chunk in ["[DONE]", ""]: - continue - try: - parsed_chunk = json.loads(chunk) - except json.JSONDecodeError: - continue - elif isinstance(chunk, dict): - parsed_chunk = chunk - else: - continue + printed_reasoning = False + printed_answer = False - # Parse delta from the chunk - choices = parsed_chunk.get("choices", []) - if not choices: - continue + async for chunk in stream: + choice = (chunk.get("choices") or [{}])[0] + delta = choice.get("delta", {}) - delta = choices[0].get("delta", {}) - reasoning_token = delta.get("reasoning_content", "") - content_token = delta.get("content", "") - - # Print reasoning token if applicable - if show_reasoning and reasoning_token: - if not reasoning_started: + # reasoning tokens + rc = delta.get("reasoning_content") + if rc and show_reasoning: + if not printed_reasoning: print("\n🧠 Reasoning: ", end="", flush=True) - reasoning_started = True - print(f"\033[90m{reasoning_token}\033[0m", end="", flush=True) - reasoning_content += reasoning_token - - # Print content token - if content_token: - if not content_started: - if show_reasoning and reasoning_started: - print(f"\n💬 Response: ", end="", flush=True) + printed_reasoning = True + print(rc, end="", flush=True) + reasoning_content += rc + + # content tokens + content_part = delta.get("content") + if content_part: + if not printed_answer: + if show_reasoning and printed_reasoning: + print("\n💬 Response: ", end="", flush=True) else: print("Assistant: ", end="", flush=True) - content_started = True - print(content_token, end="", flush=True) - full_response += content_token - - print() # Ensure newline after response + printed_answer = True + print(content_part, end="", flush=True) + full_response += content_part + print() # newline if show_reasoning: - if reasoning_started or content_started: + if printed_reasoning or printed_answer: print("\nStreaming completed.") - if reasoning_started: + if printed_reasoning: print(f"Reasoning tokens: {len(reasoning_content.split())}") - if content_started: + if printed_answer: print(f"Response tokens: {len(full_response.split())}") return full_response - - def test_tool_support(self) -> bool: - """Test if the endpoint supports function calling""" - log.debug("Testing endpoint tool calling support...") - - # Try a simple request with minimal tools to test support - messages = [{"role": "user", "content": "Hello"}] - minimal_tool = [ - { - "type": "function", - "function": {"name": "test_function", "description": "Test function"}, - } - ] - - config = ChatCompletionConfig( - model=self.model, - messages=messages, - max_tokens=10, - tools=minimal_tool, - tool_choice="none", # Don't actually call the tool - ) - - try: - response = self.client.call_chat_completions(config) - return True - except Exception as e: - log.error(f"Error: Endpoint does not support tool calling: {e}") - return False - - def demo_completions(self) -> None: - """Demo: test basic completions endpoint""" + + async def demo_completions(self) -> None: print("=" * 60) print("COMPLETIONS DEMO") print("=" * 60) - config = CompletionConfig( - model=self.model, prompt=COMPLETIONS_PROMPT, stream=False - ) - - 
log.info( - f"Testing completions with model '{self.model}' and prompt: '{config.prompt}'" + response = await call_completions( + client=self.client, + model=self.model, + prompt=COMPLETIONS_PROMPT, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, ) - response = self.client.call_completions(config) - - if isinstance(response, dict): - print("\nResponse:") - print(json.dumps(response, indent=2)) - else: - log.error("Unexpected response format") + print("\nResponse:") + print(json.dumps(response, indent=2)) - def demo_chat(self, use_streaming: bool = True) -> None: - """ - Demo: test chat completions endpoint with optional streaming - """ + async def demo_chat(self, use_streaming: bool = True) -> None: print("=" * 60) - print( - f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}" - ) + print(f"CHAT COMPLETIONS DEMO {'(STREAMING)' if use_streaming else '(NON-STREAMING)'}") print("=" * 60) - config = ChatCompletionConfig( - model=self.model, - messages=[{"role": "user", "content": CHAT_PROMPT}], - stream=use_streaming, - ) - - log.info(f"Testing chat completions with model '{self.model}'...") - response = self.client.call_chat_completions(config) + messages = [{"role": "user", "content": CHAT_PROMPT}] if use_streaming: + stream = await stream_chat_completions( + client=self.client, + model=self.model, + messages=messages, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE + ) try: - self.handle_streaming_response(response, show_reasoning=True) + await self.handle_streaming_response(stream, show_reasoning=True) except Exception as e: - log.error(f"\nError during streaming: {e}") - import traceback - - traceback.print_exc() - return - + log.error("\nError during streaming: %s", e, exc_info=True) else: - if isinstance(response, dict): - choice = response.get("choices", [{}])[0] - message = choice.get("message", {}) - content = message.get("content", "") - reasoning = message.get("reasoning_content", "") or message.get( - "reasoning", "" - ) - - if reasoning: - print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m") - - print(f"\n💬 Assistant: {content}") - print(f"\nFull Response:") - print(json.dumps(response, indent=2)) - else: - log.error("Unexpected response format") + response = await call_chat_completions( + client=self.client, + model=self.model, + messages=messages, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE + ) + choice = (response.get("choices") or [{}])[0] + message = choice.get("message", {}) + content = message.get("content", "") + reasoning = message.get("reasoning_content", "") or message.get("reasoning", "") + if reasoning: + print(f"\n🧠 Reasoning: \033[90m{reasoning}\033[0m") + print(f"\n💬 Assistant: {content}") + print(f"\nFull Response:\n{json.dumps(response, indent=2)}") + + async def test_tool_support(self) -> bool: + """Probe that tool schema is accepted (no actual call)""" + messages = [{"role": "user", "content": "Hello"}] + minimal_tool = [ + { + "type": "function", + "function": {"name": "test_function", "description": "Test function"}, + } + ] + try: + _ = await call_chat_completions( + client=self.client, + model=self.model, + messages=messages, + tools=minimal_tool, + tool_choice="none", + max_tokens=10 + ) + return True + except Exception as e: + log.error("Endpoint does not support tool calling: %s", e) + return False - def demo_ls_tool(self) -> None: - """Demo: ask LLM to list files in the current directory and describe what it sees""" + async def demo_ls_tool(self) -> None: + """Ask to list files using 
function calling, then provide final analysis""" print("=" * 60) print("TOOL USE DEMO: List Directory Contents") print("=" * 60) - # Test if tools are supported first - if not self.test_tool_support(): + if not await self.test_tool_support(): return - # Request with tool available - messages = [{"role": "user", "content": TOOLS_PROMPT}] + messages: List[Dict[str, Any]] = [{"role": "user", "content": TOOLS_PROMPT}] - config = ChatCompletionConfig( + # First pass: let the model decide tools, stream tool_calls and partial content + stream = await stream_chat_completions( + client=self.client, model=self.model, messages=messages, tools=self.tool_manager.get_ls_tool_definition(), tool_choice="auto", + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, ) - log.info(f"Making initial request with tool using model '{self.model}'...") - response = self.client.call_chat_completions(config) - - if not isinstance(response, dict): - raise ValueError("Expected dict response for tool use") - - choice = response.get("choices", [{}])[0] - message = choice.get("message", {}) - - print(f"Assistant response: {message.get('content', 'No content')}") - - # Check for tool calls - tool_calls = message.get("tool_calls") - if not tool_calls: - raise ValueError( - "No tool calls made - model may not support function calling" - ) - - print(f"Tool calls detected: {len(tool_calls)}") + assistant_content_buf: List[str] = [] + tool_calls_state: Dict[int, Dict[str, Any]] = {} + printed_reasoning = False + printed_answer = False + + async for chunk in stream: + choice = (chunk.get("choices") or [{}])[0] + delta = choice.get("delta", {}) + + rc = delta.get("reasoning_content") + if rc: + if not printed_reasoning: + printed_reasoning = True + print("🧠 Reasoning: ", end="", flush=True) + print(rc, end="", flush=True) + + content_part = delta.get("content") + if content_part: + assistant_content_buf.append(content_part) + if not printed_answer: + printed_answer = True + print("\n💬 Response: ", end="", flush=True) + print(content_part, end="", flush=True) + + if "tool_calls" in delta and delta["tool_calls"]: + for tc_delta in delta["tool_calls"]: + _merge_tool_call_delta(tool_calls_state, tc_delta) + + # If no tool calls, we’re done. 
+ if not tool_calls_state: + print("\n(No tool calls were made.)") + return - # Execute the tool call - for tool_call in tool_calls: - function_name = tool_call["function"]["name"] - print(f"Executing tool: {function_name}") + # Build assistant message with tool_calls + assistant_message = { + "role": "assistant", + "content": "".join(assistant_content_buf) if assistant_content_buf else None, + "tool_calls": _tool_state_to_message_tool_calls(tool_calls_state), + } + messages.append(assistant_message) - tool_result = self.tool_manager.execute_tool_call(tool_call) - print(f"Tool result:\n{tool_result}") + # Execute tools and feed results back + for tc in assistant_message["tool_calls"]: + tool_name = (tc.get("function") or {}).get("name") + call_id = tc.get("id") + raw_args = (tc.get("function") or {}).get("arguments") or "{}" - # Add tool result and continue conversation - messages.append(message) # Add assistant's message with tool call - messages.append( - { - "role": "tool", - "tool_call_id": tool_call["id"], - "content": tool_result, - } - ) + try: + args = json.loads(raw_args) if raw_args.strip() else {} + except Exception as e: + tool_result = json.dumps({"error": f"Argument parse failed: {str(e)}", "raw_arguments": raw_args}) + messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result}) + continue - # Get final response - final_config = ChatCompletionConfig( - model=self.model, - messages=messages, - tools=self.tool_manager.get_ls_tool_definition(), - ) + try: + if tool_name == "list_files": + tool_result = self.tool_manager.list_files() + else: + tool_result = json.dumps({"error": f"Unknown tool '{tool_name}'"}) + except Exception as e: + tool_result = json.dumps({"error": f"Tool '{tool_name}' failed: {str(e)}"}) - print("Getting final response...") - final_response = self.client.call_chat_completions(final_config) + print("\n[Tool executed]", tool_name) + print(tool_result[:500] + ("..." 
if len(tool_result) > 500 else "")) + messages.append({"role": "tool", "tool_call_id": call_id, "content": tool_result}) - if isinstance(final_response, dict): - final_choice = final_response.get("choices", [{}])[0] - final_message = final_choice.get("message", {}) - final_content = final_message.get("content", "") + # Second pass: get final streamed answer after tool results + stream2 = await stream_chat_completions( + client=self.client, + model=self.model, + messages=messages, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) - print("\n" + "=" * 60) - print("FINAL LLM ANALYSIS:") - print("=" * 60) - print(final_content) - print("=" * 60) + final_buf = [] + printed_reasoning2 = False + printed_answer2 = False + + async for chunk in stream2: + choice = (chunk.get("choices") or [{}])[0] + delta = choice.get("delta", {}) + + rc2 = delta.get("reasoning_content") + if rc2: + if not printed_reasoning2: + printed_reasoning2 = True + print("\n🧠 Reasoning (post-tools): ", end="", flush=True) + print(rc2, end="", flush=True) + + c2 = delta.get("content") + if c2: + final_buf.append(c2) + if not printed_answer2: + printed_answer2 = True + print("\n💬 Response (final): ", end="", flush=True) + print(c2, end="", flush=True) + + print("\n" + "=" * 60) + print("FINAL LLM ANALYSIS:") + print("=" * 60) + print("".join(final_buf)) + print("=" * 60) - def interactive_chat(self) -> None: + async def interactive_chat(self) -> None: """Interactive chat session with streaming""" print("=" * 60) print("INTERACTIVE STREAMING CHAT") @@ -449,7 +432,7 @@ def interactive_chat(self) -> None: print("Type 'quit' to exit, 'clear' to clear history") print() - messages = [] + messages: List[Dict[str, Any]] = [] while True: try: @@ -467,16 +450,15 @@ def interactive_chat(self) -> None: messages.append({"role": "user", "content": user_input}) - config = ChatCompletionConfig( - model=self.model, messages=messages, stream=True, temperature=0.7 - ) - print("Assistant: ", end="", flush=True) - - response = self.client.call_chat_completions(config) - assistant_content = self.handle_streaming_response( - response, show_reasoning=True + stream = await stream_chat_completions( + client=self.client, + model=self.model, + messages=messages, + max_tokens=MAX_TOKENS, + temperature=0.7 ) + assistant_content = await self.handle_streaming_response(stream, show_reasoning=True) # Add assistant response to conversation history messages.append({"role": "assistant", "content": assistant_content}) @@ -485,115 +467,64 @@ def interactive_chat(self) -> None: print("\n👋 Chat interrupted. 
Goodbye!") break except Exception as e: - log.error(f"\nError: {e}") + log.error("\nError: %s", e) continue -def main(): - """Main function with CLI switches for different tests""" - from lib.test_utils import test_args - - # Add mandatory model argument - test_args.add_argument( - "--model", required=True, help="Model to use for requests (required)" - ) - - # Add test mode arguments - test_args.add_argument( - "--completion", action="store_true", help="Test completions endpoint" - ) - test_args.add_argument( - "--chat", - action="store_true", - help="Test chat completions endpoint (non-streaming)", - ) - test_args.add_argument( - "--chat-stream", - action="store_true", - help="Test chat completions endpoint with streaming", - ) - test_args.add_argument( - "--tools", - action="store_true", - help="Test function calling with ls tool (non-streaming)", - ) - test_args.add_argument( - "--interactive", - action="store_true", - help="Start interactive streaming chat session", - ) - - args = test_args.parse_args() - - # Check that only one test mode is selected - test_modes = [ - args.completion, - args.chat, - args.chat_stream, - args.tools, - args.interactive, - ] - selected_count = sum(test_modes) - - if selected_count == 0: +# ---------------------- CLI ---------------------- +def build_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") + p.add_argument("--model", required=True, help="Model to use for requests (required)") + p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)") + + modes = p.add_mutually_exclusive_group(required=False) + modes.add_argument("--completion", action="store_true", help="Test completions endpoint") + modes.add_argument("--chat", action="store_true", help="Test chat completions endpoint (non-streaming)") + modes.add_argument("--chat-stream", action="store_true", help="Test chat completions endpoint with streaming") + modes.add_argument("--tools", action="store_true", help="Test function calling with ls tool (non-streaming+streamed phases)") + modes.add_argument("--interactive", action="store_true", help="Start interactive streaming chat session") + return p + + +async def main_async(): + args = build_arg_parser().parse_args() + + selected = sum([args.completion, args.chat, args.chat_stream, args.tools, args.interactive]) + if selected == 0: print("Please specify exactly one test mode:") print(" --completion : Test completions endpoint") print(" --chat : Test chat completions endpoint (non-streaming)") print(" --chat-stream : Test chat completions endpoint with streaming") - print(" --tools : Test function calling with ls tool (non-streaming)") + print(" --tools : Test function calling with ls tool") print(" --interactive : Start interactive streaming chat session") - print( - f"\nExample: python {sys.argv[0]} --model Qwen/Qwen3-8B --chat-stream -k YOUR_KEY -e YOUR_ENDPOINT" - ) + print(f"\nExample: python {os.path.basename(sys.argv[0])} --model Qwen/Qwen3-8B --chat-stream --endpoint my-vllm-endpoint") sys.exit(1) - elif selected_count > 1: + elif selected > 1: print("Please specify exactly one test mode") sys.exit(1) - try: - endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=args.endpoint_group_name, - account_api_key=args.api_key, - instance=args.instance, - ) - - if not endpoint_api_key: - log.error( - f"Could not retrieve API key for endpoint '{args.endpoint_group_name}'. Exiting." 
- ) - sys.exit(1) - - # Create the core API client - client = APIClient( - endpoint_group_name=args.endpoint_group_name, - api_key=args.api_key, - server_url=Endpoint.get_autoscaler_server_url(args.instance), - endpoint_api_key=endpoint_api_key, - ) + print(f"Using model: {args.model}") + print("=" * 60) - # Create tool manager and demo (passing the model parameter) - tool_manager = ToolManager() - demo = APIDemo(client, args.model, tool_manager) - - print(f"Using model: {args.model}") - print("=" * 60) - - # Run the selected test - if args.completion: - demo.demo_completions() - elif args.chat: - demo.demo_chat(use_streaming=False) - elif args.chat_stream: - demo.demo_chat(use_streaming=True) - elif args.tools: - demo.demo_ls_tool() - elif args.interactive: - demo.interactive_chat() + try: + async with Serverless() as client: + demo = APIDemo(client, args.model, ToolManager()) + + if args.completion: + await demo.demo_completions() + elif args.chat: + await demo.demo_chat(use_streaming=False) + elif args.chat_stream: + await demo.demo_chat(use_streaming=True) + elif args.tools: + await demo.demo_ls_tool() + elif args.interactive: + await demo.interactive_chat() except Exception as e: - log.error(f"Error during test: {e}", exc_info=True) + log.error("Error during test: %s", e, exc_info=True) sys.exit(1) if __name__ == "__main__": - main() + asyncio.run(main_async()) diff --git a/workers/tgi/client.py b/workers/tgi/client.py index 66dacb9..f307602 100644 --- a/workers/tgi/client.py +++ b/workers/tgi/client.py @@ -1,125 +1,61 @@ -import logging -import sys -import json -from urllib.parse import urljoin -import requests -from utils.endpoint_util import Endpoint -from utils.ssl import get_cert_file_path - -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) - - -def call_generate(endpoint_group_name: str, api_key: str, server_url: str) -> None: - WORKER_ENDPOINT = "/generate" - COST = 100 - route_payload = { - "endpoint": endpoint_group_name, - "api_key": api_key, - "cost": COST, +from vastai import Serverless +import asyncio + +ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name +MAX_TOKENS = 1024 +PROMPT = "Think step by step: Tell me about the Python programming language." 
+ +async def call_generate(client: Serverless) -> None: + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + + payload = { + "inputs": PROMPT, + "parameters": { + "max_new_tokens": MAX_TOKENS, + "temperature": 0.7, + "return_full_text": False + } } - response = requests.post( - urljoin(server_url, "/route/"), - json=route_payload, - timeout=4, - ) - response.raise_for_status() # Raise an exception for bad status codes - message = response.json() - url = message["url"] - auth_data = dict( - signature=message["signature"], - cost=message["cost"], - endpoint=message["endpoint"], - reqnum=message["reqnum"], - url=url, - ) + resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS) + + print(resp["response"]["generated_text"]) - payload = dict(inputs="tell me about cats", parameters=dict(max_new_tokens=500)) - req_data = dict(payload=payload, auth_data=auth_data) - url = urljoin(url, WORKER_ENDPOINT) - print(f"url: {url}") - response = requests.post( - url, - json=req_data, - verify=get_cert_file_path(), - ) - response.raise_for_status() - res = response.json() - print(res) +async def call_generate_stream(client: Serverless) -> None: + endpoint = await client.get_endpoint(name=ENDPOINT_NAME) -def call_generate_stream( - endpoint_group_name: str, api_key: str, server_url: str -) -> None: - WORKER_ENDPOINT = "/generate_stream" - COST = 100 - route_payload = { - "endpoint": endpoint_group_name, - "api_key": api_key, - "cost": COST, + payload = { + "inputs": PROMPT, + "parameters": { + "max_new_tokens": MAX_TOKENS, + "temperature": 0.7, + "do_sample": True, + "return_full_text": False, + } } - response = requests.post( - urljoin(server_url, "/route/"), - json=route_payload, - timeout=4, - ) - response.raise_for_status() # Raise an exception for bad status codes - message = response.json() - url = message["url"] - print(f"url: {url}") - auth_data = dict( - signature=message["signature"], - cost=message["cost"], - endpoint=message["endpoint"], - reqnum=message["reqnum"], - url=message["url"], - ) - payload = dict(inputs="tell me about dogs", parameters=dict(max_new_tokens=500)) - req_data = dict(payload=payload, auth_data=auth_data) - url = urljoin(url, WORKER_ENDPOINT) - response = requests.post(url, json=req_data, stream=True) - response.raise_for_status() # Raise an exception for bad status codes - for line in response.iter_lines(): - payload = line.decode().lstrip("data:").rstrip() - if payload: - try: - data = json.loads(payload) - print(data["token"]["text"], end="") - sys.stdout.flush() - except (json.JSONDecodeError, KeyError) as e: - log.warning(f"Failed to parse streaming response: {e}") - continue - print() + resp = await endpoint.request( + "/generate_stream", + payload, + cost=MAX_TOKENS, + stream=True, + ) + stream = resp["response"] + + printed_answer = False + async for event in stream: + tok = (event.get("token") or {}).get("text") + if tok: + if not printed_answer: + printed_answer = True + print("Answer:\n", end="", flush=True) + print(tok, end="", flush=True) + +async def main(): + async with Serverless() as client: + await call_generate(client) + await call_generate_stream(client) if __name__ == "__main__": - from lib.test_utils import test_args - - args = test_args.parse_args() - - endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=args.endpoint_group_name, - account_api_key=args.api_key, - instance=args.instance, - ) - if endpoint_api_key: - try: - call_generate( - api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, - 
server_url=args.server_url, - ) - call_generate_stream( - api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, - ) - except Exception as e: - log.error(f"Error during API call: {e}") - else: - log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ") + asyncio.run(main()) From 3adec1826d982231805ea8a91745fadb37283d28 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 11 Nov 2025 17:11:38 -0800 Subject: [PATCH 08/40] minor changes --- workers/comfyui-json/client.py | 4 ++-- workers/openai/client.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index c877df2..93e184c 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -1,11 +1,11 @@ -from vastai import Serverless from .data_types import count_workload - import uuid import random import asyncio import random +from vastai import Serverless + async def main(): async with Serverless() as client: endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name diff --git a/workers/openai/client.py b/workers/openai/client.py index 1dadc68..8c88444 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 import logging import json import os From eedf81c0a314c759977109a4f91ca3059402bd1f Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 11 Nov 2025 17:18:40 -0800 Subject: [PATCH 09/40] Updated readme and .gitignore --- .gitignore | 3 ++- README.md | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 226869e..dc47eed 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .envrc __pycache__ bin/ -lib64 \ No newline at end of file +lib64 +.venv \ No newline at end of file diff --git a/README.md b/README.md index 117600d..dda0ea2 100644 --- a/README.md +++ b/README.md @@ -39,11 +39,12 @@ reporting these metrics to the autoscaler. If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few: -* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=72d8dcb41ea3a58e06c741e2c725bc00) -* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=140778&template_id=ad72c8bf7cf695c3c9ddf0eaf6da0447) +* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d) +* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d) +* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938) Currently available workers: -* `hello_world`: A simple example worker for a basic LLM server. +* `openai`: A simple example worker for a basic vLLM server. * `comfyui`: A worker for the ComfyUI image generation backend. * `tgi`: A worker for the Text Generation Inference backend. 
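Note on the client rewrites above: the updated comfyui-json, openai, and tgi clients all converge on the same `vastai` SDK flow of opening a `Serverless` client, resolving the endpoint by name, and posting a payload to a worker route with a cost estimate. A minimal sketch of that shared pattern follows; the endpoint name, route, and payload here are illustrative placeholders, not values taken from the patches.

```python
import asyncio

from vastai import Serverless  # provided by the vastai-sdk dependency added in requirements.txt


async def main() -> None:
    # The SDK client takes over the routing and auth bookkeeping the old clients
    # did by hand (the /route/ call, signature, cost, and reqnum fields).
    async with Serverless() as client:
        # "my-endpoint" is a placeholder; use the name of your serverless endpoint group.
        endpoint = await client.get_endpoint(name="my-endpoint")

        # Route and payload shape depend on the worker; this mirrors the TGI client above.
        payload = {"inputs": "tell me about cats", "parameters": {"max_new_tokens": 64}}
        resp = await endpoint.request("/generate", payload, cost=64)

        # The worker's JSON body is returned under the "response" key.
        print(resp["response"])


if __name__ == "__main__":
    asyncio.run(main())
```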
From a12523b1d29c8a1a8ee2e402c67c73992fcb8e22 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 11 Nov 2025 17:41:12 -0800 Subject: [PATCH 10/40] Added bad code to tgi server to test --- lib/server.py | 75 +++++++++++++++++++++++++++---------------- start_server.sh | 45 +++++++++++++++++++++++++- workers/tgi/server.py | 1 + 3 files changed, 93 insertions(+), 28 deletions(-) diff --git a/lib/server.py b/lib/server.py index b21c880..25250ea 100644 --- a/lib/server.py +++ b/lib/server.py @@ -3,38 +3,59 @@ from typing import List import ssl from asyncio import run, gather - +import asyncio from lib.backend import Backend +from lib.metrics import Metrics from aiohttp import web log = logging.getLogger(__file__) def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs): - log.debug("getting certificate...") - use_ssl = os.environ.get("USE_SSL", "false") == "true" - if use_ssl is True: - ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ssl_context.load_cert_chain( - certfile="/etc/instance.crt", - keyfile="/etc/instance.key", - ) - else: - ssl_context = None - - async def main(): - log.debug("starting server...") - app = web.Application() - app.add_routes(routes) - runner = web.AppRunner(app) - await runner.setup() - site = web.TCPSite( - runner, - ssl_context=ssl_context, - port=int(os.environ["WORKER_PORT"]), - **kwargs - ) - await gather(site.start(), backend._start_tracking()) - - run(main()) + try: + log.debug("getting certificate...") + use_ssl = os.environ.get("USE_SSL", "false") == "true" + if use_ssl is True: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + raise Exception("Oh no the SSL cert is gone!") + ssl_context.load_cert_chain( + certfile="/etc/instance.crt", + keyfile="/etc/instance.key", + ) + else: + ssl_context = None + + async def main(): + log.debug("starting server...") + app = web.Application() + app.add_routes(routes) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite( + runner, + ssl_context=ssl_context, + port=int(os.environ["WORKER_PORT"]), + **kwargs + ) + await gather(site.start(), backend._start_tracking()) + + run(main()) + + except Exception as e: + err_msg = f"PyWorker failed to launch: {e}" + log.error(err_msg) + + async def beacon(): + metrics = Metrics() + metrics._set_version(getattr(backend, "version", "0")) + metrics._set_mtoken(getattr(backend, "mtoken", "")) + try: + while True: + metrics._model_errored(err_msg) + await metrics.__send_metrics_and_reset() + await asyncio.sleep(10) + finally: + await metrics.aclose() + + run(beacon()) diff --git a/start_server.sh b/start_server.sh index edc16a4..87c7702 100755 --- a/start_server.sh +++ b/start_server.sh @@ -128,5 +128,48 @@ echo "launching PyWorker server" # from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only [ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG" -(python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & + +# Run the worker in foreground so we can detect non-zero exit and report it +python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG" +STATUS=$? + +if [ $STATUS -ne 0 ]; then + echo "PyWorker exited with status $STATUS; notifying autoscaler..." 
+ + ERROR_MSG="PyWorker exited: code ${STATUS}" + MTOKEN="${MASTER_TOKEN:-}" + VERSION="${PYWORKER_VERSION:-0}" + + # Comma-separated REPORT_ADDR is supported + IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}" + for addr in "${REPORT_ADDRS[@]}"; do + # minimal, schema-compatible payload + curl -sS -X POST -H 'Content-Type: application/json' \ + -d "$(cat < Date: Tue, 11 Nov 2025 17:49:34 -0800 Subject: [PATCH 11/40] fix --- lib/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/server.py b/lib/server.py index 25250ea..def7340 100644 --- a/lib/server.py +++ b/lib/server.py @@ -53,7 +53,7 @@ async def beacon(): try: while True: metrics._model_errored(err_msg) - await metrics.__send_metrics_and_reset() + await metrics._Metrics__send_metrics_and_reset() await asyncio.sleep(10) finally: await metrics.aclose() From de9b50abb9d5043cd0b08d5cc556d3dba2616f80 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 11 Nov 2025 17:53:36 -0800 Subject: [PATCH 12/40] use set +e --- start_server.sh | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/start_server.sh b/start_server.sh index 87c7702..4763895 100755 --- a/start_server.sh +++ b/start_server.sh @@ -129,21 +129,19 @@ echo "launching PyWorker server" [ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG" -# Run the worker in foreground so we can detect non-zero exit and report it +set +e python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG" -STATUS=$? +PY_STATUS=${PIPESTATUS[0]} +set -e -if [ $STATUS -ne 0 ]; then - echo "PyWorker exited with status $STATUS; notifying autoscaler..." - - ERROR_MSG="PyWorker exited: code ${STATUS}" +if [ "${PY_STATUS}" -ne 0 ]; then + echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..." 
+ ERROR_MSG="PyWorker exited: code ${PY_STATUS}" MTOKEN="${MASTER_TOKEN:-}" VERSION="${PYWORKER_VERSION:-0}" - # Comma-separated REPORT_ADDR is supported IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}" for addr in "${REPORT_ADDRS[@]}"; do - # minimal, schema-compatible payload curl -sS -X POST -H 'Content-Type: application/json' \ -d "$(cat < Date: Tue, 11 Nov 2025 17:57:08 -0800 Subject: [PATCH 13/40] dont exit on pyworker fail --- start_server.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/start_server.sh b/start_server.sh index 4763895..09c33d8 100755 --- a/start_server.sh +++ b/start_server.sh @@ -166,8 +166,6 @@ if [ "${PY_STATUS}" -ne 0 ]; then JSON )" "${addr%/}/worker_status/" || true done - - exit "${PY_STATUS}" fi echo "launching PyWorker server done" \ No newline at end of file From a47c9d1ed0821aab48ec904a2dba863df3feff1a Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 11 Nov 2025 18:13:46 -0800 Subject: [PATCH 14/40] remove test bugs --- lib/server.py | 1 - workers/tgi/server.py | 1 - 2 files changed, 2 deletions(-) diff --git a/lib/server.py b/lib/server.py index def7340..0029311 100644 --- a/lib/server.py +++ b/lib/server.py @@ -18,7 +18,6 @@ def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs): use_ssl = os.environ.get("USE_SSL", "false") == "true" if use_ssl is True: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - raise Exception("Oh no the SSL cert is gone!") ssl_context.load_cert_chain( certfile="/etc/instance.crt", keyfile="/etc/instance.key", diff --git a/workers/tgi/server.py b/workers/tgi/server.py index 9ce8374..99fc810 100644 --- a/workers/tgi/server.py +++ b/workers/tgi/server.py @@ -127,5 +127,4 @@ async def handle_ping(_): ] if __name__ == "__main__": - blips = blorps start_server(backend, routes) From 2b26e5e20c072f650847976291e5284f0fef1420 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 12 Nov 2025 16:01:57 -0800 Subject: [PATCH 15/40] hotfix: remove g --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 13b194e..377b20a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,4 @@ Requests~=2.32 transformers~=4.52 utils==1.0.* hf_transfer>=0.1.9 -vastai-sdk>=0.2.0g \ No newline at end of file +vastai-sdk>=0.2.0 \ No newline at end of file From a4339bd3f1ba2752ad166131eab8809691551510 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 12 Nov 2025 16:10:55 -0800 Subject: [PATCH 16/40] hotfix: add f --- lib/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/backend.py b/lib/backend.py index bf1d746..19764bd 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -417,7 +417,7 @@ async def handle_log_line(log_line: str) -> None: async def tail_log(): log.debug(f"tailing file: {self.model_log_file}") - async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore'): + async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f: while True: line = await f.readline() if line: From e0449cb3c7e7719389ec8ea4192b588d70f93ebd Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Fri, 21 Nov 2025 10:22:16 -0800 Subject: [PATCH 17/40] add llama log --- workers/openai/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/workers/openai/server.py b/workers/openai/server.py index bfeb819..8dc962f 100644 --- a/workers/openai/server.py +++ b/workers/openai/server.py @@ -11,6 +11,7 @@ "llama runner started", # Ollama 
'"message":"Connected","target":"text_generation_router"', # TGI '"message":"Connected","target":"text_generation_router::server"', # TGI + "main: model loaded" # llama.cpp ] MODEL_SERVER_ERROR_LOG_MSGS = [ From 45e0c7d9caf62805495fa3d499a91909ba64363f Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Mon, 24 Nov 2025 15:02:33 -0800 Subject: [PATCH 18/40] Move model log rotate to top --- start_server.sh | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/start_server.sh b/start_server.sh index edc16a4..dd57b5a 100755 --- a/start_server.sh +++ b/start_server.sh @@ -41,6 +41,14 @@ echo_var DEBUG_LOG echo_var PYWORKER_LOG echo_var MODEL_LOG +# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines +# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only +if [ -e "$MODEL_LOG" ]; then + echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old" + cat "$MODEL_LOG" >> "$MODEL_LOG.old" + : > "$MODEL_LOG" +fi + # Populate /etc/environment with quoted values if ! grep -q "VAST" /etc/environment; then env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do @@ -124,9 +132,7 @@ cd "$SERVER_DIR" echo "launching PyWorker server" -# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines -# from the run prior to reboot. past logs are saved in $MODEL_LOG.old for debugging only -[ -e "$MODEL_LOG" ] && cat "$MODEL_LOG" >> "$MODEL_LOG.old" && : > "$MODEL_LOG" +# Model log line used to be here ! (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & echo "launching PyWorker server done" From 9c6ab7850343a2ac6d83e37f628ea8de61499d38 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Mon, 24 Nov 2025 15:22:23 -0800 Subject: [PATCH 19/40] Move model log line --- start_server.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/start_server.sh b/start_server.sh index dd57b5a..e30a2bc 100755 --- a/start_server.sh +++ b/start_server.sh @@ -132,7 +132,5 @@ cd "$SERVER_DIR" echo "launching PyWorker server" -# Model log line used to be here ! 
- (python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG") & echo "launching PyWorker server done" From e14316243859f100f040b22f13a0412fda789a51 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Tue, 25 Nov 2025 16:01:23 -0800 Subject: [PATCH 20/40] bumpy pyworker version --- lib/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/backend.py b/lib/backend.py index 19764bd..0d9a273 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -30,7 +30,7 @@ BenchmarkResult ) -VERSION = "0.2.0" +VERSION = "0.2.1" MSG_HISTORY_LEN = 100 log = logging.getLogger(__file__) From 0bcd2219ea550c4f34bb67702bf555fd16a06057 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 3 Dec 2025 12:38:52 -0800 Subject: [PATCH 21/40] Increase model wait time for vLLM --- workers/openai/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/workers/openai/server.py b/workers/openai/server.py index 8dc962f..63f21f9 100644 --- a/workers/openai/server.py +++ b/workers/openai/server.py @@ -35,6 +35,7 @@ model_server_url=os.environ["MODEL_SERVER_URL"], model_log_file=os.environ["MODEL_LOG"], allow_parallel_requests=True, + max_wait_time=600.0, benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), log_actions=[ *[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG], From adedb8ba909e387bb6fdd66faed0bbddf450f387 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 16:57:28 -0800 Subject: [PATCH 22/40] defaults to ENDPOINT_NAME and DEFAULT_MODEL but uses the flag first if present --- workers/openai/README.md | 38 ++++++++++++++++++------------- workers/openai/client.py | 48 ++++++++++++++++++++++++++-------------- 2 files changed, 55 insertions(+), 31 deletions(-) diff --git a/workers/openai/README.md b/workers/openai/README.md index 2436784..0dbaaa4 100644 --- a/workers/openai/README.md +++ b/workers/openai/README.md @@ -34,28 +34,38 @@ uv pip install -r requirements.txt Several examples have been provided in the client to help you get started with your own implementation. -### Completions +First, set your API key as an environment variable: -Call to `/v1/completions` with json response +```bash +export VAST_API_KEY= +``` + +The `--model` and `--endpoint` flags are optional. If not provided, they default to `Qwen/Qwen3-8B` and `my-vllm-endpoint` respectively. + +### Chat Completion (streaming) + +Call to `/v1/chat/completions` with streaming response ```bash -python -m workers.openai.client -k -e --completion --model +python -m workers.openai.client --chat-stream --endpoint --model ``` -### Chat Completion (json) +### Interactive Chat (streaming) -Call to `/v1/chat/completions` with json response +Interactive session with calls to `/v1/chat/completions`. + +Type `clear` to clear the chat history or `quit` to exit. ```bash -python -m workers.openai.client -k -e --chat --model +python -m workers.openai.client --interactive --endpoint --model ``` -### Chat Completion (streaming) +### Chat Completion (json) -Call to `/v1/chat/completions` with streaming response +Call to `/v1/chat/completions` with json response ```bash -python -m workers.openai.client -k -e --chat-stream --model +python -m workers.openai.client --chat --endpoint --model ``` ### Tool Use (json) @@ -65,16 +75,14 @@ Call to `/v1/chat/completions` with tool and json response. This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model. 
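+
+For reference, the tool is passed to `/v1/chat/completions` using the OpenAI-style tool schema.
+A rough sketch of what such a definition looks like (the exact name, description, and parameters
+produced by the client's `ToolManager` may differ):
+
+```python
+ls_tool = [
+    {
+        "type": "function",
+        "function": {
+            # illustrative name only; the client defines its own
+            "name": "list_directory",
+            "description": "List the files in the current working directory",
+            "parameters": {"type": "object", "properties": {}, "required": []},
+        },
+    }
+]
+```
+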
```bash -python -m workers.openai.client -k -e --tools --model +python -m workers.openai.client --tools --endpoint --model ``` -### Interactive Chat (streaming) - -Interactive session with calls to `/v1/chat/completions`. +### Completions -Type `clear` to clear the chat history or `quit` to exit. +Call to `/v1/completions` with json response ```bash -python -m workers.openai.client -k -e --interactive --model +python -m workers.openai.client --completion --endpoint --model ``` diff --git a/workers/openai/client.py b/workers/openai/client.py index 8c88444..a92ad95 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -18,7 +18,7 @@ log = logging.getLogger(__file__) # ---------------------- Prompts ---------------------- -COMPLETIONS_PROMPT = "the capital of USA is" +COMPLETIONS_PROMPT = "Zebras are primarily grazers and can subsist on lower-quality vegetation. They are preyed on mainly by" CHAT_PROMPT = "Think step by step: Tell me about the Python programming language." TOOLS_PROMPT = ( "Can you list the files in the current working directory and tell me what you see? " @@ -97,9 +97,9 @@ def _tool_state_to_message_tool_calls(state: Dict[int, Dict[str, Any]]) -> List[ # ---- OpenAI-compatible calls (non-streaming) ---- -async def call_completions(client: Serverless, *, model: str, prompt: str, **kwargs) -> Dict[str, Any]: +async def call_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs) -> Dict[str, Any]: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -113,9 +113,9 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, **kwa resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) return resp["response"] -async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: +async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -132,9 +132,9 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis return resp["response"] # ---- Streaming variants ---- -async def stream_completions(client: Serverless, *, model: str, prompt: str, **kwargs): +async def stream_completions(client: Serverless, *, model: str, prompt: str, endpoint_name: str, **kwargs): - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -150,9 +150,9 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, **k resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) return resp["response"] # async generator -async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], **kwargs): +async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs): - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + endpoint = await client.get_endpoint(name=endpoint_name) payload = { "input": { @@ -174,9 +174,10 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L class APIDemo: """Demo and testing functionality for 
the API client""" - def __init__(self, client: Serverless, model: str, tool_manager: Optional[ToolManager] = None): + def __init__(self, client: Serverless, model: str, endpoint_name: str, tool_manager: Optional[ToolManager] = None): self.client = client self.model = model + self.endpoint_name = endpoint_name self.tool_manager = tool_manager or ToolManager() # ----- Streaming handler ----- @@ -185,10 +186,15 @@ async def handle_streaming_response(self, stream, show_reasoning: bool = True) - reasoning_content = "" printed_reasoning = False printed_answer = False + finish_reason = None async for chunk in stream: choice = (chunk.get("choices") or [{}])[0] delta = choice.get("delta", {}) + + # Track finish reason + if choice.get("finish_reason"): + finish_reason = choice.get("finish_reason") # reasoning tokens rc = delta.get("reasoning_content") @@ -219,6 +225,8 @@ async def handle_streaming_response(self, stream, show_reasoning: bool = True) - print(f"Reasoning tokens: {len(reasoning_content.split())}") if printed_answer: print(f"Response tokens: {len(full_response.split())}") + if finish_reason: + print(f"Finish reason: {finish_reason}") return full_response @@ -231,6 +239,7 @@ async def demo_completions(self) -> None: client=self.client, model=self.model, prompt=COMPLETIONS_PROMPT, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE, ) @@ -249,6 +258,7 @@ async def demo_chat(self, use_streaming: bool = True) -> None: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE ) @@ -261,6 +271,7 @@ async def demo_chat(self, use_streaming: bool = True) -> None: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE ) @@ -287,6 +298,7 @@ async def test_tool_support(self) -> bool: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, tools=minimal_tool, tool_choice="none", max_tokens=10 @@ -312,6 +324,7 @@ async def demo_ls_tool(self) -> None: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, tools=self.tool_manager.get_ls_tool_definition(), tool_choice="auto", max_tokens=MAX_TOKENS, @@ -389,6 +402,7 @@ async def demo_ls_tool(self) -> None: client=self.client, model=self.model, messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=DEFAULT_TEMPERATURE, ) @@ -427,7 +441,6 @@ async def interactive_chat(self) -> None: print("=" * 60) print("INTERACTIVE STREAMING CHAT") print("=" * 60) - print(f"Using model: {self.model}") print("Type 'quit' to exit, 'clear' to clear history") print() @@ -453,7 +466,8 @@ async def interactive_chat(self) -> None: stream = await stream_chat_completions( client=self.client, model=self.model, - messages=messages, + messages=messages, + endpoint_name=self.endpoint_name, max_tokens=MAX_TOKENS, temperature=0.7 ) @@ -473,8 +487,8 @@ async def interactive_chat(self) -> None: # ---------------------- CLI ---------------------- def build_arg_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser(description="Vast vLLM Demo (Serverless SDK)") - p.add_argument("--model", required=True, help="Model to use for requests (required)") - p.add_argument("--endpoint", default="my-vllm-endpoint", help="Vast endpoint name (default: my-vllm-endpoint)") + p.add_argument("--model", default=DEFAULT_MODEL, help=f"Model to use for requests (default: 
{DEFAULT_MODEL})") + p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})") modes = p.add_mutually_exclusive_group(required=False) modes.add_argument("--completion", action="store_true", help="Test completions endpoint") @@ -502,12 +516,14 @@ async def main_async(): print("Please specify exactly one test mode") sys.exit(1) - print(f"Using model: {args.model}") print("=" * 60) + print(f"Using model: {args.model}") + print(f"Using endpoint: {args.endpoint}") + try: async with Serverless() as client: - demo = APIDemo(client, args.model, ToolManager()) + demo = APIDemo(client, args.model, args.endpoint, ToolManager()) if args.completion: await demo.demo_completions() From 6b5b1341a79387a0bb953eaf304335ab50a8c0bf Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 18:38:42 -0800 Subject: [PATCH 23/40] update tgi client --- workers/openai/README.md | 5 +- workers/tgi/README.md | 102 +++++++++++++++-- workers/tgi/client.py | 231 +++++++++++++++++++++++++++++++++------ 3 files changed, 291 insertions(+), 47 deletions(-) diff --git a/workers/openai/README.md b/workers/openai/README.md index 0dbaaa4..f7596f3 100644 --- a/workers/openai/README.md +++ b/workers/openai/README.md @@ -8,14 +8,13 @@ This is the base PyWorker for OpenAI compatible inference servers. See the [Ser This worker is compatible with any backend API that properly implements the `/v1/completions` and `/v1/chat/completions` endpoints. We currently have three templates you can choose from but you can also create your own without having to modify the PyWorker. -- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20%2B%20Qwen%2FQwen3-8B%20(Serverless)) (recommended) +- [vLLM](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=vLLM%20(Serverless)) (recommended) - [Ollama](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=Ollama%20%2B%20Qwen3%3A32b%20(Serverless)) -- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20%2B%20Qwen3-8B%20(Serverless)) All of these templates can be configured via the template interface. You may want to change the model or startup arguments, depending on the template you selected. -2. Follow the [getting started guide](https://docs.vast.ai/serverless/getting-started) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. +2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. ## Client Setup (Demo) diff --git a/workers/tgi/README.md b/workers/tgi/README.md index 5cf8488..9147e38 100644 --- a/workers/tgi/README.md +++ b/workers/tgi/README.md @@ -1,19 +1,103 @@ -This is the base PyWorker for TGI, designed to create PyWorkers that can utilize various LLMs. It offers two primary endpoints: +# HuggingFace TGI PyWorker -1. `generate`: Generates the LLM's response to a given prompt in a single request. -2. `generate_stream`: Streams the LLM's response token by token. +This is the base PyWorker for HuggingFace Text Generation Inference (TGI) servers. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. -Both endpoints use the following API payload format: +## Instance Setup + +1. Pick a template + +This worker is compatible with any TGI backend. 
We have a template you can use or you can create your own. + +- [HuggingFace TGI](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=TGI%20(Serverless)) + +The template can be configured via the template interface. You may want to change the model or startup arguments. + +2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. + +## Client Setup (Demo) + +1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client. + +```bash +git clone https://github.com/vast-ai/pyworker +cd pyworker +pip install uv +uv venv -p 3.12 +source .venv/bin/activate +uv pip install -r requirements.txt +``` + +## Using the Test Client + +The test client demonstrates both streaming and non-streaming generation using TGI's native API. + +First, set your API key as an environment variable: + +```bash +export VAST_API_KEY= +``` + +The `--endpoint` flag is optional. If not provided, it defaults to `my-tgi-endpoint`. + +### Generate (Streaming) + +Call to `/generate_stream` with streaming response: + +```bash +python -m workers.tgi.client --generate-stream --endpoint +``` + +### Generate (Non-Streaming) + +Call to `/generate` with json response: + +```bash +python -m workers.tgi.client --generate --endpoint +``` + +### Interactive Session (Streaming) + +Interactive session with streaming responses. Type `quit` to exit. + +```bash +python -m workers.tgi.client --interactive --endpoint +``` + +## API Endpoints + +TGI provides two primary endpoints: + +### Generate (Non-Streaming) + +`/generate` - Returns the complete response in a single request. ```json { - "inputs": "PROMPT", + "inputs": "Your prompt here", "parameters": { - "max_new_tokens": 250 + "max_new_tokens": 1024, + "temperature": 0.7, + "return_full_text": false } } ``` -Note that the max_new_tokens parameter, rather than the prompt size, impacts performance. For example, if an -instance is benchmarked to process 100 tokens per second, a request with max_new_tokens = 200 will take -approximately 2 seconds to complete. +### Generate Stream (Streaming) + +`/generate_stream` - Streams the response token by token. + +```json +{ + "inputs": "Your prompt here", + "parameters": { + "max_new_tokens": 1024, + "temperature": 0.7, + "do_sample": true, + "return_full_text": false + } +} +``` + +## Performance Notes + +The `max_new_tokens` parameter (not the prompt size) primarily impacts performance. For example, if an instance is benchmarked to process 100 tokens per second, a request with `max_new_tokens = 200` will take approximately 2 seconds to complete. diff --git a/workers/tgi/client.py b/workers/tgi/client.py index f307602..23b40c2 100644 --- a/workers/tgi/client.py +++ b/workers/tgi/client.py @@ -1,61 +1,222 @@ +import logging +import json +import os +import sys +import argparse + from vastai import Serverless import asyncio -ENDPOINT_NAME = "my-tgi-endpoint" # Change this to match your endpoint name +# ---------------------- Logging ---------------------- +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s[%(levelname)-5s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +log = logging.getLogger(__file__) + +# ---------------------- Defaults ---------------------- +DEFAULT_PROMPT = "Think step by step: Tell me about the Python programming language." 
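+# NOTE: the README above documents "my-tgi-endpoint" as the default endpoint name;
+# keep ENDPOINT_NAME below in sync with the endpoint you actually created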
+ +ENDPOINT_NAME = "TGI-Prod2" # change this to your TGI endpoint name MAX_TOKENS = 1024 -PROMPT = "Think step by step: Tell me about the Python programming language." +DEFAULT_TEMPERATURE = 0.7 -async def call_generate(client: Serverless) -> None: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) + +# ---------------------- API Calls ---------------------- +async def call_generate(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs) -> dict: + """Non-streaming generation via /generate endpoint""" + endpoint = await client.get_endpoint(name=endpoint_name) payload = { - "inputs": PROMPT, + "inputs": prompt, "parameters": { - "max_new_tokens": MAX_TOKENS, - "temperature": 0.7, - "return_full_text": False + "max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "return_full_text": False, } } + log.debug("POST /generate %s", json.dumps(payload)[:500]) + resp = await endpoint.request("/generate", payload, cost=payload["parameters"]["max_new_tokens"]) + return resp["response"] - resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS) - print(resp["response"]["generated_text"]) - - -async def call_generate_stream(client: Serverless) -> None: - endpoint = await client.get_endpoint(name=ENDPOINT_NAME) +async def call_generate_stream(client: Serverless, *, endpoint_name: str, prompt: str, **kwargs): + """Streaming generation via /generate_stream endpoint""" + endpoint = await client.get_endpoint(name=endpoint_name) payload = { - "inputs": PROMPT, + "inputs": prompt, "parameters": { - "max_new_tokens": MAX_TOKENS, - "temperature": 0.7, + "max_new_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), "do_sample": True, "return_full_text": False, } } - + log.debug("STREAM /generate_stream %s", json.dumps(payload)[:500]) resp = await endpoint.request( "/generate_stream", payload, - cost=MAX_TOKENS, + cost=payload["parameters"]["max_new_tokens"], stream=True, ) - stream = resp["response"] - - printed_answer = False - async for event in stream: - tok = (event.get("token") or {}).get("text") - if tok: - if not printed_answer: - printed_answer = True - print("Answer:\n", end="", flush=True) - print(tok, end="", flush=True) - -async def main(): - async with Serverless() as client: - await call_generate(client) - await call_generate_stream(client) + return resp["response"] # async generator + + +# ---------------------- Demo Runner ---------------------- +class APIDemo: + """Demo and testing functionality for the TGI API client""" + + def __init__(self, client: Serverless, endpoint_name: str): + self.client = client + self.endpoint_name = endpoint_name + + async def handle_streaming_response(self, stream) -> str: + """Process streaming response and print tokens""" + full_response = "" + printed_answer = False + + async for event in stream: + tok = (event.get("token") or {}).get("text") + if tok: + if not printed_answer: + printed_answer = True + print("\n💬 Response: ", end="", flush=True) + print(tok, end="", flush=True) + full_response += tok + + print() # newline + if printed_answer: + print(f"\nStreaming completed. 
Response tokens: {len(full_response.split())}") + + return full_response + + async def demo_generate(self) -> None: + """Demo non-streaming generation""" + print("=" * 60) + print("GENERATE DEMO (NON-STREAMING)") + print("=" * 60) + + response = await call_generate( + client=self.client, + endpoint_name=self.endpoint_name, + prompt=DEFAULT_PROMPT, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) + + print(f"\n💬 Response: {response.get('generated_text', '')}") + print(f"\nFull Response:\n{json.dumps(response, indent=2)}") + + async def demo_generate_stream(self) -> None: + """Demo streaming generation""" + print("=" * 60) + print("GENERATE DEMO (STREAMING)") + print("=" * 60) + + stream = await call_generate_stream( + client=self.client, + endpoint_name=self.endpoint_name, + prompt=DEFAULT_PROMPT, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) + + try: + await self.handle_streaming_response(stream) + except Exception as e: + log.error("\nError during streaming: %s", e, exc_info=True) + + async def interactive_chat(self) -> None: + """Interactive session with streaming generation""" + print("=" * 60) + print("INTERACTIVE STREAMING SESSION") + print("=" * 60) + print(f"Using endpoint: {self.endpoint_name}") + print("Type 'quit' to exit") + print() + + while True: + try: + user_input = input("You: ").strip() + + if user_input.lower() == "quit": + print("👋 Goodbye!") + break + elif not user_input: + continue + + print("Assistant: ", end="", flush=True) + stream = await call_generate_stream( + client=self.client, + endpoint_name=self.endpoint_name, + prompt=user_input, + max_tokens=MAX_TOKENS, + temperature=DEFAULT_TEMPERATURE, + ) + + full_response = "" + async for event in stream: + tok = (event.get("token") or {}).get("text") + if tok: + print(tok, end="", flush=True) + full_response += tok + print() # newline + + except KeyboardInterrupt: + print("\n👋 Session interrupted. 
Goodbye!") + break + except Exception as e: + log.error("\nError: %s", e) + continue + + +# ---------------------- CLI ---------------------- +def build_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Vast TGI Demo (Serverless SDK)") + p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})") + + modes = p.add_mutually_exclusive_group(required=False) + modes.add_argument("--generate", action="store_true", help="Test generate endpoint (non-streaming)") + modes.add_argument("--generate-stream", action="store_true", help="Test generate endpoint with streaming") + modes.add_argument("--interactive", action="store_true", help="Start interactive streaming session") + return p + + +async def main_async(): + args = build_arg_parser().parse_args() + + selected = sum([args.generate, args.generate_stream, args.interactive]) + if selected == 0: + print("Please specify exactly one test mode:") + print(" --generate : Test generate endpoint (non-streaming)") + print(" --generate-stream : Test generate endpoint with streaming") + print(" --interactive : Start interactive streaming session") + print(f"\nExample: python {os.path.basename(sys.argv[0])} --generate-stream --endpoint my-tgi-endpoint") + sys.exit(1) + elif selected > 1: + print("Please specify exactly one test mode") + sys.exit(1) + + print("=" * 60) + print(f"Using endpoint: {args.endpoint}") + + try: + async with Serverless() as client: + demo = APIDemo(client, args.endpoint) + + if args.generate: + await demo.demo_generate() + elif args.generate_stream: + await demo.demo_generate_stream() + elif args.interactive: + await demo.interactive_chat() + + except Exception as e: + log.error("Error during test: %s", e, exc_info=True) + sys.exit(1) + if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(main_async()) From f04138e13bee6835e9237fd129f4f7a40c1550ff Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 20:16:25 -0800 Subject: [PATCH 24/40] update to be able to get images --- workers/comfyui-json/client.py | 351 ++++++++++++++++++++++++++++++--- workers/comfyui-json/server.py | 32 +++ 2 files changed, 359 insertions(+), 24 deletions(-) diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index 93e184c..b80a9ba 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -1,35 +1,338 @@ -from .data_types import count_workload +import os +import sys +import json import uuid import random +import base64 import asyncio -import random +import logging +import argparse from vastai import Serverless -async def main(): - async with Serverless() as client: - endpoint = await client.get_endpoint(name="my-comfy-endpoint") # Change this to your endpoint name - - payload = { - "input": { - "request_id": str(uuid.uuid4()), - "modifier": "Text2Image", - "modifications": { - "prompt": "a beautiful landscape with mountains and lakes", - "width": 1024, - "height": 1024, - "steps": 20, - "seed": random.randint(0, 2**32 - 1) - }, - "workflow_json": {} # Empty since using modifier approach - } +# ---------------------- Config ---------------------- +DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed" +ENDPOINT_NAME = "Comfy-Prod2" +DEFAULT_WIDTH = 512 +DEFAULT_HEIGHT = 512 +DEFAULT_STEPS = 20 +COST = 100 # Fixed cost for ComfyUI requests + +logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s") +log = logging.getLogger(__name__) + + +# 
---------------------- API Functions ---------------------- +async def call_generate( + client: Serverless, + *, + endpoint_name: str, + prompt: str, + width: int, + height: int, + steps: int, + seed: int, +) -> dict: + """Generate image using Text2Image modifier""" + endpoint = await client.get_endpoint(name=endpoint_name) + payload = { + "input": { + "request_id": str(uuid.uuid4()), + "modifier": "Text2Image", + "modifications": { + "prompt": prompt, + "width": width, + "height": height, + "steps": steps, + "seed": seed, + }, + } + } + return await endpoint.request("/generate/sync", payload, cost=COST) + + +async def call_generate_workflow( + client: Serverless, + *, + endpoint_name: str, + workflow_json: dict, +) -> dict: + """Generate using custom workflow JSON""" + endpoint = await client.get_endpoint(name=endpoint_name) + payload = { + "input": { + "request_id": str(uuid.uuid4()), + "workflow_json": workflow_json, } + } + return await endpoint.request("/generate/sync", payload, cost=COST) + + +# ---------------------- Demo Class ---------------------- +class APIDemo: + def __init__(self, client: Serverless, endpoint_name: str): + self.client = client + self.endpoint_name = endpoint_name + + def extract_images(self, response: dict) -> list: + """Extract image info from ComfyUI response""" + images = [] - response = await endpoint.request("/generate/sync", payload, cost=count_workload()) + # Check for output array (S3/webhook configured) + if "output" in response: + for item in response["output"]: + if "url" in item: + images.append({"type": "url", "path": item["url"]}) + elif "local_path" in item: + images.append({"type": "local", "path": item["local_path"]}) + elif "base64" in item: + images.append({"type": "base64", "data": item["base64"]}) + + # Check for comfyui_response format (default) + if "comfyui_response" in response: + for prompt_id, data in response["comfyui_response"].items(): + if isinstance(data, dict) and "outputs" in data: + for node_id, node_output in data["outputs"].items(): + if "images" in node_output: + for img in node_output["images"]: + images.append({ + "type": "remote", + "filename": img.get("filename"), + "subfolder": img.get("subfolder", ""), + }) + + return images + + async def save_images(self, images: list, worker_url: str, prefix: str = "comfy") -> list: + """Save images locally by fetching from remote server""" + os.makedirs("generated_images", exist_ok=True) + saved = [] + seen = set() + + for i, img in enumerate(images): + if img["type"] == "base64": + data = img["data"] + if data.startswith("data:"): + data = data.split(",", 1)[-1] + path = f"generated_images/{prefix}_{i}.png" + with open(path, "wb") as f: + f.write(base64.b64decode(data)) + print(f" 💾 Saved: {path}") + saved.append(path) + + elif img["type"] == "url": + url = img["path"] + if url in seen: + continue + seen.add(url) + try: + import urllib.request + path = f"generated_images/{prefix}_{len(saved)}.png" + urllib.request.urlretrieve(url, path) + print(f" 💾 Downloaded: {path}") + saved.append(path) + except Exception as e: + print(f" 🔗 URL: {url}") + saved.append(url) + + elif img["type"] == "local": + remote_path = img["path"] + if remote_path in seen: + continue + seen.add(remote_path) + filename = os.path.basename(remote_path) + # Try to fetch via /view endpoint + local_path = await self._fetch_image(worker_url, filename, "", f"{prefix}_{len(saved)}.png") + if local_path: + saved.append(local_path) + else: + print(f" 📂 Remote: {remote_path}") + saved.append(remote_path) + + elif 
img["type"] == "remote": + filename = img["filename"] + if filename in seen: + continue + seen.add(filename) + subfolder = img.get("subfolder", "") + # Try to fetch via /view endpoint + local_path = await self._fetch_image(worker_url, filename, subfolder, f"{prefix}_{len(saved)}.png") + if local_path: + saved.append(local_path) + else: + print(f" 🖼️ Remote: {filename}") + saved.append(filename) + + return saved + + async def _fetch_image(self, worker_url: str, filename: str, subfolder: str, local_name: str) -> str | None: + """Fetch image directly from worker's /view endpoint""" + if not worker_url: + print(f" ⚠️ No worker URL available") + return None + + try: + import aiohttp + + params = {"filename": filename, "type": "output"} + if subfolder: + params["subfolder"] = subfolder + + url = f"{worker_url}/view" + print(f" 🔗 Fetching from: {url}") + + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params, ssl=False) as resp: + if resp.status == 200: + raw_bytes = await resp.read() + path = f"generated_images/{local_name}" + with open(path, "wb") as f: + f.write(raw_bytes) + print(f" 💾 Saved: {path}") + return path + else: + text = await resp.text() + print(f" ❌ HTTP {resp.status}: {text[:100]}") + return None + except Exception as e: + print(f" ❌ Fetch error: {e}") + return None + + async def demo_prompt( + self, + prompt: str, + width: int, + height: int, + steps: int, + seed: int | None, + ): + """Demo: Generate image from text prompt""" + print("=" * 60) + print("COMFYUI TEXT-TO-IMAGE DEMO") + print("=" * 60) + + if seed is None: + seed = random.randint(0, 2**32 - 1) + + print(f"Prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Prompt: {prompt}") + print(f"Size: {width}x{height}, Steps: {steps}, Seed: {seed}") + print("\n🎨 Generating image...") + + response = await call_generate( + self.client, + endpoint_name=self.endpoint_name, + prompt=prompt, + width=width, + height=height, + steps=steps, + seed=seed, + ) + + print("\n✅ Generation complete!") + + # Get worker URL for fetching images + worker_url = response.get("url", "") + print(f"Worker URL: {worker_url}") + + # Extract and handle images + if "response" in response: + images = self.extract_images(response["response"]) + if images: + print(f"\n📁 {len(images)} image(s) generated:") + await self.save_images(images, worker_url, prefix=f"comfy_{seed}") + else: + print("\nNo images found in response") + print(json.dumps(response, indent=2, default=str)[:2000]) + else: + print("\nUnexpected response format") + print(json.dumps(response, indent=2, default=str)[:2000]) + + async def demo_workflow(self, workflow_file: str): + """Demo: Generate using custom workflow file""" + print("=" * 60) + print("COMFYUI CUSTOM WORKFLOW DEMO") + print("=" * 60) + + if not os.path.exists(workflow_file): + log.error(f"Workflow file not found: {workflow_file}") + return + + with open(workflow_file, "r") as f: + workflow_json = json.load(f) + + print(f"Workflow: {workflow_file}") + print("\n🎨 Generating...") + + response = await call_generate_workflow( + self.client, + endpoint_name=self.endpoint_name, + workflow_json=workflow_json, + ) + + print("\n✅ Generation complete!") + + worker_url = response.get("url", "") + + if "response" in response: + images = self.extract_images(response["response"]) + if images: + print(f"\n📁 {len(images)} image(s) generated:") + await self.save_images(images, worker_url, prefix="workflow") + else: + print("\nNo images found in response") + print(json.dumps(response, indent=2, 
default=str)[:2000]) + else: + print("\nUnexpected response format") + print(json.dumps(response, indent=2, default=str)[:2000]) + + +# ---------------------- CLI ---------------------- +def build_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Vast ComfyUI-JSON Demo (Serverless SDK)") + p.add_argument("--endpoint", default=ENDPOINT_NAME, help=f"Vast endpoint name (default: {ENDPOINT_NAME})") + p.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, metavar="TEXT", + help=f"Prompt text (default: '{DEFAULT_PROMPT[:30]}...')") + p.add_argument("--workflow", type=str, metavar="FILE", help="Use custom workflow JSON file instead") + p.add_argument("--width", type=int, default=DEFAULT_WIDTH, help=f"Image width (default: {DEFAULT_WIDTH})") + p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})") + p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})") + p.add_argument("--seed", type=int, default=None, help="Seed (default: random)") + return p + + +async def main_async(): + args = build_arg_parser().parse_args() + + print("=" * 60) + print(f"Using endpoint: {args.endpoint}") + + try: + async with Serverless() as client: + demo = APIDemo(client, args.endpoint) + + if args.workflow: + await demo.demo_workflow(workflow_file=args.workflow) + else: + await demo.demo_prompt( + prompt=args.prompt, + width=args.width, + height=args.height, + steps=args.steps, + seed=args.seed, + ) + + except AttributeError as e: + if "API key" in str(e): + log.error("API key missing. Set VAST_API_KEY environment variable.") + else: + log.error(f"Error: {e}") + sys.exit(1) + except Exception as e: + log.error(f"Error: {e}") + import traceback + traceback.print_exc() + sys.exit(1) - # Get the file from the path on the local machine using SCP or SFTP - # or configure S3 to upload to cloud storage. 
- print(response["response"]["output"][0]["local_path"]) if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main_async()) diff --git a/workers/comfyui-json/server.py b/workers/comfyui-json/server.py index ed4e578..7998e71 100644 --- a/workers/comfyui-json/server.py +++ b/workers/comfyui-json/server.py @@ -4,6 +4,7 @@ import base64 from typing import Optional, Union, Type +import aiohttp from aiohttp import web, ClientResponse from lib.backend import Backend, LogAction @@ -108,8 +109,39 @@ async def handle_ping(_): return web.Response(body="pong") +async def handle_view(request: web.Request) -> web.Response: + """Proxy /view requests to ComfyUI to fetch generated images""" + # Forward query params to ComfyUI + query_string = request.query_string + url = f"{MODEL_SERVER_URL}/view?{query_string}" + + log.debug(f"Proxying /view request to: {url}") + + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + if resp.status == 200: + content = await resp.read() + return web.Response( + body=content, + status=200, + content_type=resp.content_type or "image/png" + ) + else: + text = await resp.text() + return web.Response( + text=text, + status=resp.status, + content_type="text/plain" + ) + except Exception as e: + log.error(f"Error proxying /view: {e}") + return web.Response(text=str(e), status=500) + + routes = [ web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())), + web.get("/view", handle_view), web.get("/ping", handle_ping), ] From e839cfc6e8fa3a32eef152381359df27bf15a953 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 20:22:45 -0800 Subject: [PATCH 25/40] include view in API wrapper --- workers/comfyui-json/server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/workers/comfyui-json/server.py b/workers/comfyui-json/server.py index 7998e71..daf35e5 100644 --- a/workers/comfyui-json/server.py +++ b/workers/comfyui-json/server.py @@ -14,6 +14,7 @@ MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288") +COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:18188") # Raw ComfyUI server # This is the last log line that gets emitted once comfyui+extensions have been fully loaded MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: " @@ -110,10 +111,10 @@ async def handle_ping(_): async def handle_view(request: web.Request) -> web.Response: - """Proxy /view requests to ComfyUI to fetch generated images""" - # Forward query params to ComfyUI + """Proxy /view requests to raw ComfyUI server to fetch generated images""" + # Forward query params to raw ComfyUI (not the API wrapper) query_string = request.query_string - url = f"{MODEL_SERVER_URL}/view?{query_string}" + url = f"{COMFYUI_URL}/view?{query_string}" log.debug(f"Proxying /view request to: {url}") From d4d36bf86e03f40f727179975dfb8d53518e9ed2 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Wed, 3 Dec 2025 20:45:55 -0800 Subject: [PATCH 26/40] done with comfy updates --- workers/comfyui-json/README.md | 71 ++++++++++++--- workers/comfyui-json/client.py | 160 ++++++++------------------------- 2 files changed, 96 insertions(+), 135 deletions(-) diff --git a/workers/comfyui-json/README.md b/workers/comfyui-json/README.md index 7aa1ba3..5306a23 100644 --- a/workers/comfyui-json/README.md +++ b/workers/comfyui-json/README.md @@ -1,8 +1,16 @@ # ComfyUI PyWorker -This is the base PyWorker for ComfyUI. 
It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. +This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. -The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node. +The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node. + +## Instance Setup + +1. Pick a template + +- [ComfyUI (Serverless)](https://cloud.vast.ai/?ref_id=62897&creator_id=62897&name=ComfyUI%20(Serverless)) + +2. Follow the [getting started guide](https://docs.vast.ai/documentation/serverless/quickstart) for help with configuring your serverless setup. For testing, we recommend that you use the default options presented by the web interface. ## Requirements @@ -10,6 +18,57 @@ This worker requires both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) a A docker image is provided but you may use any if the above requirements are met. +## Client + +The client demonstrates how to use the Vast Serverless SDK to generate images and save them locally. + +### Setup + +1. Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client. + +```bash +git clone https://github.com/vast-ai/pyworker +cd pyworker +pip install uv +uv venv -p 3.12 +source .venv/bin/activate +uv pip install -r requirements.txt +``` + +2. Set your API key: + +```bash +export VAST_API_KEY= +``` + +### Usage + +```bash +# Default prompt +python -m workers.comfyui-json.client + +# Custom prompt +python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow" + +# With options +python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30 +``` + +### CLI Flags + +| Flag | Default | Description | +|------|---------|-------------| +| `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name | +| `--prompt` | (default) | Text prompt for image generation | +| `--width` | 512 | Image width in pixels | +| `--height` | 512 | Image height in pixels | +| `--steps` | 20 | Number of denoising steps | +| `--seed` | (random) | Random seed for reproducibility | + +### Output + +Images are saved to `./generated_images/comfy_{seed}.png`. + ## Benchmarking ### Custom Benchmark Workflows @@ -212,11 +271,3 @@ WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds } } ``` - -## Client Libraries - -See the test client examples for implementation details on how to integrate with the ComfyUI worker. - ---- - -See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler. 
\ No newline at end of file diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index b80a9ba..a243183 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -3,16 +3,16 @@ import json import uuid import random -import base64 import asyncio import logging import argparse +import aiohttp from vastai import Serverless # ---------------------- Config ---------------------- DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed" -ENDPOINT_NAME = "Comfy-Prod2" +ENDPOINT_NAME = "my-comfyui-endpoint" DEFAULT_WIDTH = 512 DEFAULT_HEIGHT = 512 DEFAULT_STEPS = 20 @@ -74,128 +74,40 @@ def __init__(self, client: Serverless, endpoint_name: str): self.client = client self.endpoint_name = endpoint_name - def extract_images(self, response: dict) -> list: - """Extract image info from ComfyUI response""" - images = [] - - # Check for output array (S3/webhook configured) - if "output" in response: - for item in response["output"]: - if "url" in item: - images.append({"type": "url", "path": item["url"]}) - elif "local_path" in item: - images.append({"type": "local", "path": item["local_path"]}) - elif "base64" in item: - images.append({"type": "base64", "data": item["base64"]}) - - # Check for comfyui_response format (default) + def extract_filename(self, response: dict) -> str | None: + """Extract the generated image filename from ComfyUI response""" if "comfyui_response" in response: - for prompt_id, data in response["comfyui_response"].items(): + for data in response["comfyui_response"].values(): if isinstance(data, dict) and "outputs" in data: - for node_id, node_output in data["outputs"].items(): - if "images" in node_output: - for img in node_output["images"]: - images.append({ - "type": "remote", - "filename": img.get("filename"), - "subfolder": img.get("subfolder", ""), - }) - - return images - - async def save_images(self, images: list, worker_url: str, prefix: str = "comfy") -> list: - """Save images locally by fetching from remote server""" + for node_output in data["outputs"].values(): + if "images" in node_output and node_output["images"]: + return node_output["images"][0].get("filename") + return None + + async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None: + """Fetch and save image locally from the worker""" os.makedirs("generated_images", exist_ok=True) - saved = [] - seen = set() - - for i, img in enumerate(images): - if img["type"] == "base64": - data = img["data"] - if data.startswith("data:"): - data = data.split(",", 1)[-1] - path = f"generated_images/{prefix}_{i}.png" - with open(path, "wb") as f: - f.write(base64.b64decode(data)) - print(f" 💾 Saved: {path}") - saved.append(path) - - elif img["type"] == "url": - url = img["path"] - if url in seen: - continue - seen.add(url) - try: - import urllib.request - path = f"generated_images/{prefix}_{len(saved)}.png" - urllib.request.urlretrieve(url, path) - print(f" 💾 Downloaded: {path}") - saved.append(path) - except Exception as e: - print(f" 🔗 URL: {url}") - saved.append(url) - - elif img["type"] == "local": - remote_path = img["path"] - if remote_path in seen: - continue - seen.add(remote_path) - filename = os.path.basename(remote_path) - # Try to fetch via /view endpoint - local_path = await self._fetch_image(worker_url, filename, "", f"{prefix}_{len(saved)}.png") - if local_path: - saved.append(local_path) - else: - print(f" 📂 Remote: {remote_path}") - saved.append(remote_path) - - elif img["type"] == "remote": - filename = 
img["filename"] - if filename in seen: - continue - seen.add(filename) - subfolder = img.get("subfolder", "") - # Try to fetch via /view endpoint - local_path = await self._fetch_image(worker_url, filename, subfolder, f"{prefix}_{len(saved)}.png") - if local_path: - saved.append(local_path) - else: - print(f" 🖼️ Remote: {filename}") - saved.append(filename) - - return saved - - async def _fetch_image(self, worker_url: str, filename: str, subfolder: str, local_name: str) -> str | None: - """Fetch image directly from worker's /view endpoint""" + return await self._fetch_image(worker_url, filename, local_name) + + async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None: + """Fetch image from worker's /view endpoint and save locally""" if not worker_url: - print(f" ⚠️ No worker URL available") return None try: - import aiohttp - - params = {"filename": filename, "type": "output"} - if subfolder: - params["subfolder"] = subfolder - url = f"{worker_url}/view" - print(f" 🔗 Fetching from: {url}") + params = {"filename": filename, "type": "output"} async with aiohttp.ClientSession() as session: async with session.get(url, params=params, ssl=False) as resp: if resp.status == 200: - raw_bytes = await resp.read() path = f"generated_images/{local_name}" with open(path, "wb") as f: - f.write(raw_bytes) + f.write(await resp.read()) print(f" 💾 Saved: {path}") return path - else: - text = await resp.text() - print(f" ❌ HTTP {resp.status}: {text[:100]}") - return None - except Exception as e: - print(f" ❌ Fetch error: {e}") + return None + except Exception: return None async def demo_prompt( @@ -234,18 +146,17 @@ async def demo_prompt( worker_url = response.get("url", "") print(f"Worker URL: {worker_url}") - # Extract and handle images + # Fetch and save image if "response" in response: - images = self.extract_images(response["response"]) - if images: - print(f"\n📁 {len(images)} image(s) generated:") - await self.save_images(images, worker_url, prefix=f"comfy_{seed}") + filename = self.extract_filename(response["response"]) + if filename: + path = await self.save_image(worker_url, filename, f"comfy_{seed}.png") + if not path: + print(f"❌ Failed to fetch image") else: - print("\nNo images found in response") - print(json.dumps(response, indent=2, default=str)[:2000]) + print("❌ No image in response") else: - print("\nUnexpected response format") - print(json.dumps(response, indent=2, default=str)[:2000]) + print("❌ Unexpected response format") async def demo_workflow(self, workflow_file: str): """Demo: Generate using custom workflow file""" @@ -274,16 +185,15 @@ async def demo_workflow(self, workflow_file: str): worker_url = response.get("url", "") if "response" in response: - images = self.extract_images(response["response"]) - if images: - print(f"\n📁 {len(images)} image(s) generated:") - await self.save_images(images, worker_url, prefix="workflow") + filename = self.extract_filename(response["response"]) + if filename: + path = await self.save_image(worker_url, filename, "workflow.png") + if not path: + print(f"❌ Failed to fetch image") else: - print("\nNo images found in response") - print(json.dumps(response, indent=2, default=str)[:2000]) + print("❌ No image in response") else: - print("\nUnexpected response format") - print(json.dumps(response, indent=2, default=str)[:2000]) + print("❌ Unexpected response format") # ---------------------- CLI ---------------------- From 40aed9b5f8d85f3f5589cf10cc528901f57b8976 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Thu, 
4 Dec 2025 10:52:57 -0800 Subject: [PATCH 27/40] adding s3 as an option --- workers/comfyui-json/README.md | 33 ++++++++++++++- workers/comfyui-json/client.py | 74 +++++++++++++++++++++++++++++++--- 2 files changed, 101 insertions(+), 6 deletions(-) diff --git a/workers/comfyui-json/README.md b/workers/comfyui-json/README.md index 5306a23..9517dbb 100644 --- a/workers/comfyui-json/README.md +++ b/workers/comfyui-json/README.md @@ -20,7 +20,7 @@ A docker image is provided but you may use any if the above requirements are met ## Client -The client demonstrates how to use the Vast Serverless SDK to generate images and save them locally. +The client demonstrates how to use the Vast Serverless SDK to generate images, save them locally, and optionally upload to S3-compatible storage. ### Setup @@ -52,6 +52,12 @@ python -m workers.comfyui-json.client --prompt "a cat sitting on a rainbow" # With options python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 1024 --steps 30 + +# Using a custom workflow file +python -m workers.comfyui-json.client --workflow my_workflow.json + +# With S3 upload +python -m workers.comfyui-json.client --s3 ``` ### CLI Flags @@ -60,15 +66,40 @@ python -m workers.comfyui-json.client --prompt "sunset" --width 1024 --height 10 |------|---------|-------------| | `--endpoint` | `my-comfyui-endpoint` | Vast endpoint name | | `--prompt` | (default) | Text prompt for image generation | +| `--workflow` | (none) | Path to custom workflow JSON file | | `--width` | 512 | Image width in pixels | | `--height` | 512 | Image height in pixels | | `--steps` | 20 | Number of denoising steps | | `--seed` | (random) | Random seed for reproducibility | +| `--s3` | (disabled) | Upload generated images to S3 | ### Output Images are saved to `./generated_images/comfy_{seed}.png`. +### S3 Upload (Optional) + +You can optionally upload generated images to an S3-compatible storage service (AWS S3, Cloudflare R2, Backblaze B2, etc.) by using the `--s3` flag. + +**1. Set environment variables:** + +```bash +export S3_ENDPOINT_URL="https://your-account.r2.cloudflarestorage.com" +export S3_BUCKET_NAME="my-bucket" +export S3_ACCESS_KEY_ID="your-access-key-id" +export S3_SECRET_ACCESS_KEY="your-secret-access-key" +``` + +**2. Run with S3 upload enabled:** + +```bash +python -m workers.comfyui-json.client --prompt "a beautiful landscape" --s3 +``` + +Images will be saved locally AND uploaded to `s3://{bucket}/comfyui/{filename}`. + +**Note:** Requires `boto3` (`pip install boto3`). 
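+
+To confirm that images are landing in your bucket, you can list the `comfyui/` prefix from your
+local machine with the same credentials. A minimal sketch (assumes the `S3_*` variables above are
+exported; this script is not part of the worker itself):
+
+```python
+import os
+import boto3
+from botocore.config import Config
+
+s3 = boto3.client(
+    "s3",
+    endpoint_url=os.environ["S3_ENDPOINT_URL"],
+    aws_access_key_id=os.environ["S3_ACCESS_KEY_ID"],
+    aws_secret_access_key=os.environ["S3_SECRET_ACCESS_KEY"],
+    config=Config(signature_version="s3v4"),
+)
+
+# The client uploads to s3://{bucket}/comfyui/{filename}
+resp = s3.list_objects_v2(Bucket=os.environ["S3_BUCKET_NAME"], Prefix="comfyui/")
+for obj in resp.get("Contents", []):
+    print(obj["Key"], obj["Size"])
+```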
+ ## Benchmarking ### Custom Benchmark Workflows diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index a243183..10a1d91 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -12,16 +12,45 @@ # ---------------------- Config ---------------------- DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed" -ENDPOINT_NAME = "my-comfyui-endpoint" +ENDPOINT_NAME = "Comfy-Prod" DEFAULT_WIDTH = 512 DEFAULT_HEIGHT = 512 DEFAULT_STEPS = 20 COST = 100 # Fixed cost for ComfyUI requests +# Optional S3 Configuration (from environment variables) +S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL") +S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") +S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID") +S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY") + logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s") log = logging.getLogger(__name__) +def get_s3_client(): + """Create and return an S3 client configured for the S3-compatible endpoint""" + try: + import boto3 + from botocore.config import Config + except ImportError: + log.error("boto3 is required for S3 uploads. Install with: pip install boto3") + return None + + if not all([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]): + log.error("S3 environment variables not fully configured. Required:") + log.error(" S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY") + return None + + return boto3.client( + "s3", + endpoint_url=S3_ENDPOINT_URL, + aws_access_key_id=S3_ACCESS_KEY_ID, + aws_secret_access_key=S3_SECRET_ACCESS_KEY, + config=Config(signature_version="s3v4"), + ) + + # ---------------------- API Functions ---------------------- async def call_generate( client: Serverless, @@ -70,9 +99,14 @@ async def call_generate_workflow( # ---------------------- Demo Class ---------------------- class APIDemo: - def __init__(self, client: Serverless, endpoint_name: str): + def __init__(self, client: Serverless, endpoint_name: str, upload_s3: bool = False): self.client = client self.endpoint_name = endpoint_name + self.upload_s3 = upload_s3 + self.s3_client = get_s3_client() if upload_s3 else None + + if upload_s3 and not self.s3_client: + log.warning("S3 upload requested but client creation failed. 
Images will only be saved locally.") def extract_filename(self, response: dict) -> str | None: """Extract the generated image filename from ComfyUI response""" @@ -85,10 +119,29 @@ def extract_filename(self, response: dict) -> str | None: return None async def save_image(self, worker_url: str, filename: str, local_name: str) -> str | None: - """Fetch and save image locally from the worker""" + """Fetch and save image locally from the worker, optionally upload to S3""" os.makedirs("generated_images", exist_ok=True) return await self._fetch_image(worker_url, filename, local_name) + def _upload_to_s3(self, local_path: str, s3_key: str) -> str | None: + """Upload a local file to S3 and return the S3 URL""" + if not self.s3_client: + return None + + try: + self.s3_client.upload_file( + local_path, + S3_BUCKET_NAME, + s3_key, + ExtraArgs={"ContentType": "image/png"} + ) + s3_url = f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3_key}" + print(f" ☁️ Uploaded to S3: {s3_key}") + return s3_url + except Exception as e: + log.error(f"Failed to upload to S3: {e}") + return None + async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> str | None: """Fetch image from worker's /view endpoint and save locally""" if not worker_url: @@ -102,9 +155,16 @@ async def _fetch_image(self, worker_url: str, filename: str, local_name: str) -> async with session.get(url, params=params, ssl=False) as resp: if resp.status == 200: path = f"generated_images/{local_name}" + image_data = await resp.read() with open(path, "wb") as f: - f.write(await resp.read()) + f.write(image_data) print(f" 💾 Saved: {path}") + + # Upload to S3 if enabled + if self.upload_s3 and self.s3_client: + s3_key = f"comfyui/{local_name}" + self._upload_to_s3(path, s3_key) + return path return None except Exception: @@ -207,6 +267,8 @@ def build_arg_parser() -> argparse.ArgumentParser: p.add_argument("--height", type=int, default=DEFAULT_HEIGHT, help=f"Image height (default: {DEFAULT_HEIGHT})") p.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Steps (default: {DEFAULT_STEPS})") p.add_argument("--seed", type=int, default=None, help="Seed (default: random)") + p.add_argument("--s3", action="store_true", + help="Upload generated images to S3 (requires S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY env vars)") return p @@ -215,10 +277,12 @@ async def main_async(): print("=" * 60) print(f"Using endpoint: {args.endpoint}") + if args.s3: + print(f"S3 upload: enabled (bucket: {S3_BUCKET_NAME})") try: async with Serverless() as client: - demo = APIDemo(client, args.endpoint) + demo = APIDemo(client, args.endpoint, upload_s3=args.s3) if args.workflow: await demo.demo_workflow(workflow_file=args.workflow) From 222ac2a0ddfe77c96abd5036bb78af6534274d85 Mon Sep 17 00:00:00 2001 From: Colter Downing Date: Thu, 4 Dec 2025 10:54:55 -0800 Subject: [PATCH 28/40] default endpoint name --- workers/comfyui-json/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workers/comfyui-json/client.py b/workers/comfyui-json/client.py index 10a1d91..d79b30d 100644 --- a/workers/comfyui-json/client.py +++ b/workers/comfyui-json/client.py @@ -12,7 +12,7 @@ # ---------------------- Config ---------------------- DEFAULT_PROMPT = "a beautiful sunset over mountains, digital art, highly detailed" -ENDPOINT_NAME = "Comfy-Prod" +ENDPOINT_NAME = "my-comfyui-endpoint" DEFAULT_WIDTH = 512 DEFAULT_HEIGHT = 512 DEFAULT_STEPS = 20 From 7be8aa63978ef0e061d2dc155484b65dd99e9b02 Mon Sep 17 00:00:00 2001 From: Lucas 
Armand Date: Wed, 10 Dec 2025 17:38:03 -0800 Subject: [PATCH 29/40] pin pycares --- requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 377b20a..8583584 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ -aiohttp[speedups]==3.10.1 +aiohttp~=3.10.1 +aiodns~=3.6.0 +pycares~=4.11.0 anyio~=4.4 lib~=4.0 nltk~=3.9 From df61e6e9467a7ad7650995f6109bfab366717158 Mon Sep 17 00:00:00 2001 From: edgaratvast Date: Wed, 10 Dec 2025 19:34:52 -0800 Subject: [PATCH 30/40] correct version pin for aiohttp (#73) Co-authored-by: Edgar Lin --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8583584..b484d2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -aiohttp~=3.10.1 +aiohttp==3.10.1 aiodns~=3.6.0 pycares~=4.11.0 anyio~=4.4 @@ -10,4 +10,4 @@ Requests~=2.32 transformers~=4.52 utils==1.0.* hf_transfer>=0.1.9 -vastai-sdk>=0.2.0 \ No newline at end of file +vastai-sdk>=0.2.0 From 4ecc07032ff829baeae0f0a807bb5df8c9a7e536 Mon Sep 17 00:00:00 2001 From: Abiola Akinnubi Date: Thu, 11 Dec 2025 12:51:56 -0800 Subject: [PATCH 31/40] Mark pyworkers as "Error" if startup script fails. to avoid silent fail that waits for autoscaler. --- start_server.sh | 167 +++++++++++++++++++++++++++++++----------------- 1 file changed, 110 insertions(+), 57 deletions(-) diff --git a/start_server.sh b/start_server.sh index 4b07e01..2f5ecdc 100755 --- a/start_server.sh +++ b/start_server.sh @@ -22,10 +22,49 @@ function echo_var(){ echo "$1: ${!1}" } -[ -z "$BACKEND" ] && echo "BACKEND must be set!" && exit 1 -[ -z "$MODEL_LOG" ] && echo "MODEL_LOG must be set!" && exit 1 -[ -z "$HF_TOKEN" ] && echo "HF_TOKEN must be set!" && exit 1 -[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && echo "For comfyui backends, COMFY_MODEL must be set!" && exit 1 +function report_error_and_exit(){ + local error_msg="$1" + echo "ERROR: $error_msg" + + # Report error to autoscaler + MTOKEN="${MASTER_TOKEN:-}" + VERSION="${PYWORKER_VERSION:-0}" + + IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}" + for addr in "${REPORT_ADDRS[@]}"; do + curl -sS -X POST -H 'Content-Type: application/json' \ + -d "$(cat <> "$MODEL_LOG.old" - : > "$MODEL_LOG" + if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then + report_error_and_exit "Failed to rotate model log" + fi + if ! : > "$MODEL_LOG"; then + report_error_and_exit "Failed to truncate model log" + fi fi # Populate /etc/environment with quoted values if ! grep -q "VAST" /etc/environment; then - env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do + if ! env -0 | grep -zEv "^(HOME=|SHLVL=)|CONDA" | while IFS= read -r -d '' line; do name=${line%%=*} value=${line#*=} printf '%s="%s"\n' "$name" "$value" - done > /etc/environment + done > /etc/environment; then + echo "WARNING: Failed to populate /etc/environment, continuing anyway" + fi fi if [ ! -d "$ENV_PATH" ] then echo "setting up venv" if ! which uv; then - curl -LsSf https://astral.sh/uv/install.sh | sh - source ~/.local/bin/env + if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then + report_error_and_exit "Failed to install uv package manager" + fi + if [[ -f ~/.local/bin/env ]]; then + if ! source ~/.local/bin/env; then + report_error_and_exit "Failed to source uv environment" + fi + else + echo "WARNING: ~/.local/bin/env not found after uv installation" + fi fi # Fork testing - [[ ! 
-d $SERVER_DIR ]] && git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR" + if [[ ! -d $SERVER_DIR ]]; then + if ! git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"; then + report_error_and_exit "Failed to clone pyworker repository" + fi + fi if [[ -n ${PYWORKER_REF:-} ]]; then - (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF") + if ! (cd "$SERVER_DIR" && git checkout "$PYWORKER_REF"); then + report_error_and_exit "Failed to checkout pyworker reference: $PYWORKER_REF" + fi fi - uv venv --python-preference only-managed "$ENV_PATH" -p 3.10 - source "$ENV_PATH/bin/activate" + if ! uv venv --python-preference only-managed "$ENV_PATH" -p 3.10; then + report_error_and_exit "Failed to create virtual environment" + fi + + if ! source "$ENV_PATH/bin/activate"; then + report_error_and_exit "Failed to activate virtual environment" + fi - uv pip install -r "${SERVER_DIR}/requirements.txt" + if ! uv pip install -r "${SERVER_DIR}/requirements.txt"; then + report_error_and_exit "Failed to install Python requirements" + fi - touch ~/.no_auto_tmux + if ! touch ~/.no_auto_tmux; then + report_error_and_exit "Failed to create ~/.no_auto_tmux" + fi else - [[ -f ~/.local/bin/env ]] && source ~/.local/bin/env - source "$WORKSPACE_DIR/worker-env/bin/activate" + if [[ -f ~/.local/bin/env ]]; then + if ! source ~/.local/bin/env; then + report_error_and_exit "Failed to source uv environment" + fi + fi + if ! source "$WORKSPACE_DIR/worker-env/bin/activate"; then + report_error_and_exit "Failed to activate existing virtual environment" + fi echo "environment activated" echo "venv: $VIRTUAL_ENV" fi -[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && echo "$BACKEND not supported!" && exit 1 +[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && report_error_and_exit "$BACKEND not supported!" if [ "$USE_SSL" = true ]; then - cat << EOF > /etc/openssl-san.cnf + if ! cat << EOF > /etc/openssl-san.cnf [req] default_bits = 2048 distinguished_name = req_distinguished_name @@ -109,18 +183,25 @@ if [ "$USE_SSL" = true ]; then [alt_names] IP.1 = 0.0.0.0 EOF + then + report_error_and_exit "Failed to write OpenSSL config" + fi - openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ + if ! openssl req -newkey rsa:2048 -subj "/C=US/ST=CA/CN=pyworker.vast.ai/" \ -nodes \ -sha256 \ -keyout /etc/instance.key \ -out /etc/instance.csr \ - -config /etc/openssl-san.cnf + -config /etc/openssl-san.cnf; then + report_error_and_exit "Failed to generate SSL certificate request" + fi - curl --header 'Content-Type: application/octet-stream' \ - --data-binary @//etc/instance.csr \ + if ! curl --header 'Content-Type: application/octet-stream' \ + --data-binary @/etc/instance.csr \ -X \ - POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; + POST "https://console.vast.ai/api/v0/sign_cert/?instance_id=$CONTAINER_ID" > /etc/instance.crt; then + report_error_and_exit "Failed to sign SSL certificate" + fi fi @@ -128,7 +209,9 @@ fi export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED -cd "$SERVER_DIR" +if ! cd "$SERVER_DIR"; then + report_error_and_exit "Failed to cd into SERVER_DIR: $SERVER_DIR" +fi echo "launching PyWorker server" @@ -138,37 +221,7 @@ PY_STATUS=${PIPESTATUS[0]} set -e if [ "${PY_STATUS}" -ne 0 ]; then - echo "PyWorker exited with status ${PY_STATUS}; notifying autoscaler..." 
- ERROR_MSG="PyWorker exited: code ${PY_STATUS}" - MTOKEN="${MASTER_TOKEN:-}" - VERSION="${PYWORKER_VERSION:-0}" - - IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}" - for addr in "${REPORT_ADDRS[@]}"; do - curl -sS -X POST -H 'Content-Type: application/json' \ - -d "$(cat < Date: Mon, 15 Dec 2025 22:33:03 -0500 Subject: [PATCH 32/40] Use PyWorker SDK (#67) * Change PyWorker to Worker SDK * Moved /lib to vast-sdk (https://github.com/vast-ai/vast-sdk) --- README.md | 192 +++++--- lib/backend.py | 434 ------------------ lib/data_types.py | 324 ------------- lib/metrics.py | 286 ------------ lib/server.py | 60 --- lib/test_utils.py | 310 ------------- lib/wheres-my-pyworker.txt | 22 + requirements.txt | 2 +- start_server.sh | 63 ++- utils/endpoint_util.py | 136 ------ utils/ssl.py | 15 - workers/ace/README.md | 168 +++++++ {lib => workers/ace}/__init__.py | 0 workers/ace/client.py | 149 ++++++ workers/ace/worker.py | 184 ++++++++ workers/comfyui-json/README.md | 10 +- workers/comfyui-json/data_types.py | 84 ---- .../comfyui-json/misc/benchmark.json.example | 107 ----- workers/comfyui-json/misc/note.txt | 1 + workers/comfyui-json/misc/test_prompts.txt | 34 -- workers/comfyui-json/server.py | 150 ------ workers/comfyui-json/test_load.py | 8 - workers/comfyui-json/worker.py | 81 ++++ workers/comfyui/README.md | 92 ---- workers/comfyui/__init__.py | 0 workers/comfyui/client.py | 170 ------- workers/comfyui/data_types.py | 205 --------- .../comfyui/misc/default_workflows/flux.json | 137 ------ .../comfyui/misc/default_workflows/sd3.json | 142 ------ workers/comfyui/misc/test_prompts.txt | 34 -- workers/comfyui/server.py | 143 ------ workers/comfyui/test_load.py | 15 - workers/hello_world/README.md | 321 ------------- workers/hello_world/__init__.py | 0 workers/hello_world/client.py | 0 workers/hello_world/data_types.py | 48 -- workers/hello_world/server.py | 175 ------- workers/hello_world/test_load.py | 7 - workers/openai/README.templates.md | 77 ---- workers/openai/client.py | 62 ++- workers/openai/data_types/__init__.py | 0 workers/openai/data_types/client.py | 58 --- workers/openai/data_types/server.py | 207 --------- workers/openai/server.py | 62 --- workers/openai/test_load.py | 434 ------------------ workers/openai/worker.py | 78 ++++ workers/tgi/data_types.py | 73 --- workers/tgi/server.py | 130 ------ workers/tgi/test_load.py | 7 - workers/tgi/worker.py | 76 +++ workers/wan/README.md | 170 +++++++ {utils => workers/wan}/__init__.py | 0 workers/wan/client.py | 205 +++++++++ workers/wan/worker.py | 288 ++++++++++++ 54 files changed, 1616 insertions(+), 4620 deletions(-) delete mode 100644 lib/backend.py delete mode 100644 lib/data_types.py delete mode 100644 lib/metrics.py delete mode 100644 lib/server.py delete mode 100644 lib/test_utils.py create mode 100644 lib/wheres-my-pyworker.txt delete mode 100644 utils/endpoint_util.py delete mode 100644 utils/ssl.py create mode 100644 workers/ace/README.md rename {lib => workers/ace}/__init__.py (100%) create mode 100644 workers/ace/client.py create mode 100644 workers/ace/worker.py delete mode 100644 workers/comfyui-json/data_types.py delete mode 100644 workers/comfyui-json/misc/benchmark.json.example create mode 100644 workers/comfyui-json/misc/note.txt delete mode 100644 workers/comfyui-json/misc/test_prompts.txt delete mode 100644 workers/comfyui-json/server.py delete mode 100644 workers/comfyui-json/test_load.py create mode 100644 workers/comfyui-json/worker.py delete mode 100644 workers/comfyui/README.md delete mode 100644 
workers/comfyui/__init__.py delete mode 100644 workers/comfyui/client.py delete mode 100644 workers/comfyui/data_types.py delete mode 100644 workers/comfyui/misc/default_workflows/flux.json delete mode 100644 workers/comfyui/misc/default_workflows/sd3.json delete mode 100644 workers/comfyui/misc/test_prompts.txt delete mode 100644 workers/comfyui/server.py delete mode 100644 workers/comfyui/test_load.py delete mode 100644 workers/hello_world/README.md delete mode 100644 workers/hello_world/__init__.py delete mode 100644 workers/hello_world/client.py delete mode 100644 workers/hello_world/data_types.py delete mode 100644 workers/hello_world/server.py delete mode 100644 workers/hello_world/test_load.py delete mode 100644 workers/openai/README.templates.md delete mode 100644 workers/openai/data_types/__init__.py delete mode 100644 workers/openai/data_types/client.py delete mode 100644 workers/openai/data_types/server.py delete mode 100644 workers/openai/server.py delete mode 100644 workers/openai/test_load.py create mode 100644 workers/openai/worker.py delete mode 100644 workers/tgi/data_types.py delete mode 100644 workers/tgi/server.py delete mode 100644 workers/tgi/test_load.py create mode 100644 workers/tgi/worker.py create mode 100644 workers/wan/README.md rename {utils => workers/wan}/__init__.py (100%) create mode 100644 workers/wan/client.py create mode 100644 workers/wan/worker.py diff --git a/README.md b/README.md index dda0ea2..72a8ce3 100644 --- a/README.md +++ b/README.md @@ -1,90 +1,152 @@ -# Vast PyWorker +# Vast PyWorker Examples -Vast PyWorker is a Python web server designed to run alongside a LLM or image generation models running on vast, -enabling autoscaler integration. -It serves as the primary entry point for API requests, forwarding them to the model's API hosted on the -same instance. Additionally, it monitors performance metrics and estimates current workload based on factors -such as the number of tokens processed for LLMs or image resolution and steps for image generation models, -reporting these metrics to the autoscaler. +This repository contains **example PyWorkers** used by Vast.ai’s default Serverless templates (e.g., vLLM, TGI, ComfyUI, Wan, ACE). A PyWorker is a lightweight Python HTTP proxy that runs alongside your model server and: -## Project Structure +- Exposes one or more HTTP routes (e.g., `/v1/completions`, `/generate/sync`) +- Optionally validates/transforms request payloads +- Computes per-request **workload** for autoscaling +- Forwards requests to the local model server +- Optionally supports FIFO queueing when the backend cannot process concurrent requests +- Detects readiness/failure from model logs and runs a benchmark to estimate throughput -* `lib/`: Contains the core PyWorker framework code (server logic, data types, metrics). -* `workers/`: Contains specific implementations (PyWorkers) for different model servers. Each subdirectory represents a worker for a particular model type. +> Important: The **core PyWorker framework** (Worker, WorkerConfig, HandlerConfig, BenchmarkConfig, LogActionConfig) is provided by the **`vastai` / `vastai-sdk`** Python package (https://github.com/vast-ai/vast-sdk). This repo focuses on *worker implementations and examples*, not the framework internals. -## Getting Started +## Repository Purpose -1. 
**Install Dependencies:** - ```bash - pip install -r requirements.txt - ``` - You may also need `pyright` for type checking: - ```bash - sudo npm install -g pyright - # or use your preferred method to install pyright - ``` +Use this repository as: -2. **Configure Environment:** Set any necessary environment variables (e.g., `MODEL_LOG` path, API keys if needed by your worker). +- A reference for how Vast templates wire up `worker.py` +- A starting point for implementing your own custom Serverless PyWorker +- A collection of working examples for common model backends -3. **Run the Server:** Use the provided script. You'll need to specify which worker to run. - ```bash - # Example for hello_world worker (assuming MODEL_LOG is set) - ./start_server.sh workers.hello_world.server - ``` - Replace `workers.hello_world.server` with the path to the `server.py` module of the worker you want to run. +If you are looking for the framework code itself, refer to the Vast.ai SDK. -## How to Use +## Project Structure -### Using Existing Workers +Typical layout: + +- `workers/` + - Example worker implementations (each worker is usually a self-contained folder) + - Each example typically includes: + - `worker.py` (the entrypoint used by Serverless) + - Optional sample workflows / payloads (for ComfyUI-based workers) + - Optional local test harness scripts + +## How Serverless launches worker.py + +On each worker instance, the template’s startup script typically: + +1. Clones your repository from `PYWORKER_REPO` +2. Installs dependencies from `requirements.txt` +3. Starts the **model server** (vLLM, TGI, ComfyUI, etc.) +4. Runs: + ```bash + python worker.py + ``` + +Your `worker.py` builds a `WorkerConfig`, constructs a `Worker`, and starts the PyWorker HTTP server. + +## worker.py + +A PyWorker is usually a single `worker.py` that uses SDK configuration objects: + +```python +from vastai import ( + Worker, + WorkerConfig, + HandlerConfig, + BenchmarkConfig, + LogActionConfig, +) + +worker_config = WorkerConfig( + model_server_url="http://127.0.0.1", + model_server_port=18000, + model_log_file="/var/log/model/server.log", + handlers=[ + HandlerConfig( + route="/v1/completions", + allow_parallel_requests=True, + max_queue_time=60.0, + workload_calculator=lambda payload: float(payload.get("max_tokens", 0)), + benchmark_config=BenchmarkConfig( + generator=lambda: {"prompt": "hello", "max_tokens": 128}, + runs=8, + concurrency=10, + ), + ) + ], + log_action_config=LogActionConfig( + on_load=["Application startup complete."], + on_error=["Traceback (most recent call last):", "RuntimeError:"], + on_info=['"message":"Download'], + ), +) + +Worker(worker_config).run() +``` -If you are using a Vast.ai template that includes PyWorker integration (marked as autoscaler compatible), it should work out of the box. The template will typically start the appropriate PyWorker server automatically. Here's a few: +## Included Examples -* **vLLM:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=63ae93902bf3978bea033782592b784d) -* **TGI (Text Generation Inference):** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=6fa6bd5bdf5f0df63db80e40b086037d) -* **ComfyUI:** [Vast.ai Template](https://cloud.vast.ai?ref_id=62897&template_id=e6748878ba688e765e3e9fca29541938) +This repository contains example PyWorkers corresponding to common Vast templates, including: -Currently available workers: -* `openai`: A simple example worker for a basic vLLM server. 
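The `worker.py` sketch above computes workload from `max_tokens` alone. For LLM handlers, a cost that also accounts for prompt length tracks real compute more closely. Below is a minimal, illustrative calculator under that assumption; the four-characters-per-token heuristic and the 256-token fallback are not taken from the patch, and the trailing comment simply reuses the `HandlerConfig` shape shown above.

```python
from typing import Any, Dict


def completions_workload(payload: Dict[str, Any]) -> float:
    """Estimate per-request cost for an OpenAI-style completions payload.

    Rough heuristic: prompt tokens (about four characters per token) plus
    the requested max output tokens. All constants here are illustrative.
    """
    prompt = payload.get("prompt") or ""
    if isinstance(prompt, list):
        # Some clients batch several prompts into one request.
        prompt = " ".join(str(p) for p in prompt)
    prompt_tokens = len(str(prompt)) / 4.0
    max_tokens = float(payload.get("max_tokens") or 256)  # assumed fallback
    return prompt_tokens + max_tokens


# Hypothetical usage inside the WorkerConfig shown above:
# HandlerConfig(route="/v1/completions",
#               workload_calculator=completions_workload, ...)
```

Skipping a real tokenizer keeps the proxy lightweight; the estimate only needs to correlate with compute cost, not match it exactly.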
-* `comfyui`: A worker for the ComfyUI image generation backend. -* `tgi`: A worker for the Text Generation Inference backend. +- **vLLM**: OpenAI-compatible completions/chat endpoints with parallel request support +- **TGI (Text Generation Inference)**: OpenAI-compatible endpoints and log-based readiness +- **ComfyUI (Image / JSON workflows)**: `/generate/sync` for ComfyUI workflow execution +- **ComfyUI Wan 2.2 (T2V)**: ComfyUI workflow execution producing video outputs +- **ComfyUI ACE Step (Text-to-Music)**: ComfyUI workflow execution producing audio outputs -### Implementing a New Worker +Exact worker paths and naming may vary by template; use the `workers/` directory as the source of truth. -To integrate PyWorker with a model server not already supported, you need to create a new worker implementation under the `workers/` directory. Follow these general steps: +## Getting Started (Local) -1. **Create Worker Directory:** Add a new directory under `workers/` (e.g., `workers/my_model/`). -2. **Define Data Types (`data_types.py`):** - * Create a class inheriting from `lib.data_types.ApiPayload`. - * Implement methods like `for_test`, `generate_payload_json`, `count_workload`, and `from_json_msg` to handle request data, testing, and workload calculation specific to your model's API. -3. **Implement Endpoint Handlers (`server.py`):** - * For each model API endpoint you want PyWorker to proxy, create a class inheriting from `lib.data_types.EndpointHandler`. - * Implement methods like `endpoint`, `payload_cls`, `generate_payload_json`, `make_benchmark_payload` (for one handler), and `generate_client_response`. - * Instantiate `lib.backend.Backend` with your model server details, log file path, benchmark handler, and log actions. - * Define `aiohttp` routes, mapping paths to your handlers using `backend.create_handler()`. - * Use `lib.server.start_server` to run the application. -4. **Add `__init__.py`:** Create an empty `__init__.py` file in your worker directory. -5. **(Optional) Add Load Testing (`test_load.py`):** Create a script using `lib.test_harness.run` to test your worker against a Vast.ai endpoint group. -6. **(Optional) Add Client Example (`client.py`):** Provide a script demonstrating how to call your worker's endpoints. +1. Install Python dependencies for the examples you plan to run: + ```bash + pip install -r requirements.txt + ``` -**For a detailed walkthrough, refer to the `hello_world` example:** [workers/hello_world/README.md](workers/hello_world/README.md) +2. Start your model server locally (vLLM, TGI, ComfyUI, etc.) and ensure: + - You know the model server URL/port + - You have a log file path you can tail for readiness/error detection +3. Run the worker: + ```bash + python worker.py + ``` + or, if running an example from a subfolder: + ```bash + python workers//worker.py + ``` -**Type Hinting:** It is strongly recommended to use strict type hinting throughout your implementation. Use `pyright` to check for type errors. +> Note: Many examples assume they are running inside Vast templates (ports, log paths, model locations). You may need to adjust `model_server_port` and `model_log_file` for local usage. -## Testing Your Worker +## Deploying on Vast Serverless -If you implement a `test_load.py` script for your worker, you can use it to load test a Vast.ai endpoint group running your instance image. 
+To use a custom PyWorker with Serverless: -```bash -# Example for hello_world worker -python3 -m workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME" -``` +1. Create a public Git repository containing: + - `worker.py` + - `requirements.txt` -Replace `workers.hello_world.test_load` with the path to your worker's test script and provide your Vast.ai API Key (`-k`) and the target Endpoint Group Name (`-e`). Adjust the number of requests (`-n`) and requests per second (`-rps`) as needed. +2. In your Serverless template / endpoint configuration, set: + - `PYWORKER_REPO` to your Git repository URL + - (Optional) `PYWORKER_REF` to a git ref (branch, tag, or commit) -## Community & Support +3. The template startup script will clone/install and run your `worker.py`. -Join the conversation and get help: +## Guidance for Custom Workers + +When implementing your own worker: + +- Define one `HandlerConfig` per route you want to expose. +- Choose a workload function that correlates with compute cost: + - LLMs: prompt tokens + max output tokens (or `max_tokens` as a simpler proxy) + - Non-LLMs: constant cost per request (e.g., `100.0`) is often sufficient +- Set `allow_parallel_requests=False` for backends that cannot handle concurrency (e.g., many ComfyUI deployments). +- Configure exactly **one** `BenchmarkConfig` across all handlers to enable capacity estimation. +- Use `LogActionConfig` to reliably detect “model loaded” and “fatal error” log lines. + +## Community & Support -* **Vast.ai Discord:** [https://discord.gg/Pa9M29FFye](https://discord.gg/Pa9M29FFye) -* **Vast.ai Subreddit:** [https://reddit.com/r/vastai/](https://reddit.com/r/vastai/) +- Vast.ai Discord: https://discord.gg/Pa9M29FFye +- Vast.ai Subreddit: https://reddit.com/r/vastai/ \ No newline at end of file diff --git a/lib/backend.py b/lib/backend.py deleted file mode 100644 index 0d9a273..0000000 --- a/lib/backend.py +++ /dev/null @@ -1,434 +0,0 @@ -import os -import json -import time -import base64 -import subprocess -import dataclasses -import logging -from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task -from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional -from functools import cached_property -from distutils.util import strtobool - -from anyio import open_file -from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector -import asyncio - -import requests -from Crypto.Signature import pkcs1_15 -from Crypto.Hash import SHA256 -from Crypto.PublicKey import RSA - -from lib.metrics import Metrics -from lib.data_types import ( - AuthData, - EndpointHandler, - LogAction, - ApiPayload_T, - JsonDataException, - RequestMetrics, - BenchmarkResult -) - -VERSION = "0.2.1" - -MSG_HISTORY_LEN = 100 -log = logging.getLogger(__file__) - -# defines the minimum wait time between sending updates to autoscaler -LOG_POLL_INTERVAL = 0.1 -BENCHMARK_INDICATOR_FILE = ".has_benchmark" -MAX_PUBKEY_FETCH_ATTEMPTS = 3 - - -@dataclasses.dataclass -class Backend: - """ - This class is responsible for: - 1. Tailing logs and updating load time metrics - 2. Taking an EndpointHandler alongside incoming payload, preparing a json to be sent to the model, and - sending the request. It also updates metrics as it makes those requests. - 3. 
Running a benchmark from an EndpointHandler - """ - - model_server_url: str - model_log_file: str - allow_parallel_requests: bool - benchmark_handler: ( - EndpointHandler # this endpoint handler will be used for benchmarking - ) - log_actions: List[Tuple[LogAction, str]] - max_wait_time: float = 10.0 - reqnum = -1 - version = VERSION - msg_history = [] - sem: Semaphore = dataclasses.field(default_factory=Semaphore) - unsecured: bool = dataclasses.field( - default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), - ) - report_addr: str = dataclasses.field( - default_factory=lambda: os.environ.get("REPORT_ADDR", "https://run.vast.ai") - ) - mtoken: str = dataclasses.field( - default_factory=lambda: os.environ.get("MASTER_TOKEN", "") - ) - - def __post_init__(self): - self.metrics = Metrics() - self.metrics._set_version(self.version) - self.metrics._set_mtoken(self.mtoken) - self._total_pubkey_fetch_errors = 0 - self._pubkey = self._fetch_pubkey() - self.__start_healthcheck: bool = False - - @property - def pubkey(self) -> Optional[RSA.RsaKey]: - if self._pubkey is None: - self._pubkey = self._fetch_pubkey() - return self._pubkey - - @cached_property - def session(self): - log.debug(f"starting session with {self.model_server_url}") - connector = TCPConnector( - force_close=True, # Required for long running jobs - enable_cleanup_closed=True, - ) - - timeout = ClientTimeout(total=None) - return ClientSession(self.model_server_url, timeout=timeout, connector=connector) - - def create_handler( - self, - handler: EndpointHandler[ApiPayload_T], - ) -> Callable[[web.Request], Awaitable[Union[web.Response, web.StreamResponse]]]: - async def handler_fn( - request: web.Request, - ) -> Union[web.Response, web.StreamResponse]: - return await self.__handle_request(handler=handler, request=request) - - return handler_fn - - #######################################Private####################################### - def _fetch_pubkey(self): - report_addr = self.report_addr.rstrip("/") - command = ["curl", "-X", "GET", f"{report_addr}/pubkey/"] - try: - result = subprocess.check_output(command, universal_newlines=True) - log.debug("public key:") - log.debug(result) - key = RSA.import_key(result) - if key is not None: - return key - except (ValueError , subprocess.CalledProcessError) as e: - log.debug(f"Error downloading key: {e}") - self.backend_errored("Failed to get autoscaler pubkey") - - - async def __handle_request( - self, - handler: EndpointHandler[ApiPayload_T], - request: web.Request, - ) -> Union[web.Response, web.StreamResponse]: - """use this function to forward requests to the model endpoint""" - try: - data = await request.json() - auth_data, payload = handler.get_data_from_request(data) - except JsonDataException as e: - return web.json_response(data=e.message, status=422) - except json.JSONDecodeError: - return web.json_response(dict(error="invalid JSON"), status=422) - workload = payload.count_workload() - request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") - - async def cancel_api_call_if_disconnected() -> web.Response: - await request.wait_for_disconnection() - log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled") - self.metrics._request_canceled(request_metrics) - raise asyncio.CancelledError - - async def make_request() -> Union[web.Response, web.StreamResponse]: - try: - response = await self.__call_api(handler=handler, payload=payload) - status_code = 
response.status - log.debug( - " ".join( - [ - f"request with reqnum:{request_metrics.reqnum}", - f"returned status code: {status_code},", - ] - ) - ) - res = await handler.generate_client_response(request, response) - self.metrics._request_success(request_metrics) - return res - except requests.exceptions.RequestException as e: - log.debug(f"[backend] Request error: {e}") - self.metrics._request_errored(request_metrics) - return web.Response(status=500) - - ########### - - if self.__check_signature(auth_data) is False: - self.metrics._request_reject(request_metrics) - return web.Response(status=401) - - if self.metrics.model_metrics.wait_time > self.max_wait_time: - self.metrics._request_reject(request_metrics) - return web.Response(status=429) - - acquired = False - try: - self.metrics._request_start(request_metrics) - if self.allow_parallel_requests is False: - log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") - await self.sem.acquire() - acquired = True - log.debug( - f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..." - ) - else: - log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") - done, pending = await wait( - [ - create_task(make_request()), - create_task(cancel_api_call_if_disconnected()), - ], - return_when=FIRST_COMPLETED, - ) - for t in pending: - t.cancel() - await asyncio.gather(*pending, return_exceptions=True) - - done_task = done.pop() - try: - return done_task.result() - except Exception as e: - log.debug(f"Request task raised exception: {e}") - return web.Response(status=500) - except asyncio.CancelledError: - # Client is gone. Do not write a response; just unwind. - return web.Response(status=499) - except Exception as e: - log.debug(f"Exception in main handler loop {e}") - return web.Response(status=500) - finally: - # Always release the semaphore if it was acquired - if acquired: - self.sem.release() - self.metrics._request_end(request_metrics) - - @cached_property - def healthcheck_session(self): - """Dedicated session for healthchecks to avoid conflicts with API session""" - log.debug("creating dedicated healthcheck session") - connector = TCPConnector( - force_close=True, # Keep this for isolation - enable_cleanup_closed=True, - ) - timeout = ClientTimeout(total=10) # Reasonable timeout for healthchecks - return ClientSession(timeout=timeout, connector=connector) - - async def __healthcheck(self): - health_check_url = self.benchmark_handler.healthcheck_endpoint - if health_check_url is None: - log.debug("No healthcheck endpoint defined, skipping healthcheck") - return - - while True: - await sleep(10) - if self.__start_healthcheck is False: - continue - try: - log.debug(f"Performing healthcheck on {health_check_url}") - async with self.healthcheck_session.get(health_check_url) as response: - if response.status == 200: - log.debug("Healthcheck successful") - elif response.status == 503: - log.debug(f"Healthcheck failed with status: {response.status}") - self.backend_errored( - f"Healthcheck failed with status: {response.status}" - ) - else: - log.debug(f"Healthcheck Endpoint not ready: {response.status}") - except Exception as e: - log.debug(f"Healthcheck failed with exception: {e}") - self.backend_errored(str(e)) - - async def _start_tracking(self) -> None: - await gather( - self.__read_logs(), self.metrics._send_metrics_loop(), self.__healthcheck(), self.metrics._send_delete_requests_loop() - ) - - def backend_errored(self, msg: str) -> None: - self.metrics._model_errored(msg) - - async def __call_api( - 
self, handler: EndpointHandler[ApiPayload_T], payload: ApiPayload_T - ) -> ClientResponse: - api_payload = payload.generate_payload_json() - log.debug(f"posting to endpoint: '{handler.endpoint}', payload: {api_payload}") - return await self.session.post(url=handler.endpoint, json=api_payload) - - def __check_signature(self, auth_data: AuthData) -> bool: - if self.unsecured is True: - return True - - def verify_signature(message, signature): - if self.pubkey is None: - log.debug(f"No Public Key!") - return False - - h = SHA256.new(message.encode()) - try: - pkcs1_15.new(self.pubkey).verify(h, base64.b64decode(signature)) - return True - except (ValueError, TypeError): - return False - - message = { - key: value - for (key, value) in (dataclasses.asdict(auth_data).items()) - if key != "signature" and key != "__request_id" - } - if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN): - log.debug( - f"reqnum failure, got {auth_data.reqnum}, current_reqnum: {self.reqnum}" - ) - return False - elif message in self.msg_history: - log.debug(f"message: {message} already in message history") - return False - elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature): - self.reqnum = max(auth_data.reqnum, self.reqnum) - self.msg_history.append(message) - self.msg_history = self.msg_history[-MSG_HISTORY_LEN:] - return True - else: - log.debug( - f"signature verification failed, sig:{auth_data.signature}, message: {message}" - ) - return False - - async def __read_logs(self) -> Awaitable[NoReturn]: - - async def run_benchmark() -> float: - log.debug("starting benchmark") - try: - with open(BENCHMARK_INDICATOR_FILE, "r") as f: - log.debug("already ran benchmark") - # trigger model load - # payload = self.benchmark_handler.make_benchmark_payload() - # _ = await self.__call_api( - # handler=self.benchmark_handler, payload=payload - # ) - return float(f.readline()) - except FileNotFoundError: - pass - - log.debug("Initial run to trigger model loading...") - payload = self.benchmark_handler.make_benchmark_payload() - await self.__call_api(handler=self.benchmark_handler, payload=payload) - - max_throughput = 0 - sum_throughput = 0 - concurrent_requests = 10 if self.allow_parallel_requests else 1 - - for run in range(1, self.benchmark_handler.benchmark_runs + 1): - start = time.time() - benchmark_requests = [] - - for i in range(concurrent_requests): - payload = self.benchmark_handler.make_benchmark_payload() - workload = payload.count_workload() - task = self.__call_api(handler=self.benchmark_handler, payload=payload) - benchmark_requests.append( - BenchmarkResult(request_idx=i, workload=workload, task=task) - ) - - responses = await gather(*[br.task for br in benchmark_requests]) - for br, response in zip(benchmark_requests, responses): - br.response = response - - total_workload = sum(br.workload for br in benchmark_requests if br.is_successful) - time_elapsed = time.time() - start - successful_responses = sum([1 for br in benchmark_requests if br.is_successful]) - if successful_responses == 0: - self.backend_errored("No successful responses from benchmark") - log.debug(f"benchmark failed: {successful_responses}/{concurrent_requests} successful responses") - - throughput = total_workload / time_elapsed - sum_throughput += throughput - max_throughput = max(max_throughput, throughput) - - # Log results for debugging - log.debug( - "\n".join( - [ - "#" * 60, - f"Run: {run}, concurrent_requests: {concurrent_requests}", - f"Total workload: {total_workload}, time_elapsed: 
{time_elapsed}s", - f"Throughput: {throughput} workload/s", - f"Successful responses: {successful_responses}/{concurrent_requests}", - "#" * 60, - ] - ) - ) - - average_throughput = sum_throughput / self.benchmark_handler.benchmark_runs - log.debug( - f"benchmark result: avg {average_throughput} workload per second, max {max_throughput}" - ) - with open(BENCHMARK_INDICATOR_FILE, "w") as f: - f.write(str(max_throughput)) - return max_throughput - - async def handle_log_line(log_line: str) -> None: - """ - Implement this function to handle each log line for your model. - This function should mutate self.system_metrics and self.model_metrics - """ - for action, msg in self.log_actions: - match action: - case LogAction.ModelLoaded if msg in log_line: - log.debug( - f"Got log line indicating model is loaded: {log_line}" - ) - # some backends need a few seconds after logging successful startup before - # they can begin accepting requests - # await sleep(5) - try: - max_throughput = await run_benchmark() - self.__start_healthcheck = True - self.metrics._model_loaded( - max_throughput=max_throughput, - ) - except ClientConnectorError as e: - log.debug( - f"failed to connect to comfyui api during benchmark" - ) - self.backend_errored(str(e)) - case LogAction.ModelError if msg in log_line: - log.debug(f"Got log line indicating error: {log_line}") - self.backend_errored(msg) - break - case LogAction.Info if msg in log_line: - log.debug(f"Info from model logs: {log_line}") - - async def tail_log(): - log.debug(f"tailing file: {self.model_log_file}") - async with await open_file(self.model_log_file, encoding='utf-8', errors='ignore') as f: - while True: - line = await f.readline() - if line: - await handle_log_line(line.rstrip()) - else: - await asyncio.sleep(LOG_POLL_INTERVAL) - - ########### - - while True: - if os.path.isfile(self.model_log_file) is True: - return await tail_log() - else: - await sleep(1) diff --git a/lib/data_types.py b/lib/data_types.py deleted file mode 100644 index d948c60..0000000 --- a/lib/data_types.py +++ /dev/null @@ -1,324 +0,0 @@ -import time -import logging -from dataclasses import dataclass, field -from enum import Enum -from abc import ABC, abstractmethod -from typing import Dict, Any, Union, Tuple, Optional, Set, TypeVar, Generic, Type, Awaitable -from aiohttp import web, ClientResponse -import inspect - -import psutil - - -""" -type variable representing an incoming payload to pyworker that will used to calculate load and will then -be forwarded to the model -""" - -log = logging.getLogger(__file__) - - -class JsonDataException(Exception): - def __init__(self, json_msg: Dict[str, Any]): - self.message = json_msg - - -ApiPayload_T = TypeVar("ApiPayload_T", bound="ApiPayload") - - -@dataclass -class ApiPayload(ABC): - - @classmethod - @abstractmethod - def for_test(cls: Type[ApiPayload_T]) -> ApiPayload_T: - """defines how create a payload for load testing""" - pass - - @abstractmethod - def generate_payload_json(self) -> Dict[str, Any]: - """defines how to convert an ApiPayload to JSON that will be sent to model API""" - pass - - @abstractmethod - def count_workload(self) -> float: - """defines how to calculate workload for a payload""" - pass - - @classmethod - @abstractmethod - def from_json_msg( - cls: Type[ApiPayload_T], json_msg: Dict[str, Any] - ) -> ApiPayload_T: - """ - defines how to create an API payload from a JSON message, - it should throw an JsonDataException if there are issues with some fields - or they are missing in the format of - { - "field": 
"error msg" - } - """ - pass - - -@dataclass -class AuthData: - """data used to authenticate requester""" - - cost: str - endpoint: str - reqnum: int - request_idx: int - signature: str - url: str - - @classmethod - def from_json_msg(cls, json_msg: Dict[str, Any]): - errors = {} - for param in inspect.signature(cls).parameters: - if param not in json_msg: - errors[param] = "missing parameter" - if errors: - raise JsonDataException(errors) - return cls( - **{ - k: v - for k, v in json_msg.items() - if k in inspect.signature(cls).parameters - } - ) - - -@dataclass -class EndpointHandler(ABC, Generic[ApiPayload_T]): - """ - Each model endpoint will have a handler responsible for counting workload from the incoming ApiPayload - and converting it to json to be forwarded to model API - """ - - benchmark_runs: int = 8 - benchmark_words: int = 100 - - @property - @abstractmethod - def endpoint(self) -> str: - """the endpoint on the model API""" - pass - - @property - @abstractmethod - def healthcheck_endpoint(self) -> Optional[str]: - """the endpoint on the model API that is used for healthchecks""" - pass - - @classmethod - @abstractmethod - def payload_cls(cls) -> Type[ApiPayload_T]: - """ApiPayload class""" - pass - - @abstractmethod - def make_benchmark_payload(self) -> ApiPayload_T: - """defines how to create an ApiPayload for benchmarking.""" - pass - - @abstractmethod - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - """ - defines how to convert a model API response to a response to PyWorker client - """ - pass - - @classmethod - def get_data_from_request( - cls, req_data: Dict[str, Any] - ) -> Tuple[AuthData, ApiPayload_T]: - errors = {} - auth_data: Optional[AuthData] = None - payload: Optional[ApiPayload_T] = None - try: - if "auth_data" in req_data: - auth_data = AuthData.from_json_msg(req_data["auth_data"]) - else: - errors["auth_data"] = "field missing" - except JsonDataException as e: - errors["auth_data"] = e.message - try: - if "payload" in req_data: - payload_cls = cls.payload_cls() - payload = payload_cls.from_json_msg(req_data["payload"]) - else: - errors["payload"] = "field missing" - except JsonDataException as e: - errors["payload"] = e.message - if errors: - raise JsonDataException(errors) - if auth_data and payload: - return (auth_data, payload) - else: - raise Exception("error deserializing request data") - - -@dataclass -class SystemMetrics: - """General system metrics""" - - model_loading_start: float - model_loading_time: Union[float, None] - last_disk_usage: float - additional_disk_usage: float - model_is_loaded: bool - - @staticmethod - def get_disk_usage_GB(): - return psutil.disk_usage("/").used / (2**30) # want units of GB - - @classmethod - def empty(cls): - return cls( - model_loading_start=time.time(), - model_loading_time=None, - last_disk_usage=SystemMetrics.get_disk_usage_GB(), - additional_disk_usage=0.0, - model_is_loaded=False, - ) - - def update_disk_usage(self): - disk_usage = SystemMetrics.get_disk_usage_GB() - self.additional_disk_usage = disk_usage - self.last_disk_usage - self.last_disk_usage = disk_usage - - def reset(self, expected: float | None) -> None: - # autoscaler excepts model_loading_time to be populated only once, when the instance has - # finished benchmarking and is ready to receive requests. 
This applies to restarted instances - # as well: they should send model_loading_time once when they are done loading - if self.model_loading_time == expected: - self.model_loading_time = None - - -@dataclass -class RequestMetrics: - """Tracks metrics for an active request.""" - request_idx: int - reqnum: int - workload: float - status: str - success: bool = False - -@dataclass -class BenchmarkResult: - request_idx: int - workload: float - task: Awaitable[ClientResponse] - response: Optional[ClientResponse] = None - - @property - def is_successful(self) -> bool: - return self.response is not None and self.response.status == 200 - -@dataclass -class ModelMetrics: - """Model specific metrics""" - - # these are reset after being sent to autoscaler - workload_served: float - workload_received: float - workload_cancelled: float - workload_errored: float - workload_rejected: float - # these are not - workload_pending: float - error_msg: Optional[str] - max_throughput: float - requests_recieved: Set[int] = field(default_factory=set) - requests_working: dict[int, RequestMetrics] = field(default_factory=dict) - requests_deleting: list[RequestMetrics] = field(default_factory=list) - last_update: float = field(default_factory=time.time) - - @classmethod - def empty(cls): - return cls( - workload_pending=0.0, - workload_served=0.0, - workload_cancelled=0.0, - workload_errored=0.0, - workload_rejected=0.0, - workload_received=0.0, - error_msg=None, - max_throughput=0.0, - ) - - @property - def workload_processing(self) -> float: - return max(self.workload_received - self.workload_cancelled, 0.0) - - @property - def wait_time(self) -> float: - if (len(self.requests_working) == 0): - return 0.0 - return sum([request.workload for request in self.requests_working.values()]) / max(self.max_throughput, 0.00001) - - @property - def cur_load(self) -> float: - return sum([request.workload for request in self.requests_working.values()]) - - @property - def working_request_idxs(self) -> list[int]: - return [req.request_idx for req in self.requests_working.values()] - - def set_errored(self, error_msg): - self.reset() - self.error_msg = error_msg - - def reset(self): - self.workload_served = 0 - self.workload_received = 0 - self.workload_cancelled = 0 - self.workload_errored = 0 - self.workload_rejected = 0 - self.last_update = time.time() - - -@dataclass -class AutoScalerData: - """Data that is reported to autoscaler""" - - id: int - mtoken: str - version: str - loadtime: float - cur_load: float - rej_load: float - new_load: float - error_msg: str - max_perf: float - cur_perf: float - cur_capacity: float - max_capacity: float - num_requests_working: int - num_requests_recieved: int - additional_disk_usage: float - working_request_idxs: list[int] - url: str - - -class LogAction(Enum): - """ - These actions tell the backend what a log value means, for example: - actions [ - # this marks the model server as loaded - (LogAction.ModelLoaded, "Starting server"), - # these mark the model server as errored - (LogAction.ModelError, "Exception loading model"), - (LogAction.ModelError, "Server failed to bind to port"), - # this tells the backend to print any logs containing the string into its own logs - # which are visible in the vast console instance logs - (LogAction.Info, "Starting model download"), - ] - """ - - ModelLoaded = 1 - ModelError = 2 - Info = 3 diff --git a/lib/metrics.py b/lib/metrics.py deleted file mode 100644 index 48774fe..0000000 --- a/lib/metrics.py +++ /dev/null @@ -1,286 +0,0 @@ -import os -import 
time -import logging -import json -from asyncio import sleep -from dataclasses import dataclass, asdict, field -from functools import cache -import asyncio -from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError - -from lib.data_types import AutoScalerData, SystemMetrics, ModelMetrics, RequestMetrics -from typing import Awaitable, NoReturn, List - -METRICS_UPDATE_INTERVAL = 1 -DELETE_REQUESTS_INTERVAL = 1 - -log = logging.getLogger(__file__) - - -@cache -def get_url() -> str: - use_ssl = os.environ.get("USE_SSL", "false") == "true" - worker_port = os.environ[f"VAST_TCP_PORT_{os.environ['WORKER_PORT']}"] - public_ip = os.environ["PUBLIC_IPADDR"] - return f"http{'s' if use_ssl else ''}://{public_ip}:{worker_port}" - - -@dataclass -class Metrics: - version: str = "0" - mtoken: str = "" - last_metric_update: float = 0.0 - last_request_served: float = 0.0 - update_pending: bool = False - id: int = field(default_factory=lambda: int(os.environ["CONTAINER_ID"])) - report_addr: List[str] = field( - default_factory=lambda: os.environ["REPORT_ADDR"].split(",") - ) - url: str = field(default_factory=get_url) - system_metrics: SystemMetrics = field(default_factory=SystemMetrics.empty) - model_metrics: ModelMetrics = field(default_factory=ModelMetrics.empty) - _session: ClientSession | None = field(default=None, init=False, repr=False) - - async def http(self) -> ClientSession: - if self._session is None: - self._session = ClientSession( - timeout=ClientTimeout(total=10), - connector=TCPConnector(limit=8, limit_per_host=4, force_close=True, enable_cleanup_closed=True) - ) - return self._session - - async def aclose(self) -> None: - if self._session is not None: - await self._session.close() - self._session = None - - def _request_start(self, request: RequestMetrics) -> None: - """ - this function is called prior to forwarding a request to a model API. - """ - log.debug("request start") - request.status = "Started" - self.model_metrics.workload_pending += request.workload - self.model_metrics.workload_received += request.workload - self.model_metrics.requests_recieved.add(request.reqnum) - self.model_metrics.requests_working[request.reqnum] = request - self.update_pending = True - - def _request_end(self, request: RequestMetrics) -> None: - """ - this function is called after handling of a request ends, regardless of the outcome - """ - self.model_metrics.workload_pending -= request.workload - self.model_metrics.requests_working.pop(request.reqnum, None) - self.model_metrics.requests_deleting.append(request) - self.last_request_served = time.time() - - def _request_success(self, request: RequestMetrics) -> None: - """ - this function is called after a response from model API is received and forwarded. 
- """ - self.model_metrics.workload_served += request.workload - request.status = "Success" - request.success = True - self.update_pending = True - - def _request_errored(self, request: RequestMetrics) -> None: - """ - this function is called if model API returns an error - """ - self.model_metrics.workload_errored += request.workload - request.status = "Error" - request.success = False - self.update_pending = True - - def _request_canceled(self, request: RequestMetrics) -> None: - """ - this function is called if client drops connection before model API has responded - """ - self.model_metrics.workload_cancelled += request.workload - request.success = True - request.status = "Cancelled" - - def _request_reject(self, request: RequestMetrics): - """ - this function is called if the current wait time for the model is above max_wait_time - """ - self.model_metrics.requests_recieved.add(request.reqnum) - self.model_metrics.requests_deleting.append(request) - self.model_metrics.workload_rejected += request.workload - request.success = False - request.status = "Rejected" - self.update_pending = True - - async def _send_delete_requests_loop(self) -> Awaitable[NoReturn]: - while True: - await sleep(DELETE_REQUESTS_INTERVAL) - if len(self.model_metrics.requests_deleting) > 0: - await self.__send_delete_requests_and_reset() - - async def _send_metrics_loop(self) -> Awaitable[NoReturn]: - while True: - await sleep(METRICS_UPDATE_INTERVAL) - elapsed = time.time() - self.last_metric_update - if self.system_metrics.model_is_loaded is False and elapsed >= 10: - log.debug(f"sending loading model metrics after {int(elapsed)}s wait") - await self.__send_metrics_and_reset() - elif self.update_pending or elapsed > 10: - log.debug(f"sending loaded model metrics after {int(elapsed)}s wait") - await self.__send_metrics_and_reset() - - def _model_loaded(self, max_throughput: float) -> None: - self.system_metrics.model_loading_time = ( - time.time() - self.system_metrics.model_loading_start - ) - self.system_metrics.model_is_loaded = True - self.model_metrics.max_throughput = max_throughput - - def _model_errored(self, error_msg: str) -> None: - self.model_metrics.set_errored(error_msg) - self.system_metrics.model_is_loaded = True - - def _set_version(self, version: str) -> None: - self.version = version - - def _set_mtoken(self, mtoken: str) -> None: - self.mtoken = mtoken - - #######################################Private####################################### - - async def __send_delete_requests_and_reset(self): - async def post(report_addr: str, idxs: list[int], success_flag: bool) -> bool: - data = { - "worker_id": self.id, - "mtoken": self.mtoken, - "request_idxs": idxs, - "success": success_flag, - } - log.debug( - f"Deleting requests that {'succeeded' if success_flag else 'failed'}: {data['request_idxs']}" - ) - full_path = report_addr.rstrip("/") + "/delete_requests/" - for attempt in range(1, 4): - try: - session = await self.http() - async with session.post(full_path, json=data) as res: - log.debug(f"delete_requests response: {res.status}") - res.raise_for_status() - return True - except asyncio.TimeoutError: - log.debug("delete_requests timed out") - except (ClientResponseError, Exception) as e: - log.debug(f"delete_requests failed with error: {e}") - await asyncio.sleep(2) - log.debug(f"retrying delete_request, attempt: {attempt}") - return False - - # Take a snapshot of what we plan to send this tick. - # New arrivals after this snapshot will remain in the queue for the next tick. 
- snapshot = list(self.model_metrics.requests_deleting) - success_idxs = [r.request_idx for r in snapshot if r.success is True] - failed_idxs = [r.request_idx for r in snapshot if r.success is False] - - if not success_idxs and not failed_idxs: - return # nothing to do - - for report_addr in self.report_addr: - # TODO: Add a Redis subscriber queue for delete_requests - if report_addr == "https://cloud.vast.ai/api/v0": - # Patch: ignore the Redis API report_addr - continue - sent_success = True - sent_failed = True - - if success_idxs: - sent_success = await post(report_addr, success_idxs, True) - if failed_idxs: - sent_failed = await post(report_addr, failed_idxs, False) - - if sent_success and sent_failed: - # Remove only the items we actually sent from the live queue. - sent_set = set(success_idxs) | set(failed_idxs) - self.model_metrics.requests_deleting[:] = [ - r for r in self.model_metrics.requests_deleting - if r.request_idx not in sent_set - ] - break - - - async def __send_metrics_and_reset(self): - - loadtime_snapshot = self.system_metrics.model_loading_time - - def compute_autoscaler_data() -> AutoScalerData: - return AutoScalerData( - id=self.id, - mtoken=self.mtoken, - version=self.version, - loadtime=(loadtime_snapshot or 0.0), - new_load=self.model_metrics.workload_processing, - cur_load=self.model_metrics.cur_load, - rej_load=self.model_metrics.workload_rejected, - max_perf=self.model_metrics.max_throughput, - cur_perf=self.model_metrics.workload_served, - error_msg=self.model_metrics.error_msg or "", - num_requests_working=len(self.model_metrics.requests_working), - num_requests_recieved=len(self.model_metrics.requests_recieved), - additional_disk_usage=self.system_metrics.additional_disk_usage, - working_request_idxs=self.model_metrics.working_request_idxs, - cur_capacity=0, - max_capacity=0, - url=self.url, - ) - - async def send_data(report_addr: str) -> bool: - data = compute_autoscaler_data() - log_data = asdict(data) - def obfuscate(secret: str) -> str: - if secret is None: - return "" - return secret[:7] + "..." 
if len(secret) > 7 else ("*" * len(secret)) - - log_data["mtoken"] = obfuscate(log_data.get("mtoken")) - log.debug( - "\n".join( - [ - "#" * 60, - f"sending data to autoscaler", - f"{json.dumps(log_data, indent=2)}", - "#" * 60, - ] - ) - ) - - full_path = report_addr.rstrip("/") + "/worker_status/" - for attempt in range(1, 4): - try: - session = await self.http() - async with session.post(full_path, json=asdict(data)) as res: - res.raise_for_status() - return True - except asyncio.TimeoutError: - log.debug(f"autoscaler status update timed out") - except (ClientResponseError, Exception) as e: - log.debug(f"autoscaler status update failed with error: {e}") - await asyncio.sleep(2) - log.debug(f"retrying autoscaler status update, attempt: {attempt}") - log.debug(f"failed to send update through {report_addr}") - return False - - ########### - - self.system_metrics.update_disk_usage() - - sent = False - for report_addr in self.report_addr: - if await send_data(report_addr): - sent = True - break - - if sent: - # clear the one-shot loadtime only if we actually sent *this* value - self.system_metrics.reset(expected=loadtime_snapshot) - self.update_pending = False - self.model_metrics.reset() - self.last_metric_update = time.time() diff --git a/lib/server.py b/lib/server.py deleted file mode 100644 index 0029311..0000000 --- a/lib/server.py +++ /dev/null @@ -1,60 +0,0 @@ -import os -import logging -from typing import List -import ssl -from asyncio import run, gather -import asyncio - -from lib.backend import Backend -from lib.metrics import Metrics -from aiohttp import web - -log = logging.getLogger(__file__) - - -def start_server(backend: Backend, routes: List[web.RouteDef], **kwargs): - try: - log.debug("getting certificate...") - use_ssl = os.environ.get("USE_SSL", "false") == "true" - if use_ssl is True: - ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ssl_context.load_cert_chain( - certfile="/etc/instance.crt", - keyfile="/etc/instance.key", - ) - else: - ssl_context = None - - async def main(): - log.debug("starting server...") - app = web.Application() - app.add_routes(routes) - runner = web.AppRunner(app) - await runner.setup() - site = web.TCPSite( - runner, - ssl_context=ssl_context, - port=int(os.environ["WORKER_PORT"]), - **kwargs - ) - await gather(site.start(), backend._start_tracking()) - - run(main()) - - except Exception as e: - err_msg = f"PyWorker failed to launch: {e}" - log.error(err_msg) - - async def beacon(): - metrics = Metrics() - metrics._set_version(getattr(backend, "version", "0")) - metrics._set_mtoken(getattr(backend, "mtoken", "")) - try: - while True: - metrics._model_errored(err_msg) - await metrics._Metrics__send_metrics_and_reset() - await asyncio.sleep(10) - finally: - await metrics.aclose() - - run(beacon()) diff --git a/lib/test_utils.py b/lib/test_utils.py deleted file mode 100644 index d64a4b6..0000000 --- a/lib/test_utils.py +++ /dev/null @@ -1,310 +0,0 @@ -import logging -import os -import time -import argparse -from typing import Callable, List, Dict, Tuple, Dict, Any, Type -from time import sleep -import threading -from enum import Enum -from collections import Counter -from dataclasses import dataclass, field, asdict -from urllib.parse import urljoin -from utils.endpoint_util import Endpoint -from utils.ssl import get_cert_file_path -import requests - -from lib.data_types import AuthData, ApiPayload - -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", 
-) -log = logging.getLogger(__file__) - - -class ClientStatus(Enum): - FetchEndpoint = 1 - Generating = 2 - Done = 3 - Error = 4 - - -total_success = 0 -last_res = [] -stop_event = threading.Event() - -start_time = time.time() -test_args = argparse.ArgumentParser(description="Test inference endpoint") -test_args.add_argument( - "-k", dest="api_key", type=str, required=True, help="Your vast account API key" -) -test_args.add_argument( - "-e", - dest="endpoint_group_name", - type=str, - required=True, - help="Endpoint group name", -) -test_args.add_argument( - "-l", - dest="server_url", - action="store_const", - const="http://localhost:8081", - default="https://run.vast.ai", - help="Call local autoscaler instead of prod, for dev use only", -) -test_args.add_argument( - "-i", - dest="instance", - type=str, - default="prod", - help="Autoscaler shard to run the command against, default: prod", -) - -GetPayloadAndWorkload = Callable[[], Tuple[Dict[str, Any], float]] - - -def print_truncate_res(res: str): - if len(res) > 150: - print(f"{res[:50]}....{res[-100:]}") - else: - print(res) - - -@dataclass -class ClientState: - endpoint_group_name: str - api_key: str - server_url: str - worker_endpoint: str - instance: str - payload: ApiPayload - url: str = "" - status: ClientStatus = ClientStatus.FetchEndpoint - as_error: List[str] = field(default_factory=list) - infer_error: List[str] = field(default_factory=list) - conn_errors: Counter = field(default_factory=Counter) - - def make_call(self): - self.status = ClientStatus.FetchEndpoint - if not self.api_key: - self.as_error.append( - f"Endpoint {self.endpoint_group_name} not found for API key", - ) - self.status = ClientStatus.Error - return - route_payload = { - "endpoint": self.endpoint_group_name, - "api_key": self.api_key, - "cost": self.payload.count_workload(), - } - headers = {"Authorization": f"Bearer {self.api_key}"} - response = requests.post( - urljoin(self.server_url, "/route/"), - json=route_payload, - headers=headers, - timeout=4, - ) - if response.status_code != 200: - self.as_error.append( - f"code: {response.status_code}, body: {response.text}", - ) - self.status = ClientStatus.Error - return - message = response.json() - worker_address = message["url"] - req_data = dict( - payload=asdict(self.payload), - auth_data=asdict(AuthData.from_json_msg(message)), - ) - self.url = worker_address - url = urljoin(worker_address, self.worker_endpoint) - self.status = ClientStatus.Generating - - response = requests.post( - url, - json=req_data, - verify=get_cert_file_path(), - ) - if response.status_code != 200: - self.infer_error.append( - f"code: {response.status_code}, body: {response.text}, url: {url}", - ) - self.status = ClientStatus.Error - return - res = str(response.json()) - global total_success - global last_res - total_success += 1 - last_res.append(res) - self.status = ClientStatus.Done - - def simulate_user(self) -> None: - try: - self.make_call() - except Exception as e: - print(e) - self.status = ClientStatus.Error - _ = e - self.conn_errors[self.url] += 1 - - -def print_state(clients: List[ClientState], num_clients: int) -> None: - print("starting up...") - sleep(2) - center_size = 14 - global start_time - while len(clients) < num_clients or ( - any( - map( - lambda client: client.status - in [ClientStatus.FetchEndpoint, ClientStatus.Generating], - clients, - ) - ) - ): - sleep(0.5) - os.system("clear") - print( - " | ".join( - [member.name.center(center_size) for member in ClientStatus] - + [ - item.center(center_size) - for 
item in [ - "urls", - "as_error", - "infer_error", - "conn_error", - "total_success", - ] - ] - ) - ) - unique_urls = len(set([c.url for c in clients if c.url != ""])) - as_errors = sum( - map( - lambda client: len(client.as_error), - [client for client in clients], - ) - ) - infer_errors = sum( - map( - lambda client: len(client.infer_error), - [client for client in clients], - ) - ) - conn_errors = sum([client.conn_errors for client in clients], start=Counter()) - conn_errors_str = ",".join(map(str, conn_errors.values())) or "0" - elapsed = time.time() - start_time - print( - " | ".join( - map( - lambda item: str(item).center(center_size), - [ - len(list(filter(lambda x: x.status == member, clients))) - for member in ClientStatus - ] - + [ - unique_urls, - as_errors, - infer_errors, - conn_errors_str, - f"{total_success}({((total_success/elapsed) * 60):.2f}/minute)", - ], - ) - ) - ) - if conn_errors: - print("conn_errors:") - for url, count in conn_errors.items(): - print(url.ljust(28), ": ", str(count)) - elapsed = time.time() - start_time - print(f"\n elapsed: {int(elapsed // 60)}:{int(elapsed % 60)}") - if last_res: - for i, res in enumerate(last_res[-10:]): - print_truncate_res(f"res #{1+i+max(len(last_res )-10,0)}: {res}") - if stop_event.is_set(): - print("\n### waiting for existing connections to close ###") - - -def run_test( - num_requests: int, - requests_per_second: int, - endpoint_group_name: str, - api_key: str, - server_url: str, - worker_endpoint: str, - payload_cls: Type[ApiPayload], - instance: str, -): - threads = [] - - clients = [] - print_thread = threading.Thread(target=print_state, args=(clients, num_requests)) - print_thread.daemon = True # makes threads get killed on program exit - print_thread.start() - endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=endpoint_group_name, account_api_key=api_key, instance=instance - ) - if not endpoint_api_key: - log.debug(f"Endpoint {endpoint_group_name} not found for API key") - return - try: - for _ in range(num_requests): - client = ClientState( - endpoint_group_name=endpoint_group_name, - api_key=endpoint_api_key, - server_url=server_url, - worker_endpoint=worker_endpoint, - payload=payload_cls.for_test(), - instance=instance, - ) - clients.append(client) - thread = threading.Thread(target=client.simulate_user, args=()) - threads.append(thread) - thread.start() - sleep(1 / requests_per_second) - for thread in threads: - thread.join() - print("done spawning workers") - except KeyboardInterrupt: - stop_event.set() - - -def test_load_cmd( - payload_cls: Type[ApiPayload], endpoint: str, arg_parser: argparse.ArgumentParser -): - arg_parser.add_argument( - "-n", - dest="num_requests", - type=int, - required=True, - help="total number of requests", - ) - arg_parser.add_argument( - "-rps", - dest="requests_per_second", - type=float, - required=True, - help="requests per second", - ) - args = arg_parser.parse_args() - if hasattr(args, "comfy_model"): - os.environ["COMFY_MODEL"] = args.comfy_model - server_url = { - "prod": "https://run.vast.ai", - "alpha": "https://run-alpha.vast.ai", - "candidate": "https://run-candidate.vast.ai", - "local": "http://localhost:8080", - }.get(args.instance, "http://localhost:8080") - run_test( - num_requests=args.num_requests, - requests_per_second=args.requests_per_second, - api_key=args.api_key, - server_url=server_url, - endpoint_group_name=args.endpoint_group_name, - worker_endpoint=endpoint, - payload_cls=payload_cls, - instance=args.instance, - ) diff --git 
a/lib/wheres-my-pyworker.txt b/lib/wheres-my-pyworker.txt
new file mode 100644
index 0000000..d439ebb
--- /dev/null
+++ b/lib/wheres-my-pyworker.txt
@@ -0,0 +1,22 @@
+# Where did the PyWorker code go?
+We have moved the PyWorker source code into the `vastai-sdk` Python SDK.
+You can install it with
+```
+pip install vastai-sdk
+```
+
+All of the source code can be found here:
+https://github.com/vast-ai/vast-sdk
+
+It can be imported from vastai.serverless.server.lib.
+
+Serverless instances automatically run the start_server.sh script, which installs the vastai-sdk.
+This is how the PyWorker source code makes it onto your serverless instances.
+You provide a worker.py file in your PYWORKER_REPO, and the start_server.sh will
+create and run a PyWorker according to your configuration defined in the file.
+
+While you can still create and run PyWorkers for serverless using your old PyWorker code,
+we **strongly** encourage you to use the new worker.py configuration method, since
+we can guarantee backwards compatibility for all your worker definitions. No more forking pyworker :)
+
+If you encounter any issues while using PyWorker, please create a GitHub issue and we will be happy to assist.
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index b484d2d..807a1c4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,4 @@ Requests~=2.32
transformers~=4.52
utils==1.0.*
hf_transfer>=0.1.9
-vastai-sdk>=0.2.0
+vastai-sdk>=0.3.0
diff --git a/start_server.sh b/start_server.sh
index 2f5ecdc..84656c2 100755
--- a/start_server.sh
+++ b/start_server.sh
@@ -15,7 +15,6 @@ WORKER_PORT="${WORKER_PORT:-3000}"
mkdir -p "$WORKSPACE_DIR"
cd "$WORKSPACE_DIR"
-# make all output go to $DEBUG_LOG and stdout without having to add `... | tee -a $DEBUG_LOG` to every command
exec &> >(tee -a "$DEBUG_LOG")
function echo_var(){
@@ -25,11 +24,10 @@ function echo_var(){
function report_error_and_exit(){
local error_msg="$1"
echo "ERROR: $error_msg"
-
- # Report error to autoscaler
+
MTOKEN="${MASTER_TOKEN:-}"
VERSION="${PYWORKER_VERSION:-0}"
-
+
IFS=',' read -r -a REPORT_ADDRS <<< "${REPORT_ADDR}"
for addr in "${REPORT_ADDRS[@]}"; do
curl -sS -X POST -H 'Content-Type: application/json' \
@@ -38,35 +36,20 @@ function report_error_and_exit(){
"id": ${CONTAINER_ID:-0},
"mtoken": "${MTOKEN}",
"version": "${VERSION}",
- "loadtime": 0,
- "new_load": 0,
- "cur_load": 0,
- "rej_load": 0,
- "max_perf": 0,
- "cur_perf": 0,
"error_msg": "${error_msg}",
- "num_requests_working": 0,
- "num_requests_recieved": 0,
- "additional_disk_usage": 0,
- "working_request_idxs": [],
- "cur_capacity": 0,
- "max_capacity": 0,
"url": "${URL:-}"
}
JSON
)" "${addr%/}/worker_status/" || true
done
-
+
exit 1
}
-[ -z "$BACKEND" ] && report_error_and_exit "BACKEND must be set!"
-[ -z "$MODEL_LOG" ] && report_error_and_exit "MODEL_LOG must be set!"
-[ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set!"
+[ -n "$BACKEND" ] && [ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set when BACKEND is set!"
[ -z "$CONTAINER_ID" ] && report_error_and_exit "CONTAINER_ID must be set!"
[ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && report_error_and_exit "For comfyui backends, COMFY_MODEL must be set!"
-
echo "start_server.sh"
date
@@ -80,8 +63,6 @@ echo_var DEBUG_LOG
echo_var PYWORKER_LOG
echo_var MODEL_LOG
-# if instance is rebooted, we want to clear out the log file so pyworker doesn't read lines
-# from the run prior to reboot.
past logs are saved in $MODEL_LOG.old for debugging only if [ -e "$MODEL_LOG" ]; then echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old" if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then @@ -119,7 +100,6 @@ then fi fi - # Fork testing if [[ ! -d $SERVER_DIR ]]; then if ! git clone "${PYWORKER_REPO:-https://github.com/vast-ai/pyworker}" "$SERVER_DIR"; then report_error_and_exit "Failed to clone pyworker repository" @@ -159,8 +139,6 @@ else echo "venv: $VIRTUAL_ENV" fi -[ ! -d "$SERVER_DIR/workers/$BACKEND" ] && report_error_and_exit "$BACKEND not supported!" - if [ "$USE_SSL" = true ]; then if ! cat << EOF > /etc/openssl-san.cnf @@ -204,9 +182,6 @@ EOF fi fi - - - export REPORT_ADDR WORKER_PORT USE_SSL UNSECURED if ! cd "$SERVER_DIR"; then @@ -216,12 +191,34 @@ fi echo "launching PyWorker server" set +e -python3 -m "workers.$BACKEND.server" |& tee -a "$PYWORKER_LOG" -PY_STATUS=${PIPESTATUS[0]} + +PY_STATUS=1 + +if [ -f "$SERVER_DIR/worker.py" ]; then + echo "trying worker.py" + python3 -m "worker" |& tee -a "$PYWORKER_LOG" + PY_STATUS=${PIPESTATUS[0]} +fi + +if [ "${PY_STATUS}" -ne 0 ] && [ -f "$SERVER_DIR/workers/$BACKEND/worker.py" ]; then + echo "trying workers.${BACKEND}.worker" + python3 -m "workers.${BACKEND}.worker" |& tee -a "$PYWORKER_LOG" + PY_STATUS=${PIPESTATUS[0]} +fi + +if [ "${PY_STATUS}" -ne 0 ] && [ -f "$SERVER_DIR/workers/$BACKEND/server.py" ]; then + echo "trying workers.${BACKEND}.server" + python3 -m "workers.${BACKEND}.server" |& tee -a "$PYWORKER_LOG" + PY_STATUS=${PIPESTATUS[0]} +fi + set -e if [ "${PY_STATUS}" -ne 0 ]; then - report_error_and_exit "PyWorker exited with status ${PY_STATUS}" + if [ ! -f "$SERVER_DIR/worker.py" ] && [ ! -f "$SERVER_DIR/workers/$BACKEND/worker.py" ] && [ ! -f "$SERVER_DIR/workers/$BACKEND/server.py" ]; then + report_error_and_exit "Failed to find PyWorker" + fi + report_error_and_exit "PyWorker exited with status ${PY_STATUS}" fi -echo "launching PyWorker server done" \ No newline at end of file +echo "launching PyWorker server done" diff --git a/utils/endpoint_util.py b/utils/endpoint_util.py deleted file mode 100644 index 927262e..0000000 --- a/utils/endpoint_util.py +++ /dev/null @@ -1,136 +0,0 @@ -import logging -import time -from typing import Any, Dict, Optional, Tuple - -import requests - -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) - - -class Endpoint: - """ - Utility class for handling endpoint operations. 
- """ - - @staticmethod - def get_endpoint_info( - endpoint_name: str, account_api_key: str, instance: str - ) -> Optional[Dict[str, Any]]: - headers = {"Authorization": f"Bearer {account_api_key}"} - url = f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}" - # Retry a few times to smooth over transient propagation/network delays - for attempt in range(4): - try: - response = requests.get(url, headers=headers, timeout=8) - if response.status_code != 200: - # brief backoff and retry - time.sleep(0.3 * (attempt + 1)) - continue - try: - data = response.json() - except Exception: - # JSON parse failed; backoff and retry - time.sleep(0.3 * (attempt + 1)) - continue - result = data.get("results", []) if isinstance(data, dict) else [] - endpoint = next( - (item for item in result if item.get("endpoint_name") == endpoint_name), - None, - ) - if endpoint and endpoint.get("id") and endpoint.get("api_key"): - return {"id": endpoint.get("id"), "api_key": endpoint.get("api_key")} - except Exception: - # network or other transient error; retry - time.sleep(0.3 * (attempt + 1)) - return None - - @staticmethod - def get_autoscaler_server_url(instance: str) -> str: - endpoints = { - "alpha": "run-alpha", - "candidate": "run-candidate", - "prod": "run", - } - host = endpoints.get(instance) - if host: - return f"https://{host}.vast.ai/" - return "http://localhost:8080" - - @staticmethod - def get_server_url(instance: str) -> str: - endpoints = { - "alpha": "alpha", - "candidate": "candidate", - "prod": "console", - } - host = endpoints.get(instance, "alpha") - return f"https://{host}.vast.ai/api/v0/endptjobs/" - - @staticmethod - def get_endpoint_api_key( - endpoint_name: str, account_api_key: str, instance: str - ) -> Optional[str]: - """ - Fetch endpoint API key from VastAI console following the healthcheck pattern. - - Args: - endpoint_name: Name of the endpoint - account_api_key: Account API key for authentication - - Returns: - Endpoint API key if successful, None otherwise - """ - headers = {"Authorization": f"Bearer {account_api_key}"} - - try: - log.debug(f"Fetching endpoint API key for endpoint: {endpoint_name}") - response = requests.get( - f"{Endpoint.get_server_url(instance)}?autoscaler_instance={instance}", - headers=headers, - timeout=8, - ) - - if response.status_code != 200: - error_msg = f"Failed to fetch endpoint API key: {response.status_code} - {response.text}" - log.debug(error_msg) - return None - - try: - data = response.json() - except Exception as e: - log.debug(f"Failed to parse JSON response: {e}") - return None - - result = data.get("results", []) - - endpoint: Optional[Dict[str, Any]] = next( - (item for item in result if item.get("endpoint_name") == endpoint_name), - None, - ) - if not endpoint: - error_msg = f"Endpoint '{endpoint_name}' not found." - log.debug(error_msg) - return None - - endpoint_api_key = endpoint.get("api_key") - if not endpoint_api_key: - error_msg = f"API key for endpoint '{endpoint_name}' not found." 
- log.debug(error_msg) - return None - - log.debug(f"Successfully retrieved API key for endpoint: {endpoint_name}") - return endpoint_api_key - - except requests.exceptions.RequestException as e: - error_msg = f"Request error while fetching endpoint API key: {e}" - log.debug(error_msg) - return None - except Exception as e: - error_msg = f"Unexpected error while fetching endpoint API key: {e}" - log.debug(error_msg) - return None diff --git a/utils/ssl.py b/utils/ssl.py deleted file mode 100644 index 5406ac8..0000000 --- a/utils/ssl.py +++ /dev/null @@ -1,15 +0,0 @@ -import tempfile -from functools import cache - -import requests - - -@cache -def get_cert_file_path(): - cert_url = "https://console.vast.ai/static/jvastai_root.cer" - response = requests.get(cert_url) - response.raise_for_status() - # Use a temporary file that is not deleted on close - with tempfile.NamedTemporaryFile(delete=False, suffix=".cer", mode="wb") as f: - f.write(response.content) - return f.name diff --git a/workers/ace/README.md b/workers/ace/README.md new file mode 100644 index 0000000..9cff8cf --- /dev/null +++ b/workers/ace/README.md @@ -0,0 +1,168 @@ +# ComfyUI ACE Step PyWorker + +This is the PyWorker implementation for running **ACE Step v1 3.5B** text-to-music workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI audio-generation workflows through a proxy-based architecture and returning generated audio assets. + +Each request has a static cost of `1000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node. + +## Requirements + +This worker requires the following components: + +- ComfyUI (https://github.com/comfyanonymous/ComfyUI) +- ComfyUI API Wrapper (https://github.com/ai-dock/comfyui-api-wrapper) +- ACE Step v1 3.5B model and required custom nodes + +A Docker image is provided with the ACE Step model pre-installed, but any image may be used if the above requirements are met. + +## Endpoint + +The worker exposes a single synchronous endpoint: + +- `/generate/sync`: Processes a complete ComfyUI workflow JSON and generates audio output + +## Request Format + +The ACE Step worker **only supports custom workflow mode**. Modifier-based workflows are not supported. + +```json +{ + "input": { + "request_id": "uuid-string", + "workflow_json": { + // Complete ComfyUI ACE Step workflow JSON + }, + "s3": { }, + "webhook": { } + } +} +``` + +## Request Fields + +### Required Fields + +- `input`: Container for all request parameters +- `input.workflow_json`: Complete ComfyUI workflow graph for ACE Step audio generation + +### Optional Fields + +- `input.request_id`: Client-defined request identifier +- `input.s3`: S3-compatible storage configuration +- `input.webhook`: Webhook configuration for completion notifications + +The special string `"__RANDOM_INT__"` may be used in the workflow JSON and will be replaced with a random integer before submission to ComfyUI. + +## S3 Configuration + +Generated audio assets can be automatically uploaded to S3-compatible storage. Configuration can be supplied per request or via environment variables. Request-level values take precedence. 
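+
+As a minimal illustration (assuming the `vastai` SDK client shown in `workers/ace/client.py` and a placeholder endpoint name), request-level S3 settings can be passed from Python like this; the full ACE Step `workflow_json` is omitted for brevity:
+
+```python
+import asyncio
+
+from vastai import Serverless
+
+
+async def main():
+    async with Serverless() as client:
+        # "my-ace-endpoint" is a placeholder; use your endpoint group name
+        endpoint = await client.get_endpoint(name="my-ace-endpoint")
+        payload = {
+            "input": {
+                "workflow_json": {},  # complete ACE Step workflow goes here
+                "s3": {
+                    # Request-level values override any S3_* environment variables
+                    "access_key_id": "your-s3-access-key",
+                    "secret_access_key": "your-s3-secret-access-key",
+                    "endpoint_url": "https://s3.amazonaws.com",
+                    "bucket_name": "your-bucket",
+                    "region": "us-east-1",
+                },
+            }
+        }
+        response = await endpoint.request("/generate/sync", payload)
+        print(response["response"])
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
+```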
+ +### Via Request JSON + +```json +"s3": { + "access_key_id": "your-s3-access-key", + "secret_access_key": "your-s3-secret-access-key", + "endpoint_url": "https://s3.amazonaws.com", + "bucket_name": "your-bucket", + "region": "us-east-1" +} +``` + +### Via Environment Variables + +```bash +S3_ACCESS_KEY_ID=your-key +S3_SECRET_ACCESS_KEY=your-secret +S3_BUCKET_NAME=your-bucket +S3_ENDPOINT_URL=https://s3.amazonaws.com +S3_REGION=us-east-1 +``` + +## Webhook Configuration + +Webhooks are triggered on request completion or failure. + +### Via Request JSON + +```json +"webhook": { + "url": "https://your-webhook-url", + "extra_params": { + "custom_field": "value" + } +} +``` + +### Via Environment Variables + +```bash +WEBHOOK_URL=https://your-webhook-url +WEBHOOK_TIMEOUT=30 +``` + +## Example Request + +### ACE Step Text-to-Music Workflow + +```json +{ + "input": { + "workflow_json": { + "14": { + "inputs": { + "tags": "funk, pop, upbeat, 105 BPM", + "lyrics": "Turn it up and let it flow", + "lyrics_strength": 0.99, + "clip": ["40", 1] + }, + "class_type": "TextEncodeAceStepAudio" + }, + "17": { + "inputs": { + "seconds": 180, + "batch_size": 1 + }, + "class_type": "EmptyAceStepLatentAudio" + }, + "40": { + "inputs": { + "ckpt_name": "ace_step_v1_3.5b.safetensors" + }, + "class_type": "CheckpointLoaderSimple" + } + } + } +} +``` + +## Response Format + +A successful response includes execution metadata, ComfyUI output details, and generated audio assets. + +### Response Fields + +- `id`: Unique request identifier +- `status`: `completed`, `failed`, `processing`, `generating`, or `queued` +- `message`: Human-readable status message +- `comfyui_response`: Raw response from ComfyUI, including execution status and progress +- `output`: Array of generated outputs +- `timings`: Timing information for the request + +### Output Object + +Each entry in `output` includes: + +- `filename`: Generated file name (e.g., `.mp3`) +- `local_path`: File path on the worker +- `url`: Pre-signed download URL (if S3 is configured) +- `type`: Output type (`output`) +- `subfolder`: Output directory (e.g., `audio`) +- `node_id`: ComfyUI node that produced the output +- `output_type`: Output category (e.g., `audio`) + +## Notes and Limitations + +- Only full ComfyUI workflow JSONs are supported +- Concurrent requests are not supported per worker +- ACE Step model must be installed before processing requests +- Audio generation duration and runtime depend on workflow configuration \ No newline at end of file diff --git a/lib/__init__.py b/workers/ace/__init__.py similarity index 100% rename from lib/__init__.py rename to workers/ace/__init__.py diff --git a/workers/ace/client.py b/workers/ace/client.py new file mode 100644 index 0000000..4f6f577 --- /dev/null +++ b/workers/ace/client.py @@ -0,0 +1,149 @@ +from vastai import Serverless +import asyncio + + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-ace-endpoint") + + # ComfyUI API compatible json workflow for ACE Step + workflow = { + "14": { + "inputs": { + "tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic", + "lyrics": "[verse]\nNeon lights they flicker bright\nCity hums in dead of night\nRhythms pulse through concrete veins\nLost in echoes of refrains\n\n[verse]\nBassline groovin in my chest\nHeartbeats match the citys zest\nElectric whispers fill the air\nSynthesized dreams everywhere\n\n[chorus]\nTurn it up and let it 
flow\nFeel the fire let it grow\nIn this rhythm we belong\nHear the night sing out our song", + "lyrics_strength": 0.99, + "clip": ["40", 1] + }, + "class_type": "TextEncodeAceStepAudio", + "_meta": { + "title": "TextEncodeAceStepAudio" + } + }, + "17": { + "inputs": { + "seconds": 180, + "batch_size": 1 + }, + "class_type": "EmptyAceStepLatentAudio", + "_meta": { + "title": "EmptyAceStepLatentAudio" + } + }, + "18": { + "inputs": { + "samples": ["52", 0], + "vae": ["40", 2] + }, + "class_type": "VAEDecodeAudio", + "_meta": { + "title": "VAE Decode Audio" + } + }, + "40": { + "inputs": { + "ckpt_name": "ace_step_v1_3.5b.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "44": { + "inputs": { + "conditioning": ["14", 0] + }, + "class_type": "ConditioningZeroOut", + "_meta": { + "title": "ConditioningZeroOut" + } + }, + "49": { + "inputs": { + "model": ["51", 0], + "operation": ["50", 0] + }, + "class_type": "LatentApplyOperationCFG", + "_meta": { + "title": "LatentApplyOperationCFG" + } + }, + "50": { + "inputs": { + "multiplier": 1.15 + }, + "class_type": "LatentOperationTonemapReinhard", + "_meta": { + "title": "LatentOperationTonemapReinhard" + } + }, + "51": { + "inputs": { + "shift": 6, + "model": ["40", 0] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "52": { + "inputs": { + "seed": "__RANDOM_INT__", + "steps": 65, + "cfg": 4, + "sampler_name": "er_sde", + "scheduler": "linear_quadratic", + "denoise": 1, + "model": ["49", 0], + "positive": ["14", 0], + "negative": ["44", 0], + "latent_image": ["17", 0] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "59": { + "inputs": { + "filename_prefix": "audio/ComfyUI", + "quality": "V0", + "audioUI": "", + "audio": ["18", 0] + }, + "class_type": "SaveAudioMP3", + "_meta": { + "title": "Save Audio (MP3)" + } + } + } + + payload = { + "input": { + "request_id": "", + "workflow_json": workflow, + "s3": { + "access_key_id": "", + "secret_access_key": "", + "endpoint_url": "", + "bucket_name": "", + "region": "" + }, + "webhook": { + "url": "", + "extra_params": { + "user_id": "12345", + "project_id": "abc-def" + } + } + } + } + + response = await endpoint.request("/generate/sync", payload) + + # Response contains status, output, and any errors + print(response["response"]) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/workers/ace/worker.py b/workers/ace/worker.py new file mode 100644 index 0000000..11524a6 --- /dev/null +++ b/workers/ace/worker.py @@ -0,0 +1,184 @@ +import random +import sys + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# ComyUI model configuration +MODEL_SERVER_URL = 'http://127.0.0.1' +MODEL_SERVER_PORT = 18288 +MODEL_LOG_FILE = '/var/log/portal/comfyui.log' +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# ComyUI-specific log messages +MODEL_LOAD_LOG_MSG = [ + "To see the GUI go to: " +] + +MODEL_ERROR_LOG_MSGS = [ + "MetadataIncompleteBuffer", + "Value not in list: ", + "[ERROR] Provisioning Script failed" +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Downloading' +] + +benchmark_lyrics = [ + "[verse]\nGuardian cloaked in twilight hue\nShadows melt where he breaks through\nEchoes swirl in mystic flight\nHooded hero owns the night\n\n[verse]\nThrough the chaos shapes arise\nFeral whispers, glowing eyes\nOrcs and creatures side by side\nMarch within the inky tide\n\n[chorus]\nRise above the fear and 
gloom\nLet your courage fully bloom\nIn the darkness stand your ground\nHear the night proclaim your sound", + "[verse]\nMorning sun on fields of gold\nGentle stories unfold\nEvery breeze a quiet song\nWhere the peaceful hearts belong\n\n[verse]\nLanterns glow at stable doors\nRustling leaves on orchard floors\nSimple joys in every hand\nLife grows soft in fertile land\n\n[chorus]\nLet the day drift slow and free\nRoot your soul where you can be\nIn this haven warm and bright\nFeel the earth breathe pure delight", + "[verse]\nLittle feet on dusty ground\nChasing dreams without a sound\nSoccer ball in morning light\nHopes take wing in youthful flight\n\n[verse]\nChrome reflections paint the day\nSwagger in the steps that play\nCopper tones in shining air\nChildhood gleaming everywhere\n\n[chorus]\nKick the world with boundless cheer\nHold the magic close and near\nIn each moment bold and true\nLet the sky belong to you", + "[verse]\nSunset bleeds across the street\nGilded calm in summer heat\nLow-rise towers rimmed with fire\nDreams ignite as lights climb higher\n\n[verse]\nFootsteps scatter through the haze\nFutures shimmer in the blaze\nEvery window tells a tale\nFloating through a tangerine veil\n\n[chorus]\nLet the neon softly glow\nLet your restless heartbeat slow\nIn this city forged in light\nCarry hope into the night", + "[verse]\nOcean breathes in rolling arcs\nSprays of diamond, glowing sparks\nWaves unfold a perfect line\nNature’s rhythm feels divine\n\n[verse]\nSun above in golden sweep\nPaints the rise of every deep\nShimmer drifting through the blue\nWorld reborn in every view\n\n[chorus]\nLet the tide pull you along\nHear the water’s ancient song\nIn the cresting waves you’ll find\nQuiet peace for heart and mind", + "[verse]\nGlass aglow with swirling light\nFruits and mints in colors bright\nIcy whispers clink and chime\nFlowing forms suspend in time\n\n[verse]\nCreamy spirals drift within\nGentle currents slowly spin\nWarm reflections lingering sweet\nMixing flavors at your feet\n\n[chorus]\nSip the glow and let it rise\nTaste the sunset in disguise\nIn this moment clear and true\nLet the warmth flow into you", + "[verse]\nEngines rumble down the lane\nCopper clouds of steam and rain\nOilpunk dreams in metal shine\nRider drifting down the line\n\n[verse]\nLeather jacket, steady glare\nStories sparking in the air\nMagazine lights frame his face\nKing of roads in timeless grace\n\n[chorus]\nThrottle up beyond the bend\nFeel the force of steel ascend\nRide the night and hold on tight\nClaim the world in streaks of light", + "[verse]\nCut-out shapes in swirling play\nTextures dance in bold array\nCats in denim, grinning wide\nStrut across the patterned tide\n\n[verse]\nPosters hum with neon glow\nSurreal scenes begin to grow\nColors crisp as folded art\nPatchwork beating like a heart\n\n[chorus]\nLet the collage come alive\nWatch the vibrant pieces thrive\nIn this joyful, crafted space\nEvery shape finds its own place", + "[verse]\nTiny world in crystal glass\nAncient tales behind the mass\nVillage lights in winter gleam\nFrozen in a mystic dream\n\n[verse]\nLantern beams in swirling air\nSoft enchantment everywhere\nShadows drift with gentle grace\nMagic sealed within the space\n\n[chorus]\nHold the sphere and you will see\nEchoes of a memory\nIn the glow of fragile light\nLives a realm of pure delight", + "[verse]\nArmor hums with power bright\nChopping sparks in jungle night\nMecha spirits shift and scream\nThrough the ferns like shattered beams\n\n[verse]\nAxes blaze in 
glowing arcs\nLighting up the shadowed marks\nNature roars in trembling air\nClash of steel and cosmic flare\n\n[chorus]\nRaise the fire, strike the ground\nLet your legend shake the sound\nIn the wild where echoes roam\nForge the fight and carve your home", + "[verse]\nCrowds ignite in vibrant flare\nBeats explode through smoky air\nDJ robes replaced with flame\nPope on decks in holy frame\n\n[verse]\nLeather gleams in blinding light\nTurntables spin with sacred might\nChoirs echo in the bass\nHeaven pulses through the place\n\n[chorus]\nLift the roof and shake the floor\nSacred rhythm evermore\nLet the music take control\nFeel the blessing in your soul", +] + +benchmark_dataset = [ + { + "input": { + "request_id": "", + "workflow_json": { + "14": { + "inputs": { + "tags": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic", + "lyrics": lyrics, + "lyrics_strength": 0.99, + "clip": ["40", 1] + }, + "class_type": "TextEncodeAceStepAudio", + "_meta": { + "title": "TextEncodeAceStepAudio" + } + }, + "17": { + "inputs": { + "seconds": 180, + "batch_size": 1 + }, + "class_type": "EmptyAceStepLatentAudio", + "_meta": { + "title": "EmptyAceStepLatentAudio" + } + }, + "18": { + "inputs": { + "samples": ["52", 0], + "vae": ["40", 2] + }, + "class_type": "VAEDecodeAudio", + "_meta": { + "title": "VAE Decode Audio" + } + }, + "40": { + "inputs": { + "ckpt_name": "ace_step_v1_3.5b.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "44": { + "inputs": { + "conditioning": ["14", 0] + }, + "class_type": "ConditioningZeroOut", + "_meta": { + "title": "ConditioningZeroOut" + } + }, + "49": { + "inputs": { + "model": ["51", 0], + "operation": ["50", 0] + }, + "class_type": "LatentApplyOperationCFG", + "_meta": { + "title": "LatentApplyOperationCFG" + } + }, + "50": { + "inputs": { + "multiplier": 1.15 + }, + "class_type": "LatentOperationTonemapReinhard", + "_meta": { + "title": "LatentOperationTonemapReinhard" + } + }, + "51": { + "inputs": { + "shift": 6, + "model": ["40", 0] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "52": { + "inputs": { + "seed": "__RANDOM_INT__", + "steps": 65, + "cfg": 4, + "sampler_name": "er_sde", + "scheduler": "linear_quadratic", + "denoise": 1, + "model": ["49", 0], + "positive": ["14", 0], + "negative": ["44", 0], + "latent_image": ["17", 0] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "59": { + "inputs": { + "filename_prefix": "audio/ComfyUI", + "quality": "V0", + "audioUI": "", + "audio": ["18", 0] + }, + "class_type": "SaveAudioMP3", + "_meta": { + "title": "Save Audio (MP3)" + } + } + } + } + } for lyrics in benchmark_lyrics +] + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/generate/sync", + allow_parallel_requests=False, + max_queue_time=10.0, + benchmark_config=BenchmarkConfig( + dataset=benchmark_dataset, + runs=1 + ), + workload_calculator= lambda _ : 1000.0 + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/workers/comfyui-json/README.md b/workers/comfyui-json/README.md index 9517dbb..c91b09d 100644 
--- a/workers/comfyui-json/README.md +++ b/workers/comfyui-json/README.md @@ -2,7 +2,7 @@ This is the base PyWorker for ComfyUI. It provides a unified interface for running any ComfyUI workflow through a proxy-based architecture. See the [Serverless documentation](https://docs.vast.ai/serverless) for guides and how-to's. -The cost for each request has a static value of `1`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node. +The cost for each request has a static value of `100`. ComfyUI does not handle concurrent workloads and there is no current provision to load multiple instances of ComfyUI per worker node. ## Instance Setup @@ -302,3 +302,11 @@ WEBHOOK_TIMEOUT=30 # Webhook timeout in seconds } } ``` + +## Client Libraries + +See the client example for implementation details on how to integrate with the ComfyUI worker. + +--- + +See Vast's serverless documentation for more details on how to use ComfyUI with autoscaler. diff --git a/workers/comfyui-json/data_types.py b/workers/comfyui-json/data_types.py deleted file mode 100644 index 1af1f8b..0000000 --- a/workers/comfyui-json/data_types.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -import sys -import random -import dataclasses -from typing import Dict, Any -from functools import cache -from math import ceil -from pathlib import Path -import json -import logging - -from lib.data_types import ApiPayload, JsonDataException - -log = logging.getLogger(__file__) - -def count_workload() -> float: - # Always 100.0 where there is a single instance of ComfyUI handling requests - # Results will indicate % or a job completed per second. Avoids sub 0.1 sec performance indication - return 100.0 - -@dataclasses.dataclass -class ComfyWorkflowData(ApiPayload): - input: dict - - @classmethod - def for_test(cls): - """ - If the user has provided a benchmark workflow we can use it here to properly gauge performance. 
- Otherwise, use the variables available to simulate workflows of the required running time - Example: SD1.5, simple image gen 10000 steps, 512px x 512px will run for approximately 9 minutes @ ~18 it/s (RTX 4090) - """ - # Try to load benchmark.json - benchmark_file = Path("workers/comfyui-json/misc/benchmark.json") - - if benchmark_file.exists(): - try: - with open(benchmark_file, "r") as f: - benchmark_workflow = json.load(f) - return cls( - input={ - "request_id": f"test-{random.randint(1000, 99999)}", - "workflow_json": benchmark_workflow - } - ) - except (json.JSONDecodeError, IOError): - # JSON is malformed or file can't be read, fall through to default - log.error(f"Failed to benchmark using {benchmark_file}") - - # Fallback: read prompts and construct payload - log.info("Using fallback method for benchmarking") - with open("workers/comfyui-json/misc/test_prompts.txt", "r") as f: - test_prompts = f.readlines() - - test_prompt = random.choice(test_prompts).rstrip() - return cls( - input={ - "request_id": f"test-{random.randint(1000, 99999)}", - "modifier": "Text2Image", - "modifications": { - "prompt": test_prompt, - "width": os.getenv('BENCHMARK_TEST_WIDTH', 512), - "height": os.getenv('BENCHMARK_TEST_HEIGHT', 512), - "steps": os.getenv('BENCHMARK_TEST_STEPS', 20), - "seed": random.randint(0, sys.maxsize), - } - } - ) - - def generate_payload_json(self) -> Dict[str, Any]: - # input is already a dict, just return it wrapped in the expected structure - return {"input": self.input} - - def count_workload(self) -> float: - return count_workload() - - @classmethod - def from_json_msg(cls, json_msg: Dict[str, Any]) -> "ComfyWorkflowData": - # Extract required fields - if "input" not in json_msg: - raise JsonDataException({"input": "missing parameter"}) - - return cls( - input=json_msg["input"] - ) \ No newline at end of file diff --git a/workers/comfyui-json/misc/benchmark.json.example b/workers/comfyui-json/misc/benchmark.json.example deleted file mode 100644 index 3e20040..0000000 --- a/workers/comfyui-json/misc/benchmark.json.example +++ /dev/null @@ -1,107 +0,0 @@ -{ - "3": { - "inputs": { - "seed": "__RANDOM_INT__", - "steps": 20, - "cfg": 8, - "sampler_name": "euler", - "scheduler": "normal", - "denoise": 1, - "model": [ - "4", - 0 - ], - "positive": [ - "6", - 0 - ], - "negative": [ - "7", - 0 - ], - "latent_image": [ - "5", - 0 - ] - }, - "class_type": "KSampler", - "_meta": { - "title": "KSampler" - } - }, - "4": { - "inputs": { - "ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors" - }, - "class_type": "CheckpointLoaderSimple", - "_meta": { - "title": "Load Checkpoint" - } - }, - "5": { - "inputs": { - "width": 512, - "height": 512, - "batch_size": 1 - }, - "class_type": "EmptyLatentImage", - "_meta": { - "title": "Empty Latent Image" - } - }, - "6": { - "inputs": { - "text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,", - "clip": [ - "4", - 1 - ] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "7": { - "inputs": { - "text": "text, watermark", - "clip": [ - "4", - 1 - ] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "8": { - "inputs": { - "samples": [ - "3", - 0 - ], - "vae": [ - "4", - 2 - ] - }, - "class_type": "VAEDecode", - "_meta": { - "title": "VAE Decode" - } - }, - "9": { - "inputs": { - "filename_prefix": "ComfyUI", - "images": [ - "8", - 0 - ] - }, - "class_type": "SaveImage", - "_meta": { - "title": "Save Image" - } - } -} \ No 
newline at end of file diff --git a/workers/comfyui-json/misc/note.txt b/workers/comfyui-json/misc/note.txt new file mode 100644 index 0000000..3666cce --- /dev/null +++ b/workers/comfyui-json/misc/note.txt @@ -0,0 +1 @@ +# This folder is required for the provisioning scripts of ace and wan to complete. \ No newline at end of file diff --git a/workers/comfyui-json/misc/test_prompts.txt b/workers/comfyui-json/misc/test_prompts.txt deleted file mode 100644 index cfb8f8c..0000000 --- a/workers/comfyui-json/misc/test_prompts.txt +++ /dev/null @@ -1,34 +0,0 @@ -cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background -stardew valley, fine details -2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! snake Culture -realistic futuristic city-downtown with short buildings, sunset -seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water -inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award. -biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover -generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric. 
-fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details -Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting -(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece -Pope Francis wearing biker (leather jacket), a masterpiece -Luke Skywalker ordering a burger and fries from the Death Star canteen. -I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar -portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece -young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. 
illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece -Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render -Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render -fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting -crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting -london luxurious interior living-room, light walls -Parisian luxurious interior penthouse bedroom, dark walls, wooden panels -cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot -houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style -Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity -High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight -a landscape from the Moon with the Earth setting on the horizon, realistic, detailed -Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view -A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism -the street of amedieval fantasy town, at dawn, dark, highly detailed -overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark -a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field -electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render -exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar. 
diff --git a/workers/comfyui-json/server.py b/workers/comfyui-json/server.py deleted file mode 100644 index daf35e5..0000000 --- a/workers/comfyui-json/server.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import logging -import dataclasses -import base64 -from typing import Optional, Union, Type - -import aiohttp -from aiohttp import web, ClientResponse - -from lib.backend import Backend, LogAction -from lib.data_types import EndpointHandler -from lib.server import start_server -from .data_types import ComfyWorkflowData - - -MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://127.0.0.1:18288") -COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:18188") # Raw ComfyUI server - -# This is the last log line that gets emitted once comfyui+extensions have been fully loaded -MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: " -MODEL_SERVER_ERROR_LOG_MSGS = [ - "MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted - "Value not in list: ", # This error is emitted when the model file is not there at all - "[ERROR] Provisioning Script failed", # Error inserted by provisioning script if models/nodes fail to download -] - - -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) - - -async def generate_client_response( - client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - # Check if the response is actually streaming based on response headers/content-type - is_streaming_response = ( - model_response.content_type == "text/event-stream" - or model_response.content_type == "application/x-ndjson" - or model_response.headers.get("Transfer-Encoding") == "chunked" - or "stream" in model_response.content_type.lower() - ) - - if is_streaming_response: - log.debug("Detected streaming response...") - res = web.StreamResponse() - res.content_type = model_response.content_type - await res.prepare(client_request) - async for chunk in model_response.content: - await res.write(chunk) - await res.write_eof() - log.debug("Done streaming response") - return res - else: - log.debug("Detected non-streaming response...") - content = await model_response.read() - return web.Response( - body=content, - status=model_response.status, - content_type=model_response.content_type - ) - - -@dataclasses.dataclass -class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]): - - @property - def endpoint(self) -> str: - return "/generate/sync" - - @property - def healthcheck_endpoint(self) -> Optional[str]: - return f"{MODEL_SERVER_URL}/health" - - @classmethod - def payload_cls(cls) -> Type[ComfyWorkflowData]: - return ComfyWorkflowData - - def make_benchmark_payload(self) -> ComfyWorkflowData: - return ComfyWorkflowData.for_test() - - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - return await generate_client_response(client_request, model_response) - - -backend = Backend( - model_server_url=MODEL_SERVER_URL, - model_log_file=os.environ["MODEL_LOG"], - allow_parallel_requests=False, - benchmark_handler=ComfyWorkflowHandler( - benchmark_runs=3, benchmark_words=100 - ), - log_actions=[ - (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG), - (LogAction.Info, "Downloading:"), - *[ - (LogAction.ModelError, error_msg) - for error_msg in MODEL_SERVER_ERROR_LOG_MSGS - ], - ], -) - - -async def handle_ping(_): - 
return web.Response(body="pong") - - -async def handle_view(request: web.Request) -> web.Response: - """Proxy /view requests to raw ComfyUI server to fetch generated images""" - # Forward query params to raw ComfyUI (not the API wrapper) - query_string = request.query_string - url = f"{COMFYUI_URL}/view?{query_string}" - - log.debug(f"Proxying /view request to: {url}") - - try: - async with aiohttp.ClientSession() as session: - async with session.get(url) as resp: - if resp.status == 200: - content = await resp.read() - return web.Response( - body=content, - status=200, - content_type=resp.content_type or "image/png" - ) - else: - text = await resp.text() - return web.Response( - text=text, - status=resp.status, - content_type="text/plain" - ) - except Exception as e: - log.error(f"Error proxying /view: {e}") - return web.Response(text=str(e), status=500) - - -routes = [ - web.post("/generate/sync", backend.create_handler(ComfyWorkflowHandler())), - web.get("/view", handle_view), - web.get("/ping", handle_ping), -] - -if __name__ == "__main__": - start_server(backend, routes) diff --git a/workers/comfyui-json/test_load.py b/workers/comfyui-json/test_load.py deleted file mode 100644 index c493f67..0000000 --- a/workers/comfyui-json/test_load.py +++ /dev/null @@ -1,8 +0,0 @@ -from lib.test_utils import test_load_cmd, test_args -from .data_types import ComfyWorkflowData - -WORKER_ENDPOINT = "/generate/sync" - - -if __name__ == "__main__": - test_load_cmd(ComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args) diff --git a/workers/comfyui-json/worker.py b/workers/comfyui-json/worker.py new file mode 100644 index 0000000..ddb7da6 --- /dev/null +++ b/workers/comfyui-json/worker.py @@ -0,0 +1,81 @@ +import random +import sys + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# ComyUI model configuration +MODEL_SERVER_URL = 'http://127.0.0.1' +MODEL_SERVER_PORT = 18288 +MODEL_LOG_FILE = '/var/log/portal/comfyui.log' +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# ComyUI-specific log messages +MODEL_LOAD_LOG_MSG = [ + "To see the GUI go to: " +] + +MODEL_ERROR_LOG_MSGS = [ + "MetadataIncompleteBuffer", + "Value not in list: ", + "[ERROR] Provisioning Script failed" +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Downloading' +] + +benchmark_prompts = [ + "Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.", + "Cozy farming-game scene with fine details.", + "2D vector child with soccer ball; airbrush chrome; swagger; antique copper.", + "Realistic futuristic downtown of low buildings at sunset.", + "Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.", + "Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.", + "Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.", + "Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.", + "Medieval village inside glass sphere; volumetric light; macro focus.", + "Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.", + "Pope Francis DJ in leather jacket, mixing on giant console; dramatic.", +] + + + +benchmark_dataset = [ + { + "input": { + "request_id": f"test-{random.randint(1000, 99999)}", + "modifier": "Text2Image", + "modifications": { + "prompt": prompt, + "width": 512, + "height": 512, + "steps": 20, + "seed": random.randint(0, sys.maxsize) + } + } + } for prompt in benchmark_prompts +] + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + 
model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/generate/sync", + allow_parallel_requests=False, + max_queue_time=10.0, + benchmark_config=BenchmarkConfig( + dataset=benchmark_dataset, + ) + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/workers/comfyui/README.md b/workers/comfyui/README.md deleted file mode 100644 index 575e597..0000000 --- a/workers/comfyui/README.md +++ /dev/null @@ -1,92 +0,0 @@ -This is the base PyWorker for comfyui. It can be used to create PyWorker that use various models and -workflows. It provides two endpoints: - -1. `/prompt`: Uses the default comfy workflow defined under `misc/default_workflows` -2. `/custom_workflow`: Allows the client to send their own comfy workflow with each API request. - -To use the comfyui PyWorker, `$COMFY_MODEL` env variable must be set in the template. Current options are -`sd3` and `flux`. Each have example clients. - -To add new models, a JSON with name `$COMFY_MODEL.json` must be created under `misc/default_workflows` - -NOTE: default workflows follow this format: - -```json -{ - "input": { - "handler": "RawWorkflow", - "aws_access_key_id": "your-s3-access-key", - "aws_secret_access_key": "your-s3-secret-access-key", - "aws_endpoint_url": "https://my-endpoint.backblaze.com", - "aws_bucket_name": "your-bucket", - "webhook_url": "your-webhook-url", - "webhook_extra_params": {}, - "workflow_json": {} - } -} -``` - -You can ignore all of these fields except for `workflow_json`. - -Fields written as "{{FOO}}" will be replaced using data from a user request. For example, SD3's workflow has the -following nodes: - -```json - "5": { - "inputs": { - "width": "{{WIDTH}}", - "height": "{{HEIGHT}}", - "batch_size": 1 - }, - - "6": { - "inputs": { - "text": "{{PROMPT}}", - "clip": ["11", 0] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - ... - "17": { - "inputs": { - "scheduler": "simple", - "steps": "{{STEPS}}", - "denoise": 1, - "model": ["12", 0] - }, - "class_type": "BasicScheduler", - "_meta": { - "title": "BasicScheduler" - } - }, - ... - "25": { - "inputs": { - "noise_seed": "{{SEED}}" - }, - "class_type": "RandomNoise", - "_meta": { - "title": "RandomNoise" - } - } - -``` - -Incoming requests have the following JSON format: - -```json -{ - prompt: str - width: int - height: int - steps: int - seed: int -} -``` - -Each value in those fields with replace the placeholder of the same name in the default workflow. 
- -See Vast's serverless documentation for more details on how to use comfyui with autoscaler diff --git a/workers/comfyui/__init__.py b/workers/comfyui/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/workers/comfyui/client.py b/workers/comfyui/client.py deleted file mode 100644 index 7d1935e..0000000 --- a/workers/comfyui/client.py +++ /dev/null @@ -1,170 +0,0 @@ -import logging -from urllib.parse import urljoin - -import requests - -from lib.test_utils import print_truncate_res -from utils.endpoint_util import Endpoint -from utils.ssl import get_cert_file_path - -from vastai import Serverless - - -ENDPOINT_NAME = "my-comfyui-endpoint" -COST = 100 # Use a constant cost for image generation - -def call_default_workflow(client: Serverless) -> None: - WORKER_ENDPOINT = "/prompt" - COST = 100 - route_payload = { - "endpoint": endpoint_group_name, - "api_key": api_key, - "cost": COST, - } - response = requests.post( - urljoin(server_url, "/route/"), - json=route_payload, - timeout=4, - ) - response.raise_for_status() - message = response.json() - url = message["url"] - auth_data = dict( - signature=message["signature"], - cost=message["cost"], - endpoint=message["endpoint"], - reqnum=message["reqnum"], - url=message["url"], - ) - payload = dict( - prompt="a fat fluffy cat", width=1024, height=1024, steps=20, seed=123456789 - ) - req_data = dict(payload=payload, auth_data=auth_data) - url = urljoin(url, WORKER_ENDPOINT) - print(f"url: {url}") - response = requests.post( - url, - json=req_data, - verify=get_cert_file_path(), - ) - response.raise_for_status() - print_truncate_res(str(response.json())) - - -def call_custom_workflow_for_sd3( - endpoint_group_name: str, api_key: str, server_url: str -) -> None: - WORKER_ENDPOINT = "/custom-workflow" - COST = 100 - route_payload = { - "endpoint": endpoint_group_name, - "api_key": api_key, - "cost": COST, - } - response = requests.post( - urljoin(server_url, "/route/"), - json=route_payload, - timeout=4, - ) - response.raise_for_status() - message = response.json() - url = message["url"] - auth_data = dict( - signature=message["signature"], - cost=message["cost"], - endpoint=message["endpoint"], - reqnum=message["reqnum"], - url=message["url"], - request_idx=message["request_idx"], - ) - workflow = { - "3": { - "inputs": { - "seed": 156680208700286, - "steps": 20, - "cfg": 8, - "sampler_name": "euler", - "scheduler": "normal", - "denoise": 1, - "model": ["4", 0], - "positive": ["6", 0], - "negative": ["7", 0], - "latent_image": ["5", 0], - }, - "class_type": "KSampler", - }, - "4": { - "inputs": {"ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"}, - "class_type": "CheckpointLoaderSimple", - }, - "5": { - "inputs": {"width": 512, "height": 512, "batch_size": 1}, - "class_type": "EmptyLatentImage", - }, - "6": { - "inputs": { - "text": "beautiful scenery nature glass bottle landscape, purple galaxy bottle", - "clip": ["4", 1], - }, - "class_type": "CLIPTextEncode", - }, - "7": { - "inputs": {"text": "text, watermark", "clip": ["4", 1]}, - "class_type": "CLIPTextEncode", - }, - "8": { - "inputs": {"samples": ["3", 0], "vae": ["4", 2]}, - "class_type": "VAEDecode", - }, - "9": { - "inputs": {"filename_prefix": "ComfyUI", "images": ["8", 0]}, - "class_type": "SaveImage", - }, - } - # these values should match the values in the custom workflow above, - # they are used to calculate workload - custom_fields = dict( - steps=20, - width=512, - height=512, - ) - req_data = dict( - payload=dict(custom_fields=custom_fields, 
workflow=workflow), - auth_data=auth_data, - ) - url = urljoin(url, WORKER_ENDPOINT) - print(f"url: {url}") - response = requests.post( - url, - json=req_data, - verify=get_cert_file_path(), - ) - response.raise_for_status() - print_truncate_res(str(response.json())) - - -if __name__ == "__main__": - from lib.test_utils import test_args - - args = test_args.parse_args() - endpoint_api_key = Endpoint.get_endpoint_api_key( - endpoint_name=args.endpoint_group_name, - account_api_key=args.api_key, - instance=args.instance, - ) - if endpoint_api_key: - try: - call_default_workflow( - api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, - ) - call_custom_workflow_for_sd3( - api_key=endpoint_api_key, - endpoint_group_name=args.endpoint_group_name, - server_url=args.server_url, - ) - except Exception as e: - log.error(f"Error during API call: {e}") - else: - log.error(f"Failed to get API key for endpoint {args.endpoint_group_name} ") diff --git a/workers/comfyui/data_types.py b/workers/comfyui/data_types.py deleted file mode 100644 index 82d9c6a..0000000 --- a/workers/comfyui/data_types.py +++ /dev/null @@ -1,205 +0,0 @@ -import sys -import os -import json -import random -import dataclasses -import inspect -from typing import Dict, Any -from functools import cache -from math import ceil -from enum import Enum - -from lib.data_types import ApiPayload, JsonDataException - - -with open("workers/comfyui/misc/test_prompts.txt", "r") as f: - test_prompts = f.readlines() - - -class Model(Enum): - Flux = "flux" - Sd3 = "sd3" - - def get_request_time(self) -> int: - match self: - case Model.Flux: - return 23 - case Model.Sd3: - return 6 - - -@cache -def get_model() -> Model: - match os.environ.get("COMFY_MODEL"): - case "flux": - return Model.Flux - case "sd3": - return Model.Sd3 - case None: - raise Exception( - "For comfyui pyworker, $COMFY_MODEL must be set in the vast template" - ) - case model: - raise Exception(f"Unsupported comfyui model: {model}") - - -@cache -def get_request_template() -> str: - with open(f"workers/comfyui/misc/default_workflows/{get_model().value}.json") as f: - return f.read() - - -def count_workload(width: int, height: int, steps: int) -> float: - """ - we want to normalize the workload is a number such that cur_perf(tokens/second) for 1024x1024 image with - 28 steps is 200 tokens on a 4090. - - in order get that we calculate the - - A = ( absolute workload based on given data ) - B = ( absolute workload for a 1024x1024 image with 28 steps ) - - and adjust the workload to 200 tokens by A/B. - - we then adjust for difference between Flux and SD3 by multiplying this value by expected request time for a - standard image(23s for Flux, 6s for SD3). - On a 4090, this would give us a workload that would give a cur_perf(workload / request_time) of around 200 - """ - - def _calculate_absolute_tokens(width_: int, height_: int, steps_: int) -> float: - """ - This is based on how openai counts image generation tokens, see: https://openai.com/api/pricing/ - - we count how many 512x512 grids are needed to cover the image. - each tile is then counted as 175 tokens. - each image generation also has constant of 85 base tokens. - - we then adjust the count based on the number of steps. The baseline number of steps is assumed to be 28. 
- Some testing with flux gave me this data: - - steps(X) | request time(Y) - __________|_________________ - 07(0.25x) | 11s (0.47x) - 14(0.50x) | 15s (0.65x) - 21(0.75x) | 20s (0.86x) - 28(1.00x) | 23s (1.00x) - 35(1.25x) | 28s (1.21x) - 42(1.50x) | 32s (1.39x) - 49(1.75x) | 37s (1.60x) - - this gives a linear regression of Y = 0.61*X + 6.57 - - we can use this as an adjustment_factor for token count - - adjustment_factor = (0.61 * steps + 6.57) - """ - - width_grids = ceil(width_ / 512) - height_grids = ceil(height_ / 512) - tokens = 85 + width_grids * height_grids * 175 - adjustment_factor = 0.61 * steps_ + 6.57 - return tokens * adjustment_factor - - REQUEST_TIME_FOR_STANDARD_IMAGE = get_model().get_request_time() - - absolute_tokens = _calculate_absolute_tokens( - width_=width, height_=height, steps_=steps - ) - absolute_tokens_standard_image = _calculate_absolute_tokens( - width_=1024, height_=1024, steps_=28 - ) - return REQUEST_TIME_FOR_STANDARD_IMAGE * ( - (absolute_tokens / absolute_tokens_standard_image) * 200 - ) - - -@dataclasses.dataclass -class DefaultComfyWorkflowData(ApiPayload): - prompt: str - width: int - height: int - steps: int - seed: int - - @classmethod - def for_test(cls): - - test_prompt = random.choice(test_prompts).rstrip() - return cls( - prompt=test_prompt, - width=1024, - height=1024, - steps=28, - seed=random.randint(0, sys.maxsize), - ) - - def generate_payload_json( - self, - ) -> Dict[str, Any]: - return json.loads( - get_request_template() - .replace("{{PROMPT}}", self.prompt) - # these values should be of int type. Since "{{VAR}}" is wrapped with " in the template - # to make the JSON valid, we must replace the double quotes. i.e. "{{WIDTH}}" -> 1024 and not "1024" - .replace('"{{WIDTH}}"', str(self.width)) - .replace('"{{HEIGHT}}"', str(self.height)) - .replace('"{{STEPS}}"', str(self.steps)) - .replace('"{{SEED}}"', str(self.seed)) - ) - - def count_workload(self) -> float: - return count_workload(width=self.width, height=self.height, steps=self.steps) - - @classmethod - def from_json_msg(cls, json_msg: Dict[str, Any]) -> "DefaultComfyWorkflowData": - errors = {} - for param in inspect.signature(cls).parameters: - if param not in json_msg: - errors[param] = "missing parameter" - if errors: - raise JsonDataException(errors) - return cls( - **{ - k: v - for k, v in json_msg.items() - if k in inspect.signature(cls).parameters - } - ) - - -@dataclasses.dataclass -class CustomComfyWorkflowData(ApiPayload): - custom_fields: Dict[str, int] - workflow: Dict[str, Any] - - @classmethod - def for_test(cls): - raise NotImplementedError("Custom comfy workflow is not used for testing") - - def count_workload(self) -> float: - return count_workload( - width=int(self.custom_fields.get("width", 1024)), - height=int(self.custom_fields.get("height", 1024)), - steps=int(self.custom_fields.get("steps", 28)), - ) - - def generate_payload_json(self) -> Dict[str, Any]: - template_json = json.loads(get_request_template()) - template_json["input"]["workflow_json"] = self.workflow - return template_json - - @classmethod - def from_json_msg(cls, json_msg: Dict[str, Any]) -> "CustomComfyWorkflowData": - errors = {} - for param in inspect.signature(cls).parameters: - if param not in json_msg: - errors[param] = "missing parameter" - if errors: - raise JsonDataException(errors) - return cls( - **{ - k: v - for k, v in json_msg.items() - if k in inspect.signature(cls).parameters - } - ) diff --git a/workers/comfyui/misc/default_workflows/flux.json 
b/workers/comfyui/misc/default_workflows/flux.json deleted file mode 100644 index eab7e80..0000000 --- a/workers/comfyui/misc/default_workflows/flux.json +++ /dev/null @@ -1,137 +0,0 @@ -{ - "input": { - "handler": "RawWorkflow", - "aws_access_key_id": "your-s3-access-key", - "aws_secret_access_key": "your-s3-secret-access-key", - "aws_endpoint_url": "https://my-endpoint.backblaze.com", - "aws_bucket_name": "your-bucket", - "webhook_url": "your-webhook-url", - "webhook_extra_params": {}, - "workflow_json": { - "5": { - "inputs": { - "width": "{{WIDTH}}", - "height": "{{HEIGHT}}", - "batch_size": 1 - }, - "class_type": "EmptyLatentImage", - "_meta": { - "title": "Empty Latent Image" - } - }, - "6": { - "inputs": { - "text": "{{PROMPT}}", - "clip": ["11", 0] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "8": { - "inputs": { - "samples": ["13", 0], - "vae": ["10", 0] - }, - "class_type": "VAEDecode", - "_meta": { - "title": "VAE Decode" - } - }, - "9": { - "inputs": { - "filename_prefix": "ComfyUI", - "images": ["8", 0] - }, - "class_type": "SaveImage", - "_meta": { - "title": "Save Image" - } - }, - "10": { - "inputs": { - "vae_name": "ae.safetensors" - }, - "class_type": "VAELoader", - "_meta": { - "title": "Load VAE" - } - }, - "11": { - "inputs": { - "clip_name1": "t5xxl_fp16.safetensors", - "clip_name2": "clip_l.safetensors", - "type": "flux" - }, - "class_type": "DualCLIPLoader", - "_meta": { - "title": "DualCLIPLoader" - } - }, - "12": { - "inputs": { - "unet_name": "flux1-dev.safetensors", - "weight_dtype": "default" - }, - "class_type": "UNETLoader", - "_meta": { - "title": "Load Diffusion Model" - } - }, - "13": { - "inputs": { - "noise": ["25", 0], - "guider": ["22", 0], - "sampler": ["16", 0], - "sigmas": ["17", 0], - "latent_image": ["5", 0] - }, - "class_type": "SamplerCustomAdvanced", - "_meta": { - "title": "SamplerCustomAdvanced" - } - }, - "16": { - "inputs": { - "sampler_name": "euler" - }, - "class_type": "KSamplerSelect", - "_meta": { - "title": "KSamplerSelect" - } - }, - "17": { - "inputs": { - "scheduler": "simple", - "steps": "{{STEPS}}", - "denoise": 1, - "model": ["12", 0] - }, - "class_type": "BasicScheduler", - "_meta": { - "title": "BasicScheduler" - } - }, - "22": { - "inputs": { - "model": ["12", 0], - "conditioning": ["6", 0] - }, - "class_type": "BasicGuider", - "_meta": { - "title": "BasicGuider" - } - }, - "25": { - "inputs": { - "noise_seed": "{{SEED}}" - }, - "class_type": "RandomNoise", - "_meta": { - "title": "RandomNoise" - } - } - } - } -} diff --git a/workers/comfyui/misc/default_workflows/sd3.json b/workers/comfyui/misc/default_workflows/sd3.json deleted file mode 100644 index dd0ddc0..0000000 --- a/workers/comfyui/misc/default_workflows/sd3.json +++ /dev/null @@ -1,142 +0,0 @@ -{ - "input": { - "handler": "RawWorkflow", - "aws_access_key_id": "your-s3-access-key", - "aws_secret_access_key": "your-s3-secret-access-key", - "aws_endpoint_url": "https://my-endpoint.backblaze.com", - "aws_bucket_name": "your-bucket", - "webhook_url": "your-webhook-url", - "webhook_extra_params": {}, - "workflow_json": { - "6": { - "inputs": { - "text": "{{PROMPT}}", - "clip": ["252", 1] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "13": { - "inputs": { - "shift": 3, - "model": ["252", 0] - }, - "class_type": "ModelSamplingSD3", - "_meta": { - "title": "ModelSamplingSD3" - } - }, - "67": { - "inputs": { - "conditioning": ["71", 0] - }, - "class_type": 
"ConditioningZeroOut", - "_meta": { - "title": "ConditioningZeroOut" - } - }, - "68": { - "inputs": { - "start": 0.1, - "end": 1, - "conditioning": ["67", 0] - }, - "class_type": "ConditioningSetTimestepRange", - "_meta": { - "title": "ConditioningSetTimestepRange" - } - }, - "69": { - "inputs": { - "conditioning_1": ["68", 0], - "conditioning_2": ["70", 0] - }, - "class_type": "ConditioningCombine", - "_meta": { - "title": "Conditioning (Combine)" - } - }, - "70": { - "inputs": { - "start": 0, - "end": 0.1, - "conditioning": ["71", 0] - }, - "class_type": "ConditioningSetTimestepRange", - "_meta": { - "title": "ConditioningSetTimestepRange" - } - }, - "71": { - "inputs": { - "text": "bad quality, poor quality, doll, disfigured, jpg, toy, bad anatomy, missing limbs, missing fingers, 3d, cgi", - "clip": ["252", 1] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Negative Prompt)" - } - }, - "135": { - "inputs": { - "width": "{{WIDTH}}", - "height": "{{HEIGHT}}", - "batch_size": 1 - }, - "class_type": "EmptySD3LatentImage", - "_meta": { - "title": "EmptySD3LatentImage" - } - }, - "231": { - "inputs": { - "samples": ["271", 0], - "vae": ["252", 2] - }, - "class_type": "VAEDecode", - "_meta": { - "title": "VAE Decode" - } - }, - "233": { - "inputs": { - "filename_prefix": "ComfyUI", - "images": ["231", 0] - }, - "class_type": "SaveImage", - "_meta": { - "title": "Save Image" - } - }, - "252": { - "inputs": { - "ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors" - }, - "class_type": "CheckpointLoaderSimple", - "_meta": { - "title": "Load Checkpoint" - } - }, - "271": { - "inputs": { - "seed": "{{SEED}}", - "steps": "{{STEPS}}", - "cfg": 4.5, - "sampler_name": "dpmpp_2m", - "scheduler": "sgm_uniform", - "denoise": 1, - "model": ["13", 0], - "positive": ["6", 0], - "negative": ["69", 0], - "latent_image": ["135", 0] - }, - "class_type": "KSampler", - "_meta": { - "title": "KSampler" - } - } - } - } -} diff --git a/workers/comfyui/misc/test_prompts.txt b/workers/comfyui/misc/test_prompts.txt deleted file mode 100644 index cfb8f8c..0000000 --- a/workers/comfyui/misc/test_prompts.txt +++ /dev/null @@ -1,34 +0,0 @@ -cartoon character of a person with a hoodie , in style of cytus and deemo, ork, gold chains, realistic anime cat, dripping black goo, lineage revolution style, thug life, cute anthropomorphic bunny, balrog, arknights, aliased, very buff, black and red and yellow paint, painting illustration collage style, character composition in vector with white background -stardew valley, fine details -2D Vector Illustration of a child with soccer ball Art for Sublimation, Design Art, Chrome Art, Painting and Stunning Artwork, Highly Detailed Digital Painting, Airbrush Art, Highly Detailed Digital Artwork, Dramatic Artwork, stained antique yellow copper paint, digital airbrush art, detailed by Mark Brooks, Chicano airbrush art, Swagger! 
snake Culture -realistic futuristic city-downtown with short buildings, sunset -seascape by Ray Collins and artgerm, front view of a perfect wave, sunny background, ultra detailed water -inspired by realflow-cinema4d editor features, create image of a transparent luxury cup with ice fruits and mint, connected with white, yellow and pink cream, Slow - High Speed MO Photography, YouTube Video Screenshot, Abstract Clay, Transparent Cup , molecular gastronomy, wheel, 3D fluid,Simulation rendering, still video, 4k polymer clay futras photography, very surreal, Houdini Fluid Simulation, hyperrealistic CGI and FLUIDS & MULTIPHYSICS SIMULATION effect, with Somali Stain Lurex, Metallic Jacquard, Gold Thread, Mulberry Silk, Toub Saree, Warm background, a fantastic image worthy of an award. -biker with backpack on his back riding a motorcycle, Style by Ade Santora, Oilpunk, Cover photo, craig mullins style, on the cover of a magazine, Outdoor Magazine, inspired by Alex Petruk APe, image of a male biker, Cover of an award-winning magazine, the man has a backpack, photo for magazine, with a backpack, magazine cover -generate a collage-style illustration inspired by the Procreate raster graphic editor, photographic illustration with the theme, 2D vector, art for textile sublimation, containing surrealistic cartoon cat wearing a baseball cap and jeans standing in front of a poster, inspired by Sadao Watanabe, Doraemon, Japanese cartoon style, Eichiro Oda, Iconic high detail character, Director: Nakahara Nantenbō, Kastuhiro Otomo, image detailed, by Miyamoto, Hidetaka Miyazaki, Katsuhiro illustration, 8k, masterpiece, Minimize noise and grain in photo quality without lose quality and increase brightness and lighting,Symmetry and Alignment, Avoid asymmetrical shapes and out-of-focus points. Focus and Sharpness: Make sure the image is focused and sharp and encourages the viewer to see it as a work of art printed on fabric. -fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details -Iron Man, (Arnold Tsang, Toru Nakayama), Masterpiece, Studio Quality, 6k , toa, toaair, 1boy, glowing, axe, mecha, science_fiction, solo, weapon, jungle , green_background, nature, outdoors, solo, tree, weapon, mask, dynamic lighting, detailed shading, digital texture painting -(Pope Francis) wearing leather jacket is a DJ in a nightclub, mixing live on stage, giant mixing table, a masterpiece -Pope Francis wearing biker (leather jacket), a masterpiece -Luke Skywalker ordering a burger and fries from the Death Star canteen. -I want to generate a group avatar for a Feishu group chat. The role of this group is daily software technical communication. Now the subject technology stacks that members of this group discuss daily include: algorithms, data structures, optimization, functional programming, and the programming languages often discussed are: TypeScript, Java, python, etc. I hope this avatar has a simple aesthetic, this avatar is a single person avatar -portrait Anime black girl cute-fine-face, pretty face, realistic shaded Perfect face, fine details. 
Anime. realistic shaded lighting by Ilya Kuvshinov Giuseppe Dangelico Pino and Michael Garmash and Rob Rey, IAMAG premiere, WLOP matte print, cute freckles, masterpiece -young Disney socialite wearing a beige miniskirt, dark brown turtleneck sweater, small neckless, cute-fine-face, anime. illustration, realistic shaded perfect face, brown hair, grey eyes, fine details, realistic shaded lighting by ilya kuvshinov giuseppe dangelico pino and michael garmash and rob rey, iamag premiere, wlop matte print, a masterpiece -Cute small cat sitting in a movie theater eating chicken wiggs watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render -Cute small dog sitting in a movie theater eating popcorn watching a movie ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render -fox bracelet made of buckskin with fox features, rich details, fine carvings, studio lighting -crane buckskin bracelet with crane features, rich details, fine carvings, studio lighting -london luxurious interior living-room, light walls -Parisian luxurious interior penthouse bedroom, dark walls, wooden panels -cute girl, crop-top, blond hair, black glasses, stretching, with background by greg rutkowski makoto shinkai kyoto animation key art feminine mid shot -houses in front, houses background, straight houses, digital art, smooth, sharp focus, gravity falls style, doraemon style, shinchan style, anime style -Simplified technical drawing, Leonardo da Vinci, Mechanical Dinosaur Skeleton, Minimalistic annotations, Hand-drawn illustrations, Basic design and engineering, Wonder and curiosity -High quality 8K painting impressionist style of a Japanese modern city street with a girl on the foreground wearing a traditional wedding dress with a fox mask, staring at the sky, daylight -a landscape from the Moon with the Earth setting on the horizon, realistic, detailed -Isometric Atlantis city,great architecture with columns, great details, ornaments,seaweed, blue ambiance, 3D cartoon style, soft light, 45° view -A hyper realistic avatar of a guy riding on a black honda cbr 650r in leather suit,high detail, high quality,8K,photo realism -the street of amedieval fantasy town, at dawn, dark, highly detailed -overwhelmingly beautiful eagle framed with vector flowers, long shiny wavy flowing hair, polished, ultra detailed vector floral illustration mixed with hyper realism, muted pastel colors, vector floral details in background, muted colors, hyper detailed ultra intricate overwhelming realism in detailed complex scene with magical fantasy atmosphere, no signature, no watermark -a highly detailed matte painting of a man on a hill watching a rocket launch in the distance by studio ghibli, makoto shinkai, by artgerm, by wlop, by greg rutkowski, volumetric lighting, octane render, 4 k resolution, trending on artstation, masterpiece | hyperrealism| highly detailed| insanely detailed| intricate| cinematic lighting| depth of field -electronik robot and ofice ,unreal engine, cozy indoor lighting, artstation, detailed, digital painting,cinematic,character design by mark ryden and pixar and hayao miyazaki, unreal 5, daz, hyperrealistic, octane render -exquisitely intricately detailed illustration, of a small world with a lake and a rainbow, inside a closed glass jar. 
diff --git a/workers/comfyui/server.py b/workers/comfyui/server.py deleted file mode 100644 index 4b5e025..0000000 --- a/workers/comfyui/server.py +++ /dev/null @@ -1,143 +0,0 @@ -import os -import logging -import dataclasses -import base64 -from typing import Optional, Union, Type - -from aiohttp import web, ClientResponse -from anyio import open_file - -from lib.backend import Backend, LogAction -from lib.data_types import EndpointHandler -from lib.server import start_server -from .data_types import DefaultComfyWorkflowData, CustomComfyWorkflowData - - -MODEL_SERVER_URL = "http://127.0.0.1:18288" # API Wrapper Service - -# This is the last log line that gets emitted once comfyui+extensions have been fully loaded -MODEL_SERVER_START_LOG_MSG = "To see the GUI go to: http://127.0.0.1:18188" -MODEL_SERVER_ERROR_LOG_MSGS = [ - "MetadataIncompleteBuffer", # This error is emitted when the downloaded model is corrupted - "Value not in list: unet_name", # This error is emitted when the model file is not there at all -] - - -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) - - -async def generate_client_response( - request: web.Request, response: ClientResponse -) -> Union[web.Response, web.StreamResponse]: - _ = request - match response.status: - case 200: - log.debug("SUCCESS") - res = await response.json() - if "output" not in res: - return web.json_response( - data=dict(error="there was an error in the workflow"), - status=422, - ) - image_paths = [path["local_path"] for path in res["output"]["images"]] - if not image_paths: - return web.json_response( - data=dict(error="workflow did not produce any images"), - status=422, - ) - images = [] - for image_path in image_paths: - async with await open_file(image_path, mode="rb") as f: - contents = await f.read() - images.append( - f"data:image/png;base64,{base64.b64encode(contents).decode('utf-8')}" - ) - return web.json_response(data=dict(images=images)) - case code: - log.debug("SENDING RESPONSE: ERROR: unknown code") - return web.Response(status=code) - - -@dataclasses.dataclass -class DefaultComfyWorkflowHandler(EndpointHandler[DefaultComfyWorkflowData]): - - @property - def endpoint(self) -> str: - return "/runsync" - - @property - def healthcheck_endpoint(self) -> Optional[str]: - return None - - @classmethod - def payload_cls(cls) -> Type[DefaultComfyWorkflowData]: - return DefaultComfyWorkflowData - - def make_benchmark_payload(self) -> DefaultComfyWorkflowData: - return DefaultComfyWorkflowData.for_test() - - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - return await generate_client_response(client_request, model_response) - - -@dataclasses.dataclass -class CustomComfyWorkflowHandler(EndpointHandler[CustomComfyWorkflowData]): - - @property - def endpoint(self) -> str: - return "/runsync" - - @property - def healthcheck_endpoint(self) -> Optional[str]: - return None - - @classmethod - def payload_cls(cls) -> Type[CustomComfyWorkflowData]: - return CustomComfyWorkflowData - - def make_benchmark_payload(self) -> CustomComfyWorkflowData: - return CustomComfyWorkflowData.for_test() - - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - return await generate_client_response(client_request, model_response) - - -backend = 
Backend( - model_server_url=MODEL_SERVER_URL, - model_log_file=os.environ["MODEL_LOG"], - allow_parallel_requests=False, - benchmark_handler=DefaultComfyWorkflowHandler( - benchmark_runs=3, benchmark_words=100 - ), - log_actions=[ - (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG), - (LogAction.Info, "Downloading:"), - *[ - (LogAction.ModelError, error_msg) - for error_msg in MODEL_SERVER_ERROR_LOG_MSGS - ], - ], -) - - -async def handle_ping(_): - return web.Response(body="pong") - - -routes = [ - web.post("/prompt", backend.create_handler(DefaultComfyWorkflowHandler())), - web.post("/custom-workflow", backend.create_handler(CustomComfyWorkflowHandler())), - web.get("/ping", handle_ping), -] - -if __name__ == "__main__": - start_server(backend, routes) diff --git a/workers/comfyui/test_load.py b/workers/comfyui/test_load.py deleted file mode 100644 index a6d468a..0000000 --- a/workers/comfyui/test_load.py +++ /dev/null @@ -1,15 +0,0 @@ -from lib.test_utils import test_load_cmd, test_args -from .data_types import DefaultComfyWorkflowData, Model - -WORKER_ENDPOINT = "/prompt" - - -if __name__ == "__main__": - test_args.add_argument( - "-m", - dest="comfy_model", - choices=list(map(lambda x: x.value, Model)), - required=True, - help="Image generation model name", - ) - test_load_cmd(DefaultComfyWorkflowData, WORKER_ENDPOINT, arg_parser=test_args) diff --git a/workers/hello_world/README.md b/workers/hello_world/README.md deleted file mode 100644 index a523e10..0000000 --- a/workers/hello_world/README.md +++ /dev/null @@ -1,321 +0,0 @@ -# Vast PyWorker - -## Hello_world example - -There is a hello_world PyWorker implementation under `workers/hello_world`. This PyWorker is -created for an LLM model server that runs on port 5001 and has two API endpoints: - -1. `/generate`: generates a full response to the prompt and sends a JSON response -2. `/generate_stream`: streams a response one token at a time - -Both of these endpoints take the same API JSON payload: - -``` -{ - "prompt": String, - "max_response_tokens": Number | null -} -``` - -We want the PyWorker to also expose two endpoints that correspond to the above endpoints. - -### Structure - -All PyWorkers have four files: - -``` -. -└── workers - └── hello_world - ├── __init__.py - ├── data_types.py # contains data types representing model API endpoints - ├── server.py # contains endpoint handlers - └── test_load.py # script for load testing - -``` - -All of the classes follow strict type hinting. It is recommended that you type hint all of your functions. -This will allow your IDE or VSCode with the `pyright` plugin to find any type errors in your implementation. -You can also install `pyright` with `sudo npm install -g pyright` and run `pyright` in the root of the project to find -any type errors. - -### data_types.py: Contains data types representing model API endpoints - -This file defines the structure of the data your model server expects (its API contract) and, critically, how PyWorker *interprets* that data for autoscaling purposes. You define Python data classes that mirror the JSON payloads your model's API uses. - -These classes **must** inherit from `lib.data_types.ApiPayload`. This inheritance is not just for structure; it's how PyWorker knows how to: - -* **Parse Incoming Requests:** Convert JSON from clients into usable Python objects. -* **Calculate Workload:** Determine the computational cost of a request. -* **Generate Test Data:** Create realistic inputs for benchmarking.
-* **Format Requests for the Model Server:** Prepare data for the underlying model. - - -```python -import dataclasses -import random -from typing import Dict, Any - -from transformers import OpenAIGPTTokenizer # used to count tokens in a prompt -import nltk # used to download a list of all words to generate a random prompt and benchmark the LLM model - -from lib.data_types import ApiPayload - -nltk.download("words") -WORD_LIST = nltk.corpus.words.words() - -# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs -tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt") - -@dataclasses.dataclass -class InputData(ApiPayload): - prompt: str - max_response_tokens: int - - @classmethod - def for_test(cls) -> "ApiPayload": - """defines how create a payload for load testing""" - prompt = " ".join(random.choices(WORD_LIST, k=int(250))) - return cls(prompt=prompt, max_response_tokens=300) - - def generate_payload_json(self) -> Dict[str, Any]: - """defines how to convert an ApiPayload to JSON that will be sent to model API""" - return dataclasses.asdict(self) - - def count_workload(self) -> float: - """defines how to calculate workload for a payload""" - return len(tokenizer.tokenize(self.prompt)) - - @classmethod - def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData": - """ - defines how to transform JSON data to AuthData and payload type, - in this case `InputData` defined above represents the data sent to the model API. - AuthData is data generated by autoscaler in order to authenticate payloads. - In this case, the transformation is simple and 1:1. That is not always the case. See comfyui's PyWorker - for more complicated examples - """ - errors = {} - for param in inspect.signature(cls).parameters: - if param not in json_msg: - errors[param] = "missing parameter" - if errors: - raise JsonDataException(errors) - return cls( - **{ - k: v - for k, v in json_msg.items() - if k in inspect.signature(cls).parameters - } - ) - -``` - -### server.py: Creating Your Model's API Endpoints - -This section guides you through creating the core of your custom model API: the `EndpointHandler`. Think of `EndpointHandler` as the bridge between incoming requests from users and your underlying model. It's the key to making your model accessible and scalable. - -**Why use an `EndpointHandler`?** - -* **Organized Request Handling:** It provides a structured way to handle different types of requests (like generating text, generating images, or performing other model-specific tasks). -* **Scalability:** By separating request handling from the model itself, you can easily scale your API to handle many concurrent users. -* **Flexibility:** You can customize how requests are processed, validated, and transformed before being sent to your model. -* **Standard Interface:** It provides a consistent interface for interacting with your model, regardless of the underlying implementation. - -For every model API endpoint you want to expose (e.g., `/generate`, `/generate_stream`), you'll implement an `EndpointHandler`. This class is responsible for: -The `EndpointHandler` achieves this through several key methods: - -* **Receiving and validating incoming requests (`get_data_from_request`):** This method ensures the request contains the necessary data (authentication and payload) and is in the correct format. It's the entry point for all requests. 
-* **Defining the endpoint (`endpoint`):** This method specifies the URL endpoint on the model API server where requests will be sent (e.g., `/generate`). -* **Specifying the payload type (`payload_cls`):** This method indicates the specific `ApiPayload` class used for this endpoint, defining the structure of the request data. -* **Creating benchmark payloads (`make_benchmark_payload`):** This method creates payloads specifically for benchmarking the model's performance. -* **Handling the model's response (`generate_client_response`):** This method takes the response from the model API server and transforms it into the format expected by the client making the request to your PyWorker. This allows you to customize the output as needed. - -The `EndpointHandler` class has several abstract functions that you *must* implement to define the behavior of your specific endpoints. Here, we'll implement two common endpoints: `/generate` (for synchronous requests) and `/generate_stream` (for streaming responses): - -```python - -""" -AuthData is a dataclass that represents Authentication data sent from Autoscaler to client requesting a route. -When a user requests a route from autoscaler, see Vast's Serverless documentation for how routing and AuthData -work. -When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON: -{ - auth_data: AuthData, - payload : InputData # defined above -} -""" -from aiohttp import web - -from lib.data_types import EndpointHandler, JsonDataException -from lib.server import start_server -from .data_types import InputData - -# This class is the implementer for the '/generate' endpoint of model API -@dataclasses.dataclass -class GenerateHandler(EndpointHandler[InputData]): - - @property - def endpoint(self) -> str: - # the API endpoint - return "/generate" - - @classmethod - def payload_cls(cls) -> Type[InputData]: - """this function should just return ApiPayload subclass used by this handler""" - return InputData - - def generate_payload_json(self, payload: InputData) -> Dict[str, Any]: - """ - defines how to convert `InputData` defined above, to - JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but - can be more complicated, See comfyui for an example - """ - return dataclasses.asdict(payload) - - def make_benchmark_payload(self) -> InputData: - """ - defines how to generate an InputData for benchmarking. This needs to be defined in only - one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test() - method on InputData. However, in some cases you might need to fine tune your InputData used for - benchmarking to closely resemble the average request users call the endpoint with in order to get best - autoscaling performance - """ - return InputData.for_test() - - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - """ - defines how to convert a model API response to a response to PyWorker client - """ - _ = client_request - match model_response.status: - case 200: - log.debug("SUCCESS") - data = await model_response.json() - return web.json_response(data=data) - case code: - log.debug("SENDING RESPONSE: ERROR: unknown code") - return web.Response(status=code) - - -``` - -We also handle `GenerateStreamHandler` for streaming responses. 
It is identical to `GenerateHandler`, except for -the endpoint name and how we create a web response, as it is a streaming response: - -```python -class GenerateStreamHandler(EndpointHandler[InputData]): - @property - def endpoint(self) -> str: - return "/generate_stream" - - @classmethod - def payload_cls(cls) -> Type[InputData]: - return InputData - - def generate_payload_json(self, payload: InputData) -> Dict[str, Any]: - return dataclasses.asdict(payload) - - def make_benchmark_payload(self) -> InputData: - return InputData.for_test() - - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - match model_response.status: - case 200: - log.debug("Streaming response...") - res = web.StreamResponse() - res.content_type = "text/event-stream" - await res.prepare(client_request) - async for chunk in model_response.content: - await res.write(chunk) - await res.write_eof() - log.debug("Done streaming response") - return res - case code: - log.debug("SENDING RESPONSE: ERROR: unknown code") - return web.Response(status=code) - - -``` - -You can now instantiate a Backend and use it to handle requests. - -```python -from lib.backend import Backend, LogAction - -# the url and port of model API -MODEL_SERVER_URL = "http://0.0.0.0:5001" - - -# This is the log line that is emitted once the server has started -MODEL_SERVER_START_LOG_MSG = "server has started" -MODEL_SERVER_ERROR_LOG_MSGS = [ - "Exception: corrupted model file" # message in the logs indicating the unrecoverable error -] - -backend = Backend( - model_server_url=MODEL_SERVER_URL, - # location of model log file - model_log_file=os.environ["MODEL_LOG"], - # for some model backends that can only handle one request at a time, be sure to set this to False to - # let PyWorker handling queueing requests. - allow_parallel_requests=True, - # give the backend an EndpointHandler instance that is used for benchmarking - # number of benchmark run and number of words for a random benchmark run are given - benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256), - # defines how to handle specific log messages. 
See docstring of LogAction for details - log_actions=[ - (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG), - (LogAction.Info, '"message":"Download'), - *[ - (LogAction.ModelError, error_msg) - for error_msg in MODEL_SERVER_ERROR_LOG_MSGS - ], - ], -) - -# this is a simple ping handler for PyWorker -async def handle_ping(_: web.Request): - return web.Response(body="pong") - -# this is a handler for forwarding a health check to model API -async def handle_healthcheck(_: web.Request): - healthcheck_res = await backend.session.get("/healthcheck") - return web.Response(body=healthcheck_res.content, status=healthcheck_res.status) - -routes = [ - web.post("/generate", backend.create_handler(GenerateHandler())), - web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())), - web.get("/ping", handle_ping), - web.get("/healthcheck", handle_healthcheck), -] - -if __name__ == "__main__": - # start server, called from start_server.sh - start_server(backend, routes) -``` - -### test_load.py - -Here you can create a script that allows you to test an endpoint group running instances with this PyWorker - -```python -from lib.test_harness import run -from .data_types import InputData - -WORKER_ENDPOINT = "/generate" - -if __name__ == "__main__": - run(InputData.for_test(), WORKER_ENDPOINT) -``` - -You can then run the following command from the root of this repo to load test the endpoint group: - -```sh -# sends 1000 requests at the rate of 0.5 requests per second -python3 -m workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME" -``` diff --git a/workers/hello_world/__init__.py b/workers/hello_world/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/workers/hello_world/client.py b/workers/hello_world/client.py deleted file mode 100644 index e69de29..0000000 diff --git a/workers/hello_world/data_types.py b/workers/hello_world/data_types.py deleted file mode 100644 index 0c2a296..0000000 --- a/workers/hello_world/data_types.py +++ /dev/null @@ -1,48 +0,0 @@ -import dataclasses -import random -import inspect -from typing import Dict, Any - -from transformers import OpenAIGPTTokenizer -import nltk - -from lib.data_types import ApiPayload, JsonDataException - -nltk.download("words") -WORD_LIST = nltk.corpus.words.words() - -# used to count tokens and workload for LLM -tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt") - - -@dataclasses.dataclass -class InputData(ApiPayload): - prompt: str - max_response_tokens: int - - @classmethod - def for_test(cls) -> "InputData": - prompt = " ".join(random.choices(WORD_LIST, k=int(250))) - return cls(prompt=prompt, max_response_tokens=300) - - def generate_payload_json(self) -> Dict[str, Any]: - return dataclasses.asdict(self) - - def count_workload(self) -> int: - return len(tokenizer.tokenize(self.prompt)) - - @classmethod - def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData": - errors = {} - for param in inspect.signature(cls).parameters: - if param not in json_msg: - errors[param] = "missing parameter" - if errors: - raise JsonDataException(errors) - return cls( - **{ - k: v - for k, v in json_msg.items() - if k in inspect.signature(cls).parameters - } - ) diff --git a/workers/hello_world/server.py b/workers/hello_world/server.py deleted file mode 100644 index 91fb9a5..0000000 --- a/workers/hello_world/server.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -PyWorker works as a man-in-the-middle between the client and model API. Its function is: -1.
receive request from client, update metrics such as workload of a request, number of pending requests, etc. -2a. transform the data and forward the transformed data to model API -2b. send updated metrics to autoscaler -3. transform response from model API(if needed) and forward the response to client - -PyWorker forward requests to many model API endpoint. each endpoint must have an EndpointHandler. You can also -write function to just forward requests that don't generate anything with the model to model API without an -EndpointHandler. This is useful for endpoints such as healthchecks. See below for example -""" - -import os -import logging -import dataclasses -from typing import Dict, Any, Optional, Union, Type - -from aiohttp import web, ClientResponse - -from lib.backend import Backend, LogAction -from lib.data_types import EndpointHandler -from lib.server import start_server -from .data_types import InputData - -# the url and port of model API -MODEL_SERVER_URL = "http://0.0.0.0:5001" - - -# This is the log line that is emitted once the server has started -MODEL_SERVER_START_LOG_MSG = "infer server has started" -MODEL_SERVER_ERROR_LOG_MSGS = [ - "Exception: corrupted model file" # message in the logs indicating the unrecoverable error -] - - -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) - - -# This class is the implementer for the '/generate' endpoint of model API -@dataclasses.dataclass -class GenerateHandler(EndpointHandler[InputData]): - - @property - def endpoint(self) -> str: - # the API endpoint - return "/generate" - - @property - def healthcheck_endpoint(self) -> Optional[str]: - return None - - @classmethod - def payload_cls(cls) -> Type[InputData]: - return InputData - - def generate_payload_json(self, payload: InputData) -> Dict[str, Any]: - """ - defines how to convert `InputData` defined above, to - json data to be sent to the model API - """ - return dataclasses.asdict(payload) - - def make_benchmark_payload(self) -> InputData: - """ - defines how to generate an InputData for benchmarking. This needs to be defined in only - one EndpointHandler, the one passed to the backend as the benchmark handler - """ - return InputData.for_test() - - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - """ - defines how to convert a model API response to a response to PyWorker client - """ - _ = client_request - match model_response.status: - case 200: - log.debug("SUCCESS") - data = await model_response.json() - return web.json_response(data=data) - case code: - log.debug("SENDING RESPONSE: ERROR: unknown code") - return web.Response(status=code) - - -# This is the same as GenerateHandler, except that it calls a streaming endpoint of the model API and streams the -# response, which itself is streaming, back to the client. 
-# it is nearly identical to handler as above, but it calls a different model API endpoint and it streams the -# streaming response from model API to client -class GenerateStreamHandler(EndpointHandler[InputData]): - @property - def endpoint(self) -> str: - return "/generate_stream" - - @property - def healthcheck_endpoint(self) -> Optional[str]: - return None - - @classmethod - def payload_cls(cls) -> Type[InputData]: - return InputData - - def generate_payload_json(self, payload: InputData) -> Dict[str, Any]: - return dataclasses.asdict(payload) - - def make_benchmark_payload(self) -> InputData: - return InputData.for_test() - - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - match model_response.status: - case 200: - log.debug("Streaming response...") - res = web.StreamResponse() - res.content_type = "text/event-stream" - await res.prepare(client_request) - async for chunk in model_response.content: - await res.write(chunk) - await res.write_eof() - log.debug("Done streaming response") - return res - case code: - log.debug("SENDING RESPONSE: ERROR: unknown code") - return web.Response(status=code) - - -# This is the backend instance of pyworker. Only one must be made which uses EndpointHandlers to process -# incoming requests -backend = Backend( - model_server_url=MODEL_SERVER_URL, - model_log_file=os.environ["MODEL_LOG"], - allow_parallel_requests=True, - # give the backend a handler instance that is used for benchmarking - # number of benchmark run and number of words for a random benchmark run are given - benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256), - # defines how to handle specific log messages. See docstring of LogAction for details - log_actions=[ - (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG), - (LogAction.Info, '"message":"Download'), - *[ - (LogAction.ModelError, error_msg) - for error_msg in MODEL_SERVER_ERROR_LOG_MSGS - ], - ], -) - - -# this is a simple ping handler for pyworker -async def handle_ping(_: web.Request): - return web.Response(body="pong") - - -# this is a handler for forwarding a health check to modelAPI -async def handle_healthcheck(_: web.Request): - healthcheck_res = await backend.session.get("/healthcheck") - return web.Response(body=healthcheck_res.content, status=healthcheck_res.status) - - -routes = [ - web.post("/generate", backend.create_handler(GenerateHandler())), - web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())), - web.get("/ping", handle_ping), - web.get("/healthcheck", handle_healthcheck), -] - -if __name__ == "__main__": - # start the PyWorker server - start_server(backend, routes) diff --git a/workers/hello_world/test_load.py b/workers/hello_world/test_load.py deleted file mode 100644 index b0fc674..0000000 --- a/workers/hello_world/test_load.py +++ /dev/null @@ -1,7 +0,0 @@ -from lib.test_utils import test_load_cmd, test_args -from .data_types import InputData - -WORKER_ENDPOINT = "/generate" - -if __name__ == "__main__": - test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args) diff --git a/workers/openai/README.templates.md b/workers/openai/README.templates.md deleted file mode 100644 index f4d7c2b..0000000 --- a/workers/openai/README.templates.md +++ /dev/null @@ -1,77 +0,0 @@ -# + (serverless) - -Run with our serverless autoscaling infrastructure. 
- -See the [serverless documentation](https://docs.vast.ai/serverless) and the [Getting Started](https://docs.vast.ai/serverless/getting-started) guide for in-depth details about how to use these templates. - -## Configuration - -Two environment variables are provided to help you configure the server: - -| Variable | Default Value | Used For | -| --- | --- | --- | -| `MODEL_NAME` | `` | The model to load. Also accepts [hf.co/repo/model](#) links | -| `` | `` | Arguments to pass to the `` command | - -This template has been configured to work with VRAM. Setting alternative models and server arguments will change the VRAM requirements. Check model cards and for guidance. - -## Usage - -We have provided a demonstration client to help you implement this template into your own infrastructure - -### Client Setup - -Clone the PyWorker repository to your local machine and install the necessary requirements for running the test client. - -```bash -git clone https://github.com/vast-ai/pyworker -cd pyworker -pip install uv -uv venv -p 3.12 -source .venv/bin/activate -uv pip install -r requirements.txt -``` - -### Completions - -Call to `/v1/completions` with json response - -```bash -python -m workers.openai.client -k -e --completion --model -``` - -### Chat Completion (json) - -Call to `/v1/chat/completions` with json response - -```bash -python -m workers.openai.client -k -e --chat --model -``` - -### Chat Completion (streaming) - -Call to `/v1/chat/completions` with streaming response - -```bash -python -m workers.openai.client -k -e --chat-stream --model -``` - -### Tool Use (json) - -Call to `/v1/chat/completions` with tool and json response. - -This test defines a simple tool which will list the contents of the local pyworker directory. The output is then analysed by the model. - -```bash -python -m workers.openai.client -k -e --tools --model -``` - -### Interactive Chat (streaming) - -Interactive session with calls to `/v1/chat/completions`. - -Type `clear` to clear the chat history or `quit` to exit. 
- -```bash -python -m workers.openai.client -k -e --interactive --model -``` diff --git a/workers/openai/client.py b/workers/openai/client.py index a92ad95..2385aef 100644 --- a/workers/openai/client.py +++ b/workers/openai/client.py @@ -102,15 +102,13 @@ async def call_completions(client: Serverless, *, model: str, prompt: str, endpo endpoint = await client.get_endpoint(name=endpoint_name) payload = { - "input": { - "model": model, - "prompt": prompt, - "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), - "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), - } + "model": model, + "prompt": prompt, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), } log.debug("POST /v1/completions %s", json.dumps(payload)[:500]) - resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"]) + resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"]) return resp["response"] async def call_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs) -> Dict[str, Any]: @@ -118,17 +116,15 @@ async def call_chat_completions(client: Serverless, *, model: str, messages: Lis endpoint = await client.get_endpoint(name=endpoint_name) payload = { - "input": { - "model": model, - "messages": messages, - "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), - "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), - **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), - **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), - } + "model": model, + "messages": messages, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), + **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), } log.debug("POST /v1/chat/completions %s", json.dumps(payload)[:500]) - resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"]) + resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["max_tokens"]) return resp["response"] # ---- Streaming variants ---- @@ -137,17 +133,15 @@ async def stream_completions(client: Serverless, *, model: str, prompt: str, end endpoint = await client.get_endpoint(name=endpoint_name) payload = { - "input": { - "model": model, - "prompt": prompt, - "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), - "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), - "stream": True, - **({"stop": kwargs["stop"]} if "stop" in kwargs else {}), - } + "model": model, + "prompt": prompt, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "stream": True, + **({"stop": kwargs["stop"]} if "stop" in kwargs else {}), } log.debug("STREAM /v1/completions %s", json.dumps(payload)[:500]) - resp = await endpoint.request("/v1/completions", payload, cost=payload["input"]["max_tokens"], stream=True) + resp = await endpoint.request("/v1/completions", payload, cost=payload["max_tokens"], stream=True) return resp["response"] # async generator async def stream_chat_completions(client: Serverless, *, model: str, messages: List[Dict[str, Any]], endpoint_name: str, **kwargs): @@ -155,18 +149,16 @@ async def stream_chat_completions(client: Serverless, *, model: str, messages: L endpoint = await client.get_endpoint(name=endpoint_name) 
payload = { - "input": { - "model": model, - "messages": messages, - "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), - "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), - "stream": True, - **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), - **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), - } + "model": model, + "messages": messages, + "max_tokens": kwargs.get("max_tokens", MAX_TOKENS), + "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), + "stream": True, + **({"tools": kwargs["tools"]} if "tools" in kwargs else {}), + **({"tool_choice": kwargs["tool_choice"]} if "tool_choice" in kwargs else {}), } log.debug("STREAM /v1/chat/completions %s", json.dumps(payload)[:500]) - resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["input"]["max_tokens"], stream=True) + resp = await endpoint.request("/v1/chat/completions", payload, cost=payload["max_tokens"], stream=True) return resp["response"] # async generator diff --git a/workers/openai/data_types/__init__.py b/workers/openai/data_types/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/workers/openai/data_types/client.py b/workers/openai/data_types/client.py deleted file mode 100644 index 444ae2d..0000000 --- a/workers/openai/data_types/client.py +++ /dev/null @@ -1,58 +0,0 @@ -import json -from dataclasses import dataclass, field, fields, is_dataclass -from typing import Optional, List, Dict, Any - - -class SerializableDataclass: - def _serialize_recursive(self, obj: Any) -> Any: - if is_dataclass(obj): - return { - field.name: self._serialize_recursive(getattr(obj, field.name)) - for field in fields(obj) - } - elif isinstance(obj, dict): - return {key: self._serialize_recursive(value) for key, value in obj.items()} - elif isinstance(obj, (list, tuple)): - return [self._serialize_recursive(item) for item in obj] - elif isinstance(obj, set): - return [self._serialize_recursive(item) for item in obj] - else: - return obj - - def to_dict(self) -> Dict[str, Any]: - return self._serialize_recursive(self) - - def to_json(self, indent: int = 2) -> str: - return json.dumps(self.to_dict(), indent=indent) - - -@dataclass -class CompletionConfig(SerializableDataclass): - """Configuration for completion requests""" - - model: str - prompt: str = "Hello" - max_tokens: int = 256 - temperature: float = 0.7 - top_k: int = 20 - top_p: float = 0.4 - stream: bool = False - - -@dataclass -class ChatCompletionConfig(SerializableDataclass): - """Configuration for chat completion requests""" - - model: str - messages: list = field(default_factory=list) - max_tokens: int = 2096 - temperature: float = 0.7 - top_k: int = 20 - top_p: float = 0.4 - stream: bool = False - tools: Optional[List[Dict[str, Any]]] = field(default_factory=list) - tool_choice: str = "auto" - - def __post_init__(self): - if self.messages is None: - self.messages = [{"role": "user", "content": "Hello"}] diff --git a/workers/openai/data_types/server.py b/workers/openai/data_types/server.py deleted file mode 100644 index e549864..0000000 --- a/workers/openai/data_types/server.py +++ /dev/null @@ -1,207 +0,0 @@ -import os, json, random -from abc import ABC, abstractmethod -from dataclasses import dataclass -from lib.data_types import EndpointHandler, ApiPayload, JsonDataException -from typing import Union, Type, Dict, Any, Optional -from aiohttp import web, ClientResponse -import nltk -import logging - -nltk.download("words") -WORD_LIST = nltk.corpus.words.words() -log = 
logging.getLogger(__name__) - -""" -Generic dataclass accepts any dictionary in input. -""" - - -@dataclass -class GenericData(ApiPayload, ABC): - input: Dict[str, Any] - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GenericData": - return cls(input=data["input"]) - - @classmethod - def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericData": - errors = {} - - # Validate required parameters - required_params = ["input"] - for param in required_params: - if param not in json_msg: - errors[param] = "missing parameter" - - if errors: - raise JsonDataException(errors) - - try: - # Create clean data dict and delegate to from_dict - clean_data = {"input": json_msg["input"]} - - return cls.from_dict(clean_data) - - except (json.JSONDecodeError, JsonDataException) as e: - errors["parameters"] = str(e) - raise JsonDataException(errors) - - @classmethod - @abstractmethod - def for_test(cls) -> "GenericData": - pass - - def generate_payload_json(self) -> Dict[str, Any]: - return self.input - - def count_workload(self) -> int: - return self.input.get("max_tokens", 0) - - -@dataclass -class GenericHandler(EndpointHandler[GenericData], ABC): - - @property - @abstractmethod - def endpoint(self) -> str: - pass - - @property - def healthcheck_endpoint(self) -> Optional[str]: - return os.environ.get("MODEL_HEALTH_ENDPOINT") - - @classmethod - def payload_cls(cls) -> Type[GenericData]: - return GenericData - - @abstractmethod - def make_benchmark_payload(self) -> GenericData: - pass - - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - match model_response.status: - case 200: - # Check if the response is actually streaming based on response headers/content-type - is_streaming_response = ( - model_response.content_type == "text/event-stream" - or model_response.content_type == "application/x-ndjson" - or model_response.headers.get("Transfer-Encoding") == "chunked" - or "stream" in model_response.content_type.lower() - ) - - if is_streaming_response: - log.debug("Detected streaming response...") - res = web.StreamResponse() - res.content_type = model_response.content_type - await res.prepare(client_request) - async for chunk in model_response.content: - await res.write(chunk) - await res.write_eof() - log.debug("Done streaming response") - return res - else: - log.debug("Detected non-streaming response...") - content = await model_response.read() - return web.Response( - body=content, - status=200, - content_type=model_response.content_type, - ) - case code: - log.debug("SENDING RESPONSE: ERROR: unknown code") - return web.Response(status=code) - - -@dataclass -class CompletionsData(GenericData): - @classmethod - def for_test(cls) -> "CompletionsData": - system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base: - - Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines - with distinctive black-and-white striped coats. There are three living species: Grévy's zebra - (Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the - genus Equus with horses and asses, the three groups being the only living members of the family - Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern - and southern Africa and can be found in a variety of habitats such as savannahs, grasslands, - woodlands, shrublands, and mountainous areas. 
- - Please answer the following question based on the above context.""" - unique_question = " ".join(random.choices(WORD_LIST, k=int(100))) - model = os.environ.get("MODEL_NAME") - if not model: - raise ValueError("MODEL_NAME environment variable not set") - - test_input = { - "model": model, - "prompt": f"{system_prompt}\n\n{unique_question}", - "temperature": 0.7, - "max_tokens": 500, - } - return cls(input=test_input) - - -@dataclass -class CompletionsHandler(GenericHandler): - @property - def endpoint(self) -> str: - return "/v1/completions" - - @classmethod - def payload_cls(cls) -> Type[CompletionsData]: - return CompletionsData - - def make_benchmark_payload(self) -> CompletionsData: - return CompletionsData.for_test() - - -@dataclass -class ChatCompletionsData(GenericData): - """Chat completions-specific data implementation""" - - @classmethod - def for_test(cls) -> "ChatCompletionsData": - system_prompt = """You are a helpful AI assistant. You have access to the following knowledge base: - - Zebras (US: /ˈziːbrəz/, UK: /ˈzɛbrəz, ˈziː-/)[2] (subgenus Hippotigris) are African equines - with distinctive black-and-white striped coats. There are three living species: Grévy's zebra - (Equus grevyi), the plains zebra (E. quagga), and the mountain zebra (E. zebra). Zebras share the - genus Equus with horses and asses, the three groups being the only living members of the family - Equidae. Zebra stripes come in different patterns, unique to each individual. Zebras inhabit eastern - and southern Africa and can be found in a variety of habitats such as savannahs, grasslands, - woodlands, shrublands, and mountainous areas. - - Please answer the following question based on the above context.""" - unique_question = " ".join(random.choices(WORD_LIST, k=int(100))) - model = os.environ.get("MODEL_NAME") - if not model: - raise ValueError("MODEL_NAME environment variable not set") - - # Chat completions use messages format instead of prompt - test_input = { - "model": model, - "messages": [ - {"role": "system", "content": system_prompt}, # Shared prefix - {"role": "user", "content": unique_question} # Unique per request - ], - "temperature": 0.7, - "max_tokens": 500, - } - return cls(input=test_input) - - -@dataclass -class ChatCompletionsHandler(GenericHandler): - @property - def endpoint(self) -> str: - return "/v1/chat/completions" - - @classmethod - def payload_cls(cls) -> Type[ChatCompletionsData]: - return ChatCompletionsData - - def make_benchmark_payload(self) -> ChatCompletionsData: - return ChatCompletionsData.for_test() diff --git a/workers/openai/server.py b/workers/openai/server.py deleted file mode 100644 index 63f21f9..0000000 --- a/workers/openai/server.py +++ /dev/null @@ -1,62 +0,0 @@ -import os -import logging -from .data_types.server import CompletionsHandler, ChatCompletionsHandler -from aiohttp import web -from lib.backend import Backend, LogAction -from lib.server import start_server - -# This line indicates that the inference server is listening -MODEL_SERVER_START_LOG_MSG = [ - "Application startup complete.", # vLLM - "llama runner started", # Ollama - '"message":"Connected","target":"text_generation_router"', # TGI - '"message":"Connected","target":"text_generation_router::server"', # TGI - "main: model loaded" # llama.cpp -] - -MODEL_SERVER_ERROR_LOG_MSGS = [ - "INFO exited: vllm", # vLLM - "RuntimeError: Engine", # vLLM - "Error: pull model manifest:", # Ollama - "stalled; retrying", # Ollama - "Error: WebserverFailed", # TGI - "Error: DownloadError", # TGI - "Error: 
ShardCannotStart", # TGI -] - -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) - -backend = Backend( - model_server_url=os.environ["MODEL_SERVER_URL"], - model_log_file=os.environ["MODEL_LOG"], - allow_parallel_requests=True, - max_wait_time=600.0, - benchmark_handler=CompletionsHandler(benchmark_runs=3, benchmark_words=256), - log_actions=[ - *[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG], - (LogAction.Info, '"message":"Download'), - *[ - (LogAction.ModelError, error_msg) - for error_msg in MODEL_SERVER_ERROR_LOG_MSGS - ], - ], -) - - -async def handle_ping(_): - return web.Response(body="pong") - - -routes = [ - web.post("/v1/completions", backend.create_handler(CompletionsHandler())), - web.post("/v1/chat/completions", backend.create_handler(ChatCompletionsHandler())), - web.get("/ping", handle_ping), -] - -if __name__ == "__main__": - start_server(backend, routes) diff --git a/workers/openai/test_load.py b/workers/openai/test_load.py deleted file mode 100644 index 9cb5f37..0000000 --- a/workers/openai/test_load.py +++ /dev/null @@ -1,434 +0,0 @@ -from lib.test_utils import test_args -from utils.endpoint_util import Endpoint -from utils.ssl import get_cert_file_path -from lib.data_types import AuthData -from .data_types.server import CompletionsData - -import os -import time -import threading -import requests -from dataclasses import dataclass -from collections import Counter -from urllib.parse import urljoin, urlparse -import re - -# Headless plotting -import matplotlib -matplotlib.use("Agg") -import logging -logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING) -import matplotlib.pyplot as plt -import numpy as np -from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED -from requests.adapters import HTTPAdapter - -def get_incremented_path(path: str) -> str: - base, ext = os.path.splitext(path) - if not os.path.exists(path): - return path - i = 1 - while os.path.exists(f"{base}-{i}{ext}"): - i += 1 - return f"{base}-{i}{ext}" - -WORKER_ENDPOINT = "/v1/completions" # This will return the full text output at once. 
Latency metrics reflect that (ie not measuring TTFT) - -@dataclass -class ReqResult: - worker_url: str - route_ms: float - worker_ms: float - total_ms: float - ok: bool - error: str = "" - status_code: int = 0 - t_start: float = 0.0 - t_end: float = 0.0 - workload: float = 0.0 - -def do_one(endpoint_name: str, - endpoint_id: int, - endpoint_api_key: str, - server_url: str, - worker_endpoint: str, - payload, - results_list, - t0, - status_samples, - route_session, - worker_session): - try: - workload = payload.count_workload() - route_payload = {"endpoint": endpoint_name, "api_key": endpoint_api_key, "cost": workload} - headers = {"Authorization": f"Bearer {endpoint_api_key}"} - start = time.time() - r0 = route_session.post(urljoin(server_url, "/route/"), json=route_payload, headers=headers, timeout=4) - t_after_route = time.time() - if r0.status_code != 200: - results_list.append(ReqResult(worker_url="", - route_ms=(t_after_route - start) * 1000.0, - worker_ms=0.0, - total_ms=(t_after_route - start) * 1000.0, - ok=False, - error=f"route error {r0.reason} {r0.text}", - status_code=r0.status_code, - t_start=start - t0, - t_end=t_after_route - t0, - workload=workload)) - return - msg = r0.json() - - # 1) Check if we got a worker back from route - worker_url = msg.get("url", "") - if not worker_url: - status = msg.get("status", "") - m = re.search(r"total workers:\s*(\d+).*loading workers:\s*(\d+).*standby workers:\s*(\d+).*error workers:\s*(\d+)", status, re.I | re.S) - if m: - tot, loading, standby, err = map(int, m.groups()) - idle = max(tot - loading - standby - err, 0) - status_samples.append((time.time() - t0, idle)) - - # 2) If we got a worker, send the request - if worker_url: - req = dict(payload=payload.__dict__, auth_data=AuthData.from_json_msg(msg).__dict__) - t_before_worker = time.time() - r1 = worker_session.post( - urljoin(worker_url, worker_endpoint), - json=req, - verify=get_cert_file_path(), - timeout=(4, 120), - ) - t_after_worker = time.time() - if r1.status_code != 200: - results_list.append(ReqResult(worker_url=worker_url, - route_ms=(t_after_route - start) * 1000.0, - worker_ms=(t_after_worker - t_before_worker) * 1000.0, - total_ms=(t_after_worker - start) * 1000.0, - ok=False, - error=f"worker inference error {r1.reason} {r1.text}", - status_code=r1.status_code, - t_start=start - t0, - t_end=t_after_worker - t0, - workload=workload)) - return - # Success case - results_list.append(ReqResult(worker_url=worker_url, - route_ms=(t_after_route - start) * 1000.0, - worker_ms=(t_after_worker - t_before_worker) * 1000.0, - total_ms=(t_after_worker - start) * 1000.0, - ok=True, - error="", - status_code=200, - t_start=start - t0, - t_end=t_after_worker - t0, - workload=workload)) - - # 3) If so, sample via /get_endpoint_workers/ for eligible (idle) worker tracking - if worker_url: - try: - r_status = route_session.post( - urljoin(server_url, "/get_endpoint_workers/"), - json={"id": endpoint_id}, - headers={"Authorization": f"Bearer {endpoint_api_key}"}, - timeout=3, - ) - if r_status.status_code == 200: - workers = r_status.json() - idle = 0 - for w in workers: - st = str(w.get("status", "")).lower() - if (st in ("idle")): - idle += 1 - status_samples.append((time.time() - t0, idle)) - except Exception: - pass - except Exception as e: - t = time.time() - results_list.append(ReqResult(worker_url="", - route_ms=0.0, - worker_ms=0.0, - total_ms=0.0, - ok=False, - error=f"unknown error {e}", - status_code=0, - t_start=t - t0, - t_end=t - t0, - workload=0.0)) - -def 
run_load_with_metrics(num_requests: int, - requests_per_second: float, - endpoint_group_name: str, - account_api_key: str, - server_url: str, - worker_endpoint: str, - instance: str, - out_path: str): - - ep_info = Endpoint.get_endpoint_info(endpoint_name=endpoint_group_name, - account_api_key=account_api_key, - instance=instance) - if not ep_info or not ep_info.get("api_key") or not ep_info.get("id"): - print(f"Endpoint {endpoint_group_name} not found for API key") - return - endpoint_id = int(ep_info["id"]) - endpoint_api_key = ep_info["api_key"] - - t0 = time.time() - results = [] - status_samples = [] - max_concurrency = int(os.environ.get("MAX_CONCURRENCY", "8192")) - submit_queue_factor = 2 # cap queued tasks to reduce memory - - # Shared HTTP sessions with connection pooling (persistent connections) - def make_session(pool_connections: int, pool_maxsize: int) -> requests.Session: - sess = requests.Session() - adapter = HTTPAdapter(pool_connections=pool_connections, pool_maxsize=pool_maxsize, max_retries=0) - sess.mount("https://", adapter) - sess.mount("http://", adapter) - return sess - - # Router: mostly single host, small connection pool is sufficient - route_session = make_session(pool_connections=1, pool_maxsize=max_concurrency) - # Workers: many hosts; allow many pools and per-host concurrency up to max_concurrency - worker_session = make_session(pool_connections=64, pool_maxsize=max_concurrency // 8) - - # Fire requests using a thread pool, scheduling at requested RPS - inflight = set() - with ThreadPoolExecutor(max_workers=max_concurrency) as executor: - for i in range(num_requests): - # Pace submissions to RPS - target_time = t0 + i / max(requests_per_second, 1e-9) - sleep_s = target_time - time.time() - if sleep_s > 0: - time.sleep(min(sleep_s, 0.5)) # sleep in chunks to stay responsive - - payload = CompletionsData.for_test() - fut = executor.submit( - do_one, - endpoint_group_name, - endpoint_id, - endpoint_api_key, - server_url, - worker_endpoint, - payload, - results, - t0, - status_samples, - route_session, - worker_session, - ) - inflight.add(fut) - # Prevent unbounded queue growth - if len(inflight) >= max_concurrency * submit_queue_factor: - done, not_done = wait(inflight, return_when=FIRST_COMPLETED) - inflight = not_done - # Wait for all outstanding tasks - if inflight: - wait(inflight) - # Close sessions - try: - route_session.close() - finally: - worker_session.close() - - # Aggregate results - oks = [r for r in results if r.ok] - errs = [r for r in results if not r.ok] - total_reqs = len(results) - succ = len(oks) - - total_ms = np.array([r.total_ms for r in oks]) if succ else np.array([]) - worker_ms = np.array([r.worker_ms for r in oks]) if succ else np.array([]) - route_ms = np.array([r.route_ms for r in oks]) if succ else np.array([]) - - avg_total = float(np.mean(total_ms)) if succ else 0.0 - avg_worker = float(np.mean(worker_ms)) if succ else 0.0 - avg_route = float(np.mean(route_ms)) if succ else 0.0 - p50_total, p95_total = (float(np.percentile(total_ms, 50)), float(np.percentile(total_ms, 95))) if succ else (0.0, 0.0) - - # Distribution over workers (by host:port) - hosts = [urlparse(r.worker_url).netloc for r in oks if r.worker_url] - dist = Counter(hosts) - - # Idle over time (mode per second) - idle_ts, idle_vals = [], [] - if status_samples: - buckets = {} - for ts, idle in status_samples: - k = int(ts) - buckets.setdefault(k, []).append(idle) - keys = sorted(buckets.keys()) - idle_ts = keys - # Use the most frequent sampled value per second 
(mode) to keep integer counts - idle_vals = [] - for k in keys: - vals_k = [int(v) for v in buckets[k]] - if vals_k: - cnt = Counter(vals_k) - idle_vals.append(cnt.most_common(1)[0][0]) - else: - idle_vals.append(0) - - print(f"\nResults: total={total_reqs} success={succ} errors={len(errs)}") - print(f"Avg latency (ms): {avg_total:.1f} p50: {p50_total:.1f} p95: {p95_total:.1f}") - print(f"Avg route latency (ms): {avg_route:.1f} Avg worker latency (ms): {avg_worker:.1f}") - if errs: - print("Sample errors:") - for e in errs[:5]: - print(f" {e.status_code} {e.error}") - - # Plot: 2x3 grid - fig, axes = plt.subplots(2, 3, figsize=(15, 8)) - fig.suptitle(f"Load test: {endpoint_group_name} n={total_reqs}, rps={requests_per_second}, success={succ}") - - # Dist per worker - ax0 = axes[0, 0] - if dist: - items = sorted(dist.items(), key=lambda kv: kv[1], reverse=True) - labels, counts = zip(*items) - ax0.bar(range(len(labels)), counts) - ax0.set_xticks(range(len(labels))) - ax0.set_xticklabels(labels, rotation=45, ha="right", fontsize=8) - ax0.set_title("Request distribution over workers") - ax0.set_ylabel("count") - - # Latency histogram (total) - ax1 = axes[0, 1] - if succ: - ax1.hist(total_ms, bins=30) - ax1.set_title("Total latency (ms)") - ax1.set_xlabel("ms") - ax1.set_ylabel("freq") - - # Eligible workers over time - ax_idle = axes[0, 2] - if idle_ts: - ax_idle.plot(idle_ts, idle_vals, "-o", ms=3) - ax_idle.set_title("Eligible workers over time") - ax_idle.set_xlabel("time (s)") - ax_idle.set_ylabel("eligible count") - - # Throughput over time (completions/sec) - ax_idle = axes[1, 0] - ax_idle.clear() - if succ: - per_sec = {} - for r in oks: - s = int(r.t_end) - per_sec[s] = per_sec.get(s, 0) + 1 - ts = sorted(per_sec.keys()) - vals = [per_sec[t] for t in ts] - ax_idle.plot(ts, vals, "-o", ms=3) - ax_idle.set_title("Completions per second") - ax_idle.set_xlabel("time (s)") - ax_idle.set_ylabel("completions / sec") - - # Summary text - ax3 = axes[1, 1] - ax3.axis("off") - text = ( - f"Total requests: {total_reqs}\n" - f"Success: {succ} Errors: {len(errs)}\n" - f"Avg total latency: {avg_total:.1f} ms\n" - f"p50: {p50_total:.1f} ms p95: {p95_total:.1f} ms\n" - f"Avg route latency: {avg_route:.1f} ms\n" - f"Avg worker latency: {avg_worker:.1f} ms\n" - f"300 errors: {len([r for r in errs if r.status_code >= 300 and r.status_code < 400])}\n" - f"429 errors: {len([r for r in errs if r.status_code == 429])}\n" - f"500 errors: {len([r for r in errs if r.status_code >= 500])}\n" - f"Other errors: {len([r for r in errs if r.status_code not in [300, 429, 500]])}\n" - ) - ax3.set_title("Summary") - ax3.text(0.02, 0.98, text, va="top", ha="left", fontsize=11, transform=ax3.transAxes) - - # Error count over time - ax_errors = axes[1, 2] - all_end_times = [int(r.t_end) for r in results if r.t_end > 0] - if all_end_times: - min_second = min(all_end_times) - max_second = max(all_end_times) - # Count errors per second - errors_per_second = {} - for result in errs: - second = int(result.t_end) - errors_per_second[second] = errors_per_second.get(second, 0) + 1 - # Create complete timeline including zeros - time_seconds = list(range(min_second, max_second + 1)) - error_counts = [errors_per_second.get(sec, 0) for sec in time_seconds] - ax_errors.plot(time_seconds, error_counts, "-o", ms=3) - ax_errors.set_title("Errors per second") - ax_errors.set_xlabel("time (s)") - ax_errors.set_ylabel("errors / sec") - - # Ensure unique output path and create directory if needed - final_out_path = 
get_incremented_path(out_path) - out_dir = os.path.dirname(final_out_path) - if out_dir: - os.makedirs(out_dir, exist_ok=True) - - plt.tight_layout(rect=[0, 0, 1, 0.96]) - plt.savefig(final_out_path, dpi=120) - print(f"Saved report to: {final_out_path}") - - # Per-worker latency boxplot (top 12 by volume) - groups = {} - for r in oks: - host = urlparse(r.worker_url).netloc - groups.setdefault(host, []).append(r.total_ms) - items = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)[:12] - if items: - labels, data = zip(*items) - fig2, axb = plt.subplots(1, 1, figsize=(12, 5)) - axb.boxplot(data, showfliers=False) - axb.set_xticklabels(labels, rotation=45, ha="right", fontsize=8) - axb.set_title("Per-worker latency (ms)") - axb.set_ylabel("ms") - plt.tight_layout() - extra_out = get_incremented_path(os.path.splitext(out_path)[0] + "-workers.png") - plt.savefig(extra_out, dpi=120) - fig2.tight_layout() - fig2.savefig(extra_out, dpi=120) - print(f"Saved worker latency plot to: {extra_out}") - -if __name__ == "__main__": - # Check if MODEL_NAME environment variable is set - model_name_set = os.environ.get("MODEL_NAME") is not None - - # Add model argument - required only if MODEL_NAME is not set - test_args.add_argument( - "--model", - dest="model", - required=not model_name_set, - help="Model to use for completions request (required if MODEL_NAME env var not set)", - ) - - # Parse known args to get model early, before adding load args - known_args, _ = test_args.parse_known_args() - if hasattr(known_args, "model") and known_args.model: - os.environ["MODEL_NAME"] = known_args.model - print(f"Set MODEL_NAME environment variable to: {known_args.model}") - - # Load test args - test_args.add_argument("-n", dest="num_requests", type=int, required=True, help="total number of requests") - test_args.add_argument("-rps", dest="requests_per_second", type=float, required=True, help="requests per second") - test_args.add_argument("--out", dest="out_path", type=str, default="load_test_report.png", help="path to save the report image") - args = test_args.parse_args() - - server_url = { - "prod": "https://run.vast.ai", - "alpha": "https://run-alpha.vast.ai", - "candidate": "https://run-candidate.vast.ai", - "local": "http://localhost:8080" - }.get(args.instance, "http://localhost:8080") - - run_load_with_metrics( - num_requests=args.num_requests, - requests_per_second=args.requests_per_second, - endpoint_group_name=args.endpoint_group_name, - account_api_key=args.api_key, - server_url=server_url, - worker_endpoint=WORKER_ENDPOINT, - instance=args.instance, - out_path=args.out_path, - ) \ No newline at end of file diff --git a/workers/openai/worker.py b/workers/openai/worker.py new file mode 100644 index 0000000..6cf17f0 --- /dev/null +++ b/workers/openai/worker.py @@ -0,0 +1,78 @@ +import nltk +import random +import os + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# vLLM model configuration +MODEL_SERVER_URL = 'http://127.0.0.1' +MODEL_SERVER_PORT = 18000 +MODEL_LOG_FILE = '/var/log/portal/vllm.log' +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# vLLM-specific log messages +MODEL_LOAD_LOG_MSG = [ + "Application startup complete.", +] + +MODEL_ERROR_LOG_MSGS = [ + "INFO exited: vllm", + "RuntimeError: Engine", + "Traceback (most recent call last):" +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Download' +] + +nltk.download("words") +WORD_LIST = nltk.corpus.words.words() + + +def completions_benchmark_generator() -> dict: + prompt = " 
".join(random.choices(WORD_LIST, k=int(250))) + model = os.environ.get("MODEL_NAME") + if not model: + raise ValueError("MODEL_NAME environment variable not set") + + benchmark_data = { + "model": model, + "prompt": prompt, + "temperature": 0.7, + "max_tokens": 500, + } + + return benchmark_data + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/v1/completions", + workload_calculator= lambda data: data.get("max_tokens", 0), + allow_parallel_requests=True, + max_queue_time=60.0, + benchmark_config=BenchmarkConfig( + generator=completions_benchmark_generator, + concurrency=100, + runs=2 + ) + ), + HandlerConfig( + route="/v1/chat/completions", + workload_calculator= lambda data: data.get("max_tokens", 0), + allow_parallel_requests=True, + max_queue_time=60.0, + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/workers/tgi/data_types.py b/workers/tgi/data_types.py deleted file mode 100644 index 56e0b5b..0000000 --- a/workers/tgi/data_types.py +++ /dev/null @@ -1,73 +0,0 @@ -import dataclasses -import random -import inspect -from typing import Dict, Any - -from transformers import OpenAIGPTTokenizer -import nltk - -from lib.data_types import ApiPayload, JsonDataException - -nltk.download("words") -WORD_LIST = nltk.corpus.words.words() - -tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt") - - -@dataclasses.dataclass -class InputParameters: - max_new_tokens: int = 256 - - @classmethod - def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputParameters": - errors = {} - for param in inspect.signature(cls).parameters: - if param not in json_msg: - errors[param] = "missing parameter" - if errors: - raise JsonDataException(errors) - return cls( - **{ - k: v - for k, v in json_msg.items() - if k in inspect.signature(cls).parameters - } - ) - - -@dataclasses.dataclass -class InputData(ApiPayload): - inputs: str - parameters: InputParameters - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "InputData": - return cls( - inputs=data["inputs"], parameters=InputParameters(**data["parameters"]) - ) - - @classmethod - def for_test(cls) -> "InputData": - prompt = " ".join(random.choices(WORD_LIST, k=int(250))) - return cls(inputs=prompt, parameters=InputParameters()) - - def generate_payload_json(self) -> Dict[str, Any]: - return dataclasses.asdict(self) - - def count_workload(self) -> int: - return self.parameters.max_new_tokens - - @classmethod - def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData": - errors = {} - for param in inspect.signature(cls).parameters: - if param not in json_msg: - errors[param] = "missing parameter" - if errors: - raise JsonDataException(errors) - try: - parameters = InputParameters.from_json_msg(json_msg["parameters"]) - return cls(inputs=json_msg["inputs"], parameters=parameters) - except JsonDataException as e: - errors["parameters"] = e.message - raise JsonDataException(errors) diff --git a/workers/tgi/server.py b/workers/tgi/server.py deleted file mode 100644 index 99fc810..0000000 --- a/workers/tgi/server.py +++ /dev/null @@ -1,130 +0,0 @@ -import os -import logging -from typing import Union, Type -import dataclasses - -from aiohttp import web, ClientResponse - -from lib.backend import Backend, 
LogAction -from lib.data_types import EndpointHandler -from lib.server import start_server -from .data_types import InputData - - -MODEL_SERVER_URL = "http://0.0.0.0:5001" - -# This is the last log line that gets emitted once comfyui+extensions have been fully loaded -MODEL_SERVER_START_LOG_MSG = [ - '"message":"Connected","target":"text_generation_router"', - '"message":"Connected","target":"text_generation_router::server"', -] -MODEL_SERVER_ERROR_LOG_MSGS = [ - "Error: WebserverFailed", - "Error: DownloadError", - "Error: ShardCannotStart", -] - - -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s[%(levelname)-5s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -log = logging.getLogger(__file__) - - -@dataclasses.dataclass -class GenerateHandler(EndpointHandler[InputData]): - - @property - def endpoint(self) -> str: - return "/generate" - - @property - def healthcheck_endpoint(self) -> str: - return f"{MODEL_SERVER_URL}/health" - - @classmethod - def payload_cls(cls) -> Type[InputData]: - return InputData - - def make_benchmark_payload(self) -> InputData: - return InputData.for_test() - - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - _ = client_request - match model_response.status: - case 200: - log.debug("SUCCESS") - data = await model_response.json() - return web.json_response(data=data) - case code: - log.debug("SENDING RESPONSE: ERROR: unknown code") - return web.Response(status=code) - - -class GenerateStreamHandler(EndpointHandler[InputData]): - @property - def endpoint(self) -> str: - return "/generate_stream" - - @property - def healthcheck_endpoint(self) -> str: - return f"{MODEL_SERVER_URL}/health" - - @classmethod - def payload_cls(cls) -> Type[InputData]: - return InputData - - def make_benchmark_payload(self) -> InputData: - return InputData.for_test() - - async def generate_client_response( - self, client_request: web.Request, model_response: ClientResponse - ) -> Union[web.Response, web.StreamResponse]: - match model_response.status: - case 200: - log.debug("Streaming response...") - res = web.StreamResponse() - res.content_type = "text/event-stream" - await res.prepare(client_request) - async for chunk in model_response.content: - await res.write(chunk) - await res.write_eof() - log.debug("Done streaming response") - return res - case code: - log.debug("SENDING RESPONSE: ERROR: unknown code") - return web.Response(status=code) - - -backend = Backend( - model_server_url=MODEL_SERVER_URL, - model_log_file=os.environ["MODEL_LOG"], - allow_parallel_requests=True, - benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256), - log_actions=[ - *[(LogAction.ModelLoaded, info_msg) for info_msg in MODEL_SERVER_START_LOG_MSG], - (LogAction.Info, '"message":"Download'), - *[ - (LogAction.ModelError, error_msg) - for error_msg in MODEL_SERVER_ERROR_LOG_MSGS - ], - ], -) - - -async def handle_ping(_): - return web.Response(body="pong") - - -routes = [ - web.post("/generate", backend.create_handler(GenerateHandler())), - web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())), - web.get("/ping", handle_ping), -] - -if __name__ == "__main__": - start_server(backend, routes) diff --git a/workers/tgi/test_load.py b/workers/tgi/test_load.py deleted file mode 100644 index b0fc674..0000000 --- a/workers/tgi/test_load.py +++ /dev/null @@ -1,7 +0,0 @@ -from lib.test_utils import test_load_cmd, test_args -from .data_types import 
InputData - -WORKER_ENDPOINT = "/generate" - -if __name__ == "__main__": - test_load_cmd(InputData, WORKER_ENDPOINT, arg_parser=test_args) diff --git a/workers/tgi/worker.py b/workers/tgi/worker.py new file mode 100644 index 0000000..f8084ab --- /dev/null +++ b/workers/tgi/worker.py @@ -0,0 +1,76 @@ +import nltk +import random + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# TGI model configuration +MODEL_SERVER_URL = 'http://0.0.0.0' +MODEL_SERVER_PORT = 5001 +MODEL_LOG_FILE = "/workspace/infer.log" +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# TGI-specific log messages +MODEL_LOAD_LOG_MSG = [ + '"message":"Connected","target":"text_generation_router"', + '"message":"Connected","target":"text_generation_router::server"', +] + +MODEL_ERROR_LOG_MSGS = [ + "Error: WebserverFailed", + "Error: DownloadError", + "Error: ShardCannotStart", +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Download' +] + +nltk.download("words") +WORD_LIST = nltk.corpus.words.words() + + +def benchmark_generator() -> dict: + prompt = " ".join(random.choices(WORD_LIST, k=int(250))) + + benchmark_data = { + "inputs": prompt, + "parameters": { + "max_new_tokens": 128, + "temperature": 0.7, + "return_full_text": False + } + } + + return benchmark_data + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/generate", + allow_parallel_requests=True, + max_queue_time=60.0, + benchmark_config=BenchmarkConfig( + generator=benchmark_generator, + concurrency=50 + ), + workload_calculator= lambda x: x["parameters"]["max_new_tokens"] + ), + HandlerConfig( + route="/generate_stream", + allow_parallel_requests=True, + max_queue_time=60.0, + workload_calculator= lambda x: x["parameters"]["max_new_tokens"] + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file diff --git a/workers/wan/README.md b/workers/wan/README.md new file mode 100644 index 0000000..857f485 --- /dev/null +++ b/workers/wan/README.md @@ -0,0 +1,170 @@ +# ComfyUI Wan 2.2 PyWorker + +This is the PyWorker implementation for running **Wan 2.2 T2V A14B** text-to-video workflows in ComfyUI. It provides a unified interface for executing complete ComfyUI video-generation workflows through a proxy-based architecture and returning generated video assets. + +Each request has a static cost of `10000`. ComfyUI does not support concurrent workloads, and there is no provision to run multiple ComfyUI instances per worker node. + +## Requirements + +This worker requires the following components: + +- ComfyUI (https://github.com/comfyanonymous/ComfyUI) +- ComfyUI API Wrapper (https://github.com/ai-dock/comfyui-api-wrapper) +- Wan 2.2 T2V A14B models and required custom nodes + +A Docker image is provided with all required Wan 2.2 models pre-installed, but any image may be used if the above requirements are met. + +## Endpoint + +The worker exposes a single synchronous endpoint: + +- `/generate/sync`: Processes a complete ComfyUI workflow JSON and generates video output + +## Request Format + +The Wan 2.2 worker **only supports custom workflow mode**. Modifier-based workflows are not supported. 
+ +```json +{ + "input": { + "request_id": "uuid-string", + "workflow_json": { + // Complete ComfyUI Wan 2.2 workflow JSON + }, + "s3": { }, + "webhook": { } + } +} +``` + +## Request Fields + +### Required Fields + +- `input`: Container for all request parameters +- `input.workflow_json`: Complete ComfyUI workflow graph for Wan 2.2 video generation + +### Optional Fields + +- `input.request_id`: Client-defined request identifier +- `input.s3`: S3-compatible storage configuration +- `input.webhook`: Webhook configuration for completion notifications + +The special string `"__RANDOM_INT__"` may be used in the workflow JSON and will be replaced with a random integer before submission to ComfyUI. + +## S3 Configuration + +Generated video assets can be automatically uploaded to S3-compatible storage. Configuration can be supplied per request or via environment variables. Request-level values take precedence. + +### Via Request JSON + +```json +"s3": { + "access_key_id": "your-s3-access-key", + "secret_access_key": "your-s3-secret-access-key", + "endpoint_url": "https://s3.amazonaws.com", + "bucket_name": "your-bucket", + "region": "us-east-1" +} +``` + +### Via Environment Variables + +```bash +S3_ACCESS_KEY_ID=your-key +S3_SECRET_ACCESS_KEY=your-secret +S3_BUCKET_NAME=your-bucket +S3_ENDPOINT_URL=https://s3.amazonaws.com +S3_REGION=us-east-1 +``` + +## Webhook Configuration + +Webhooks are triggered on request completion or failure. + +### Via Request JSON + +```json +"webhook": { + "url": "https://your-webhook-url", + "extra_params": { + "custom_field": "value" + } +} +``` + +### Via Environment Variables + +```bash +WEBHOOK_URL=https://your-webhook-url +WEBHOOK_TIMEOUT=30 +``` + +## Example Request + +### Wan 2.2 Text-to-Video Workflow + +```json +{ + "input": { + "workflow_json": { + "90": { + "inputs": { + "clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors", + "type": "wan", + "device": "default" + }, + "class_type": "CLIPLoader" + }, + "99": { + "inputs": { + "text": "A cinematic slow-motion portrait of a woman turning her head", + "clip": ["90", 0] + }, + "class_type": "CLIPTextEncode" + }, + "104": { + "inputs": { + "width": 640, + "height": 640, + "length": 81, + "batch_size": 1 + }, + "class_type": "EmptyHunyuanLatentVideo" + } + } + } +} +``` + +## Response Format + +A successful response includes execution metadata, ComfyUI output details, and generated video assets. 
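For illustration only, a successful response might be shaped roughly like the sketch below; the values are placeholders, and only the field names are taken from the output description in the sections that follow (the `node_id` of `"98"` assumes the SaveVideo node from the example workflow above).

```json
{
  "id": "uuid-string",
  "status": "completed",
  "message": "Job complete",
  "comfyui_response": { },
  "output": [
    {
      "filename": "ComfyUI_00001.mp4",
      "local_path": "/path/to/output/video/ComfyUI_00001.mp4",
      "url": "https://s3.example.com/your-bucket/video/ComfyUI_00001.mp4",
      "type": "output",
      "subfolder": "video",
      "node_id": "98",
      "output_type": "images"
    }
  ],
  "timings": { }
}
```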
+ +### Response Fields + +- `id`: Unique request identifier +- `status`: `completed`, `failed`, `processing`, `generating`, or `queued` +- `message`: Human-readable status message +- `comfyui_response`: Raw response from ComfyUI, including execution status and progress +- `output`: Array of generated outputs +- `timings`: Timing information for the request + +### Output Object + +Each entry in `output` includes: + +- `filename`: Generated file name (e.g., `.mp4`) +- `local_path`: File path on the worker +- `url`: Pre-signed download URL (if S3 is configured) +- `type`: Output type (`output`) +- `subfolder`: Output directory (e.g., `video`) +- `node_id`: ComfyUI node that produced the output +- `output_type`: Output category (e.g., `images`) + +## Notes and Limitations + +- Only full ComfyUI workflow JSONs are supported +- Concurrent requests are not supported per worker +- Wan 2.2 models must be installed before processing requests +- Video generation workflows may take several minutes depending on resolution, length, and GPU performance \ No newline at end of file diff --git a/utils/__init__.py b/workers/wan/__init__.py similarity index 100% rename from utils/__init__.py rename to workers/wan/__init__.py diff --git a/workers/wan/client.py b/workers/wan/client.py new file mode 100644 index 0000000..cfb708d --- /dev/null +++ b/workers/wan/client.py @@ -0,0 +1,205 @@ +from vastai import Serverless +import asyncio + +async def main(): + async with Serverless() as client: + endpoint = await client.get_endpoint(name="my-wan-endpoint") + + # ComfyUI API compatible json workflow for Wan 2.2 T2V + workflow = { + "90": { + "inputs": { + "clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors", + "type": "wan", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + }, + "91": { + "inputs": { + "text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW", + "clip": ["90", 0] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Negative Prompt)" + } + }, + "92": { + "inputs": { + "vae_name": "wan_2.1_vae.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "93": { + "inputs": { + "shift": 8.000000000000002, + "model": ["101", 0] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "94": { + "inputs": { + "shift": 8, + "model": ["102", 0] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "95": { + "inputs": { + "add_noise": "disable", + "noise_seed": 0, + "steps": 20, + "cfg": 3.5, + "sampler_name": "euler", + "scheduler": "simple", + "start_at_step": 10, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": ["94", 0], + "positive": ["99", 0], + "negative": ["91", 0], + "latent_image": ["96", 0] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced)" + } + }, + "96": { + "inputs": { + "add_noise": "enable", + "noise_seed": "__RANDOM_INT__", + "steps": 20, + "cfg": 3.5, + "sampler_name": "euler", + "scheduler": "simple", + "start_at_step": 0, + "end_at_step": 10, + "return_with_leftover_noise": "enable", + "model": ["93", 0], + "positive": ["99", 0], + "negative": ["91", 0], + "latent_image": ["104", 0] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced)" + } + }, + "97": { + "inputs": { + "samples": ["95", 0], + "vae": ["92", 0] 
+ }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "98": { + "inputs": { + "filename_prefix": "video/ComfyUI", + "format": "auto", + "codec": "auto", + "video": ["100", 0] + }, + "class_type": "SaveVideo", + "_meta": { + "title": "Save Video" + } + }, + "99": { + "inputs": { + "text": "Beautiful young European woman with honey blonde hair gracefully turning her head back over shoulder, gentle smile, bright eyes looking at camera. Hair flowing in slow motion as she turns. Soft natural lighting, clean background, cinematic portrait.", + "clip": ["90", 0] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Positive Prompt)" + } + }, + "100": { + "inputs": { + "fps": 16, + "images": ["97", 0] + }, + "class_type": "CreateVideo", + "_meta": { + "title": "Create Video" + } + }, + "101": { + "inputs": { + "unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors", + "weight_dtype": "default" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "Load Diffusion Model" + } + }, + "102": { + "inputs": { + "unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors", + "weight_dtype": "default" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "Load Diffusion Model" + } + }, + "104": { + "inputs": { + "width": 640, + "height": 640, + "length": 81, + "batch_size": 1 + }, + "class_type": "EmptyHunyuanLatentVideo", + "_meta": { + "title": "EmptyHunyuanLatentVideo" + } + } + } + + payload = { + "input": { + "request_id": "", + "workflow_json": workflow, + "s3": { + "access_key_id": "", + "secret_access_key": "", + "endpoint_url": "", + "bucket_name": "", + "region": "" + }, + "webhook": { + "url": "", + "extra_params": { + "user_id": "12345", + "project_id": "abc-def" + } + } + } + } + + response = await endpoint.request("/generate/sync", payload) + + # Response contains status, output, and any errors + print(response["response"]) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/workers/wan/worker.py b/workers/wan/worker.py new file mode 100644 index 0000000..174b5f4 --- /dev/null +++ b/workers/wan/worker.py @@ -0,0 +1,288 @@ +import random +import sys + +from vastai import Worker, WorkerConfig, HandlerConfig, LogActionConfig, BenchmarkConfig + +# ComyUI model configuration +MODEL_SERVER_URL = 'http://127.0.0.1' +MODEL_SERVER_PORT = 18288 +MODEL_LOG_FILE = '/var/log/portal/comfyui.log' +MODEL_HEALTHCHECK_ENDPOINT = "/health" + +# ComyUI-specific log messages +MODEL_LOAD_LOG_MSG = [ + "To see the GUI go to: " +] + +MODEL_ERROR_LOG_MSGS = [ + "MetadataIncompleteBuffer", + "Value not in list: ", + "[ERROR] Provisioning Script failed" +] + +MODEL_INFO_LOG_MSGS = [ + '"message":"Downloading' +] + +benchmark_prompts = [ + "Cartoon hoodie hero; orc, anime cat, bunny; black goo; buff; vector on white.", + "Cozy farming-game scene with fine details.", + "2D vector child with soccer ball; airbrush chrome; swagger; antique copper.", + "Realistic futuristic downtown of low buildings at sunset.", + "Perfect wave front view; sunny seascape; ultra-detailed water; artful feel.", + "Clear cup with ice, fruit, mint; creamy swirls; fluid-sim CGI; warm glow.", + "Male biker with backpack on motorcycle; oilpunk; award-worthy magazine cover.", + "Collage for textile; surreal cartoon cat in cap/jeans before poster; crisp.", + "Medieval village inside glass sphere; volumetric light; macro focus.", + "Iron Man with glowing axe; mecha sci-fi; jungle scene; dynamic light.", + "Pope Francis DJ in leather 
jacket, mixing on giant console; dramatic.", +] + +benchmark_dataset = [ + { + "input": { + "workflow_json": { + "90": { + "inputs": { + "clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors", + "type": "wan", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + }, + "91": { + "inputs": { + "text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW", + "clip": [ + "90", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Negative Prompt)" + } + }, + "92": { + "inputs": { + "vae_name": "wan_2.1_vae.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "93": { + "inputs": { + "shift": 8.000000000000002, + "model": [ + "101", + 0 + ] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "94": { + "inputs": { + "shift": 8, + "model": [ + "102", + 0 + ] + }, + "class_type": "ModelSamplingSD3", + "_meta": { + "title": "ModelSamplingSD3" + } + }, + "95": { + "inputs": { + "add_noise": "disable", + "noise_seed": 0, + "steps": 20, + "cfg": 3.5, + "sampler_name": "euler", + "scheduler": "simple", + "start_at_step": 10, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": [ + "94", + 0 + ], + "positive": [ + "99", + 0 + ], + "negative": [ + "91", + 0 + ], + "latent_image": [ + "96", + 0 + ] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced)" + } + }, + "96": { + "inputs": { + "add_noise": "enable", + "noise_seed": "__RANDOM_INT__", + "steps": 20, + "cfg": 3.5, + "sampler_name": "euler", + "scheduler": "simple", + "start_at_step": 0, + "end_at_step": 10, + "return_with_leftover_noise": "enable", + "model": [ + "93", + 0 + ], + "positive": [ + "99", + 0 + ], + "negative": [ + "91", + 0 + ], + "latent_image": [ + "104", + 0 + ] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced)" + } + }, + "97": { + "inputs": { + "samples": [ + "95", + 0 + ], + "vae": [ + "92", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "98": { + "inputs": { + "filename_prefix": "video/ComfyUI", + "format": "auto", + "codec": "auto", + "video": [ + "100", + 0 + ] + }, + "class_type": "SaveVideo", + "_meta": { + "title": "Save Video" + } + }, + "99": { + "inputs": { + "text":prompt, + "clip": [ + "90", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Positive Prompt)" + } + }, + "100": { + "inputs": { + "fps": 16, + "images": [ + "97", + 0 + ] + }, + "class_type": "CreateVideo", + "_meta": { + "title": "Create Video" + } + }, + "101": { + "inputs": { + "unet_name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors", + "weight_dtype": "default" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "Load Diffusion Model" + } + }, + "102": { + "inputs": { + "unet_name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors", + "weight_dtype": "default" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "Load Diffusion Model" + } + }, + "104": { + "inputs": { + "width": 640, + "height": 640, + "length": 81, + "batch_size": 1 + }, + "class_type": "EmptyHunyuanLatentVideo", + "_meta": { + "title": "EmptyHunyuanLatentVideo" + } + } + } + } + } for prompt in benchmark_prompts +] + +worker_config = WorkerConfig( + model_server_url=MODEL_SERVER_URL, + model_server_port=MODEL_SERVER_PORT, + 
model_log_file=MODEL_LOG_FILE, + model_healthcheck_url=MODEL_HEALTHCHECK_ENDPOINT, + handlers=[ + HandlerConfig( + route="/generate/sync", + allow_parallel_requests=False, + max_queue_time=10.0, + benchmark_config=BenchmarkConfig( + dataset=benchmark_dataset, + runs=1 + ), + workload_calculator= lambda _ : 10000.0 + ) + ], + log_action_config=LogActionConfig( + on_load=MODEL_LOAD_LOG_MSG, + on_error=MODEL_ERROR_LOG_MSGS, + on_info=MODEL_INFO_LOG_MSGS + ) +) + +Worker(worker_config).run() \ No newline at end of file From 29f836eb1af976155b0caebabae6d73e403b6ca4 Mon Sep 17 00:00:00 2001 From: LucasArmandVast Date: Mon, 15 Dec 2025 22:58:02 -0500 Subject: [PATCH 33/40] Backwards compatible vLLM payload (#75) * Support old vLLM payloads --- workers/openai/worker.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/workers/openai/worker.py b/workers/openai/worker.py index 6cf17f0..995fb3d 100644 --- a/workers/openai/worker.py +++ b/workers/openai/worker.py @@ -28,6 +28,12 @@ nltk.download("words") WORD_LIST = nltk.corpus.words.words() +def request_parser(request): + data = request + if request.get("input") is not None: + data = request.get("input") + return data + def completions_benchmark_generator() -> dict: prompt = " ".join(random.choices(WORD_LIST, k=int(250))) @@ -55,6 +61,7 @@ def completions_benchmark_generator() -> dict: workload_calculator= lambda data: data.get("max_tokens", 0), allow_parallel_requests=True, max_queue_time=60.0, + request_parser=request_parser, benchmark_config=BenchmarkConfig( generator=completions_benchmark_generator, concurrency=100, @@ -66,6 +73,7 @@ def completions_benchmark_generator() -> dict: workload_calculator= lambda data: data.get("max_tokens", 0), allow_parallel_requests=True, max_queue_time=60.0, + request_parser=request_parser ) ], log_action_config=LogActionConfig( From 9daf17148797c8b78f580de2a92ac2e5c8639472 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 17 Dec 2025 11:38:55 -0800 Subject: [PATCH 34/40] Increase queue limits for vLLM and TGI --- workers/openai/worker.py | 4 ++-- workers/tgi/worker.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/workers/openai/worker.py b/workers/openai/worker.py index 995fb3d..96ad077 100644 --- a/workers/openai/worker.py +++ b/workers/openai/worker.py @@ -60,8 +60,8 @@ def completions_benchmark_generator() -> dict: route="/v1/completions", workload_calculator= lambda data: data.get("max_tokens", 0), allow_parallel_requests=True, - max_queue_time=60.0, request_parser=request_parser, + max_queue_time=600.0, benchmark_config=BenchmarkConfig( generator=completions_benchmark_generator, concurrency=100, @@ -72,8 +72,8 @@ def completions_benchmark_generator() -> dict: route="/v1/chat/completions", workload_calculator= lambda data: data.get("max_tokens", 0), allow_parallel_requests=True, - max_queue_time=60.0, request_parser=request_parser + max_queue_time=600.0, ) ], log_action_config=LogActionConfig( diff --git a/workers/tgi/worker.py b/workers/tgi/worker.py index f8084ab..85425e2 100644 --- a/workers/tgi/worker.py +++ b/workers/tgi/worker.py @@ -52,7 +52,7 @@ def benchmark_generator() -> dict: HandlerConfig( route="/generate", allow_parallel_requests=True, - max_queue_time=60.0, + max_queue_time=600.0, benchmark_config=BenchmarkConfig( generator=benchmark_generator, concurrency=50 @@ -62,7 +62,7 @@ def benchmark_generator() -> dict: HandlerConfig( route="/generate_stream", allow_parallel_requests=True, - max_queue_time=60.0, + max_queue_time=600.0, workload_calculator= lambda x: 
x["parameters"]["max_new_tokens"] ) ], From bcb04b9a328442cfcd3a4294f952fc71f93778b9 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 17 Dec 2025 11:40:40 -0800 Subject: [PATCH 35/40] add missing comma --- workers/openai/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workers/openai/worker.py b/workers/openai/worker.py index 96ad077..9298f45 100644 --- a/workers/openai/worker.py +++ b/workers/openai/worker.py @@ -72,7 +72,7 @@ def completions_benchmark_generator() -> dict: route="/v1/chat/completions", workload_calculator= lambda data: data.get("max_tokens", 0), allow_parallel_requests=True, - request_parser=request_parser + request_parser=request_parser, max_queue_time=600.0, ) ], From e02f4bc943dbe403b763f789b0ea777abcc64d04 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 17 Dec 2025 11:55:33 -0800 Subject: [PATCH 36/40] Lowered concurrency of vLLM and TGI benchmarks --- workers/openai/worker.py | 4 ++-- workers/tgi/worker.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/workers/openai/worker.py b/workers/openai/worker.py index 9298f45..95b4dd9 100644 --- a/workers/openai/worker.py +++ b/workers/openai/worker.py @@ -64,8 +64,8 @@ def completions_benchmark_generator() -> dict: max_queue_time=600.0, benchmark_config=BenchmarkConfig( generator=completions_benchmark_generator, - concurrency=100, - runs=2 + concurrency=10, + runs=3 ) ), HandlerConfig( diff --git a/workers/tgi/worker.py b/workers/tgi/worker.py index 85425e2..9d83062 100644 --- a/workers/tgi/worker.py +++ b/workers/tgi/worker.py @@ -55,7 +55,8 @@ def benchmark_generator() -> dict: max_queue_time=600.0, benchmark_config=BenchmarkConfig( generator=benchmark_generator, - concurrency=50 + concurrency=10, + runs=3 ), workload_calculator= lambda x: x["parameters"]["max_new_tokens"] ), From bd3e0032a1c4df91d9f0e3e9b52918b52ccc6041 Mon Sep 17 00:00:00 2001 From: LucasArmandVast Date: Thu, 18 Dec 2025 00:01:52 -0500 Subject: [PATCH 37/40] Add SDK version checking (#76) --- start_server.sh | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/start_server.sh b/start_server.sh index 84656c2..4d12a86 100755 --- a/start_server.sh +++ b/start_server.sh @@ -46,6 +46,21 @@ JSON exit 1 } +function install_vastai_sdk() { + if [ -n "${SDK_VERSION:-}" ]; then + echo "Installing vastai-sdk version ${SDK_VERSION}" + if ! uv pip install "vastai-sdk==${SDK_VERSION}"; then + report_error_and_exit "Failed to install vastai-sdk==${SDK_VERSION}" + fi + else + echo "Installing default vastai-sdk" + if ! uv pip install vastai-sdk; then + report_error_and_exit "Failed to install vastai-sdk" + fi + fi +} + + [ -n "$BACKEND" ] && [ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set when BACKEND is set!" [ -z "$CONTAINER_ID" ] && report_error_and_exit "CONTAINER_ID must be set!" [ "$BACKEND" = "comfyui" ] && [ -z "$COMFY_MODEL" ] && report_error_and_exit "For comfyui backends, COMFY_MODEL must be set!" @@ -123,6 +138,8 @@ then report_error_and_exit "Failed to install Python requirements" fi + install_vastai_sdk + if ! 
touch ~/.no_auto_tmux; then report_error_and_exit "Failed to create ~/.no_auto_tmux" fi From 4d786b4d1772c4e3e326038797669cedd92c9667 Mon Sep 17 00:00:00 2001 From: LucasArmandVast Date: Fri, 2 Jan 2026 13:23:07 -0500 Subject: [PATCH 38/40] SDK Versioning Improvements (#77) * Add SDK_BRANCH --- start_server.sh | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/start_server.sh b/start_server.sh index 4d12a86..129031c 100755 --- a/start_server.sh +++ b/start_server.sh @@ -47,19 +47,31 @@ JSON } function install_vastai_sdk() { + # If SDK_BRANCH is set, install vastai-sdk from the vast-sdk repo at that branch/tag/commit. + if [ -n "${SDK_BRANCH:-}" ]; then + if [ -n "${SDK_VERSION:-}" ]; then + echo "WARNING: Both SDK_BRANCH and SDK_VERSION are set; using SDK_BRANCH=${SDK_BRANCH}" + fi + echo "Installing vastai-sdk from https://github.com/vast-ai/vast-sdk/ @ ${SDK_BRANCH}" + if ! uv pip install "vastai-sdk @ git+https://github.com/vast-ai/vast-sdk.git@${SDK_BRANCH}"; then + report_error_and_exit "Failed to install vastai-sdk from vast-ai/vast-sdk@${SDK_BRANCH}" + fi + return 0 + fi + if [ -n "${SDK_VERSION:-}" ]; then echo "Installing vastai-sdk version ${SDK_VERSION}" if ! uv pip install "vastai-sdk==${SDK_VERSION}"; then report_error_and_exit "Failed to install vastai-sdk==${SDK_VERSION}" fi - else - echo "Installing default vastai-sdk" - if ! uv pip install vastai-sdk; then - report_error_and_exit "Failed to install vastai-sdk" - fi + return 0 fi -} + echo "Installing default vastai-sdk" + if ! uv pip install vastai-sdk; then + report_error_and_exit "Failed to install vastai-sdk" + fi +} [ -n "$BACKEND" ] && [ -z "$HF_TOKEN" ] && report_error_and_exit "HF_TOKEN must be set when BACKEND is set!" [ -z "$CONTAINER_ID" ] && report_error_and_exit "CONTAINER_ID must be set!" From f319db6bd5f8a6f8a591e11efd5a2a9ab731b855 Mon Sep 17 00:00:00 2001 From: LucasArmandVast Date: Mon, 12 Jan 2026 20:03:18 -0500 Subject: [PATCH 39/40] flag for model log rotate (#78) --- start_server.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/start_server.sh b/start_server.sh index 129031c..2faefd2 100755 --- a/start_server.sh +++ b/start_server.sh @@ -90,7 +90,8 @@ echo_var DEBUG_LOG echo_var PYWORKER_LOG echo_var MODEL_LOG -if [ -e "$MODEL_LOG" ]; then +ROTATE_MODEL_LOG="${ROTATE_MODEL_LOG:-false}" +if [ "$ROTATE_MODEL_LOG" = "true" ] && [ -e "$MODEL_LOG" ]; then echo "Rotating model log at $MODEL_LOG to $MODEL_LOG.old" if ! cat "$MODEL_LOG" >> "$MODEL_LOG.old"; then report_error_and_exit "Failed to rotate model log" From aaca1c96459c3d41020e0e30dafd193d3d9a7c47 Mon Sep 17 00:00:00 2001 From: Lucas Armand Date: Wed, 14 Jan 2026 10:47:07 -0800 Subject: [PATCH 40/40] Updated requirements to only require vastai-sdk --- requirements.txt | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/requirements.txt b/requirements.txt index 807a1c4..6f33eb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1 @@ -aiohttp==3.10.1 -aiodns~=3.6.0 -pycares~=4.11.0 -anyio~=4.4 -lib~=4.0 -nltk~=3.9 -psutil~=6.0 -pycryptodome~=3.20 -Requests~=2.32 -transformers~=4.52 -utils==1.0.* -hf_transfer>=0.1.9 vastai-sdk>=0.3.0