remove beartype for efficency purpose

This commit is contained in:
alexisxy 2023-09-12 22:26:23 -04:00
parent ed93b3a88f
commit e44972d335
16 changed files with 4 additions and 158 deletions

View File

@ -3,8 +3,6 @@ import json
from typing import Any
import tiktoken
from beartype import beartype
from beartype.door import is_bearable
from agent.prompts import *
from browser_env import Trajectory
@ -48,11 +46,9 @@ class TeacherForcingAgent(Agent):
def __init__(self) -> None:
super().__init__()
@beartype
def set_action_set_tag(self, tag: str) -> None:
self.action_set_tag = tag
@beartype
def set_actions(self, action_seq: str | list[str]) -> None:
if isinstance(action_seq, str):
action_strs = action_seq.strip().split("\n")
@ -79,14 +75,12 @@ class TeacherForcingAgent(Agent):
self.actions: list[Action] = actions
@beartype
def next_action(
self, trajectory: Trajectory, intent: str, meta_data: Any
) -> Action:
"""Predict the next action given the observation"""
return self.actions.pop(0)
@beartype
def reset(
self,
test_config_file: str,
@ -113,11 +107,9 @@ class PromptAgent(Agent):
self.prompt_constructor = prompt_constructor
self.action_set_tag = action_set_tag
@beartype
def set_action_set_tag(self, tag: str) -> None:
self.action_set_tag = tag
@beartype
def next_action(
self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any]
) -> Action:

View File

@ -4,7 +4,6 @@ from pathlib import Path
from typing import Any, TypedDict
import tiktoken
from beartype import beartype
from browser_env import Action, ActionParsingError, Trajectory
from browser_env.env_config import URL_MAPPINGS
@ -38,7 +37,6 @@ class PromptConstructor(object):
self.instruction: Instruction = instruction
self.tokenizer = tokenizer
@beartype
def get_lm_api_input(
self, intro: str, examples: list[tuple[str, str]], current: str
) -> APIInput:
@ -84,7 +82,6 @@ class PromptConstructor(object):
f"Provider {self.lm_config.provider} not implemented"
)
@beartype
def construct(
self,
trajectory: Trajectory,
@ -93,7 +90,6 @@ class PromptConstructor(object):
) -> APIInput:
raise NotImplementedError
@beartype
def map_url_to_real(self, url: str) -> str:
"""Map the urls to their real world counterparts"""
for i, j in URL_MAPPINGS.items():
@ -101,7 +97,6 @@ class PromptConstructor(object):
url = url.replace(i, j)
return url
@beartype
def map_url_to_local(self, url: str) -> str:
"""Map the urls to their local counterparts"""
for i, j in URL_MAPPINGS.items():
@ -109,11 +104,9 @@ class PromptConstructor(object):
url = url.replace(j, i)
return url
@beartype
def _extract_action(self, response: str) -> str:
raise NotImplementedError
@beartype
def extract_action(self, response: str) -> str:
response = self._extract_action(response)
response = self.map_url_to_local(response)
@ -131,7 +124,6 @@ class DirectPromptConstructor(PromptConstructor):
):
super().__init__(instruction_path, lm_config, tokenizer)
@beartype
def construct(
self,
trajectory: Trajectory,
@ -167,7 +159,6 @@ class DirectPromptConstructor(PromptConstructor):
prompt = self.get_lm_api_input(intro, examples, current)
return prompt
@beartype
def _extract_action(self, response: str) -> str:
action_splitter = self.instruction["meta_data"]["action_splitter"]
pattern = rf"{action_splitter}(.*?){action_splitter}"
@ -192,7 +183,6 @@ class CoTPromptConstructor(PromptConstructor):
super().__init__(instruction_path, lm_config, tokenizer)
self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
@beartype
def construct(
self,
trajectory: Trajectory,
@ -225,7 +215,6 @@ class CoTPromptConstructor(PromptConstructor):
prompt = self.get_lm_api_input(intro, examples, current)
return prompt
@beartype
def _extract_action(self, response: str) -> str:
# find the first occurence of action
action_splitter = self.instruction["meta_data"]["action_splitter"]

View File

@ -12,8 +12,6 @@ from typing import Any, TypedDict, Union, cast
import numpy as np
import numpy.typing as npt
from beartype import beartype
from beartype.door import is_bearable
from gymnasium import spaces
from playwright._impl._api_structures import ViewportSize
from playwright.async_api import BrowserContext as ABrowserContext
@ -55,7 +53,6 @@ from browser_env.processors import (
)
@beartype
def is_in_viewport(
element: Locator, viewport: ViewportSize, threshold: float = 0.3
) -> bool:
@ -75,7 +72,6 @@ def is_in_viewport(
return ratio > threshold
@beartype
async def async_is_in_viewport(
element: ALocator, viewport: ViewportSize, threshold: float = 0.3
) -> bool:
@ -111,7 +107,6 @@ class Action(TypedDict):
raw_prediction: str # raw prediction from the model
@beartype
def action2str(
action: Action, action_set_tag: str, semantic_element: str = ""
) -> str:
@ -274,7 +269,6 @@ class ActionTypes(IntEnum):
return f"ACTION_TYPES.{self.name}"
@beartype
def is_equivalent(a: Action, b: Action) -> bool:
"""Return True if two actions are equal."""
if a["action_type"] != b["action_type"]:
@ -338,12 +332,11 @@ _role2id: dict[RolesType, int] = {
_id2role: list[RolesType] = sorted(_role2id, key=_role2id.get) # type: ignore[arg-type]
@beartype
def _keys2ids(keys: list[int | str] | str) -> list[int]:
return list(
map(
lambda key: _key2id[str(key)]
if is_bearable(key, str)
if isinstance(key, str)
else int(key),
keys,
)
@ -424,7 +417,6 @@ def create_random_action() -> Action:
}
@beartype
def create_none_action() -> Action:
"""Return a valid action object that does nothing."""
return {
@ -445,14 +437,12 @@ def create_none_action() -> Action:
}
@beartype
def create_stop_action(answer: str) -> Action:
action = create_none_action()
action.update({"action_type": ActionTypes.STOP, "answer": answer})
return action
@beartype
def create_scroll_action(direction: str) -> Action:
"""Return the playwright action"""
assert direction in ["up", "down"]
@ -466,7 +456,6 @@ def create_scroll_action(direction: str) -> Action:
return action
@beartype
def create_mouse_hover_action(
left: float | None = None, top: float | None = None
) -> Action:
@ -481,7 +470,6 @@ def create_mouse_hover_action(
return action
@beartype
def create_key_press_action(key_comb: str) -> Action:
"""Return the key press action"""
@ -504,7 +492,6 @@ def create_key_press_action(key_comb: str) -> Action:
return action
@beartype
def create_page_focus_action(page_number: int) -> Action:
"""Return a valid action object with type PAGE_FOCUS."""
action = create_none_action()
@ -517,7 +504,6 @@ def create_page_focus_action(page_number: int) -> Action:
return action
@beartype
def create_new_tab_action() -> Action:
"""Return a valid action object with type NEW_TAB."""
action = create_none_action()
@ -529,7 +515,6 @@ def create_new_tab_action() -> Action:
return action
@beartype
def create_go_back_action() -> Action:
"""Return a valid action object with type GO_BACK."""
action = create_none_action()
@ -541,7 +526,6 @@ def create_go_back_action() -> Action:
return action
@beartype
def create_go_forward_action() -> Action:
"""Return a valid action object with type GO_FORWARD."""
action = create_none_action()
@ -553,7 +537,6 @@ def create_go_forward_action() -> Action:
return action
@beartype
def create_goto_url_action(url: str) -> Action:
"""Return a valid action object with type GOTO_URL."""
action = create_none_action()
@ -566,7 +549,6 @@ def create_goto_url_action(url: str) -> Action:
return action
@beartype
def create_page_close_action() -> Action:
"""Return a valid action object with type PAGE_CLOSE."""
action = create_none_action()
@ -578,7 +560,6 @@ def create_page_close_action() -> Action:
return action
@beartype
def create_mouse_click_action(
left: float | None = None, top: float | None = None
) -> Action:
@ -602,7 +583,6 @@ def create_mouse_click_action(
return action
@beartype
def create_keyboard_type_action(keys: list[int | str] | str) -> Action:
"""Return a valid action object with type TYPE."""
action = create_none_action()
@ -615,7 +595,6 @@ def create_keyboard_type_action(keys: list[int | str] | str) -> Action:
return action
@beartype
def create_click_action(
element_id: str = "",
element_role: RolesType = "link",
@ -637,7 +616,6 @@ def create_click_action(
return action
@beartype
def create_hover_action(
element_id: str = "",
element_role: RolesType = "link",
@ -659,7 +637,6 @@ def create_hover_action(
return action
@beartype
def create_type_action(
text: str,
element_id: str = "",
@ -683,7 +660,6 @@ def create_type_action(
return action
@beartype
def create_check_action(pw_code: str) -> Action:
action = create_none_action()
action.update(
@ -695,7 +671,6 @@ def create_check_action(pw_code: str) -> Action:
return action
@beartype
def create_select_option_action(
pw_code: str,
) -> Action:
@ -709,7 +684,6 @@ def create_select_option_action(
return action
@beartype
def create_focus_action(
element_role: RolesType, element_name: str = "", nth: int = 0
) -> Action:
@ -728,7 +702,6 @@ def create_focus_action(
return action
@beartype
def create_focus_and_click_action(
element_role: RolesType, element_name: str = "", nth: int = 0
) -> Action:
@ -748,7 +721,6 @@ def create_focus_and_click_action(
return action
@beartype
def create_focus_and_type_action(
keys: list[int | str] | str,
element_role: RolesType,
@ -771,7 +743,6 @@ def create_focus_and_type_action(
return action
@beartype
def execute_scroll(direction: str, page: Page) -> None:
# perform the action
# code from natbot
@ -785,7 +756,6 @@ def execute_scroll(direction: str, page: Page) -> None:
)
@beartype
async def aexecute_scroll(direction: str, page: APage) -> None:
# perform the action
# code from natbot
@ -799,19 +769,16 @@ async def aexecute_scroll(direction: str, page: APage) -> None:
)
@beartype
def execute_key_press(key: str, page: Page) -> None:
"""Press a key."""
page.keyboard.press(key)
@beartype
async def aexecute_key_press(key: str, page: APage) -> None:
"""Press a key."""
await page.keyboard.press(key)
@beartype
def execute_mouse_hover(left: float, top: float, page: Page) -> None:
"""Click at coordinates (left, top)."""
viewport_size = page.viewport_size
@ -821,7 +788,6 @@ def execute_mouse_hover(left: float, top: float, page: Page) -> None:
)
@beartype
async def aexecute_mouse_hover(left: float, top: float, page: APage) -> None:
"""Click at coordinates (left, top)."""
viewport_size = page.viewport_size
@ -840,7 +806,6 @@ def execute_mouse_click(left: float, top: float, page: Page) -> None:
)
@beartype
async def aexecute_mouse_click(left: float, top: float, page: APage) -> None:
"""Click at coordinates (left, top)."""
viewport_size = page.viewport_size
@ -850,19 +815,16 @@ async def aexecute_mouse_click(left: float, top: float, page: APage) -> None:
)
@beartype
def execute_keyboard_type(text: str, page: Page) -> None:
"""Fill the focused element with text."""
page.keyboard.type(text)
@beartype
async def aexecute_keyboard_type(text: str, page: APage) -> None:
"""Fill the focused element with text."""
await page.keyboard.type(text)
@beartype
def execute_click_current(page: Page) -> None:
"""Click at the current mouse position."""
locators = page.locator("*:focus")
@ -874,7 +836,6 @@ def execute_click_current(page: Page) -> None:
locators.click()
@beartype
async def aexecute_click_current(page: APage) -> None:
"""Click at the current mouse position."""
locators = page.locator("*:focus")
@ -889,21 +850,18 @@ async def aexecute_click_current(page: APage) -> None:
await page.wait_for_load_state("load")
@beartype
def execute_type(keys: list[int], page: Page) -> None:
"""Send keystrokes to the focused element."""
text = "".join([_id2key[key] for key in keys])
page.keyboard.type(text)
@beartype
async def aexecute_type(keys: list[int], page: APage) -> None:
"""Send keystrokes to the focused element."""
text = "".join([_id2key[key] for key in keys])
await page.keyboard.type(text)
@beartype
def execute_focus(
element_role: int, element_name: str, nth: int, page: Page
) -> None:
@ -940,7 +898,6 @@ def execute_focus(
element_location_list[nth][0].focus()
@beartype
async def aexecute_focus(
element_role: int, element_name: str, nth: int, page: APage
) -> None:
@ -977,7 +934,6 @@ async def aexecute_focus(
await element_location_list[nth][0].focus()
@beartype
def locate(locator_calls: list[ParsedPlaywrightCode], page: Page) -> Locator:
locator = page
for call in locator_calls:
@ -988,7 +944,6 @@ def locate(locator_calls: list[ParsedPlaywrightCode], page: Page) -> Locator:
return locator # type: ignore[return-value]
@beartype
async def alocate(
locator_calls: list[ParsedPlaywrightCode], page: APage
) -> ALocator:
@ -1001,7 +956,6 @@ async def alocate(
return locator # type: ignore[return-value]
@beartype
def execute_playwright_click(
locator_code: list[ParsedPlaywrightCode],
page: Page,
@ -1014,7 +968,6 @@ def execute_playwright_click(
locator.click(*pw_action_args, **pw_action_kwargs)
@beartype
async def aexecute_playwright_click(
locator_code: list[ParsedPlaywrightCode],
page: APage,
@ -1027,7 +980,6 @@ async def aexecute_playwright_click(
await locator.click(*pw_action_args, **pw_action_kwargs)
@beartype
def execute_playwright_hover(
locator_code: list[ParsedPlaywrightCode], page: Page
) -> None:
@ -1037,7 +989,6 @@ def execute_playwright_hover(
locator.hover()
@beartype
async def aexecute_playwright_hover(
locator_code: list[ParsedPlaywrightCode], page: APage
) -> None:
@ -1047,7 +998,6 @@ async def aexecute_playwright_hover(
await locator.hover()
@beartype
def execute_playwright_type(
text: str,
locator_code: list[ParsedPlaywrightCode],
@ -1061,7 +1011,6 @@ def execute_playwright_type(
locator.type(*pw_action_args, **pw_action_kwargs)
@beartype
async def aexecute_playwright_type(
text: str,
locator_code: list[ParsedPlaywrightCode],
@ -1075,7 +1024,6 @@ async def aexecute_playwright_type(
await locator.type(*pw_action_args, **pw_action_kwargs)
@beartype
def execute_playwright_select_option(
locator_code: list[ParsedPlaywrightCode],
page: Page,
@ -1087,7 +1035,6 @@ def execute_playwright_select_option(
locator.select_option(*pw_action_args, **pw_action_kwargs)
@beartype
async def aexecute_playwright_select_option(
locator_code: list[ParsedPlaywrightCode],
page: APage,
@ -1099,7 +1046,6 @@ async def aexecute_playwright_select_option(
await locator.select_option(*pw_action_args, **pw_action_kwargs)
@beartype
def execute_playwright_check(
locator_code: list[ParsedPlaywrightCode], page: Page
) -> None:
@ -1108,7 +1054,6 @@ def execute_playwright_check(
locator.check()
@beartype
async def aexecute_playwright_check(
locator_code: list[ParsedPlaywrightCode], page: APage
) -> None:
@ -1117,7 +1062,6 @@ async def aexecute_playwright_check(
await locator.check()
@beartype
def execute_action(
action: Action,
page: Page,
@ -1252,7 +1196,6 @@ def execute_action(
return page
@beartype
async def aexecute_action(
action: Action, page: APage, browser_ctx: ABrowserContext
) -> APage:
@ -1383,7 +1326,6 @@ async def aexecute_action(
return page
@beartype
def parse_playwright_code(code: str) -> list[ParsedPlaywrightCode]:
# extract function calls
if not code.startswith("page."):
@ -1444,14 +1386,12 @@ def parse_playwright_code(code: str) -> list[ParsedPlaywrightCode]:
return parsed_chain
@beartype
class ActionParsingError(Exception):
def __init__(self, message: str) -> None:
self.message = message
super().__init__(self.message)
@beartype
def create_playwright_action(playwright_code: str) -> Action:
"""Main function to return individual playwright action"""
# get the last action
@ -1524,7 +1464,6 @@ def create_playwright_action(playwright_code: str) -> Action:
raise ActionParsingError(f"Unknown playwright action {action}")
@beartype
def create_id_based_action(action_str: str) -> Action:
"""Main function to return individual id based action"""
action_str = action_str.strip()

View File

@ -5,7 +5,6 @@ from pathlib import Path
import numpy as np
import numpy.typing as npt
from beartype import beartype
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.async_api import Page, ViewportSize, async_playwright
@ -23,7 +22,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
and observation space is the html content of the page.
"""
@beartype
def __init__(
self,
max_page_length: int = 2048,
@ -46,7 +44,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
self.timeout = timeout
self.viewport_size = viewport_size
@beartype
async def setup(self, config_file: Path | None = None) -> None:
self.context_manager = async_playwright()
self.playwright = await self.context_manager.__aenter__()
@ -73,7 +70,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
if start_url:
await self.page.goto(start_url)
@beartype
async def areset(
self,
*,
@ -104,7 +100,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
{"page": DetachedPage(self.page.url, content)},
)
@beartype
def reset(
self,
*,
@ -120,7 +115,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
def close(self) -> None:
asyncio.run(self.aclose())
@beartype
async def astep(
self, action: Action
) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]:
@ -153,7 +147,6 @@ class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
},
)
@beartype
def step(
self, action: Action
) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]:

View File

@ -3,7 +3,6 @@ import glob
from itertools import combinations
from pathlib import Path
from beartype import beartype
from playwright.sync_api import sync_playwright
from browser_env.env_config import (
@ -18,7 +17,6 @@ HEADLESS = True
SLOW_MO = 0
@beartype
def is_expired(
storage_state: Path, url: str, keyword: str, url_exact: bool = True
) -> bool:
@ -44,7 +42,6 @@ def is_expired(
return url not in d_url
@beartype
def renew_comb(comb: list[str]) -> None:
context_manager = sync_playwright()
playwright = context_manager.__enter__()
@ -91,7 +88,6 @@ def renew_comb(comb: list[str]) -> None:
context_manager.__exit__()
@beartype
def main() -> None:
sites = ["gitlab", "shopping", "shopping_admin", "reddit"]
urls = [

View File

@ -8,7 +8,6 @@ from typing import Any, Union
import numpy as np
import numpy.typing as npt
from beartype import beartype
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.sync_api import (
@ -39,7 +38,6 @@ class PlaywrightScript:
value: str | None = None # avatar movie, Enter
@beartype
def parse_action(action: str) -> PlaywrightScript:
splitted = action.strip().split(" ")
assert len(splitted) >= 2
@ -73,7 +71,6 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
and observation space is the html content of the page.
"""
@beartype
def __init__(
self,
max_page_length: int = 8192,
@ -121,7 +118,6 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
self.observation_handler.get_observation_space()
)
@beartype
def setup(self, config_file: Path | None = None) -> None:
self.context_manager = sync_playwright()
self.playwright = self.context_manager.__enter__()
@ -168,23 +164,19 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
client.send("Accessibility.enable")
self.page.client = client # type: ignore
@beartype
def get_page_client(self, page: Page) -> CDPSession:
return page.client # type: ignore
@beartype
def _get_obs(self) -> dict[str, Observation]:
obs = self.observation_handler.get_observation(
self.page, self.get_page_client(self.page)
)
return obs
@beartype
def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
metadata = self.observation_handler.get_observation_metadata()
return metadata
@beartype
def reset(
self,
*,
@ -223,12 +215,10 @@ class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
return (observation, info)
@beartype
def save_trace(self, trace_path: str | Path) -> None:
if self.save_trace_enabled:
self.context.tracing.stop(path=trace_path)
@beartype
def close(self) -> None:
if self.reset_finished:
self.context_manager.__exit__()

View File

@ -5,7 +5,6 @@ import re
from pathlib import Path
from typing import Any
from beartype import beartype
from PIL import Image
from agent.prompts import *
@ -35,7 +34,6 @@ HTML_TEMPLATE = """
"""
@beartype
def get_render_action(
action: Action,
observation_metadata: dict[str, ObservationMetadata],
@ -63,7 +61,6 @@ def get_render_action(
return action_str
@beartype
def get_action_description(
action: Action,
observation_metadata: dict[str, ObservationMetadata],

View File

@ -5,7 +5,6 @@ from typing import Any, TypedDict, Union
import numpy as np
import numpy.typing as npt
from beartype import beartype
from gymnasium import spaces
from playwright.sync_api import CDPSession, Page, ViewportSize
@ -60,7 +59,6 @@ class TextObervationProcessor(ObservationProcessor):
create_empty_metadata()
) # use the store meta data of this observation type
@beartype
def fetch_browser_info(
self,
page: Page,
@ -108,7 +106,6 @@ class TextObervationProcessor(ObservationProcessor):
return info
@beartype
@staticmethod
def get_bounding_client_rect(
client: CDPSession, backend_node_id: str
@ -134,7 +131,6 @@ class TextObervationProcessor(ObservationProcessor):
except Exception as e:
return {"result": {"subtype": "error"}}
@beartype
@staticmethod
def get_element_in_viewport_ratio(
elem_left_bound: float,
@ -167,7 +163,6 @@ class TextObervationProcessor(ObservationProcessor):
ratio = overlap_width * overlap_height / width * height
return ratio
@beartype
def fetch_page_html(
self,
info: BrowserInfo,
@ -323,7 +318,6 @@ class TextObervationProcessor(ObservationProcessor):
return dom_tree
@beartype
@staticmethod
def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
"""Parse the html tree into a string text"""
@ -367,7 +361,6 @@ class TextObervationProcessor(ObservationProcessor):
html = dfs(0, 0)
return html, obs_nodes_info
@beartype
def fetch_page_accessibility_tree(
self,
info: BrowserInfo,
@ -487,7 +480,6 @@ class TextObervationProcessor(ObservationProcessor):
return accessibility_tree
@beartype
@staticmethod
def parse_accessibility_tree(
accessibility_tree: AccessibilityTree,
@ -575,7 +567,6 @@ class TextObervationProcessor(ObservationProcessor):
tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
return tree_str, obs_nodes_info
@beartype
@staticmethod
def clean_accesibility_tree(tree_str: str) -> str:
"""further clean accesibility tree"""
@ -598,7 +589,6 @@ class TextObervationProcessor(ObservationProcessor):
return "\n".join(clean_lines)
@beartype
def process(self, page: Page, client: CDPSession) -> str:
# get the tab info
open_tabs = page.context.pages
@ -657,7 +647,6 @@ class TextObervationProcessor(ObservationProcessor):
content = f"{tab_title_str}\n\n{content}"
return content
@beartype
def get_element_center(self, element_id: str) -> tuple[float, float]:
node_info = self.obs_nodes_info[element_id]
node_bound = node_info["union_bound"]
@ -705,7 +694,6 @@ class ObservationHandler:
)
self.viewport_size = viewport_size
@beartype
def get_observation_space(self) -> spaces.Dict:
text_space = spaces.Text(
min_length=0,
@ -729,7 +717,6 @@ class ObservationHandler:
return spaces.Dict({"text": text_space, "image": image_space})
@beartype
def get_observation(
self, page: Page, client: CDPSession
) -> dict[str, Observation]:
@ -737,7 +724,6 @@ class ObservationHandler:
image_obs = self.image_processor.process(page, client)
return {"text": text_obs, "image": image_obs}
@beartype
def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
return {
"text": self.text_processor.meta_data,

View File

@ -4,7 +4,6 @@ from typing import Any, Dict, TypedDict, Union
import numpy as np
import numpy.typing as npt
from beartype import beartype
from PIL import Image
@ -14,7 +13,6 @@ class DetachedPage:
content: str # html
@beartype
def png_bytes_to_numpy(png: bytes) -> npt.NDArray[np.uint8]:
"""Convert png bytes to numpy array

View File

@ -8,8 +8,6 @@ from pathlib import Path
from typing import Any, Tuple, Union
import evaluate # type: ignore[import]
from beartype import beartype
from beartype.door import is_bearable
from playwright.sync_api import CDPSession, Page
from browser_env.actions import Action
@ -26,7 +24,6 @@ from evaluation_harness.helper_functions import (
Trajectory = list[Union[Action, StateInfo]]
@beartype
class Evaluator(object):
def __init__(self, eval_tag: str = "") -> None:
self.eval_tag = eval_tag
@ -43,7 +40,7 @@ class Evaluator(object):
@staticmethod
def get_last_action(trajectory: Trajectory) -> Action:
try:
is_bearable(trajectory[-1], Action)
# is_bearable(trajectory[-1], Action)
last_action = trajectory[-1]
except Exception:
raise ValueError(
@ -55,7 +52,7 @@ class Evaluator(object):
@staticmethod
def get_last_state(trajectory: Trajectory) -> StateInfo:
try:
is_bearable(trajectory[-2], StateInfo)
# is_bearable(trajectory[-2], StateInfo)
last_state = trajectory[-2]
except Exception:
raise ValueError(
@ -65,7 +62,6 @@ class Evaluator(object):
return last_state # type: ignore[return-value]
@beartype
class StringExactEvaluator(Evaluator):
"""Check whether the answer is exactly the same as one of the reference answers"""
@ -95,7 +91,6 @@ class StringExactEvaluator(Evaluator):
return 0.0
@beartype
class StringEvaluator(Evaluator):
"""Check whether the answer is correct with:
exact match: the answer is exactly the same as the reference answer
@ -144,7 +139,6 @@ class StringEvaluator(Evaluator):
return score
@beartype
class StringSoftEvaluator(Evaluator):
"""Use text generation metrics such as BLEU, ROUGE, etc. to evaluate the answer"""
@ -167,7 +161,6 @@ class StringSoftEvaluator(Evaluator):
return float(rouge["rouge1"])
@beartype
class URLExactEvaluator(Evaluator):
"""Check whether the URL is exactly the same as of the reference URLs"""
@ -205,7 +198,6 @@ class URLExactEvaluator(Evaluator):
raise ValueError(f"Unknown matching rule: {matching_rule}")
@beartype
class HTMLContentExactEvaluator(Evaluator):
"""Check whether the contents appear in the page"""
@ -286,7 +278,6 @@ class HTMLContentExactEvaluator(Evaluator):
######
@beartype
class EvaluatorPartial(Evaluator):
def __init__(self) -> None:
raise NotImplementedError
@ -301,7 +292,6 @@ class EvaluatorPartial(Evaluator):
raise NotImplementedError
@beartype
class URLSoftEvaluator(EvaluatorPartial):
"""Parse the URL and compare the domain and parameters"""
@ -367,7 +357,6 @@ class EvaluatorComb:
return score
@beartype
def evaluator_router(config_file: Path | str) -> EvaluatorComb:
"""Router to get the evaluator class"""
with open(config_file, "r") as f:

View File

@ -4,7 +4,6 @@ from typing import Any
from urllib.parse import urlparse
import requests
from beartype import beartype
from playwright.sync_api import CDPSession, Page
from browser_env.env_config import (
@ -21,7 +20,6 @@ from llms.providers.openai_utils import (
)
@beartype
def shopping_get_auth_token() -> str:
response = requests.post(
url=f"{SHOPPING}/rest/default/V1/integration/admin/token",
@ -37,7 +35,6 @@ def shopping_get_auth_token() -> str:
return token
@beartype
def shopping_get_latest_order_url() -> str:
"""Get the latest order url from the shopping website."""
@ -62,7 +59,6 @@ def shopping_get_latest_order_url() -> str:
return order_url
@beartype
def shopping_get_sku_latest_review_author(sku: str) -> str:
"""Get the latest review for shopping admin."""
header = {
@ -80,7 +76,6 @@ def shopping_get_sku_latest_review_author(sku: str) -> str:
return author
@beartype
def shopping_get_sku_latest_review_rating(sku: str) -> str:
"""Get the latest review for shopping admin."""
header = {
@ -99,7 +94,6 @@ def shopping_get_sku_latest_review_rating(sku: str) -> str:
return rating
@beartype
def reddit_get_post_url(url: str) -> str:
"""Get the post url"""
# Url is http://domain/f/subreddit/post_id/...
@ -118,7 +112,6 @@ def reddit_get_post_url(url: str) -> str:
return post_url
@beartype
def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str:
# get the account index
try:
@ -150,7 +143,6 @@ def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str:
return role
@beartype
def llm_fuzzy_match(pred: str, reference: str, question: str) -> float:
"""Check whether the prediction matches the reference with GPT-3.5"""
messages: list[dict[str, Any]] = []

4
run.py
View File

@ -9,7 +9,6 @@ import time
from pathlib import Path
import openai
from beartype import beartype
from agent import (
Agent,
@ -144,7 +143,6 @@ def config() -> argparse.Namespace:
return args
@beartype
def early_stop(
trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
@ -201,7 +199,6 @@ def early_stop(
return False, ""
@beartype
def test(
args: argparse.Namespace,
agent: Agent | PromptAgent | TeacherForcingAgent,
@ -369,7 +366,6 @@ def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
return unfinished_configs
@beartype
def dump_config(args: argparse.Namespace) -> None:
config_file = Path(args.result_dir) / "config.json"
if not config_file.exists():

View File

@ -6,7 +6,6 @@ import time
from typing import Dict, Optional, Tuple, Type, Union, cast
import pytest
from beartype import beartype
from playwright.sync_api import Page, expect
from browser_env import (
@ -21,13 +20,11 @@ from browser_env.env_config import *
HEADLESS = False
@beartype
def gen_tmp_storage_state() -> None:
with open(f"scripts/tmp_storage_state.json", "w") as f:
json.dump({"storage_state": ".auth/gitlab_state.json"}, f)
@beartype
def get_observation(
observation_type: str, current_viewport_only: bool
) -> None:

View File

@ -5,7 +5,6 @@ import tempfile
from typing import Callable, Dict, Optional, Tuple, Type, Union, cast
import pytest
from beartype.door import is_bearable
from gymnasium.vector import AsyncVectorEnv
from playwright.sync_api import Page
@ -128,7 +127,7 @@ def test_parallel_script_browser_env() -> None:
]
)
)
assert is_bearable(info["page"].tolist(), list[DetachedPage])
# assert is_bearable(info["page"].tolist(), list[DetachedPage])
assert info["page"][0].url == "https://www.rfc-editor.org/rfc/rfc2606.html"
assert info["page"][1].url == "https://www.rfc-editor.org/rfc/rfc6761.html"
vector_env.close() # type: ignore[no-untyped-call]

View File

@ -6,7 +6,6 @@ from pathlib import Path
from typing import Any
import pytest
from beartype import beartype
from py import test
from agent import Agent, TeacherForcingAgent
@ -249,7 +248,6 @@ def test_html_content_url_comb_success(
assert score == 1.0
@beartype
@pytest.mark.skipif(
IN_GITHUB_ACTIONS, reason="Won't work using the demo sites"
)
@ -273,7 +271,6 @@ def test_func_success(
assert score == 1.0
@beartype
@pytest.mark.skipif(
IN_GITHUB_ACTIONS, reason="Won't work using the demo sites"
)
@ -297,7 +294,6 @@ def test_func_fail(
assert score == 0.0
@beartype
def test_func_url_func_last_success(
script_browser_env: ScriptBrowserEnv,
) -> None:
@ -319,7 +315,6 @@ def test_func_url_func_last_success(
assert score == 1.0
@beartype
def test_func_url_func_page_success(
script_browser_env: ScriptBrowserEnv,
) -> None:

View File

@ -2,8 +2,6 @@ import json
import os
from pathlib import Path
from beartype import beartype
from browser_env import ScriptBrowserEnv
from browser_env.env_config import *
from evaluation_harness.helper_functions import (