# Mirror of https://github.com/web-arena-x/webarena.git (synced 2026-02-06 11:16:53 +00:00)
"""Script to run end-to-end evaluation on the benchmark"""
|
|
import argparse
|
|
import base64
|
|
import glob
|
|
import io
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
import subprocess
|
|
import time
|
|
from itertools import chain
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import openai
|
|
import tiktoken
|
|
from beartype import beartype
|
|
from PIL import Image
|
|
from prompt_toolkit import prompt
|
|
|
|
from agent import Agent, PromptAgent, TeacherForcingAgent
|
|
from agent.prompts import *
|
|
from browser_env import (
|
|
Action,
|
|
ActionTypes,
|
|
ObservationMetadata,
|
|
ScriptBrowserEnv,
|
|
StateInfo,
|
|
action2str,
|
|
create_stop_action,
|
|
)
|
|
from browser_env.actions import is_equivalent
|
|
from evaluation_harness import evaluator_router
|
|
from llms import lm_config
|
|
|
|
LOG_FOLDER = "log_files"
Path(LOG_FOLDER).mkdir(parents=True, exist_ok=True)
LOG_FILE_NAME = f"{LOG_FOLDER}/log_{time.strftime('%Y%m%d%H%M%S', time.localtime())}_{random.randint(0, 10000)}.log"

logger = logging.getLogger("logger")
logger.setLevel(logging.INFO)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
logger.addHandler(console_handler)

file_handler = logging.FileHandler(LOG_FILE_NAME)
file_handler.setLevel(logging.DEBUG)
logger.addHandler(file_handler)

# Set the log format
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)

# A trajectory alternates between environment states and agent actions:
# [state_0, action_0, state_1, action_1, ...]
Trajectory = list[Action | StateInfo]

# Braces are doubled so that str.format only substitutes {body}.
HTML_TEMPLATE = """
<!DOCTYPE html>
<head>
    <style>
        pre {{
            white-space: pre-wrap;
            word-wrap: break-word;
        }}
    </style>
</head>
<html>
    <body>
    {body}
    </body>
</html>
"""


def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )
    parser.add_argument(
        "--render", action="store_true", help="Render the browser"
    )
    parser.add_argument(
        "--slow_mo",
        type=int,
        default=0,
        help="Slow down the browser by the specified amount",
    )
    parser.add_argument(
        "--action_set_tag", default="id_accessibility_tree", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["accessibility_tree", "html", "image"],
        default="accessibility_tree",
        help="Observation type",
    )
    parser.add_argument(
        "--current_viewport_only",
        action="store_true",
        help="Only use the current viewport for the observation",
    )
    parser.add_argument("--viewport_width", type=int, default=1280)
    parser.add_argument("--viewport_height", type=int, default=720)
    parser.add_argument("--save_trace_enabled", action="store_true")
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)

    parser.add_argument("--max_steps", type=int, default=30)

    # agent config
    parser.add_argument("--agent_type", type=str, default="prompt")
    parser.add_argument(
        "--instruction_path",
        type=str,
        default="agents/prompts/state_action_agent.json",
    )
    parser.add_argument(
        "--parsing_failure_th",
        help="When consecutive parsing failures exceed this threshold, the agent will stop",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--repeating_action_failure_th",
        help="When consecutive repeated actions exceed this threshold, the agent will stop",
        type=int,
        default=3,
    )

    # lm config
    parser.add_argument("--provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
    parser.add_argument("--mode", type=str, default="chat")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--context_length", type=int, default=0)
    parser.add_argument("--max_tokens", type=int, default=384)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument(
        "--max_obs_length",
        type=int,
        help="when not zero, truncate the observation to this length before feeding it to the model",
        default=1920,
    )

    # example config
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=1000)

    # logging related
    parser.add_argument("--result_dir", type=str, default="")
    args = parser.parse_args()

    # check whether the action space is compatible with the observation space
    if (
        args.action_set_tag == "id_accessibility_tree"
        and args.observation_type != "accessibility_tree"
    ):
        raise ValueError(
            f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
        )

    return args


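# Example invocation (illustrative only; the entry-point file name is an
# assumption and is not defined in this file):
#   python run.py --test_start_idx 0 --test_end_idx 10 \
#       --model gpt-3.5-turbo-0613 --result_dir cache/demo_run
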
@beartype
def get_render_action(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
) -> str:
    """Parse the predicted action into a more comprehensive HTML snippet for rendering."""
    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["element_id"] in text_meta_data["obs_nodes_info"]:
                node_content = text_meta_data["obs_nodes_info"][
                    action["element_id"]
                ]["text"]
            else:
                node_content = "No match found"

            action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
            action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
            action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"

        case "playwright":
            action_str = action["pw_code"]
        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")
    return action_str


@beartype
def get_action_description(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
    prompt_constructor: PromptConstructor | None,
) -> str:
    """Generate the text version of the predicted action to store in the action history for prompt use.

    May contain hint information to recover from failures."""

    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in text_meta_data["obs_nodes_info"]:
                    node_content = text_meta_data["obs_nodes_info"][
                        action["element_id"]
                    ]["text"]
                    node_content = " ".join(node_content.split()[1:])
                    action_str = action2str(
                        action, action_set_tag, node_content
                    )
                else:
                    action_str = f"Attempt to perform \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    action_splitter = prompt_constructor.instruction[
                        "meta_data"
                    ]["action_splitter"]
                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
                else:
                    action_str = action2str(action, action_set_tag, "")

        case "playwright":
            action_str = action["pw_code"]

        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")

    return action_str


class RenderHelper(object):
    """Helper class to render text and image observations and metadata in the trajectory"""

    def __init__(
        self, config_file: str, result_dir: str, action_set_tag: str
    ) -> None:
        with open(config_file, "r") as f:
            _config = json.load(f)
            _config_str = ""
            for k, v in _config.items():
                _config_str += f"{k}: {v}\n"
            _config_str = f"<pre>{_config_str}</pre>\n"
            task_id = _config["task_id"]

        self.action_set_tag = action_set_tag

        self.render_file = open(
            Path(result_dir) / f"render_{task_id}.html", "a+"
        )
        self.render_file.truncate(0)
        # write the initial template
        self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
        self.render_file.read()
        self.render_file.flush()

    def render(
        self,
        action: Action,
        state_info: StateInfo,
        meta_data: dict[str, Any],
        render_screenshot: bool = False,
    ) -> None:
        """Render the trajectory"""
        # text observation
        observation = state_info["observation"]
        text_obs = observation["text"]
        info = state_info["info"]
        new_content = f"<h2>New Page</h2>\n"
        new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
        new_content += f"<div class='state_obv'><pre>{text_obs}</pre></div>\n"

        if render_screenshot:
            # image observation
            img_obs = observation["image"]
            image = Image.fromarray(img_obs)
            byte_io = io.BytesIO()
            image.save(byte_io, format="PNG")
            byte_io.seek(0)
            image_bytes = base64.b64encode(byte_io.read())
            image_str = image_bytes.decode("utf-8")
            new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"

        # meta data
        new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"

        # action
        action_str = get_render_action(
            action,
            info["observation_metadata"],
            action_set_tag=self.action_set_tag,
        )
        # with yellow background
        action_str = f"<div class='predict_action'>{action_str}</div>"
        new_content += f"{action_str}\n"

        # append the new content to the existing <body>
        self.render_file.seek(0)
        html = self.render_file.read()
        html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
        html_body += new_content

        html = HTML_TEMPLATE.format(body=html_body)
        self.render_file.seek(0)
        self.render_file.truncate()
        self.render_file.write(html)
        self.render_file.flush()

    def close(self) -> None:
        self.render_file.close()


@beartype
def early_stop(
    trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
    """Check whether the episode should stop early and return the reason"""

    # reached the max number of steps
    num_steps = (len(trajectory) - 1) / 2
    if num_steps >= max_steps:
        return True, f"Reach max steps {max_steps}"

    last_k_actions: list[Action]
    action_seq: list[Action]

    # Case: parsing failure for k consecutive times
    k = thresholds["parsing_failure"]
    last_k_actions = trajectory[1::2][-k:]  # type: ignore[assignment]
    if len(last_k_actions) >= k:
        if all(
            [
                action["action_type"] == ActionTypes.NONE
                for action in last_k_actions
            ]
        ):
            return True, f"Failed to parse actions for {k} times"

    # Case: same action for k times
    k = thresholds["repeating_action"]
    last_k_actions = trajectory[1::2][-k:]  # type: ignore[assignment]
    action_seq = trajectory[1::2]  # type: ignore[assignment]

    if len(action_seq) == 0:
        return False, ""

    last_action: Action = action_seq[-1]

    if last_action["action_type"] != ActionTypes.TYPE:
        if len(last_k_actions) >= k:
            if all(
                [
                    is_equivalent(action, last_action)
                    for action in last_k_actions
                ]
            ):
                return True, f"Same action for {k} times"

    else:
        # typing actions: count repeats over the whole action sequence
        if (
            sum([is_equivalent(action, last_action) for action in action_seq])
            >= k
        ):
            return True, f"Same typing action for {k} times"

    return False, ""


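# The per-task loop in test() below works roughly as follows: reset the agent
# and the browser environment for each config file, then repeatedly check
# early_stop, ask the agent for the next action, render and execute it, and
# append the new state to the trajectory; finally, the task-specific evaluator
# scores the completed trajectory.
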
@beartype
def test(
    args: argparse.Namespace,
    agent: Agent | PromptAgent,
    config_file_list: list[str],
) -> None:
    scores = []
    max_steps = args.max_steps

    early_stop_thresholds = {
        "parsing_failure": args.parsing_failure_th,
        "repeating_action": args.repeating_action_failure_th,
    }

    env = ScriptBrowserEnv(
        headless=not args.render,
        slow_mo=args.slow_mo,
        observation_type=args.observation_type,
        current_viewport_only=args.current_viewport_only,
        viewport_size={
            "width": args.viewport_width,
            "height": args.viewport_height,
        },
        save_trace_enabled=args.save_trace_enabled,
        sleep_after_execution=args.sleep_after_execution,
    )

    for config_file in config_file_list:
        try:
            render_helper = RenderHelper(
                config_file, args.result_dir, args.action_set_tag
            )

            # get intent
            with open(config_file) as f:
                _c = json.load(f)
                intent = _c["intent"]
                task_id = _c["task_id"]

            logger.info(f"[Config file]: {config_file}")
            logger.info(f"[Intent]: {intent}")

            agent.reset(config_file)
            trajectory: Trajectory = []
            obs, info = env.reset(options={"config_file": config_file})
            state_info: StateInfo = {"observation": obs, "info": info}
            trajectory.append(state_info)

            meta_data = {"action_history": ["None"]}
            while True:
                early_stop_flag, stop_info = early_stop(
                    trajectory, max_steps, early_stop_thresholds
                )

                if early_stop_flag:
                    action = create_stop_action(f"Early stop: {stop_info}")
                else:
                    try:
                        action = agent.next_action(
                            trajectory, intent, meta_data=meta_data
                        )
                    except ValueError as e:
                        # get the error message
                        action = create_stop_action(f"ERROR: {str(e)}")

                trajectory.append(action)

                action_str = get_action_description(
                    action,
                    state_info["info"]["observation_metadata"],
                    action_set_tag=args.action_set_tag,
                    prompt_constructor=agent.prompt_constructor
                    if isinstance(agent, PromptAgent)
                    else None,
                )
                render_helper.render(
                    action, state_info, meta_data, args.render_screenshot
                )
                meta_data["action_history"].append(action_str)

                if action["action_type"] == ActionTypes.STOP:
                    break

                obs, _, terminated, _, info = env.step(action)
                state_info = {"observation": obs, "info": info}
                trajectory.append(state_info)

                if terminated:
                    # add an action placeholder
                    trajectory.append(create_stop_action(""))
                    break

            evaluator = evaluator_router(config_file)
            score = evaluator(
                trajectory=trajectory,
                config_file=config_file,
                page=env.page,
                client=env.get_page_client(env.page),
            )

            scores.append(score)

            if score == 1:
                logger.info(f"[Result] (PASS) {config_file}")
            else:
                logger.info(f"[Result] (FAIL) {config_file}")

            if args.save_trace_enabled:
                env.save_trace(
                    Path(args.result_dir) / "traces" / f"{task_id}.zip"
                )

        except openai.error.OpenAIError as e:
            logger.info(f"[OpenAI Error] {repr(e)}")
        except Exception as e:
            logger.info(f"[Unhandled Error] {repr(e)}")
            import traceback

            # write to the error file
            with open(Path(args.result_dir) / "error.txt", "a") as f:
                f.write(f"[Config file]: {config_file}\n")
                f.write(f"[Unhandled Error] {repr(e)}\n")
                f.write(traceback.format_exc())  # write stack trace to file

        # logger.info(f"[Render] {render_helper.render_file.name}")
        # subprocess.run(["open", render_helper.render_file.name])
        render_helper.close()

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")


def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig:
    llm_config = lm_config.LMConfig(
        provider=args.provider, model=args.model, mode=args.mode
    )
    if args.provider == "openai":
        llm_config.gen_config["temperature"] = args.temperature
        llm_config.gen_config["top_p"] = args.top_p
        llm_config.gen_config["context_length"] = args.context_length
        llm_config.gen_config["max_tokens"] = args.max_tokens
        llm_config.gen_config["stop_token"] = args.stop_token
        llm_config.gen_config["max_obs_length"] = args.max_obs_length
    else:
        raise NotImplementedError(f"provider {args.provider} not implemented")
    return llm_config


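# Note: with the default CLI arguments, construct_llm_config() above fills
# gen_config with roughly {"temperature": 1.0, "top_p": 0.9,
# "context_length": 0, "max_tokens": 384, "stop_token": None,
# "max_obs_length": 1920} (illustrative; the container layout is defined by
# lm_config.LMConfig).
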
def construct_agent(args: argparse.Namespace) -> Agent:
    llm_config = construct_llm_config(args)

    agent: Agent
    if args.agent_type == "teacher_forcing":
        agent = TeacherForcingAgent()
    elif args.agent_type == "prompt":
        # the prompt constructor class is named in the instruction JSON's
        # meta_data and instantiated dynamically via eval
        with open(args.instruction_path) as f:
            constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
        tokenizer = tiktoken.encoding_for_model(llm_config.model)
        prompt_constructor = eval(constructor_type)(
            args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
        )
        agent = PromptAgent(
            action_set_tag=args.action_set_tag,
            lm_config=llm_config,
            prompt_constructor=prompt_constructor,
        )
    else:
        raise NotImplementedError(
            f"agent type {args.agent_type} not implemented"
        )
    return agent


def prepare(args: argparse.Namespace) -> None:
    # convert prompt python files to json
    from agent.prompts import to_json

    to_json.run()

    # prepare the result dir
    result_dir = args.result_dir
    if not result_dir:
        result_dir = (
            f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}"
        )
    if not Path(result_dir).exists():
        Path(result_dir).mkdir(parents=True, exist_ok=True)
        args.result_dir = result_dir
        logger.info(f"Create result dir: {result_dir}")

    if not (Path(result_dir) / "traces").exists():
        (Path(result_dir) / "traces").mkdir(parents=True)

    # record which log file belongs to this run
    with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
        f.write(f"{LOG_FILE_NAME}\n")


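# get_unfinished() matches the render_{task_id}.html files written by
# RenderHelper against the given config files, so re-running with the same
# --result_dir only executes tasks that have not yet produced a render file.
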
def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
    result_files = glob.glob(f"{result_dir}/*.html")
    task_ids = [
        os.path.basename(f).split(".")[0].split("_")[1] for f in result_files
    ]
    unfinished_configs = []
    for config_file in config_files:
        task_id = os.path.basename(config_file).split(".")[0]
        if task_id not in task_ids:
            unfinished_configs.append(config_file)
    return unfinished_configs


@beartype
def dump_config(args: argparse.Namespace) -> None:
    config_file = Path(args.result_dir) / "config.json"
    if not config_file.exists():
        with open(config_file, "w") as f:
            json.dump(vars(args), f, indent=4)
            logger.info(f"Dump config to {config_file}")


if __name__ == "__main__":
    args = config()
    args.sleep_after_execution = 2.5
    prepare(args)

    test_file_list = []
    st_idx = args.test_start_idx
    ed_idx = args.test_end_idx
    for i in range(st_idx, ed_idx):
        test_file_list.append(f"config_files/{i}.json")
    test_file_list = get_unfinished(test_file_list, args.result_dir)
    print(f"Total {len(test_file_list)} tasks left")
    # these settings are forced here and override whatever was passed on the CLI
    args.render = True
    args.render_screenshot = True
    args.save_trace_enabled = True

    args.current_viewport_only = True
    dump_config(args)

    agent = construct_agent(args)
    test(args, agent, test_file_list)