diff --git a/src/generalagents/action.py b/src/generalagents/action.py index 453968d..1c23c0b 100644 --- a/src/generalagents/action.py +++ b/src/generalagents/action.py @@ -105,6 +105,7 @@ class ActionDrag: class ActionScroll: kind: Literal["scroll"] scroll_delta: int + coordinate: Coordinate @dataclass diff --git a/src/generalagents/agent.py b/src/generalagents/agent.py index ee3a06d..3f17697 100644 --- a/src/generalagents/agent.py +++ b/src/generalagents/agent.py @@ -6,7 +6,7 @@ import httpx from PIL import Image -from generalagents.action import Action +from generalagents.action import Action, ActionKind class Session: @@ -30,7 +30,12 @@ def __init__( self.previous_actions = [] self.temperature = temperature - def plan(self, observation: Image.Image) -> Action: + def plan( + self, + observation: Image.Image, + *, + allowed_action_kinds: list[ActionKind] | None = None, + ) -> Action: buffer = BytesIO() observation.save(buffer, format="WEBP") image_url = f"data:image/webp;base64,{base64.b64encode(buffer.getvalue()).decode('utf8')}" @@ -41,6 +46,7 @@ def plan(self, observation: Image.Image) -> Action: "image_url": image_url, "previous_actions": self.previous_actions[-self.max_previous_actions :], "temperature": self.temperature, + "allowed_action_kinds": allowed_action_kinds, } res = self.client.post("/v1/control/predict", json=data) diff --git a/src/generalagents/macos/computer.py b/src/generalagents/macos/computer.py index caa1ad0..cdd6e82 100644 --- a/src/generalagents/macos/computer.py +++ b/src/generalagents/macos/computer.py @@ -76,7 +76,8 @@ def _execute_action(self, action: Action) -> None: pyautogui.moveTo(*self._scaled(start)) pyautogui.dragTo(*self._scaled(end), duration=0.5) - case ActionScroll(kind="scroll", scroll_delta=delta): + case ActionScroll(kind="scroll", scroll_delta=delta, coordinate=coord): + pyautogui.moveTo(*self._scaled(coord)) pyautogui.scroll(float(delta * self.scale_factor)) case ActionWait(kind="wait"):