From a0ba941e233e21ff0630a45889777c0755c0b6d2 Mon Sep 17 00:00:00 2001 From: Jurgis Pasukonis Date: Fri, 28 Mar 2025 18:41:01 -0700 Subject: [PATCH] Set allowed_action_kinds --- src/generalagents/agent.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/generalagents/agent.py b/src/generalagents/agent.py index ee3a06d..3f17697 100644 --- a/src/generalagents/agent.py +++ b/src/generalagents/agent.py @@ -6,7 +6,7 @@ import httpx from PIL import Image -from generalagents.action import Action +from generalagents.action import Action, ActionKind class Session: @@ -30,7 +30,12 @@ def __init__( self.previous_actions = [] self.temperature = temperature - def plan(self, observation: Image.Image) -> Action: + def plan( + self, + observation: Image.Image, + *, + allowed_action_kinds: list[ActionKind] | None = None, + ) -> Action: buffer = BytesIO() observation.save(buffer, format="WEBP") image_url = f"data:image/webp;base64,{base64.b64encode(buffer.getvalue()).decode('utf8')}" @@ -41,6 +46,7 @@ def plan(self, observation: Image.Image) -> Action: "image_url": image_url, "previous_actions": self.previous_actions[-self.max_previous_actions :], "temperature": self.temperature, + "allowed_action_kinds": allowed_action_kinds, } res = self.client.post("/v1/control/predict", json=data)