diff --git a/python/agents/personalized-shopping/deployment/deploy.py b/python/agents/personalized-shopping/deployment/deploy.py index 069a4350d..0cd175f5e 100644 --- a/python/agents/personalized-shopping/deployment/deploy.py +++ b/python/agents/personalized-shopping/deployment/deploy.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import vertexai -from vertexai.preview.reasoning_engines import AdkApp -from vertexai import agent_engines -from dotenv import load_dotenv import os +import vertexai +from dotenv import load_dotenv from personalized_shopping.agent import root_agent +from vertexai import agent_engines +from vertexai.preview.reasoning_engines import AdkApp load_dotenv() diff --git a/python/agents/personalized-shopping/personalized_shopping/__init__.py b/python/agents/personalized-shopping/personalized_shopping/__init__.py index 6d8e36d1d..5bf2451fe 100644 --- a/python/agents/personalized-shopping/personalized_shopping/__init__.py +++ b/python/agents/personalized-shopping/personalized_shopping/__init__.py @@ -21,7 +21,7 @@ os.environ["GOOGLE_CLOUD_LOCATION"] = "global" os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "True") -import torch +import torch # noqa: E402 # Workaround to Resolve the PyTorch-Streamlit Incompatibility Issue torch.classes.__path__ = [] @@ -29,9 +29,9 @@ # Initialize webshop environment (requires Java) # If Java is not available (e.g., in CI), set webshop_env to None try: - from .shared_libraries.init_env import init_env, webshop_env + from .shared_libraries.init_env import init_env, webshop_env # noqa: E402 except Exception: webshop_env = None init_env = None -from . import agent +from . import agent # noqa: F401, E402 diff --git a/python/agents/personalized-shopping/personalized_shopping/agent.py b/python/agents/personalized-shopping/personalized_shopping/agent.py index 578cb5f1b..b733b6566 100644 --- a/python/agents/personalized-shopping/personalized_shopping/agent.py +++ b/python/agents/personalized-shopping/personalized_shopping/agent.py @@ -15,10 +15,9 @@ from google.adk.agents import Agent from google.adk.tools import FunctionTool -from .tools.search import search -from .tools.click import click - from .prompt import personalized_shopping_agent_instruction +from .tools.click import click +from .tools.search import search root_agent = Agent( model="gemini-2.5-flash", diff --git a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/init_env.py b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/init_env.py index fb6293f68..4d92121df 100644 --- a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/init_env.py +++ b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/init_env.py @@ -41,5 +41,7 @@ def get_webshop_env(): if _webshop_env is None: _webshop_env = init_env(num_product_items) _webshop_env.reset() - print(f"Finished initializing WebshopEnv with {num_product_items} items.") + print( + f"Finished initializing WebshopEnv with {num_product_items} items." + ) return _webshop_env diff --git a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/search_engine/convert_product_file_format.py b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/search_engine/convert_product_file_format.py index fd9c978e7..0e870c1b1 100644 --- a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/search_engine/convert_product_file_format.py +++ b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/search_engine/convert_product_file_format.py @@ -14,11 +14,12 @@ import json import sys + from tqdm import tqdm sys.path.insert(0, "../") -from web_agent_site.engine.engine import load_products +from web_agent_site.engine.engine import load_products # noqa: E402 all_products, *_ = load_products(filepath="../data/items_shuffle.json") diff --git a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/__init__.py b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/__init__.py index 937bbf92f..8cab9c749 100644 --- a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/__init__.py +++ b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .envs.web_agent_text_env import WebAgentTextEnv +from .envs.web_agent_text_env import WebAgentTextEnv # noqa: F401 diff --git a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/engine/engine.py b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/engine/engine.py index 2e3e28bef..38237d875 100644 --- a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/engine/engine.py +++ b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/engine/engine.py @@ -14,13 +14,13 @@ """ """ -from ast import literal_eval -from collections import defaultdict -from decimal import Decimal import json import os import random import re +from ast import literal_eval +from collections import defaultdict +from decimal import Decimal from flask import render_template_string from pyserini.search.lucene import LuceneSearcher @@ -179,7 +179,9 @@ def get_top_n_product_from_keywords( docs = [search_engine.doc(hit.docid) for hit in hits] top_n_asins = [json.loads(doc.raw())["id"] for doc in docs] top_n_products = [ - product_item_dict[asin] for asin in top_n_asins if asin in product_item_dict + product_item_dict[asin] + for asin in top_n_asins + if asin in product_item_dict ] return top_n_products @@ -334,7 +336,10 @@ def load_products(filepath, num_products=None, human_goals=True): option_values = [] for option_content in option_contents: option_value = ( - option_content["value"].strip().replace("/", " | ").lower() + option_content["value"] + .strip() + .replace("/", " | ") + .lower() ) option_image = option_content.get("image", None) @@ -364,7 +369,9 @@ def load_products(filepath, num_products=None, human_goals=True): if asin in human_attributes: products[i]["instructions"] = human_attributes[asin] else: - products[i]["instruction_text"] = attributes[asin].get("instruction", None) + products[i]["instruction_text"] = attributes[asin].get( + "instruction", None + ) products[i]["instruction_attributes"] = attributes[asin].get( "instruction_attributes", None diff --git a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/engine/goal.py b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/engine/goal.py index 2d2fc9333..de9523d5c 100644 --- a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/engine/goal.py +++ b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/engine/goal.py @@ -14,12 +14,14 @@ """Functions for specifying goals and reward calculations.""" -from collections import defaultdict import itertools import random -from rich import print +from collections import defaultdict + import spacy +from rich import print from thefuzz import fuzz + from .normalize import normalize_color nlp = spacy.load("en_core_web_sm") @@ -53,7 +55,9 @@ def get_human_goals(all_products, product_prices): price_range = [p for p in PRICE_RANGE if p > price][:4] if len(price_range) >= 2: _, price_upper = sorted(random.sample(price_range, 2)) - price_text = f", and price lower than {price_upper:.2f} dollars" + price_text = ( + f", and price lower than {price_upper:.2f} dollars" + ) else: price_upper = 1000000 price_text = "" @@ -67,7 +71,8 @@ def get_human_goals(all_products, product_prices): "query": item["query"], "name": item["name"], "product_category": item["product_category"], - "instruction_text": product["instruction"].strip(".") + price_text, + "instruction_text": product["instruction"].strip(".") + + price_text, "attributes": attributes, "price_upper": price_upper, "goal_options": product["instruction_options"], @@ -86,7 +91,10 @@ def get_synthetic_goals(all_products, product_prices): goals = [] cnt_atts = defaultdict(int) for product in all_products: - if "instruction_text" not in product or product["instruction_text"] is None: + if ( + "instruction_text" not in product + or product["instruction_text"] is None + ): continue product_goals = [] asin = product["asin"] @@ -111,14 +119,18 @@ def get_synthetic_goals(all_products, product_prices): options = product["options"] option_names = sorted(options) combinations = list( - itertools.product(*(options[option_name] for option_name in option_names)) + itertools.product( + *(options[option_name] for option_name in option_names) + ) ) for combination in combinations: goal_options = dict() for i, o in enumerate(combination): # option_text.append(f'{option_names[i]}: {o}') goal_options[option_names[i]] = o - option_text = ", and ".join([f"{k}: {v}" for k, v in goal_options.items()]) + option_text = ", and ".join( + [f"{k}: {v}" for k, v in goal_options.items()] + ) option_text = " with " + option_text if option_text else "" product_goals.append( { @@ -138,9 +150,9 @@ def get_synthetic_goals(all_products, product_prices): cnt_atts[att] += 1 goals += product_goals for goal in goals: - goal["weight"] = sum(1.0 / cnt_atts[att] for att in goal["attributes"]) / len( - goal["attributes"] - ) + goal["weight"] = sum( + 1.0 / cnt_atts[att] for att in goal["attributes"] + ) / len(goal["attributes"]) return goals @@ -152,7 +164,9 @@ def get_type_reward(purchased_product, goal): purchased_product_category = [ x.strip() for x in purchased_product["product_category"].split("›") ] - goal_product_category = [x.strip() for x in goal["product_category"].split("›")] + goal_product_category = [ + x.strip() for x in goal["product_category"].split("›") + ] category_match = ( len(set(purchased_product_category) & set(goal_product_category)) >= 2 ) @@ -245,7 +259,11 @@ def get_option_reward(purchased_options, goal_options): break # Calculate option reward as fraction of goal options hit - r_option = num_option_matches / len(goal_options) if len(goal_options) > 0 else None + r_option = ( + num_option_matches / len(goal_options) + if len(goal_options) > 0 + else None + ) return r_option, num_option_matches @@ -253,7 +271,9 @@ def get_reward(purchased_product, goal, price, options, **kwargs): """Get cumulative reward score for purchased product and goal""" r_type_dict = get_type_reward(purchased_product, goal) - r_price = (price <= goal["price_upper"]) if goal["price_upper"] > 0 else None + r_price = ( + (price <= goal["price_upper"]) if goal["price_upper"] > 0 else None + ) r_att, num_attr_matches = get_attribute_reward(purchased_product, goal) diff --git a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/envs/web_agent_text_env.py b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/envs/web_agent_text_env.py index d8ecc303b..cd7c568f2 100644 --- a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/envs/web_agent_text_env.py +++ b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/envs/web_agent_text_env.py @@ -12,18 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict import json import random import string import time +from collections import defaultdict + +import gym +import numpy as np +import torch from bs4 import BeautifulSoup from bs4.element import Comment from flask import Flask -import gym from gym.envs.registration import register -import numpy as np -import torch + from ..engine.engine import ( ACTION_TO_TEMPLATE, BACK_TO_SEARCH, @@ -45,7 +47,6 @@ random_idx, ) - app = Flask(__name__) @@ -124,7 +125,11 @@ def step(self, action): action_name, action_arg = parse_action(action) if action_arg is not None: action_arg = action_arg.lower() - if action_name == "search" and action_arg is not None and action_arg != "": + if ( + action_name == "search" + and action_arg is not None + and action_arg != "" + ): status = self.browser.search(action_arg) elif ( action_name == "click" @@ -214,7 +219,9 @@ def observation(self): elif self.observation_mode == "url": return self.state["url"] else: - raise ValueError(f"Observation mode {self.observation_mode} not supported.") + raise ValueError( + f"Observation mode {self.observation_mode} not supported." + ) @property def state(self): @@ -246,13 +253,20 @@ def convert_html_to_text(self, html, simple=False): processed_t = f"[button] {t} [button_]" elif t.parent.name == "label": # options if f'"{t}"' in self.state["url"]: - processed_t = f" [clicked button] {t} [clicked button_]" + processed_t = ( + f" [clicked button] {t} [clicked button_]" + ) observation = f"You have clicked {t}.\n" + observation else: processed_t = f" [button] {t} [button_]" elif t.parent.get("class") == ["product-link"]: # product asins - if f"{t}" in self.server.user_sessions[self.session]["asins"]: - processed_t = f"\n[clicked button] {t} [clicked button_]" + if ( + f"{t}" + in self.server.user_sessions[self.session]["asins"] + ): + processed_t = ( + f"\n[clicked button] {t} [clicked button_]" + ) else: processed_t = f"\n[button] {t} [button_]" else: # regular, unclickable text @@ -273,7 +287,9 @@ def reset(self, session=None, instruction_text=None): self.session = self.session_prefix + self.session init_url = f"{self.base_url}/{self.session}" - self.browser.get(init_url, session_id=self.session, session_int=session_int) + self.browser.get( + init_url, session_id=self.session, session_int=session_int + ) self.text_to_clickable = None self.instruction_text = ( @@ -295,7 +311,9 @@ def close(self): def tag_visible(element): ignore = {"style", "script", "head", "title", "meta", "[document]"} - return element.parent.name not in ignore and not isinstance(element, Comment) + return element.parent.name not in ignore and not isinstance( + element, Comment + ) class SimServer: @@ -332,7 +350,9 @@ def __init__( ) ) self.search_engine = init_search_engine(num_products=num_products) - self.goals = get_goals(self.all_products, self.product_prices, human_goals) + self.goals = get_goals( + self.all_products, self.product_prices, human_goals + ) self.show_attrs = show_attrs # Fix outcome for random shuffling of goals @@ -342,7 +362,9 @@ def __init__( # Apply `filter_goals` parameter if exists to select speific goal(s) if filter_goals is not None: self.goals = [ - goal for (i, goal) in enumerate(self.goals) if filter_goals(i, goal) + goal + for (i, goal) in enumerate(self.goals) + if filter_goals(i, goal) ] # Imposes `limit` on goals via random selection @@ -561,7 +583,9 @@ def receive(self, session_id, current_url, session_int=None, **kwargs): if session_id not in self.user_sessions: idx = ( session_int - if (session_int is not None and isinstance(session_int, int)) + if ( + session_int is not None and isinstance(session_int, int) + ) else random_idx(self.cum_weights) ) goal = self.goals[idx] @@ -679,7 +703,9 @@ def __init__(self, server): def get(self, url, session_id=None, session_int=None): """Set browser variables to corresponding link, page HTML for URL""" - self.session_id = url.split("/")[-1] if session_id is None else session_id + self.session_id = ( + url.split("/")[-1] if session_id is None else session_id + ) self.page_source, _, _ = self.server.receive( self.session_id, self.current_url, session_int=session_int ) diff --git a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/utils.py b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/utils.py index aaebe62a7..3206119a6 100644 --- a/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/utils.py +++ b/python/agents/personalized-shopping/personalized_shopping/shared_libraries/web_agent_site/utils.py @@ -15,8 +15,8 @@ import bisect import hashlib import logging -from os.path import abspath, dirname, join import random +from os.path import abspath, dirname, join BASE_DIR = dirname(abspath(__file__)) DEBUG_PROD_SIZE = None # set to `None` to disable @@ -49,7 +49,9 @@ def setup_logger(session_id, user_log_dir): """Creates a log file and logging object for the corresponding session ID""" logger = logging.getLogger(session_id) formatter = logging.Formatter("%(message)s") - file_handler = logging.FileHandler(user_log_dir / f"{session_id}.jsonl", mode="w") + file_handler = logging.FileHandler( + user_log_dir / f"{session_id}.jsonl", mode="w" + ) file_handler.setFormatter(formatter) logger.setLevel(logging.INFO) logger.addHandler(file_handler)