diff --git a/src/placeos-rest-api/controllers/mcp.cr b/src/placeos-rest-api/controllers/mcp.cr new file mode 100644 index 00000000..e6dde4f2 --- /dev/null +++ b/src/placeos-rest-api/controllers/mcp.cr @@ -0,0 +1,345 @@ +require "./application" +require "./mcp/models" +require "./mcp/sse" + +module PlaceOS::Api + class MCP < Application + include Utils::CoreHelper + include MCPModels + base "/api/engine/v2/mcp" + + alias SessionStream = Hash(String, SSE::Connection) + class_getter session_streams : SessionStream = SessionStream.new + + class_getter session_store : Hash(String, String | Bool | Nil) = Hash(String, String | Bool | Nil).new + + add_responder("text/event-stream") { |_io, _result| } + + @[AC::Route::Filter(:before_action, only: [:handler])] + def check_accept_headers + accept = request.headers["accept"]? || "" + accept_types = accept.strip.split(',') + has_json = accept_types.any? { |media| media.strip.starts_with?(CONTENT_TYPE_JSON) } + has_sse = accept_types.any? { |media| media.strip.starts_with?(CONTENT_TYPE_SSE) } + + render_error_resp("Not Acceptable: Client must accept both application/json and text/event-stream", :not_acceptable) unless has_json && has_sse + end + + @[AC::Route::Filter(:before_action, only: [:index])] + def check_accept_headers + accept = request.headers["accept"]? || "" + accept_types = accept.strip.split(',') + has_sse = accept_types.any? { |media| media.strip.starts_with?(CONTENT_TYPE_SSE) } + + render_error_resp("Not Acceptable: Client must accept text/event-stream", :not_acceptable) unless has_sse + end + + @[AC::Route::Filter(:before_action, only: [:handler])] + def check_content_type + content_type = request.headers["content-type"]? || "" + valid = content_type.strip.split(',').any? { |media| media.strip == CONTENT_TYPE_JSON } + render_error_resp("Unsupported Media Type: Content-Type must be application/json", :unsupported_media_type) unless valid + end + + @[AC::Route::Filter(:before_action, except: [:destroy])] + def validate_protocol_version + return unless request.headers[MCP_SESSION_ID_HEADER]? + protocol_version = request.headers[MCP_PROTOCOL_VERSION_HEADER] || DEFAULT_NEGOTIATED_PROTOCOL_VERSION + unless protocol_version.in?(SUPPORTED_PROTOCOL_VERSIONS) + supported_versions = SUPPORTED_PROTOCOL_VERSIONS.join(", ") + render_error_resp("Bad Request: Unsupported protocol version: #{protocol_version}. Supported versions: #{supported_versions}", :bad_request) + end + end + + @[AC::Route::Filter(:before_action)] + def validate_or_add_session + mcp_session_id = session_store[MCP_SESSION_ID_HEADER]? + request_mcp_session = request.headers[MCP_SESSION_ID_HEADER]? + return render_error_resp("Bad Request: Missing session ID", :bad_request) if request_mcp_session.nil? && mcp_session_id + + render_error_resp("Not Found: Invalid or expired session ID", :bad_request) unless mcp_session_id == request_mcp_session + if session_val = mcp_session_id + response.headers[MCP_SESSION_ID_HEADER] = session_val.to_s + end + end + + # MCP Server HTTP Streamable endpoint + @[AC::Route::POST("/:sys_id/:module_slug", body: :raw)] + def handler( + sys_id : String, + @[AC::Param::Info(description: "the combined module class and index, index is optional and defaults to 1", example: "Display_2")] + module_slug : String, + raw : JSON::Any, + ) + return render_error_resp("Payload Too Large: Message exceeds maximum size", :payload_too_large) if raw.to_json.size > MAXIMUM_MESSAGE_SIZE + messages = raw.as_a? ? raw.as_a.map(&.as_h) : [raw.as_h] + + initialize_req = messages.any?(&.["method"]?.try &.as_s.== "initialize") + if initialize_req + return render_error_resp("Invalid Request: Server already initialized", :bad_request) if session_store[MCP_SESSION_ID_HEADER]? + return render_error_resp("Invalid Request: Only one initialization request is allowed", :bad_request) if messages.size > 1 + + session_val = UUID.random.to_s + session_store[MCP_SESSION_ID_HEADER] = session_val + response.headers[MCP_SESSION_ID_HEADER] = session_val + render json: initialize_resp(messages.first).to_json + end + + notifications = messages.select { |msg| msg["method"]?.try &.as_s.starts_with?("notifications/") } + if notifications.size > 0 + notifications.each { |notification| Log.info { {message: "Receive notification", notification: notification.to_json} } } + render :accepted + end + + errors = messages.select { |msg| msg.has_key?("error") } + if errors.size > 0 + errors.each { |error| Log.info { {message: "Receive error", error: error.to_json} } } + render :accepted + end + + result = [] of JSONRPCResponse + messages.each do |rpc_request| + method = rpc_request["method"].as_s + if method == "ping" + result << ping_resp(rpc_request) + elsif method == "tools/list" + resp = handle_tools_list(sys_id, module_slug, rpc_request["id"]) + if resp.is_a?(CallError) + break render_error_resp(resp) + end + result << resp + elsif method == "tools/call" + result << handle_tools_call(sys_id, module_slug, rpc_request) + end + end + + render json: result + end + + # MCP HTTP Streamable SSE connection requested by client for server to client communication + @[AC::Route::GET("/:sys_id/:module_slug")] + def index(sys_id : String, + @[AC::Param::Info(description: "the combined module class and index, index is optional and defaults to 1", example: "Display_2")] + module_slug : String,) + session_id = request.headers.get?(MCP_SESSION_ID_HEADER) + return render_error_resp("Bad Request: #{MCP_SESSION_ID_HEADER} header is required", :bad_request) unless session_id + return render_error_resp("Bad Request: #{MCP_SESSION_ID_HEADER} header must be a single value", :bad_request) if session_id && session_id.size > 1 + + response.headers.add(MCP_SESSION_ID_HEADER, session_id.not_nil!) + sess_key = "#{sys_id}|#{module_slug}|#{session_id}" + return render_error_resp("Conflict: Only one SSE stream is allowed per session", :conflict, ErrorCode::ConnectionClosed) if self.class.session_streams.has_key?(sess_key) + SSE.upgrade_response(response) do |conn| + self.class.session_streams[sess_key] = conn + end + end + + # Deletes established session and closes SSE connection (if any) + @[AC::Route::DELETE("/:sys_id/:module_slug")] + def destroy(sys_id : String, + @[AC::Param::Info(description: "the combined module class and index, index is optional and defaults to 1", example: "Display_2")] + module_slug : String,) : Nil + session_id = request.headers.get?(MCP_SESSION_ID_HEADER) + return render_error_resp("Bad Request: #{MCP_SESSION_ID_HEADER} header is required", :bad_request) unless session_id + return render_error_resp("Bad Request: #{MCP_SESSION_ID_HEADER} header must be a single value", :bad_request) if session_id && session_id.size > 1 + + sess_key = "#{sys_id}|#{module_slug}|#{session_id}" + return render_error_resp("Bad Request: SSE session not found", :bad_request) unless self.class.session_streams.has_key?(sess_key) + self.class.session_streams[sess_key].close + session_store.delete(MCP_SESSION_ID_HEADER) + render :ok + end + + # MCP HTTP Streamable is only requested to suport POST/GET/DELETE methods. This method returns JSONRPC error and closes connection + @[AC::Route::PUT("/:sys_id/:module_slug")] + @[AC::Route::PATCH("/:sys_id/:module_slug")] + def unsupported( + sys_id : String, + @[AC::Param::Info(description: "the combined module class and index, index is optional and defaults to 1", example: "Display_2")] + module_slug : String, + ) + header = HTTP::Headers{ + "Allow" => "GET, POST, DELETE", + } + + render_error_resp("Method Not Allowed", :method_not_allowed, ErrorCode::ConnectionClosed, header) + end + + private def ping_resp(req : Hash(String, JSON::Any)) : JSONRPCResponse + req_id = req["id"].raw.is_a?(Number) ? req["id"].as_i64 : req["id"].as_s + JSONRPCResponse.new(req_id, EmptyResult.new) + end + + private def initialize_resp(client : Hash(String, JSON::Any)) + server_info = Implementation.new(name: Api::APP_NAME, version: Api::VERSION) + req_id = client["id"].raw.is_a?(Number) ? client["id"].as_i64 : client["id"].as_s + requested_version = client["params"]["protocolVersion"].as_s + + proto_version = requested_version.in?(SUPPORTED_PROTOCOL_VERSIONS) ? requested_version : DEFAULT_NEGOTIATED_PROTOCOL_VERSION + capabilities = ServerCapabilities.new(tools: ServerCapabilities.new_capability(false)) + result = InitializeResult.new(proto_version, capabilities, server_info) + + JSONRPCResponse.new(req_id, result) + end + + alias FunctionSchema = NamedTuple(function: String, description: String, parameters: Hash(String, JSON::Any)) + + private def handle_tools_list(sys_id : String, module_slug : String, id : JSON::Any) : JSONRPCResponse | CallError + req_id = id.raw.is_a?(Number) ? id.as_i64 : id.as_s + resp = exec_func("function_schemas", sys_id, module_slug) + if resp.is_a?(CallError) + return resp.as(CallError) + end + + schemas = Array(FunctionSchema).from_json(resp.as(String)) + + tools = [] of Tool + schemas.each do |schema| + required = [] of String + properties = {} of String => JSON::Any + schema[:parameters].each do |param_name, param_spec| + next unless param_hash = param_spec.as_h? + + prop_schema = {} of String => JSON::Any + optional = false + if any_of = param_hash["anyOf"]? + types = any_of.as_a.map { |t| t["type"].as_s } + optional = types.any? { |type| type.downcase == "null" } + type = types.reject { |type| type.downcase == "null" }.first + prop_schema["type"] = JSON::Any.new(type) + selected_type = any_of.as_a.select { |val| val["type"].as_s == type && val.as_h.has_key?("format") } + format = selected_type.empty? ? type.capitalize : selected_type.first["format"].as_s + prop_schema["description"] = JSON::Any.new(format) + else + prop_schema["type"] = param_hash["type"] + prop_schema["description"] = JSON::Any.new(param_hash["type"].as_s.capitalize) + optional = false + end + + properties[param_name] = JSON::Any.new(prop_schema) + required << param_name unless optional + end + input_schema = ToolSchema.new( + properties: properties, + required: required.empty? ? nil : required + ) + tool_name = schema[:function] + title = tool_name.split('_').map(&.capitalize).join(" ") + tools << Tool.new(name: tool_name, title: title, description: schema[:description], input_schema: input_schema) + end + + JSONRPCResponse.new(req_id, ListToolResult.new(tools)) + end + + private def handle_tools_call(sys_id : String, module_slug : String, req : Hash(String, JSON::Any)) : JSONRPCResponse + req_id = req["id"].raw.is_a?(Number) ? req["id"].as_i64 : req["id"].as_s + result = if params = req["params"]?.try &.as_h? + method = params["name"] + args = params["arguments"] + resp = exec_func("function_schemas", sys_id, module_slug, args) + if resp.is_a?(CallError) + call_error = resp.as(CallError) + content = [] of ContentBlock + content << TextContent.new(call_error.error.error.message) + CallToolResult.new(content, is_error: true) + else + is_json = resp.strip.starts_with?("[") || resp.strip.starts_with?("{") + structured_content = JSON.parse(resp).as_h rescue nil if is_json + content = [] of ContentBlock + content << TextContent.new(resp) + CallToolResult.new(content, structured_content: structured_content) + end + else + content = [] of ContentBlock + content << TextContent.new("Invalid tools/call. Missing params") + CallToolResult.new(content, is_error: true) + end + + JSONRPCResponse.new(req_id, result) + end + + private def session_store + self.class.session_store + end + + private def exec_func(method : String, sys_id : String, module_slug : String, args : JSON::Any? = nil) : String | CallError + module_name, index = RemoteDriver.get_parts(module_slug) + Log.context.set(module_name: module_name, index: index, method: method) + + remote_driver = RemoteDriver.new( + sys_id: sys_id, + module_name: module_name, + index: index, + user_id: current_user.id, + ) { |module_id| + ::PlaceOS::Model::Module.find!(module_id).edge_id.as(String) + } + + response_text, status_code = remote_driver.exec( + security: driver_clearance(user_token), + function: method, + args: args, + request_id: request_id, + ) + return response_text + rescue e : RemoteDriver::Error + handle_tool_list_execute_error(e) + rescue e + create_error(e.message || "Uknown Internal error", :internal_server_error, ErrorCode::InternalError) + end + + record CallError, status_code : Symbol, error : JSONRPCError, headers : HTTP::Headers = HTTP::Headers.new + + private def create_error(message, status_code : Symbol, error_code = ErrorCode::InvalidRequest, headers = HTTP::Headers.new) + error = JSONRPCError.new("server-error", error_code, message) + CallError.new(status_code, error, headers) + end + + private def render_error_resp(message, status_code : Symbol, error_code = ErrorCode::InvalidRequest, headers = HTTP::Headers.new) + headers.add("Content-Type", CONTENT_TYPE_JSON) + error = create_error(message, status_code, error_code, headers) + render_error_resp(error) + end + + private def render_error_resp(error : CallError) + response.headers.merge!(error.headers) + + render ActionController::Responders::STATUS_CODES[error.status_code], json: error.error.to_json + end + + private def handle_tool_list_execute_error(error : Driver::Proxy::RemoteDriver::Error) + message = error.error_code.to_s.gsub('_', ' ') + Log.context.set( + error: message, + sys_id: error.system_id, + module_name: error.module_name, + index: error.index, + remote_backtrace: error.remote_backtrace, + ) + + status, error_code = case error.error_code + in DriverError::ModuleNotFound, DriverError::SystemNotFound + Log.info { error.message } + {:not_found, ErrorCode::InvalidRequest} + in DriverError::ParseError, DriverError::BadRequest, DriverError::UnknownCommand + Log.error { error.message } + {:bad_request, ErrorCode::InvalidRequest} + in DriverError::RequestFailed, DriverError::UnexpectedFailure + Log.info { error.message } + {:internal_server_error, ErrorCode::InternalError} + in DriverError::AccessDenied + Log.info { error.message } + {:unauthorized, ErrorCode::InvalidRequest} + end + + msg = { + error: message, + sys_id: error.system_id, + module_name: error.module_name, + index: error.index, + message: error.message, + }.to_json + + create_error(msg, status, error_code) + end + end +end diff --git a/src/placeos-rest-api/controllers/mcp/models.cr b/src/placeos-rest-api/controllers/mcp/models.cr new file mode 100644 index 00000000..31001369 --- /dev/null +++ b/src/placeos-rest-api/controllers/mcp/models.cr @@ -0,0 +1,320 @@ +require "json" + +module PlaceOS::Api + module MCPModels + # Protocol constants + LATEST_PROTOCOL_VERSION = "2025-06-18" + DEFAULT_NEGOTIATED_PROTOCOL_VERSION = "2025-03-26" + SUPPORTED_PROTOCOL_VERSIONS = [ + LATEST_PROTOCOL_VERSION, + "2025-03-26", + "2024-11-05", + "2024-10-07", + ] + + JSONRPC_VERSION = "2.0" + + # Maximum size for incoming messages + MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB + # Header names + MCP_SESSION_ID_HEADER = "mcp-session-id" + MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" + LAST_EVENT_ID_HEADER = "last-event-id" + + # Content types + CONTENT_TYPE_JSON = "application/json" + CONTENT_TYPE_SSE = "text/event-stream" + + # JSON-RPC types + alias ProgressToken = String | Int64 + alias Cursor = String + alias RequestId = String | Int64 + alias ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource + EmptyJsonObject = Hash(String, JSON::Any).new + # Enums for constrained values + enum ErrorCode + ConnectionClosed = -32000 + RequestTimeout = -32001 + ParseError = -32700 + InvalidRequest = -32600 + MethodNotFound = -32601 + InvalidParams = -32602 + InternalError = -32603 + end + + struct JSONRPCError + include JSON::Serializable + getter jsonrpc : String + getter id : RequestId + getter error : ErrorDetail + + def initialize(@id, @error, @jsonrpc = JSONRPC_VERSION) + end + + def self.new(id : RequestId, code : ErrorCode, msg : String) + details = ErrorDetail.new(code, msg) + new(id, details) + end + end + + struct ErrorDetail + include JSON::Serializable + getter code : ErrorCode + getter message : String + getter data : JSON::Any? + + def initialize(@code, @message, @data = nil) + end + end + + abstract struct Result + include JSON::Serializable + include JSON::Serializable::Unmapped + + @[JSON::Field(key: "_meta")] + getter meta : Hash(String, JSON::Any)? + + def initialize(@meta = nil) + end + end + + struct JSONRPCResponse + include JSON::Serializable + getter jsonrpc : String + getter id : RequestId + getter result : Result + + def initialize(@id, @result, @jsonrpc = JSONRPC_VERSION) + end + end + + struct EmptyResult < Result + def initialize + super + end + end + + struct InitializeResult < Result + @[JSON::Field(key: "protocolVersion")] + getter protocol_version : String + getter capabilities : ServerCapabilities + @[JSON::Field(key: "serverInfo")] + getter server_info : Implementation + getter instructions : String? + + def initialize(@protocol_version, @capabilities, @server_info, @instructions = nil, @meta = nil) + end + end + + struct ListToolResult < Result + getter tools : Array(Tool) + @[JSON::Field(key: "nextCursor")] + getter next_cursor : Cursor? + + def initialize(@tools, @next_cursor = nil, @meta = nil) + end + end + + struct CallToolResult < Result + getter content : Array(ContentBlock) + @[JSON::Field(key: "structuredContent")] + getter structured_content : Hash(String, JSON::Any)? + @[JSON::Field(key: "isError")] + getter is_error : Bool? + + def initialize(@content = [] of ContentBlock, @structured_content = nil, + @is_error = nil, @meta = nil) + end + end + + struct Implementation + include JSON::Serializable + include JSON::Serializable::Unmapped + getter name : String + getter version : String + getter title : String? + + def initialize(@name, @version, @title = nil) + end + end + + struct ServerCapabilities + include JSON::Serializable + include JSON::Serializable::Unmapped + getter experimental : Hash(String, JSON::Any)? + getter logging : Hash(String, JSON::Any)? + getter completions : Hash(String, JSON::Any)? + getter prompts : Capability? + getter resources : Capability? + getter tools : Capability? + + def initialize(@experimental = nil, @logging = nil, @completions = nil, + @prompts = nil, @resources = nil, @tools = nil) + end + + def self.new_capability(list_changed : Bool, subscribe : Bool? = nil) + Capability.new(list_changed, subscribe) + end + + struct Capability + include JSON::Serializable + include JSON::Serializable::Unmapped + getter subscribe : Bool? + @[JSON::Field(key: "listChanged")] + getter list_changed : Bool? + + def initialize(@list_changed = nil, @subscribe = nil) + end + end + end + + struct TextContent + include JSON::Serializable + include JSON::Serializable::Unmapped + getter type : String + getter text : String + @[JSON::Field(key: "_meta")] + getter meta : Hash(String, JSON::Any)? + + def initialize(@text, @meta = nil, @type = "text") + end + end + + struct ImageContent + include JSON::Serializable + include JSON::Serializable::Unmapped + getter type : String + getter data : String # Base64 + @[JSON::Field(key: "mimeType")] + getter mime_type : String + @[JSON::Field(key: "_meta")] + getter meta : Hash(String, JSON::Any)? + + def initialize(@data, @mime_type, @meta = nil, @type = "image") + end + end + + struct AudioContent + include JSON::Serializable + include JSON::Serializable::Unmapped + getter type : String + getter data : String # Base64 + @[JSON::Field(key: "mimeType")] + getter mime_type : String + @[JSON::Field(key: "_meta")] + getter meta : Hash(String, JSON::Any)? + + def initialize(@data, @mime_type, @meta = nil, @type = "audio") + end + end + + struct EmbeddedResource + include JSON::Serializable + include JSON::Serializable::Unmapped + getter type : String + getter resource : TextResourceContents | BlobResourceContents + @[JSON::Field(key: "_meta")] + getter meta : Hash(String, JSON::Any)? + + def initialize(@resource, @meta = nil, @type = "resource") + end + end + + struct ResourceLink + include JSON::Serializable + include JSON::Serializable::Unmapped + getter type : String + getter name : String + getter uri : String + getter description : String? + getter title : String? + @[JSON::Field(key: "mimeType")] + getter mime_type : String? + @[JSON::Field(key: "_meta")] + getter meta : Hash(String, JSON::Any)? + + def initialize(@name, @uri, @description = nil, @title = nil, @mime_type = nil, @meta = nil, @type = "resource_link") + end + end + + struct ToolSchema + include JSON::Serializable + include JSON::Serializable::Unmapped + @type : String + getter properties : Hash(String, JSON::Any)? + getter required : Array(String)? + + def initialize(@properties = EmptyJsonObject, @required = nil) + @type = "object" + end + end + + struct ToolAnnotations + include JSON::Serializable + include JSON::Serializable::Unmapped + getter title : String? + @[JSON::Field(key: "readOnlyHint")] + getter read_only_hint : Bool? + @[JSON::Field(key: "destructiveHint")] + getter destructive_hint : Bool? + @[JSON::Field(key: "idempotentHint")] + getter idempotent_hint : Bool? + @[JSON::Field(key: "openWorldHint")] + getter open_world_hint : Bool? + + def initialize(@title = nil, @read_only_hint = nil, @destructive_hint = nil, + @idempotent_hint = nil, @open_world_hint = nil) + end + end + + struct Tool + include JSON::Serializable + include JSON::Serializable::Unmapped + getter name : String + getter description : String? + getter title : String? + @[JSON::Field(key: "inputSchema")] + getter input_schema : ToolSchema + @[JSON::Field(key: "outputSchema")] + getter output_schema : ToolSchema? + getter annotations : ToolAnnotations? + @[JSON::Field(key: "_meta")] + getter meta : Hash(String, JSON::Any)? + + def initialize(@name, @input_schema, @title = nil, @description = nil, + @output_schema = nil, @annotations = nil, @meta = nil) + end + end + + abstract struct ResourceContents + include JSON::Serializable + include JSON::Serializable::Unmapped + getter uri : String + @[JSON::Field(key: "mimeType")] + getter mime_type : String? + @[JSON::Field(key: "_meta")] + getter meta : Hash(String, JSON::Any)? + + def initialize(@uri, @mime_type = nil, @meta = nil) + end + end + + struct TextResourceContents < ResourceContents + getter text : String + + def initialize(uri, text, mime_type = nil, meta = nil) + super(uri, mime_type, meta) + @text = text + end + end + + struct BlobResourceContents < ResourceContents + getter blob : String # Base64 encoded + + def initialize(uri, blob, mime_type = nil, meta = nil) + super(uri, mime_type, meta) + @blob = blob + end + end + end +end diff --git a/src/placeos-rest-api/controllers/mcp/sse.cr b/src/placeos-rest-api/controllers/mcp/sse.cr new file mode 100644 index 00000000..63e1d43e --- /dev/null +++ b/src/placeos-rest-api/controllers/mcp/sse.cr @@ -0,0 +1,134 @@ +require "http" +require "http/server/handler" + +module PlaceOS::Api + module SSE + # Represents an SSE client connection + class Connection + @mutex = Mutex.new + @closed = false + @closed_channel = Channel(Nil).new(1) + property on_close : Proc(Nil) = -> { } + + def initialize(@io : IO) + spawn_reader + end + + # Spawns reader fiber to detect client disconnects + private def spawn_reader + spawn do + begin + # Read any incoming data to detect disconnects + buffer = Bytes.new(128) + while @io.read(buffer) > 0 + # SSE clients shouldn't send data except for initial request + end + rescue IO::EOFError | IO::Error + # Normal disconnect + ensure + close + end + end + end + + # Close connection and clean up + def close + return if @closed + @closed = true + @on_close.call + @io.close rescue nil + @closed_channel.send(nil) rescue nil + end + + # Check if connection is closed + def closed? + @closed + end + + # Wait until connection is closed + def wait + @closed_channel.receive unless closed? + end + + # Send SSE-formatted message + def send(data : String, id : String? = nil, event : String? = nil, retry : Int32? = nil) + @mutex.synchronize do + return if closed? + build_message(event, data, id, retry) + @io.flush + end + rescue ex : IO::Error + close + end + + # Format message according to SSE spec + private def build_message(event, data, id, retry) + @io << "id: #{id.gsub(/\R/, " ")}\n" if id + @io << "retry: #{retry}\n" if retry + @io << "event: #{event.gsub(/\R/, " ")}\n" if event + + data.each_line do |line| + @io << "data: #{line.chomp("\r")}\n" + end + + @io << '\n' + end + end + + # Manages multiple SSE connections + class SSEChannel + @connections = [] of Connection + @mutex = Mutex.new + + # Add connection to channel + def add(connection) + @mutex.synchronize do + @connections << connection + connection.on_close = -> { remove(connection) } + end + end + + # Remove connection from channel + def remove(connection) + @mutex.synchronize do + @connections.delete(connection) + end + end + + # Broadcast message to all connections + def broadcast(data : String, id = nil, event = nil, retry = nil) + @mutex.synchronize do + @connections.reject! do |conn| + if conn.closed? + true + else + conn.send(event, data, id, retry) rescue true + false + end + end + end + end + + # Get current connection count + def size + @mutex.synchronize { @connections.size } + end + end + + # Helper to upgrade HTTP response to SSE + def self.upgrade_response(response : HTTP::Server::Response, &block : Connection ->) + # Set required SSE headers + response.headers["Content-Type"] = "text/event-stream" + response.headers["Cache-Control"] = "no-cache" + response.headers["Connection"] = "keep-alive" + response.status = HTTP::Status::OK + + # Upgrade connection + response.upgrade do |io| + conn = Connection.new(io) + block.call(conn) # Pass connection to block + conn.wait # Keep fiber alive until connection closes + end + end + end +end