mirror of
https://github.com/web-arena-x/webarena.git
synced 2026-02-06 11:16:53 +00:00
remove beartype for efficency purpose
This commit is contained in:
parent
ed93b3a88f
commit
e44972d335
@ -3,8 +3,6 @@ import json
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
from beartype import beartype
|
||||
from beartype.door import is_bearable
|
||||
|
||||
from agent.prompts import *
|
||||
from browser_env import Trajectory
|
||||
@ -48,11 +46,9 @@ class TeacherForcingAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@beartype
|
||||
def set_action_set_tag(self, tag: str) -> None:
|
||||
self.action_set_tag = tag
|
||||
|
||||
@beartype
|
||||
def set_actions(self, action_seq: str | list[str]) -> None:
|
||||
if isinstance(action_seq, str):
|
||||
action_strs = action_seq.strip().split("\n")
|
||||
@ -79,14 +75,12 @@ class TeacherForcingAgent(Agent):
|
||||
|
||||
self.actions: list[Action] = actions
|
||||
|
||||
@beartype
|
||||
def next_action(
|
||||
self, trajectory: Trajectory, intent: str, meta_data: Any
|
||||
) -> Action:
|
||||
"""Predict the next action given the observation"""
|
||||
return self.actions.pop(0)
|
||||
|
||||
@beartype
|
||||
def reset(
|
||||
self,
|
||||
test_config_file: str,
|
||||
@ -113,11 +107,9 @@ class PromptAgent(Agent):
|
||||
self.prompt_constructor = prompt_constructor
|
||||
self.action_set_tag = action_set_tag
|
||||
|
||||
@beartype
|
||||
def set_action_set_tag(self, tag: str) -> None:
|
||||
self.action_set_tag = tag
|
||||
|
||||
@beartype
|
||||
def next_action(
|
||||
self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any]
|
||||
) -> Action:
|
||||
|
||||
@ -4,7 +4,6 @@ from pathlib import Path
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import tiktoken
|
||||
from beartype import beartype
|
||||
|
||||
from browser_env import Action, ActionParsingError, Trajectory
|
||||
from browser_env.env_config import URL_MAPPINGS
|
||||
@ -38,7 +37,6 @@ class PromptConstructor(object):
|
||||
self.instruction: Instruction = instruction
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@beartype
|
||||
def get_lm_api_input(
|
||||
self, intro: str, examples: list[tuple[str, str]], current: str
|
||||
) -> APIInput:
|
||||
@ -84,7 +82,6 @@ class PromptConstructor(object):
|
||||
f"Provider {self.lm_config.provider} not implemented"
|
||||
)
|
||||
|
||||
@beartype
|
||||
def construct(
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
@ -93,7 +90,6 @@ class PromptConstructor(object):
|
||||
) -> APIInput:
|
||||
raise NotImplementedError
|
||||
|
||||
@beartype
|
||||
def map_url_to_real(self, url: str) -> str:
|
||||
"""Map the urls to their real world counterparts"""
|
||||
for i, j in URL_MAPPINGS.items():
|
||||
@ -101,7 +97,6 @@ class PromptConstructor(object):
|
||||
url = url.replace(i, j)
|
||||
return url
|
||||
|
||||
@beartype
|
||||
def map_url_to_local(self, url: str) -> str:
|
||||
"""Map the urls to their local counterparts"""
|
||||
for i, j in URL_MAPPINGS.items():
|
||||
@ -109,11 +104,9 @@ class PromptConstructor(object):
|
||||
url = url.replace(j, i)
|
||||
return url
|
||||
|
||||
@beartype
|
||||
def _extract_action(self, response: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@beartype
|
||||
def extract_action(self, response: str) -> str:
|
||||
response = self._extract_action(response)
|
||||
response = self.map_url_to_local(response)
|
||||
@ -131,7 +124,6 @@ class DirectPromptConstructor(PromptConstructor):
|
||||
):
|
||||
super().__init__(instruction_path, lm_config, tokenizer)
|
||||
|
||||
@beartype
|
||||
def construct(
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
@ -167,7 +159,6 @@ class DirectPromptConstructor(PromptConstructor):
|
||||
prompt = self.get_lm_api_input(intro, examples, current)
|
||||
return prompt
|
||||
|
||||
@beartype
|
||||
def _extract_action(self, response: str) -> str:
|
||||
action_splitter = self.instruction["meta_data"]["action_splitter"]
|
||||
pattern = rf"{action_splitter}(.*?){action_splitter}"
|
||||
@ -192,7 +183,6 @@ class CoTPromptConstructor(PromptConstructor):
|
||||
super().__init__(instruction_path, lm_config, tokenizer)
|
||||
self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
|
||||
|
||||
@beartype
|
||||
def construct(
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
@ -225,7 +215,6 @@ class CoTPromptConstructor(PromptConstructor):
|
||||
prompt = self.get_lm_api_input(intro, examples, current)
|
||||
return prompt
|
||||
|
||||
@beartype
|
||||
def _extract_action(self, response: str) -> str:
|
||||
# find the first occurence of action
|
||||
action_splitter = self.instruction["meta_data"]["action_splitter"]
|
||||
|
||||
@ -12,8 +12,6 @@ from typing import Any, TypedDict, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from beartype import beartype
|
||||
from beartype.door import is_bearable
|
||||
from gymnasium import spaces
|
||||
from playwright._impl._api_structures import ViewportSize
|
||||
from playwright.async_api import BrowserContext as ABrowserContext
|
||||
@ -55,7 +53,6 @@ from browser_env.processors import (
|
||||
)
|
||||
|
||||
|
||||
@beartype
|
||||
def is_in_viewport(
|
||||
element: Locator, viewport: ViewportSize, threshold: float = 0.3
|
||||
) -> bool:
|
||||
@ -75,7 +72,6 @@ def is_in_viewport(
|
||||
return ratio > threshold
|
||||
|
||||
|
||||
@beartype
|
||||
async def async_is_in_viewport(
|
||||
element: ALocator, viewport: ViewportSize, threshold: float = 0.3
|
||||
) -> bool:
|
||||
@ -111,7 +107,6 @@ class Action(TypedDict):
|
||||
raw_prediction: str # raw prediction from the model
|
||||
|
||||
|
||||
@beartype
|
||||
def action2str(
|
||||
action: Action, action_set_tag: str, semantic_element: str = ""
|
||||
) -> str:
|
||||
@ -274,7 +269,6 @@ class ActionTypes(IntEnum):
|
||||
return f"ACTION_TYPES.{self.name}"
|
||||
|
||||
|
||||
@beartype
|
||||
def is_equivalent(a: Action, b: Action) -> bool:
|
||||
"""Return True if two actions are equal."""
|
||||
if a["action_type"] != b["action_type"]:
|
||||
@ -338,12 +332,11 @@ _role2id: dict[RolesType, int] = {
|
||||
_id2role: list[RolesType] = sorted(_role2id, key=_role2id.get) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@beartype
|
||||
def _keys2ids(keys: list[int | str] | str) -> list[int]:
|
||||
return list(
|
||||
map(
|
||||
lambda key: _key2id[str(key)]
|
||||
if is_bearable(key, str)
|
||||
if isinstance(key, str)
|
||||
else int(key),
|
||||
keys,
|
||||
)
|
||||
@ -424,7 +417,6 @@ def create_random_action() -> Action:
|
||||
}
|
||||
|
||||
|
||||
@beartype
|
||||
def create_none_action() -> Action:
|
||||
"""Return a valid action object that does nothing."""
|
||||
return {
|
||||
@ -445,14 +437,12 @@ def create_none_action() -> Action:
|
||||
}
|
||||
|
||||
|
||||
@beartype
|
||||
def create_stop_action(answer: str) -> Action:
|
||||
action = create_none_action()
|
||||
action.update({"action_type": ActionTypes.STOP, "answer": answer})
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_scroll_action(direction: str) -> Action:
|
||||
"""Return the playwright action"""
|
||||
assert direction in ["up", "down"]
|
||||
@ -466,7 +456,6 @@ def create_scroll_action(direction: str) -> Action:
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_mouse_hover_action(
|
||||
left: float | None = None, top: float | None = None
|
||||
) -> Action:
|
||||
@ -481,7 +470,6 @@ def create_mouse_hover_action(
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_key_press_action(key_comb: str) -> Action:
|
||||
"""Return the key press action"""
|
||||
|
||||
@ -504,7 +492,6 @@ def create_key_press_action(key_comb: str) -> Action:
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_page_focus_action(page_number: int) -> Action:
|
||||
"""Return a valid action object with type PAGE_FOCUS."""
|
||||
action = create_none_action()
|
||||
@ -517,7 +504,6 @@ def create_page_focus_action(page_number: int) -> Action:
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_new_tab_action() -> Action:
|
||||
"""Return a valid action object with type NEW_TAB."""
|
||||
action = create_none_action()
|
||||
@ -529,7 +515,6 @@ def create_new_tab_action() -> Action:
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_go_back_action() -> Action:
|
||||
"""Return a valid action object with type GO_BACK."""
|
||||
action = create_none_action()
|
||||
@ -541,7 +526,6 @@ def create_go_back_action() -> Action:
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_go_forward_action() -> Action:
|
||||
"""Return a valid action object with type GO_FORWARD."""
|
||||
action = create_none_action()
|
||||
@ -553,7 +537,6 @@ def create_go_forward_action() -> Action:
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_goto_url_action(url: str) -> Action:
|
||||
"""Return a valid action object with type GOTO_URL."""
|
||||
action = create_none_action()
|
||||
@ -566,7 +549,6 @@ def create_goto_url_action(url: str) -> Action:
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_page_close_action() -> Action:
|
||||
"""Return a valid action object with type PAGE_CLOSE."""
|
||||
action = create_none_action()
|
||||
@ -578,7 +560,6 @@ def create_page_close_action() -> Action:
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_mouse_click_action(
|
||||
left: float | None = None, top: float | None = None
|
||||
) -> Action:
|
||||
@ -602,7 +583,6 @@ def create_mouse_click_action(
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_keyboard_type_action(keys: list[int | str] | str) -> Action:
|
||||
"""Return a valid action object with type TYPE."""
|
||||
action = create_none_action()
|
||||
@ -615,7 +595,6 @@ def create_keyboard_type_action(keys: list[int | str] | str) -> Action:
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_click_action(
|
||||
element_id: str = "",
|
||||
element_role: RolesType = "link",
|
||||
@ -637,7 +616,6 @@ def create_click_action(
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_hover_action(
|
||||
element_id: str = "",
|
||||
element_role: RolesType = "link",
|
||||
@ -659,7 +637,6 @@ def create_hover_action(
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_type_action(
|
||||
text: str,
|
||||
element_id: str = "",
|
||||
@ -683,7 +660,6 @@ def create_type_action(
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_check_action(pw_code: str) -> Action:
|
||||
action = create_none_action()
|
||||
action.update(
|
||||
@ -695,7 +671,6 @@ def create_check_action(pw_code: str) -> Action:
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_select_option_action(
|
||||
pw_code: str,
|
||||
) -> Action:
|
||||
@ -709,7 +684,6 @@ def create_select_option_action(
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_focus_action(
|
||||
element_role: RolesType, element_name: str = "", nth: int = 0
|
||||
) -> Action:
|
||||
@ -728,7 +702,6 @@ def create_focus_action(
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_focus_and_click_action(
|
||||
element_role: RolesType, element_name: str = "", nth: int = 0
|
||||
) -> Action:
|
||||
@ -748,7 +721,6 @@ def create_focus_and_click_action(
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def create_focus_and_type_action(
|
||||
keys: list[int | str] | str,
|
||||
element_role: RolesType,
|
||||
@ -771,7 +743,6 @@ def create_focus_and_type_action(
|
||||
return action
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_scroll(direction: str, page: Page) -> None:
|
||||
# perform the action
|
||||
# code from natbot
|
||||
@ -785,7 +756,6 @@ def execute_scroll(direction: str, page: Page) -> None:
|
||||
)
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_scroll(direction: str, page: APage) -> None:
|
||||
# perform the action
|
||||
# code from natbot
|
||||
@ -799,19 +769,16 @@ async def aexecute_scroll(direction: str, page: APage) -> None:
|
||||
)
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_key_press(key: str, page: Page) -> None:
|
||||
"""Press a key."""
|
||||
page.keyboard.press(key)
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_key_press(key: str, page: APage) -> None:
|
||||
"""Press a key."""
|
||||
await page.keyboard.press(key)
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_mouse_hover(left: float, top: float, page: Page) -> None:
|
||||
"""Click at coordinates (left, top)."""
|
||||
viewport_size = page.viewport_size
|
||||
@ -821,7 +788,6 @@ def execute_mouse_hover(left: float, top: float, page: Page) -> None:
|
||||
)
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_mouse_hover(left: float, top: float, page: APage) -> None:
|
||||
"""Click at coordinates (left, top)."""
|
||||
viewport_size = page.viewport_size
|
||||
@ -840,7 +806,6 @@ def execute_mouse_click(left: float, top: float, page: Page) -> None:
|
||||
)
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_mouse_click(left: float, top: float, page: APage) -> None:
|
||||
"""Click at coordinates (left, top)."""
|
||||
viewport_size = page.viewport_size
|
||||
@ -850,19 +815,16 @@ async def aexecute_mouse_click(left: float, top: float, page: APage) -> None:
|
||||
)
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_keyboard_type(text: str, page: Page) -> None:
|
||||
"""Fill the focused element with text."""
|
||||
page.keyboard.type(text)
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_keyboard_type(text: str, page: APage) -> None:
|
||||
"""Fill the focused element with text."""
|
||||
await page.keyboard.type(text)
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_click_current(page: Page) -> None:
|
||||
"""Click at the current mouse position."""
|
||||
locators = page.locator("*:focus")
|
||||
@ -874,7 +836,6 @@ def execute_click_current(page: Page) -> None:
|
||||
locators.click()
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_click_current(page: APage) -> None:
|
||||
"""Click at the current mouse position."""
|
||||
locators = page.locator("*:focus")
|
||||
@ -889,21 +850,18 @@ async def aexecute_click_current(page: APage) -> None:
|
||||
await page.wait_for_load_state("load")
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_type(keys: list[int], page: Page) -> None:
|
||||
"""Send keystrokes to the focused element."""
|
||||
text = "".join([_id2key[key] for key in keys])
|
||||
page.keyboard.type(text)
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_type(keys: list[int], page: APage) -> None:
|
||||
"""Send keystrokes to the focused element."""
|
||||
text = "".join([_id2key[key] for key in keys])
|
||||
await page.keyboard.type(text)
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_focus(
|
||||
element_role: int, element_name: str, nth: int, page: Page
|
||||
) -> None:
|
||||
@ -940,7 +898,6 @@ def execute_focus(
|
||||
element_location_list[nth][0].focus()
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_focus(
|
||||
element_role: int, element_name: str, nth: int, page: APage
|
||||
) -> None:
|
||||
@ -977,7 +934,6 @@ async def aexecute_focus(
|
||||
await element_location_list[nth][0].focus()
|
||||
|
||||
|
||||
@beartype
|
||||
def locate(locator_calls: list[ParsedPlaywrightCode], page: Page) -> Locator:
|
||||
locator = page
|
||||
for call in locator_calls:
|
||||
@ -988,7 +944,6 @@ def locate(locator_calls: list[ParsedPlaywrightCode], page: Page) -> Locator:
|
||||
return locator # type: ignore[return-value]
|
||||
|
||||
|
||||
@beartype
|
||||
async def alocate(
|
||||
locator_calls: list[ParsedPlaywrightCode], page: APage
|
||||
) -> ALocator:
|
||||
@ -1001,7 +956,6 @@ async def alocate(
|
||||
return locator # type: ignore[return-value]
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_playwright_click(
|
||||
locator_code: list[ParsedPlaywrightCode],
|
||||
page: Page,
|
||||
@ -1014,7 +968,6 @@ def execute_playwright_click(
|
||||
locator.click(*pw_action_args, **pw_action_kwargs)
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_playwright_click(
|
||||
locator_code: list[ParsedPlaywrightCode],
|
||||
page: APage,
|
||||
@ -1027,7 +980,6 @@ async def aexecute_playwright_click(
|
||||
await locator.click(*pw_action_args, **pw_action_kwargs)
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_playwright_hover(
|
||||
locator_code: list[ParsedPlaywrightCode], page: Page
|
||||
) -> None:
|
||||
@ -1037,7 +989,6 @@ def execute_playwright_hover(
|
||||
locator.hover()
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_playwright_hover(
|
||||
locator_code: list[ParsedPlaywrightCode], page: APage
|
||||
) -> None:
|
||||
@ -1047,7 +998,6 @@ async def aexecute_playwright_hover(
|
||||
await locator.hover()
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_playwright_type(
|
||||
text: str,
|
||||
locator_code: list[ParsedPlaywrightCode],
|
||||
@ -1061,7 +1011,6 @@ def execute_playwright_type(
|
||||
locator.type(*pw_action_args, **pw_action_kwargs)
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_playwright_type(
|
||||
text: str,
|
||||
locator_code: list[ParsedPlaywrightCode],
|
||||
@ -1075,7 +1024,6 @@ async def aexecute_playwright_type(
|
||||
await locator.type(*pw_action_args, **pw_action_kwargs)
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_playwright_select_option(
|
||||
locator_code: list[ParsedPlaywrightCode],
|
||||
page: Page,
|
||||
@ -1087,7 +1035,6 @@ def execute_playwright_select_option(
|
||||
locator.select_option(*pw_action_args, **pw_action_kwargs)
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_playwright_select_option(
|
||||
locator_code: list[ParsedPlaywrightCode],
|
||||
page: APage,
|
||||
@ -1099,7 +1046,6 @@ async def aexecute_playwright_select_option(
|
||||
await locator.select_option(*pw_action_args, **pw_action_kwargs)
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_playwright_check(
|
||||
locator_code: list[ParsedPlaywrightCode], page: Page
|
||||
) -> None:
|
||||
@ -1108,7 +1054,6 @@ def execute_playwright_check(
|
||||
locator.check()
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_playwright_check(
|
||||
locator_code: list[ParsedPlaywrightCode], page: APage
|
||||
) -> None:
|
||||
@ -1117,7 +1062,6 @@ async def aexecute_playwright_check(
|
||||
await locator.check()
|
||||
|
||||
|
||||
@beartype
|
||||
def execute_action(
|
||||
action: Action,
|
||||
page: Page,
|
||||
@ -1252,7 +1196,6 @@ def execute_action(
|
||||
return page
|
||||
|
||||
|
||||
@beartype
|
||||
async def aexecute_action(
|
||||
action: Action, page: APage, browser_ctx: ABrowserContext
|
||||
) -> APage:
|
||||
@ -1383,7 +1326,6 @@ async def aexecute_action(
|
||||
return page
|
||||
|
||||
|
||||
@beartype
|
||||
def parse_playwright_code(code: str) -> list[ParsedPlaywrightCode]:
|
||||
# extract function calls
|
||||
if not code.startswith("page."):
|
||||
@ -1444,14 +1386,12 @@ def parse_playwright_code(code: str) -> list[ParsedPlaywrightCode]:
|
||||
return parsed_chain
|
||||
|
||||
|
||||
@beartype
|
||||
class ActionParsingError(Exception):
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
@beartype
|
||||
def create_playwright_action(playwright_code: str) -> Action:
|
||||
"""Main function to return individual playwright action"""
|
||||
# get the last action
|
||||
@ -1524,7 +1464,6 @@ def create_playwright_action(playwright_code: str) -> Action:
|
||||
raise ActionParsingError(f"Unknown playwright action {action}")
|
||||
|
||||
|
||||
@beartype
|
||||
def create_id_based_action(action_str: str) -> Action:
|
||||
"""Main function to return individual id based action"""
|
||||
action_str = action_str.strip()
|
||||
|
||||
@ -5,7 +5,6 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from beartype import beartype
|
||||
from gymnasium import Env
|
||||
from gymnasium.spaces import Box, Text
|
||||
from playwright.async_api import Page, ViewportSize, async_playwright
|
||||
@ -23,7 +22,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
|
||||
and observation space is the html content of the page.
|
||||
"""
|
||||
|
||||
@beartype
|
||||
def __init__(
|
||||
self,
|
||||
max_page_length: int = 2048,
|
||||
@ -46,7 +44,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
|
||||
self.timeout = timeout
|
||||
self.viewport_size = viewport_size
|
||||
|
||||
@beartype
|
||||
async def setup(self, config_file: Path | None = None) -> None:
|
||||
self.context_manager = async_playwright()
|
||||
self.playwright = await self.context_manager.__aenter__()
|
||||
@ -73,7 +70,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
|
||||
if start_url:
|
||||
await self.page.goto(start_url)
|
||||
|
||||
@beartype
|
||||
async def areset(
|
||||
self,
|
||||
*,
|
||||
@ -104,7 +100,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
|
||||
{"page": DetachedPage(self.page.url, content)},
|
||||
)
|
||||
|
||||
@beartype
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
@ -120,7 +115,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
|
||||
def close(self) -> None:
|
||||
asyncio.run(self.aclose())
|
||||
|
||||
@beartype
|
||||
async def astep(
|
||||
self, action: Action
|
||||
) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]:
|
||||
@ -153,7 +147,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
|
||||
},
|
||||
)
|
||||
|
||||
@beartype
|
||||
def step(
|
||||
self, action: Action
|
||||
) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]:
|
||||
|
||||
@ -3,7 +3,6 @@ import glob
|
||||
from itertools import combinations
|
||||
from pathlib import Path
|
||||
|
||||
from beartype import beartype
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
from browser_env.env_config import (
|
||||
@ -18,7 +17,6 @@ HEADLESS = True
|
||||
SLOW_MO = 0
|
||||
|
||||
|
||||
@beartype
|
||||
def is_expired(
|
||||
storage_state: Path, url: str, keyword: str, url_exact: bool = True
|
||||
) -> bool:
|
||||
@ -44,7 +42,6 @@ def is_expired(
|
||||
return url not in d_url
|
||||
|
||||
|
||||
@beartype
|
||||
def renew_comb(comb: list[str]) -> None:
|
||||
context_manager = sync_playwright()
|
||||
playwright = context_manager.__enter__()
|
||||
@ -91,7 +88,6 @@ def renew_comb(comb: list[str]) -> None:
|
||||
context_manager.__exit__()
|
||||
|
||||
|
||||
@beartype
|
||||
def main() -> None:
|
||||
sites = ["gitlab", "shopping", "shopping_admin", "reddit"]
|
||||
urls = [
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from beartype import beartype
|
||||
from gymnasium import Env
|
||||
from gymnasium.spaces import Box, Text
|
||||
from playwright.sync_api import (
|
||||
@ -39,7 +38,6 @@ class PlaywrightScript:
|
||||
value: str | None = None # avatar movie, Enter
|
||||
|
||||
|
||||
@beartype
|
||||
def parse_action(action: str) -> PlaywrightScript:
|
||||
splitted = action.strip().split(" ")
|
||||
assert len(splitted) >= 2
|
||||
@ -73,7 +71,6 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
|
||||
and observation space is the html content of the page.
|
||||
"""
|
||||
|
||||
@beartype
|
||||
def __init__(
|
||||
self,
|
||||
max_page_length: int = 8192,
|
||||
@ -121,7 +118,6 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
|
||||
self.observation_handler.get_observation_space()
|
||||
)
|
||||
|
||||
@beartype
|
||||
def setup(self, config_file: Path | None = None) -> None:
|
||||
self.context_manager = sync_playwright()
|
||||
self.playwright = self.context_manager.__enter__()
|
||||
@ -168,23 +164,19 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
|
||||
client.send("Accessibility.enable")
|
||||
self.page.client = client # type: ignore
|
||||
|
||||
@beartype
|
||||
def get_page_client(self, page: Page) -> CDPSession:
|
||||
return page.client # type: ignore
|
||||
|
||||
@beartype
|
||||
def _get_obs(self) -> dict[str, Observation]:
|
||||
obs = self.observation_handler.get_observation(
|
||||
self.page, self.get_page_client(self.page)
|
||||
)
|
||||
return obs
|
||||
|
||||
@beartype
|
||||
def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
|
||||
metadata = self.observation_handler.get_observation_metadata()
|
||||
return metadata
|
||||
|
||||
@beartype
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
@ -223,12 +215,10 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
|
||||
|
||||
return (observation, info)
|
||||
|
||||
@beartype
|
||||
def save_trace(self, trace_path: str | Path) -> None:
|
||||
if self.save_trace_enabled:
|
||||
self.context.tracing.stop(path=trace_path)
|
||||
|
||||
@beartype
|
||||
def close(self) -> None:
|
||||
if self.reset_finished:
|
||||
self.context_manager.__exit__()
|
||||
|
||||
@ -5,7 +5,6 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from beartype import beartype
|
||||
from PIL import Image
|
||||
|
||||
from agent.prompts import *
|
||||
@ -35,7 +34,6 @@ HTML_TEMPLATE = """
|
||||
"""
|
||||
|
||||
|
||||
@beartype
|
||||
def get_render_action(
|
||||
action: Action,
|
||||
observation_metadata: dict[str, ObservationMetadata],
|
||||
@ -63,7 +61,6 @@ def get_render_action(
|
||||
return action_str
|
||||
|
||||
|
||||
@beartype
|
||||
def get_action_description(
|
||||
action: Action,
|
||||
observation_metadata: dict[str, ObservationMetadata],
|
||||
|
||||
@ -5,7 +5,6 @@ from typing import Any, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from beartype import beartype
|
||||
from gymnasium import spaces
|
||||
from playwright.sync_api import CDPSession, Page, ViewportSize
|
||||
|
||||
@ -60,7 +59,6 @@ class TextObervationProcessor(ObservationProcessor):
|
||||
create_empty_metadata()
|
||||
) # use the store meta data of this observation type
|
||||
|
||||
@beartype
|
||||
def fetch_browser_info(
|
||||
self,
|
||||
page: Page,
|
||||
@ -108,7 +106,6 @@ class TextObervationProcessor(ObservationProcessor):
|
||||
|
||||
return info
|
||||
|
||||
@beartype
|
||||
@staticmethod
|
||||
def get_bounding_client_rect(
|
||||
client: CDPSession, backend_node_id: str
|
||||
@ -134,7 +131,6 @@ class TextObervationProcessor(ObservationProcessor):
|
||||
except Exception as e:
|
||||
return {"result": {"subtype": "error"}}
|
||||
|
||||
@beartype
|
||||
@staticmethod
|
||||
def get_element_in_viewport_ratio(
|
||||
elem_left_bound: float,
|
||||
@ -167,7 +163,6 @@ class TextObervationProcessor(ObservationProcessor):
|
||||
ratio = overlap_width * overlap_height / width * height
|
||||
return ratio
|
||||
|
||||
@beartype
|
||||
def fetch_page_html(
|
||||
self,
|
||||
info: BrowserInfo,
|
||||
@ -323,7 +318,6 @@ class TextObervationProcessor(ObservationProcessor):
|
||||
|
||||
return dom_tree
|
||||
|
||||
@beartype
|
||||
@staticmethod
|
||||
def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
|
||||
"""Parse the html tree into a string text"""
|
||||
@ -367,7 +361,6 @@ class TextObervationProcessor(ObservationProcessor):
|
||||
html = dfs(0, 0)
|
||||
return html, obs_nodes_info
|
||||
|
||||
@beartype
|
||||
def fetch_page_accessibility_tree(
|
||||
self,
|
||||
info: BrowserInfo,
|
||||
@ -487,7 +480,6 @@ class TextObervationProcessor(ObservationProcessor):
|
||||
|
||||
return accessibility_tree
|
||||
|
||||
@beartype
|
||||
@staticmethod
|
||||
def parse_accessibility_tree(
|
||||
accessibility_tree: AccessibilityTree,
|
||||
@ -575,7 +567,6 @@ class TextObervationProcessor(ObservationProcessor):
|
||||
tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
|
||||
return tree_str, obs_nodes_info
|
||||
|
||||
@beartype
|
||||
@staticmethod
|
||||
def clean_accesibility_tree(tree_str: str) -> str:
|
||||
"""further clean accesibility tree"""
|
||||
@ -598,7 +589,6 @@ class TextObervationProcessor(ObservationProcessor):
|
||||
|
||||
return "\n".join(clean_lines)
|
||||
|
||||
@beartype
|
||||
def process(self, page: Page, client: CDPSession) -> str:
|
||||
# get the tab info
|
||||
open_tabs = page.context.pages
|
||||
@ -657,7 +647,6 @@ class TextObervationProcessor(ObservationProcessor):
|
||||
content = f"{tab_title_str}\n\n{content}"
|
||||
return content
|
||||
|
||||
@beartype
|
||||
def get_element_center(self, element_id: str) -> tuple[float, float]:
|
||||
node_info = self.obs_nodes_info[element_id]
|
||||
node_bound = node_info["union_bound"]
|
||||
@ -705,7 +694,6 @@ class ObservationHandler:
|
||||
)
|
||||
self.viewport_size = viewport_size
|
||||
|
||||
@beartype
|
||||
def get_observation_space(self) -> spaces.Dict:
|
||||
text_space = spaces.Text(
|
||||
min_length=0,
|
||||
@ -729,7 +717,6 @@ class ObservationHandler:
|
||||
|
||||
return spaces.Dict({"text": text_space, "image": image_space})
|
||||
|
||||
@beartype
|
||||
def get_observation(
|
||||
self, page: Page, client: CDPSession
|
||||
) -> dict[str, Observation]:
|
||||
@ -737,7 +724,6 @@ class ObservationHandler:
|
||||
image_obs = self.image_processor.process(page, client)
|
||||
return {"text": text_obs, "image": image_obs}
|
||||
|
||||
@beartype
|
||||
def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
|
||||
return {
|
||||
"text": self.text_processor.meta_data,
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import Any, Dict, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from beartype import beartype
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@ -14,7 +13,6 @@ class DetachedPage:
|
||||
content: str # html
|
||||
|
||||
|
||||
@beartype
|
||||
def png_bytes_to_numpy(png: bytes) -> npt.NDArray[np.uint8]:
|
||||
"""Convert png bytes to numpy array
|
||||
|
||||
|
||||
@ -8,8 +8,6 @@ from pathlib import Path
|
||||
from typing import Any, Tuple, Union
|
||||
|
||||
import evaluate # type: ignore[import]
|
||||
from beartype import beartype
|
||||
from beartype.door import is_bearable
|
||||
from playwright.sync_api import CDPSession, Page
|
||||
|
||||
from browser_env.actions import Action
|
||||
@ -26,7 +24,6 @@ from evaluation_harness.helper_functions import (
|
||||
Trajectory = list[Union[Action, StateInfo]]
|
||||
|
||||
|
||||
@beartype
|
||||
class Evaluator(object):
|
||||
def __init__(self, eval_tag: str = "") -> None:
|
||||
self.eval_tag = eval_tag
|
||||
@ -43,7 +40,7 @@ class Evaluator(object):
|
||||
@staticmethod
|
||||
def get_last_action(trajectory: Trajectory) -> Action:
|
||||
try:
|
||||
is_bearable(trajectory[-1], Action)
|
||||
# is_bearable(trajectory[-1], Action)
|
||||
last_action = trajectory[-1]
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
@ -55,7 +52,7 @@ class Evaluator(object):
|
||||
@staticmethod
|
||||
def get_last_state(trajectory: Trajectory) -> StateInfo:
|
||||
try:
|
||||
is_bearable(trajectory[-2], StateInfo)
|
||||
# is_bearable(trajectory[-2], StateInfo)
|
||||
last_state = trajectory[-2]
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
@ -65,7 +62,6 @@ class Evaluator(object):
|
||||
return last_state # type: ignore[return-value]
|
||||
|
||||
|
||||
@beartype
|
||||
class StringExactEvaluator(Evaluator):
|
||||
"""Check whether the answer is exactly the same as one of the reference answers"""
|
||||
|
||||
@ -95,7 +91,6 @@ class StringExactEvaluator(Evaluator):
|
||||
return 0.0
|
||||
|
||||
|
||||
@beartype
|
||||
class StringEvaluator(Evaluator):
|
||||
"""Check whether the answer is correct with:
|
||||
exact match: the answer is exactly the same as the reference answer
|
||||
@ -144,7 +139,6 @@ class StringEvaluator(Evaluator):
|
||||
return score
|
||||
|
||||
|
||||
@beartype
|
||||
class StringSoftEvaluator(Evaluator):
|
||||
"""Use text generation metrics such as BLEU, ROUGE, etc. to evaluate the answer"""
|
||||
|
||||
@ -167,7 +161,6 @@ class StringSoftEvaluator(Evaluator):
|
||||
return float(rouge["rouge1"])
|
||||
|
||||
|
||||
@beartype
|
||||
class URLExactEvaluator(Evaluator):
|
||||
"""Check whether the URL is exactly the same as of the reference URLs"""
|
||||
|
||||
@ -205,7 +198,6 @@ class URLExactEvaluator(Evaluator):
|
||||
raise ValueError(f"Unknown matching rule: {matching_rule}")
|
||||
|
||||
|
||||
@beartype
|
||||
class HTMLContentExactEvaluator(Evaluator):
|
||||
"""Check whether the contents appear in the page"""
|
||||
|
||||
@ -286,7 +278,6 @@ class HTMLContentExactEvaluator(Evaluator):
|
||||
######
|
||||
|
||||
|
||||
@beartype
|
||||
class EvaluatorPartial(Evaluator):
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError
|
||||
@ -301,7 +292,6 @@ class EvaluatorPartial(Evaluator):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@beartype
|
||||
class URLSoftEvaluator(EvaluatorPartial):
|
||||
"""Parse the URL and compare the domain and parameters"""
|
||||
|
||||
@ -367,7 +357,6 @@ class EvaluatorComb:
|
||||
return score
|
||||
|
||||
|
||||
@beartype
|
||||
def evaluator_router(config_file: Path | str) -> EvaluatorComb:
|
||||
"""Router to get the evaluator class"""
|
||||
with open(config_file, "r") as f:
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from beartype import beartype
|
||||
from playwright.sync_api import CDPSession, Page
|
||||
|
||||
from browser_env.env_config import (
|
||||
@ -21,7 +20,6 @@ from llms.providers.openai_utils import (
|
||||
)
|
||||
|
||||
|
||||
@beartype
|
||||
def shopping_get_auth_token() -> str:
|
||||
response = requests.post(
|
||||
url=f"{SHOPPING}/rest/default/V1/integration/admin/token",
|
||||
@ -37,7 +35,6 @@ def shopping_get_auth_token() -> str:
|
||||
return token
|
||||
|
||||
|
||||
@beartype
|
||||
def shopping_get_latest_order_url() -> str:
|
||||
"""Get the latest order url from the shopping website."""
|
||||
|
||||
@ -62,7 +59,6 @@ def shopping_get_latest_order_url() -> str:
|
||||
return order_url
|
||||
|
||||
|
||||
@beartype
|
||||
def shopping_get_sku_latest_review_author(sku: str) -> str:
|
||||
"""Get the latest review for shopping admin."""
|
||||
header = {
|
||||
@ -80,7 +76,6 @@ def shopping_get_sku_latest_review_author(sku: str) -> str:
|
||||
return author
|
||||
|
||||
|
||||
@beartype
|
||||
def shopping_get_sku_latest_review_rating(sku: str) -> str:
|
||||
"""Get the latest review for shopping admin."""
|
||||
header = {
|
||||
@ -99,7 +94,6 @@ def shopping_get_sku_latest_review_rating(sku: str) -> str:
|
||||
return rating
|
||||
|
||||
|
||||
@beartype
|
||||
def reddit_get_post_url(url: str) -> str:
|
||||
"""Get the post url"""
|
||||
# Url is http://domain/f/subreddit/post_id/...
|
||||
@ -118,7 +112,6 @@ def reddit_get_post_url(url: str) -> str:
|
||||
return post_url
|
||||
|
||||
|
||||
@beartype
|
||||
def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str:
|
||||
# get the account index
|
||||
try:
|
||||
@ -150,7 +143,6 @@ def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str:
|
||||
return role
|
||||
|
||||
|
||||
@beartype
|
||||
def llm_fuzzy_match(pred: str, reference: str, question: str) -> float:
|
||||
"""Check whether the prediction matches the reference with GPT-3.5"""
|
||||
messages: list[dict[str, Any]] = []
|
||||
|
||||
4
run.py
4
run.py
@ -9,7 +9,6 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
import openai
|
||||
from beartype import beartype
|
||||
|
||||
from agent import (
|
||||
Agent,
|
||||
@ -144,7 +143,6 @@ def config() -> argparse.Namespace:
|
||||
return args
|
||||
|
||||
|
||||
@beartype
|
||||
def early_stop(
|
||||
trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
|
||||
) -> tuple[bool, str]:
|
||||
@ -201,7 +199,6 @@ def early_stop(
|
||||
return False, ""
|
||||
|
||||
|
||||
@beartype
|
||||
def test(
|
||||
args: argparse.Namespace,
|
||||
agent: Agent | PromptAgent | TeacherForcingAgent,
|
||||
@ -369,7 +366,6 @@ def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
|
||||
return unfinished_configs
|
||||
|
||||
|
||||
@beartype
|
||||
def dump_config(args: argparse.Namespace) -> None:
|
||||
config_file = Path(args.result_dir) / "config.json"
|
||||
if not config_file.exists():
|
||||
|
||||
@ -6,7 +6,6 @@ import time
|
||||
from typing import Dict, Optional, Tuple, Type, Union, cast
|
||||
|
||||
import pytest
|
||||
from beartype import beartype
|
||||
from playwright.sync_api import Page, expect
|
||||
|
||||
from browser_env import (
|
||||
@ -21,13 +20,11 @@ from browser_env.env_config import *
|
||||
HEADLESS = False
|
||||
|
||||
|
||||
@beartype
|
||||
def gen_tmp_storage_state() -> None:
|
||||
with open(f"scripts/tmp_storage_state.json", "w") as f:
|
||||
json.dump({"storage_state": ".auth/gitlab_state.json"}, f)
|
||||
|
||||
|
||||
@beartype
|
||||
def get_observation(
|
||||
observation_type: str, current_viewport_only: bool
|
||||
) -> None:
|
||||
|
||||
@ -5,7 +5,6 @@ import tempfile
|
||||
from typing import Callable, Dict, Optional, Tuple, Type, Union, cast
|
||||
|
||||
import pytest
|
||||
from beartype.door import is_bearable
|
||||
from gymnasium.vector import AsyncVectorEnv
|
||||
from playwright.sync_api import Page
|
||||
|
||||
@ -128,7 +127,7 @@ def test_parallel_script_browser_env() -> None:
|
||||
]
|
||||
)
|
||||
)
|
||||
assert is_bearable(info["page"].tolist(), list[DetachedPage])
|
||||
# assert is_bearable(info["page"].tolist(), list[DetachedPage])
|
||||
assert info["page"][0].url == "https://www.rfc-editor.org/rfc/rfc2606.html"
|
||||
assert info["page"][1].url == "https://www.rfc-editor.org/rfc/rfc6761.html"
|
||||
vector_env.close() # type: ignore[no-untyped-call]
|
||||
|
||||
@ -6,7 +6,6 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from beartype import beartype
|
||||
from py import test
|
||||
|
||||
from agent import Agent, TeacherForcingAgent
|
||||
@ -249,7 +248,6 @@ def test_html_content_url_comb_success(
|
||||
assert score == 1.0
|
||||
|
||||
|
||||
@beartype
|
||||
@pytest.mark.skipif(
|
||||
IN_GITHUB_ACTIONS, reason="Won't work using the demo sites"
|
||||
)
|
||||
@ -273,7 +271,6 @@ def test_func_success(
|
||||
assert score == 1.0
|
||||
|
||||
|
||||
@beartype
|
||||
@pytest.mark.skipif(
|
||||
IN_GITHUB_ACTIONS, reason="Won't work using the demo sites"
|
||||
)
|
||||
@ -297,7 +294,6 @@ def test_func_fail(
|
||||
assert score == 0.0
|
||||
|
||||
|
||||
@beartype
|
||||
def test_func_url_func_last_success(
|
||||
script_browser_env: ScriptBrowserEnv,
|
||||
) -> None:
|
||||
@ -319,7 +315,6 @@ def test_func_url_func_last_success(
|
||||
assert score == 1.0
|
||||
|
||||
|
||||
@beartype
|
||||
def test_func_url_func_page_success(
|
||||
script_browser_env: ScriptBrowserEnv,
|
||||
) -> None:
|
||||
|
||||
@ -2,8 +2,6 @@ import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from beartype import beartype
|
||||
|
||||
from browser_env import ScriptBrowserEnv
|
||||
from browser_env.env_config import *
|
||||
from evaluation_harness.helper_functions import (
|
||||
|
||||
Loading…
Reference in New Issue
Block a user