diff --git a/agent/agent.py b/agent/agent.py index d561238..908feee 100644 --- a/agent/agent.py +++ b/agent/agent.py @@ -3,8 +3,6 @@ import json from typing import Any import tiktoken -from beartype import beartype -from beartype.door import is_bearable from agent.prompts import * from browser_env import Trajectory @@ -48,11 +46,9 @@ class TeacherForcingAgent(Agent): def __init__(self) -> None: super().__init__() - @beartype def set_action_set_tag(self, tag: str) -> None: self.action_set_tag = tag - @beartype def set_actions(self, action_seq: str | list[str]) -> None: if isinstance(action_seq, str): action_strs = action_seq.strip().split("\n") @@ -79,14 +75,12 @@ class TeacherForcingAgent(Agent): self.actions: list[Action] = actions - @beartype def next_action( self, trajectory: Trajectory, intent: str, meta_data: Any ) -> Action: """Predict the next action given the observation""" return self.actions.pop(0) - @beartype def reset( self, test_config_file: str, @@ -113,11 +107,9 @@ class PromptAgent(Agent): self.prompt_constructor = prompt_constructor self.action_set_tag = action_set_tag - @beartype def set_action_set_tag(self, tag: str) -> None: self.action_set_tag = tag - @beartype def next_action( self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any] ) -> Action: diff --git a/agent/prompts/prompt_constructor.py b/agent/prompts/prompt_constructor.py index 23419c1..6e2d3cb 100644 --- a/agent/prompts/prompt_constructor.py +++ b/agent/prompts/prompt_constructor.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any, TypedDict import tiktoken -from beartype import beartype from browser_env import Action, ActionParsingError, Trajectory from browser_env.env_config import URL_MAPPINGS @@ -38,7 +37,6 @@ class PromptConstructor(object): self.instruction: Instruction = instruction self.tokenizer = tokenizer - @beartype def get_lm_api_input( self, intro: str, examples: list[tuple[str, str]], current: str ) -> APIInput: @@ -84,7 +82,6 @@ class PromptConstructor(object): f"Provider {self.lm_config.provider} not implemented" ) - @beartype def construct( self, trajectory: Trajectory, @@ -93,7 +90,6 @@ class PromptConstructor(object): ) -> APIInput: raise NotImplementedError - @beartype def map_url_to_real(self, url: str) -> str: """Map the urls to their real world counterparts""" for i, j in URL_MAPPINGS.items(): @@ -101,7 +97,6 @@ class PromptConstructor(object): url = url.replace(i, j) return url - @beartype def map_url_to_local(self, url: str) -> str: """Map the urls to their local counterparts""" for i, j in URL_MAPPINGS.items(): @@ -109,11 +104,9 @@ class PromptConstructor(object): url = url.replace(j, i) return url - @beartype def _extract_action(self, response: str) -> str: raise NotImplementedError - @beartype def extract_action(self, response: str) -> str: response = self._extract_action(response) response = self.map_url_to_local(response) @@ -131,7 +124,6 @@ class DirectPromptConstructor(PromptConstructor): ): super().__init__(instruction_path, lm_config, tokenizer) - @beartype def construct( self, trajectory: Trajectory, @@ -167,7 +159,6 @@ class DirectPromptConstructor(PromptConstructor): prompt = self.get_lm_api_input(intro, examples, current) return prompt - @beartype def _extract_action(self, response: str) -> str: action_splitter = self.instruction["meta_data"]["action_splitter"] pattern = rf"{action_splitter}(.*?){action_splitter}" @@ -192,7 +183,6 @@ class CoTPromptConstructor(PromptConstructor): super().__init__(instruction_path, lm_config, tokenizer) self.answer_phrase = self.instruction["meta_data"]["answer_phrase"] - @beartype def construct( self, trajectory: Trajectory, @@ -225,7 +215,6 @@ class CoTPromptConstructor(PromptConstructor): prompt = self.get_lm_api_input(intro, examples, current) return prompt - @beartype def _extract_action(self, response: str) -> str: # find the first occurence of action action_splitter = self.instruction["meta_data"]["action_splitter"] diff --git a/browser_env/actions.py b/browser_env/actions.py index 6dbc21c..60f941a 100644 --- a/browser_env/actions.py +++ b/browser_env/actions.py @@ -12,8 +12,6 @@ from typing import Any, TypedDict, Union, cast import numpy as np import numpy.typing as npt -from beartype import beartype -from beartype.door import is_bearable from gymnasium import spaces from playwright._impl._api_structures import ViewportSize from playwright.async_api import BrowserContext as ABrowserContext @@ -55,7 +53,6 @@ from browser_env.processors import ( ) -@beartype def is_in_viewport( element: Locator, viewport: ViewportSize, threshold: float = 0.3 ) -> bool: @@ -75,7 +72,6 @@ def is_in_viewport( return ratio > threshold -@beartype async def async_is_in_viewport( element: ALocator, viewport: ViewportSize, threshold: float = 0.3 ) -> bool: @@ -111,7 +107,6 @@ class Action(TypedDict): raw_prediction: str # raw prediction from the model -@beartype def action2str( action: Action, action_set_tag: str, semantic_element: str = "" ) -> str: @@ -274,7 +269,6 @@ class ActionTypes(IntEnum): return f"ACTION_TYPES.{self.name}" -@beartype def is_equivalent(a: Action, b: Action) -> bool: """Return True if two actions are equal.""" if a["action_type"] != b["action_type"]: @@ -338,12 +332,11 @@ _role2id: dict[RolesType, int] = { _id2role: list[RolesType] = sorted(_role2id, key=_role2id.get) # type: ignore[arg-type] -@beartype def _keys2ids(keys: list[int | str] | str) -> list[int]: return list( map( lambda key: _key2id[str(key)] - if is_bearable(key, str) + if isinstance(key, str) else int(key), keys, ) @@ -424,7 +417,6 @@ def create_random_action() -> Action: } -@beartype def create_none_action() -> Action: """Return a valid action object that does nothing.""" return { @@ -445,14 +437,12 @@ def create_none_action() -> Action: } -@beartype def create_stop_action(answer: str) -> Action: action = create_none_action() action.update({"action_type": ActionTypes.STOP, "answer": answer}) return action -@beartype def create_scroll_action(direction: str) -> Action: """Return the playwright action""" assert direction in ["up", "down"] @@ -466,7 +456,6 @@ def create_scroll_action(direction: str) -> Action: return action -@beartype def create_mouse_hover_action( left: float | None = None, top: float | None = None ) -> Action: @@ -481,7 +470,6 @@ def create_mouse_hover_action( return action -@beartype def create_key_press_action(key_comb: str) -> Action: """Return the key press action""" @@ -504,7 +492,6 @@ def create_key_press_action(key_comb: str) -> Action: return action -@beartype def create_page_focus_action(page_number: int) -> Action: """Return a valid action object with type PAGE_FOCUS.""" action = create_none_action() @@ -517,7 +504,6 @@ def create_page_focus_action(page_number: int) -> Action: return action -@beartype def create_new_tab_action() -> Action: """Return a valid action object with type NEW_TAB.""" action = create_none_action() @@ -529,7 +515,6 @@ def create_new_tab_action() -> Action: return action -@beartype def create_go_back_action() -> Action: """Return a valid action object with type GO_BACK.""" action = create_none_action() @@ -541,7 +526,6 @@ def create_go_back_action() -> Action: return action -@beartype def create_go_forward_action() -> Action: """Return a valid action object with type GO_FORWARD.""" action = create_none_action() @@ -553,7 +537,6 @@ def create_go_forward_action() -> Action: return action -@beartype def create_goto_url_action(url: str) -> Action: """Return a valid action object with type GOTO_URL.""" action = create_none_action() @@ -566,7 +549,6 @@ def create_goto_url_action(url: str) -> Action: return action -@beartype def create_page_close_action() -> Action: """Return a valid action object with type PAGE_CLOSE.""" action = create_none_action() @@ -578,7 +560,6 @@ def create_page_close_action() -> Action: return action -@beartype def create_mouse_click_action( left: float | None = None, top: float | None = None ) -> Action: @@ -602,7 +583,6 @@ def create_mouse_click_action( return action -@beartype def create_keyboard_type_action(keys: list[int | str] | str) -> Action: """Return a valid action object with type TYPE.""" action = create_none_action() @@ -615,7 +595,6 @@ def create_keyboard_type_action(keys: list[int | str] | str) -> Action: return action -@beartype def create_click_action( element_id: str = "", element_role: RolesType = "link", @@ -637,7 +616,6 @@ def create_click_action( return action -@beartype def create_hover_action( element_id: str = "", element_role: RolesType = "link", @@ -659,7 +637,6 @@ def create_hover_action( return action -@beartype def create_type_action( text: str, element_id: str = "", @@ -683,7 +660,6 @@ def create_type_action( return action -@beartype def create_check_action(pw_code: str) -> Action: action = create_none_action() action.update( @@ -695,7 +671,6 @@ def create_check_action(pw_code: str) -> Action: return action -@beartype def create_select_option_action( pw_code: str, ) -> Action: @@ -709,7 +684,6 @@ def create_select_option_action( return action -@beartype def create_focus_action( element_role: RolesType, element_name: str = "", nth: int = 0 ) -> Action: @@ -728,7 +702,6 @@ def create_focus_action( return action -@beartype def create_focus_and_click_action( element_role: RolesType, element_name: str = "", nth: int = 0 ) -> Action: @@ -748,7 +721,6 @@ def create_focus_and_click_action( return action -@beartype def create_focus_and_type_action( keys: list[int | str] | str, element_role: RolesType, @@ -771,7 +743,6 @@ def create_focus_and_type_action( return action -@beartype def execute_scroll(direction: str, page: Page) -> None: # perform the action # code from natbot @@ -785,7 +756,6 @@ def execute_scroll(direction: str, page: Page) -> None: ) -@beartype async def aexecute_scroll(direction: str, page: APage) -> None: # perform the action # code from natbot @@ -799,19 +769,16 @@ async def aexecute_scroll(direction: str, page: APage) -> None: ) -@beartype def execute_key_press(key: str, page: Page) -> None: """Press a key.""" page.keyboard.press(key) -@beartype async def aexecute_key_press(key: str, page: APage) -> None: """Press a key.""" await page.keyboard.press(key) -@beartype def execute_mouse_hover(left: float, top: float, page: Page) -> None: """Click at coordinates (left, top).""" viewport_size = page.viewport_size @@ -821,7 +788,6 @@ def execute_mouse_hover(left: float, top: float, page: Page) -> None: ) -@beartype async def aexecute_mouse_hover(left: float, top: float, page: APage) -> None: """Click at coordinates (left, top).""" viewport_size = page.viewport_size @@ -840,7 +806,6 @@ def execute_mouse_click(left: float, top: float, page: Page) -> None: ) -@beartype async def aexecute_mouse_click(left: float, top: float, page: APage) -> None: """Click at coordinates (left, top).""" viewport_size = page.viewport_size @@ -850,19 +815,16 @@ async def aexecute_mouse_click(left: float, top: float, page: APage) -> None: ) -@beartype def execute_keyboard_type(text: str, page: Page) -> None: """Fill the focused element with text.""" page.keyboard.type(text) -@beartype async def aexecute_keyboard_type(text: str, page: APage) -> None: """Fill the focused element with text.""" await page.keyboard.type(text) -@beartype def execute_click_current(page: Page) -> None: """Click at the current mouse position.""" locators = page.locator("*:focus") @@ -874,7 +836,6 @@ def execute_click_current(page: Page) -> None: locators.click() -@beartype async def aexecute_click_current(page: APage) -> None: """Click at the current mouse position.""" locators = page.locator("*:focus") @@ -889,21 +850,18 @@ async def aexecute_click_current(page: APage) -> None: await page.wait_for_load_state("load") -@beartype def execute_type(keys: list[int], page: Page) -> None: """Send keystrokes to the focused element.""" text = "".join([_id2key[key] for key in keys]) page.keyboard.type(text) -@beartype async def aexecute_type(keys: list[int], page: APage) -> None: """Send keystrokes to the focused element.""" text = "".join([_id2key[key] for key in keys]) await page.keyboard.type(text) -@beartype def execute_focus( element_role: int, element_name: str, nth: int, page: Page ) -> None: @@ -940,7 +898,6 @@ def execute_focus( element_location_list[nth][0].focus() -@beartype async def aexecute_focus( element_role: int, element_name: str, nth: int, page: APage ) -> None: @@ -977,7 +934,6 @@ async def aexecute_focus( await element_location_list[nth][0].focus() -@beartype def locate(locator_calls: list[ParsedPlaywrightCode], page: Page) -> Locator: locator = page for call in locator_calls: @@ -988,7 +944,6 @@ def locate(locator_calls: list[ParsedPlaywrightCode], page: Page) -> Locator: return locator # type: ignore[return-value] -@beartype async def alocate( locator_calls: list[ParsedPlaywrightCode], page: APage ) -> ALocator: @@ -1001,7 +956,6 @@ async def alocate( return locator # type: ignore[return-value] -@beartype def execute_playwright_click( locator_code: list[ParsedPlaywrightCode], page: Page, @@ -1014,7 +968,6 @@ def execute_playwright_click( locator.click(*pw_action_args, **pw_action_kwargs) -@beartype async def aexecute_playwright_click( locator_code: list[ParsedPlaywrightCode], page: APage, @@ -1027,7 +980,6 @@ async def aexecute_playwright_click( await locator.click(*pw_action_args, **pw_action_kwargs) -@beartype def execute_playwright_hover( locator_code: list[ParsedPlaywrightCode], page: Page ) -> None: @@ -1037,7 +989,6 @@ def execute_playwright_hover( locator.hover() -@beartype async def aexecute_playwright_hover( locator_code: list[ParsedPlaywrightCode], page: APage ) -> None: @@ -1047,7 +998,6 @@ async def aexecute_playwright_hover( await locator.hover() -@beartype def execute_playwright_type( text: str, locator_code: list[ParsedPlaywrightCode], @@ -1061,7 +1011,6 @@ def execute_playwright_type( locator.type(*pw_action_args, **pw_action_kwargs) -@beartype async def aexecute_playwright_type( text: str, locator_code: list[ParsedPlaywrightCode], @@ -1075,7 +1024,6 @@ async def aexecute_playwright_type( await locator.type(*pw_action_args, **pw_action_kwargs) -@beartype def execute_playwright_select_option( locator_code: list[ParsedPlaywrightCode], page: Page, @@ -1087,7 +1035,6 @@ def execute_playwright_select_option( locator.select_option(*pw_action_args, **pw_action_kwargs) -@beartype async def aexecute_playwright_select_option( locator_code: list[ParsedPlaywrightCode], page: APage, @@ -1099,7 +1046,6 @@ async def aexecute_playwright_select_option( await locator.select_option(*pw_action_args, **pw_action_kwargs) -@beartype def execute_playwright_check( locator_code: list[ParsedPlaywrightCode], page: Page ) -> None: @@ -1108,7 +1054,6 @@ def execute_playwright_check( locator.check() -@beartype async def aexecute_playwright_check( locator_code: list[ParsedPlaywrightCode], page: APage ) -> None: @@ -1117,7 +1062,6 @@ async def aexecute_playwright_check( await locator.check() -@beartype def execute_action( action: Action, page: Page, @@ -1252,7 +1196,6 @@ def execute_action( return page -@beartype async def aexecute_action( action: Action, page: APage, browser_ctx: ABrowserContext ) -> APage: @@ -1383,7 +1326,6 @@ async def aexecute_action( return page -@beartype def parse_playwright_code(code: str) -> list[ParsedPlaywrightCode]: # extract function calls if not code.startswith("page."): @@ -1444,14 +1386,12 @@ def parse_playwright_code(code: str) -> list[ParsedPlaywrightCode]: return parsed_chain -@beartype class ActionParsingError(Exception): def __init__(self, message: str) -> None: self.message = message super().__init__(self.message) -@beartype def create_playwright_action(playwright_code: str) -> Action: """Main function to return individual playwright action""" # get the last action @@ -1524,7 +1464,6 @@ def create_playwright_action(playwright_code: str) -> Action: raise ActionParsingError(f"Unknown playwright action {action}") -@beartype def create_id_based_action(action_str: str) -> Action: """Main function to return individual id based action""" action_str = action_str.strip() diff --git a/browser_env/async_envs.py b/browser_env/async_envs.py index 312d770..29fb32f 100644 --- a/browser_env/async_envs.py +++ b/browser_env/async_envs.py @@ -5,7 +5,6 @@ from pathlib import Path import numpy as np import numpy.typing as npt -from beartype import beartype from gymnasium import Env from gymnasium.spaces import Box, Text from playwright.async_api import Page, ViewportSize, async_playwright @@ -23,7 +22,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]): and observation space is the html content of the page. """ - @beartype def __init__( self, max_page_length: int = 2048, @@ -46,7 +44,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]): self.timeout = timeout self.viewport_size = viewport_size - @beartype async def setup(self, config_file: Path | None = None) -> None: self.context_manager = async_playwright() self.playwright = await self.context_manager.__aenter__() @@ -73,7 +70,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]): if start_url: await self.page.goto(start_url) - @beartype async def areset( self, *, @@ -104,7 +100,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]): {"page": DetachedPage(self.page.url, content)}, ) - @beartype def reset( self, *, @@ -120,7 +115,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]): def close(self) -> None: asyncio.run(self.aclose()) - @beartype async def astep( self, action: Action ) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]: @@ -153,7 +147,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]): }, ) - @beartype def step( self, action: Action ) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]: diff --git a/browser_env/auto_login.py b/browser_env/auto_login.py index 689ec32..d466603 100644 --- a/browser_env/auto_login.py +++ b/browser_env/auto_login.py @@ -3,7 +3,6 @@ import glob from itertools import combinations from pathlib import Path -from beartype import beartype from playwright.sync_api import sync_playwright from browser_env.env_config import ( @@ -18,7 +17,6 @@ HEADLESS = True SLOW_MO = 0 -@beartype def is_expired( storage_state: Path, url: str, keyword: str, url_exact: bool = True ) -> bool: @@ -44,7 +42,6 @@ def is_expired( return url not in d_url -@beartype def renew_comb(comb: list[str]) -> None: context_manager = sync_playwright() playwright = context_manager.__enter__() @@ -91,7 +88,6 @@ def renew_comb(comb: list[str]) -> None: context_manager.__exit__() -@beartype def main() -> None: sites = ["gitlab", "shopping", "shopping_admin", "reddit"] urls = [ diff --git a/browser_env/envs.py b/browser_env/envs.py index af8388a..d820502 100644 --- a/browser_env/envs.py +++ b/browser_env/envs.py @@ -8,7 +8,6 @@ from typing import Any, Union import numpy as np import numpy.typing as npt -from beartype import beartype from gymnasium import Env from gymnasium.spaces import Box, Text from playwright.sync_api import ( @@ -39,7 +38,6 @@ class PlaywrightScript: value: str | None = None # avatar movie, Enter -@beartype def parse_action(action: str) -> PlaywrightScript: splitted = action.strip().split(" ") assert len(splitted) >= 2 @@ -73,7 +71,6 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]): and observation space is the html content of the page. """ - @beartype def __init__( self, max_page_length: int = 8192, @@ -121,7 +118,6 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]): self.observation_handler.get_observation_space() ) - @beartype def setup(self, config_file: Path | None = None) -> None: self.context_manager = sync_playwright() self.playwright = self.context_manager.__enter__() @@ -168,23 +164,19 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]): client.send("Accessibility.enable") self.page.client = client # type: ignore - @beartype def get_page_client(self, page: Page) -> CDPSession: return page.client # type: ignore - @beartype def _get_obs(self) -> dict[str, Observation]: obs = self.observation_handler.get_observation( self.page, self.get_page_client(self.page) ) return obs - @beartype def _get_obs_metadata(self) -> dict[str, ObservationMetadata]: metadata = self.observation_handler.get_observation_metadata() return metadata - @beartype def reset( self, *, @@ -223,12 +215,10 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]): return (observation, info) - @beartype def save_trace(self, trace_path: str | Path) -> None: if self.save_trace_enabled: self.context.tracing.stop(path=trace_path) - @beartype def close(self) -> None: if self.reset_finished: self.context_manager.__exit__() diff --git a/browser_env/helper_functions.py b/browser_env/helper_functions.py index ac91b30..3c66f70 100644 --- a/browser_env/helper_functions.py +++ b/browser_env/helper_functions.py @@ -5,7 +5,6 @@ import re from pathlib import Path from typing import Any -from beartype import beartype from PIL import Image from agent.prompts import * @@ -35,7 +34,6 @@ HTML_TEMPLATE = """ """ -@beartype def get_render_action( action: Action, observation_metadata: dict[str, ObservationMetadata], @@ -63,7 +61,6 @@ def get_render_action( return action_str -@beartype def get_action_description( action: Action, observation_metadata: dict[str, ObservationMetadata], diff --git a/browser_env/processors.py b/browser_env/processors.py index 76a0371..d4de787 100644 --- a/browser_env/processors.py +++ b/browser_env/processors.py @@ -5,7 +5,6 @@ from typing import Any, TypedDict, Union import numpy as np import numpy.typing as npt -from beartype import beartype from gymnasium import spaces from playwright.sync_api import CDPSession, Page, ViewportSize @@ -60,7 +59,6 @@ class TextObervationProcessor(ObservationProcessor): create_empty_metadata() ) # use the store meta data of this observation type - @beartype def fetch_browser_info( self, page: Page, @@ -108,7 +106,6 @@ class TextObervationProcessor(ObservationProcessor): return info - @beartype @staticmethod def get_bounding_client_rect( client: CDPSession, backend_node_id: str @@ -134,7 +131,6 @@ class TextObervationProcessor(ObservationProcessor): except Exception as e: return {"result": {"subtype": "error"}} - @beartype @staticmethod def get_element_in_viewport_ratio( elem_left_bound: float, @@ -167,7 +163,6 @@ class TextObervationProcessor(ObservationProcessor): ratio = overlap_width * overlap_height / width * height return ratio - @beartype def fetch_page_html( self, info: BrowserInfo, @@ -323,7 +318,6 @@ class TextObervationProcessor(ObservationProcessor): return dom_tree - @beartype @staticmethod def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]: """Parse the html tree into a string text""" @@ -367,7 +361,6 @@ class TextObervationProcessor(ObservationProcessor): html = dfs(0, 0) return html, obs_nodes_info - @beartype def fetch_page_accessibility_tree( self, info: BrowserInfo, @@ -487,7 +480,6 @@ class TextObervationProcessor(ObservationProcessor): return accessibility_tree - @beartype @staticmethod def parse_accessibility_tree( accessibility_tree: AccessibilityTree, @@ -575,7 +567,6 @@ class TextObervationProcessor(ObservationProcessor): tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0) return tree_str, obs_nodes_info - @beartype @staticmethod def clean_accesibility_tree(tree_str: str) -> str: """further clean accesibility tree""" @@ -598,7 +589,6 @@ class TextObervationProcessor(ObservationProcessor): return "\n".join(clean_lines) - @beartype def process(self, page: Page, client: CDPSession) -> str: # get the tab info open_tabs = page.context.pages @@ -657,7 +647,6 @@ class TextObervationProcessor(ObservationProcessor): content = f"{tab_title_str}\n\n{content}" return content - @beartype def get_element_center(self, element_id: str) -> tuple[float, float]: node_info = self.obs_nodes_info[element_id] node_bound = node_info["union_bound"] @@ -705,7 +694,6 @@ class ObservationHandler: ) self.viewport_size = viewport_size - @beartype def get_observation_space(self) -> spaces.Dict: text_space = spaces.Text( min_length=0, @@ -729,7 +717,6 @@ class ObservationHandler: return spaces.Dict({"text": text_space, "image": image_space}) - @beartype def get_observation( self, page: Page, client: CDPSession ) -> dict[str, Observation]: @@ -737,7 +724,6 @@ class ObservationHandler: image_obs = self.image_processor.process(page, client) return {"text": text_obs, "image": image_obs} - @beartype def get_observation_metadata(self) -> dict[str, ObservationMetadata]: return { "text": self.text_processor.meta_data, diff --git a/browser_env/utils.py b/browser_env/utils.py index 568a92b..1814242 100644 --- a/browser_env/utils.py +++ b/browser_env/utils.py @@ -4,7 +4,6 @@ from typing import Any, Dict, TypedDict, Union import numpy as np import numpy.typing as npt -from beartype import beartype from PIL import Image @@ -14,7 +13,6 @@ class DetachedPage: content: str # html -@beartype def png_bytes_to_numpy(png: bytes) -> npt.NDArray[np.uint8]: """Convert png bytes to numpy array diff --git a/evaluation_harness/evaluators.py b/evaluation_harness/evaluators.py index 2a70d2b..1ec2526 100644 --- a/evaluation_harness/evaluators.py +++ b/evaluation_harness/evaluators.py @@ -8,8 +8,6 @@ from pathlib import Path from typing import Any, Tuple, Union import evaluate # type: ignore[import] -from beartype import beartype -from beartype.door import is_bearable from playwright.sync_api import CDPSession, Page from browser_env.actions import Action @@ -26,7 +24,6 @@ from evaluation_harness.helper_functions import ( Trajectory = list[Union[Action, StateInfo]] -@beartype class Evaluator(object): def __init__(self, eval_tag: str = "") -> None: self.eval_tag = eval_tag @@ -43,7 +40,7 @@ class Evaluator(object): @staticmethod def get_last_action(trajectory: Trajectory) -> Action: try: - is_bearable(trajectory[-1], Action) + # is_bearable(trajectory[-1], Action) last_action = trajectory[-1] except Exception: raise ValueError( @@ -55,7 +52,7 @@ class Evaluator(object): @staticmethod def get_last_state(trajectory: Trajectory) -> StateInfo: try: - is_bearable(trajectory[-2], StateInfo) + # is_bearable(trajectory[-2], StateInfo) last_state = trajectory[-2] except Exception: raise ValueError( @@ -65,7 +62,6 @@ class Evaluator(object): return last_state # type: ignore[return-value] -@beartype class StringExactEvaluator(Evaluator): """Check whether the answer is exactly the same as one of the reference answers""" @@ -95,7 +91,6 @@ class StringExactEvaluator(Evaluator): return 0.0 -@beartype class StringEvaluator(Evaluator): """Check whether the answer is correct with: exact match: the answer is exactly the same as the reference answer @@ -144,7 +139,6 @@ class StringEvaluator(Evaluator): return score -@beartype class StringSoftEvaluator(Evaluator): """Use text generation metrics such as BLEU, ROUGE, etc. to evaluate the answer""" @@ -167,7 +161,6 @@ class StringSoftEvaluator(Evaluator): return float(rouge["rouge1"]) -@beartype class URLExactEvaluator(Evaluator): """Check whether the URL is exactly the same as of the reference URLs""" @@ -205,7 +198,6 @@ class URLExactEvaluator(Evaluator): raise ValueError(f"Unknown matching rule: {matching_rule}") -@beartype class HTMLContentExactEvaluator(Evaluator): """Check whether the contents appear in the page""" @@ -286,7 +278,6 @@ class HTMLContentExactEvaluator(Evaluator): ###### -@beartype class EvaluatorPartial(Evaluator): def __init__(self) -> None: raise NotImplementedError @@ -301,7 +292,6 @@ class EvaluatorPartial(Evaluator): raise NotImplementedError -@beartype class URLSoftEvaluator(EvaluatorPartial): """Parse the URL and compare the domain and parameters""" @@ -367,7 +357,6 @@ class EvaluatorComb: return score -@beartype def evaluator_router(config_file: Path | str) -> EvaluatorComb: """Router to get the evaluator class""" with open(config_file, "r") as f: diff --git a/evaluation_harness/helper_functions.py b/evaluation_harness/helper_functions.py index 3d59efd..915ef1f 100644 --- a/evaluation_harness/helper_functions.py +++ b/evaluation_harness/helper_functions.py @@ -4,7 +4,6 @@ from typing import Any from urllib.parse import urlparse import requests -from beartype import beartype from playwright.sync_api import CDPSession, Page from browser_env.env_config import ( @@ -21,7 +20,6 @@ from llms.providers.openai_utils import ( ) -@beartype def shopping_get_auth_token() -> str: response = requests.post( url=f"{SHOPPING}/rest/default/V1/integration/admin/token", @@ -37,7 +35,6 @@ def shopping_get_auth_token() -> str: return token -@beartype def shopping_get_latest_order_url() -> str: """Get the latest order url from the shopping website.""" @@ -62,7 +59,6 @@ def shopping_get_latest_order_url() -> str: return order_url -@beartype def shopping_get_sku_latest_review_author(sku: str) -> str: """Get the latest review for shopping admin.""" header = { @@ -80,7 +76,6 @@ def shopping_get_sku_latest_review_author(sku: str) -> str: return author -@beartype def shopping_get_sku_latest_review_rating(sku: str) -> str: """Get the latest review for shopping admin.""" header = { @@ -99,7 +94,6 @@ def shopping_get_sku_latest_review_rating(sku: str) -> str: return rating -@beartype def reddit_get_post_url(url: str) -> str: """Get the post url""" # Url is http://domain/f/subreddit/post_id/... @@ -118,7 +112,6 @@ def reddit_get_post_url(url: str) -> str: return post_url -@beartype def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str: # get the account index try: @@ -150,7 +143,6 @@ def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str: return role -@beartype def llm_fuzzy_match(pred: str, reference: str, question: str) -> float: """Check whether the prediction matches the reference with GPT-3.5""" messages: list[dict[str, Any]] = [] diff --git a/run.py b/run.py index 7c8a7b8..c4781c2 100644 --- a/run.py +++ b/run.py @@ -9,7 +9,6 @@ import time from pathlib import Path import openai -from beartype import beartype from agent import ( Agent, @@ -144,7 +143,6 @@ def config() -> argparse.Namespace: return args -@beartype def early_stop( trajectory: Trajectory, max_steps: int, thresholds: dict[str, int] ) -> tuple[bool, str]: @@ -201,7 +199,6 @@ def early_stop( return False, "" -@beartype def test( args: argparse.Namespace, agent: Agent | PromptAgent | TeacherForcingAgent, @@ -369,7 +366,6 @@ def get_unfinished(config_files: list[str], result_dir: str) -> list[str]: return unfinished_configs -@beartype def dump_config(args: argparse.Namespace) -> None: config_file = Path(args.result_dir) / "config.json" if not config_file.exists(): diff --git a/scripts/collect_obs.py b/scripts/collect_obs.py index c361b86..e5121b0 100644 --- a/scripts/collect_obs.py +++ b/scripts/collect_obs.py @@ -6,7 +6,6 @@ import time from typing import Dict, Optional, Tuple, Type, Union, cast import pytest -from beartype import beartype from playwright.sync_api import Page, expect from browser_env import ( @@ -21,13 +20,11 @@ from browser_env.env_config import * HEADLESS = False -@beartype def gen_tmp_storage_state() -> None: with open(f"scripts/tmp_storage_state.json", "w") as f: json.dump({"storage_state": ".auth/gitlab_state.json"}, f) -@beartype def get_observation( observation_type: str, current_viewport_only: bool ) -> None: diff --git a/tests/test_browser_env/test_script_browser_env.py b/tests/test_browser_env/test_script_browser_env.py index 7f9fcf1..33a7886 100644 --- a/tests/test_browser_env/test_script_browser_env.py +++ b/tests/test_browser_env/test_script_browser_env.py @@ -5,7 +5,6 @@ import tempfile from typing import Callable, Dict, Optional, Tuple, Type, Union, cast import pytest -from beartype.door import is_bearable from gymnasium.vector import AsyncVectorEnv from playwright.sync_api import Page @@ -128,7 +127,7 @@ def test_parallel_script_browser_env() -> None: ] ) ) - assert is_bearable(info["page"].tolist(), list[DetachedPage]) + # assert is_bearable(info["page"].tolist(), list[DetachedPage]) assert info["page"][0].url == "https://www.rfc-editor.org/rfc/rfc2606.html" assert info["page"][1].url == "https://www.rfc-editor.org/rfc/rfc6761.html" vector_env.close() # type: ignore[no-untyped-call] diff --git a/tests/test_evaluation_harness/test_exact_evaluators.py b/tests/test_evaluation_harness/test_exact_evaluators.py index a0def14..9715ccf 100644 --- a/tests/test_evaluation_harness/test_exact_evaluators.py +++ b/tests/test_evaluation_harness/test_exact_evaluators.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Any import pytest -from beartype import beartype from py import test from agent import Agent, TeacherForcingAgent @@ -249,7 +248,6 @@ def test_html_content_url_comb_success( assert score == 1.0 -@beartype @pytest.mark.skipif( IN_GITHUB_ACTIONS, reason="Won't work using the demo sites" ) @@ -273,7 +271,6 @@ def test_func_success( assert score == 1.0 -@beartype @pytest.mark.skipif( IN_GITHUB_ACTIONS, reason="Won't work using the demo sites" ) @@ -297,7 +294,6 @@ def test_func_fail( assert score == 0.0 -@beartype def test_func_url_func_last_success( script_browser_env: ScriptBrowserEnv, ) -> None: @@ -319,7 +315,6 @@ def test_func_url_func_last_success( assert score == 1.0 -@beartype def test_func_url_func_page_success( script_browser_env: ScriptBrowserEnv, ) -> None: diff --git a/tests/test_evaluation_harness/test_helper_functions.py b/tests/test_evaluation_harness/test_helper_functions.py index b8406e4..bd671b9 100644 --- a/tests/test_evaluation_harness/test_helper_functions.py +++ b/tests/test_evaluation_harness/test_helper_functions.py @@ -2,8 +2,6 @@ import json import os from pathlib import Path -from beartype import beartype - from browser_env import ScriptBrowserEnv from browser_env.env_config import * from evaluation_harness.helper_functions import (