webarena/run.py
2023-07-25 00:30:29 -10:00

624 lines
21 KiB
Python

"""Script to run end-to-end evaluation on the benchmark"""
import argparse
import base64
import glob
import io
import json
import logging
import os
import random
import re
import subprocess
import time
from itertools import chain
from pathlib import Path
from typing import Any
import openai
import tiktoken
from beartype import beartype
from PIL import Image
from prompt_toolkit import prompt
from agent import Agent, PromptAgent, TeacherForcingAgent
from agent.prompts import *
from browser_env import (
Action,
ActionTypes,
ObservationMetadata,
ScriptBrowserEnv,
StateInfo,
action2str,
create_stop_action,
)
from browser_env.actions import is_equivalent
from evaluation_harness import evaluator_router
from llms import lm_config
# Each run logs to its own file under LOG_FOLDER; the file name combines a
# local timestamp with a random suffix.
LOG_FOLDER = "log_files"
Path(LOG_FOLDER).mkdir(parents=True, exist_ok=True)
LOG_FILE_NAME = f"{LOG_FOLDER}/log_{time.strftime('%Y%m%d%H%M%S', time.localtime())}_{random.randint(0, 10000)}.log"

# One record layout shared by the console and the file handler.
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

# The logger itself filters at INFO; both handlers accept everything the
# logger lets through.
logger = logging.getLogger("logger")
logger.setLevel(logging.INFO)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

file_handler = logging.FileHandler(LOG_FILE_NAME)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# A trajectory is a flat list mixing observed states and issued actions.
Trajectory = list[Action | StateInfo]
# Skeleton of the per-task render page. It is filled in with ``str.format``,
# so the literal CSS braces are doubled and ``{body}`` is the only placeholder.
# ``<head>`` lives inside ``<html>`` so that the generated document is
# well-formed HTML.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
<style>
pre {{
white-space: pre-wrap;
word-wrap: break-word;
}}
</style>
</head>
<body>
{body}
</body>
</html>
"""
def config() -> argparse.Namespace:
    """Parse the command-line options of the evaluation run.

    Returns:
        The parsed namespace (read from ``sys.argv``).

    Raises:
        ValueError: if ``--action_set_tag id_accessibility_tree`` is combined
            with an observation type other than ``accessibility_tree``.
    """
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # browser / environment config
    parser.add_argument(
        "--render", action="store_true", help="Render the browser"
    )
    parser.add_argument(
        "--slow_mo",
        type=int,
        default=0,
        help="Slow down the browser by the specified amount",
    )
    parser.add_argument(
        "--action_set_tag", default="id_accessibility_tree", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["accessibility_tree", "html", "image"],
        default="accessibility_tree",
        help="Observation type",
    )
    parser.add_argument(
        "--current_viewport_only",
        action="store_true",
        help="Only use the current viewport for the observation",
    )
    parser.add_argument("--viewport_width", type=int, default=1280)
    parser.add_argument("--viewport_height", type=int, default=720)
    parser.add_argument("--save_trace_enabled", action="store_true")
    # Declared here because the evaluation loop reads ``args.render_screenshot``;
    # without the flag the attribute only exists if a caller sets it by hand.
    parser.add_argument(
        "--render_screenshot",
        action="store_true",
        help="Embed a screenshot of every step in the rendered HTML",
    )
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=30)

    # agent config
    parser.add_argument("--agent_type", type=str, default="prompt")
    # NOTE(review): the default path starts with "agents/" while the package
    # imported by this file is "agent" - confirm the directory name.
    parser.add_argument(
        "--instruction_path",
        type=str,
        default="agents/prompts/state_action_agent.json",
    )
    parser.add_argument(
        "--parsing_failure_th",
        help="When consecutive parsing failure exceeds this threshold, the agent will stop",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--repeating_action_failure_th",
        help="When consecutive repeating action exceeds this threshold, the agent will stop",
        type=int,
        default=3,
    )

    # lm config
    parser.add_argument("--provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
    parser.add_argument("--mode", type=str, default="chat")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--context_length", type=int, default=0)
    parser.add_argument("--max_tokens", type=int, default=384)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument(
        "--max_obs_length",
        type=int,
        help="when not zero, will truncate the observation to this length before feeding to the model",
        default=1920,
    )

    # example config
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=1000)

    # logging related
    parser.add_argument("--result_dir", type=str, default="")

    args = parser.parse_args()

    # check whether the action space is compatible with the observation space
    if (
        args.action_set_tag == "id_accessibility_tree"
        and args.observation_type != "accessibility_tree"
    ):
        raise ValueError(
            f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
        )
    return args
@beartype
def get_render_action(
action: Action,
observation_metadata: dict[str, ObservationMetadata],
action_set_tag: str,
) -> str:
"""Parse the predicted actions for rendering purpose. More comprehensive information"""
match action_set_tag:
case "id_accessibility_tree":
text_meta_data = observation_metadata["text"]
if action["element_id"] in text_meta_data["obs_nodes_info"]:
node_content = text_meta_data["obs_nodes_info"][
action["element_id"]
]["text"]
else:
node_content = "No match found"
action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"
case "playwright":
action_str = action["pw_code"]
case _:
raise ValueError(f"Unknown action type {action['action_type']}")
return action_str
@beartype
def get_action_description(
action: Action,
observation_metadata: dict[str, ObservationMetadata],
action_set_tag: str,
prompt_constructor: PromptConstructor | None,
) -> str:
"""Generate the text version of the predicted actions to store in action history for prompt use.
May contain hint information to recover from the failures"""
match action_set_tag:
case "id_accessibility_tree":
text_meta_data = observation_metadata["text"]
if action["action_type"] in [
ActionTypes.CLICK,
ActionTypes.HOVER,
ActionTypes.TYPE,
]:
action_name = str(action["action_type"]).split(".")[1].lower()
if action["element_id"] in text_meta_data["obs_nodes_info"]:
node_content = text_meta_data["obs_nodes_info"][
action["element_id"]
]["text"]
node_content = " ".join(node_content.split()[1:])
action_str = action2str(
action, action_set_tag, node_content
)
else:
action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
else:
if (
action["action_type"] == ActionTypes.NONE
and prompt_constructor is not None
):
action_splitter = prompt_constructor.instruction[
"meta_data"
]["action_splitter"]
action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
else:
action_str = action2str(action, action_set_tag, "")
case "playwright":
action_str = action["pw_code"]
case _:
raise ValueError(f"Unknown action type {action['action_type']}")
return action_str
class RenderHelper(object):
    """Render the text/image observations and meta data of a trajectory to HTML.

    One ``render_<task_id>.html`` file is written per task into ``result_dir``.
    The page body is accumulated in memory and the whole file is rewritten
    after every step, so the file on disk is always a complete document.
    """

    def __init__(
        self, config_file: str, result_dir: str, action_set_tag: str
    ) -> None:
        """Create (or overwrite) the render file and write the task config.

        Args:
            config_file: path of the task's JSON config; must contain
                ``"task_id"``.
            result_dir: existing directory that receives the render file.
            action_set_tag: forwarded to ``get_render_action`` on each step.
        """
        from html import escape

        with open(config_file, "r") as f:
            _config = json.load(f)
        config_lines = "".join(f"{k}: {v}\n" for k, v in _config.items())
        task_id = _config["task_id"]

        self.action_set_tag = action_set_tag
        # Everything that goes between <body> and </body>. Keeping it here
        # avoids re-parsing the file with a regex, which truncated the page
        # whenever an observation itself contained "</body>".
        self._body = f"<pre>{escape(config_lines)}</pre>\n"
        # "w+" truncates an existing file and keeps the handle readable.
        self.render_file = open(
            Path(result_dir) / f"render_{task_id}.html",
            "w+",
            encoding="utf-8",
        )
        self._write_page()

    def _write_page(self) -> None:
        """Replace the file content with the template filled with the body."""
        self.render_file.seek(0)
        self.render_file.truncate()
        self.render_file.write(HTML_TEMPLATE.format(body=self._body))
        self.render_file.flush()

    def render(
        self,
        action: Action,
        state_info: StateInfo,
        meta_data: dict[str, Any],
        render_screenshot: bool = False,
    ) -> None:
        """Append one step (state, previous action, predicted action) to the page.

        Args:
            action: the action predicted on ``state_info``.
            state_info: current observation and info; ``info["page"].url`` and
                ``info["observation_metadata"]`` are read.
            meta_data: must hold a non-empty ``"action_history"`` list; its
                last entry is shown as the previous action.
            render_screenshot: when true, ``observation["image"]`` is embedded
                as a base64 PNG.
        """
        from html import escape

        observation = state_info["observation"]
        info = state_info["info"]
        page_url = escape(str(info["page"].url))

        parts = [
            "<h2>New Page</h2>\n",
            # quoted + escaped so URLs with spaces or '>' stay one attribute
            f"<h3 class='url'><a href='{page_url}'>URL: {page_url}</a></h3>\n",
            # text observation (escaped: it may contain markup characters)
            f"<div class='state_obv'><pre>{escape(str(observation['text']))}</pre></div>\n",
        ]

        if render_screenshot:
            # image observation, inlined as a data URI
            image = Image.fromarray(observation["image"])
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
            parts.append(
                f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"
            )

        # meta data: the description of the previous action
        previous_action = escape(str(meta_data["action_history"][-1]))
        parts.append(
            f"<div class='prev_action' style='background-color:pink'>{previous_action}</div>\n"
        )

        # predicted action; get_render_action already returns an HTML fragment
        action_str = get_render_action(
            action,
            info["observation_metadata"],
            action_set_tag=self.action_set_tag,
        )
        parts.append(f"<div class='predict_action'>{action_str}</div>\n")

        self._body += "".join(parts)
        self._write_page()

    def close(self) -> None:
        """Close the render file."""
        self.render_file.close()
@beartype
def early_stop(
    trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
    """Decide whether the episode has to be stopped before the agent does so.

    Actions are read from the odd positions of ``trajectory``.

    Args:
        trajectory: the states and actions recorded so far.
        max_steps: step budget of the episode.
        thresholds: needs the keys ``"parsing_failure"`` and
            ``"repeating_action"``.

    Returns:
        ``(True, reason)`` when the run must stop, ``(False, "")`` otherwise.
    """
    # Step budget exhausted.
    if (len(trajectory) - 1) / 2 >= max_steps:
        return True, f"Reach max steps {max_steps}"

    issued: list[Action] = trajectory[1::2]  # type: ignore[assignment]

    # The most recent actions all failed to parse.
    parse_limit = thresholds["parsing_failure"]
    recent = issued[-parse_limit:]
    if len(recent) >= parse_limit and all(
        [step["action_type"] == ActionTypes.NONE for step in recent]
    ):
        return True, f"Failed to parse actions for {parse_limit} times"

    # The same action keeps being issued.
    repeat_limit = thresholds["repeating_action"]
    recent = issued[-repeat_limit:]
    if not issued:
        return False, ""

    newest: Action = issued[-1]
    if newest["action_type"] == ActionTypes.TYPE:
        # Typing actions are counted over the whole episode, not only over
        # the most recent window.
        repeats = sum([is_equivalent(step, newest) for step in issued])
        if repeats >= repeat_limit:
            return True, f"Same typing action for {repeat_limit} times"
    elif len(recent) >= repeat_limit and all(
        [is_equivalent(step, newest) for step in recent]
    ):
        return True, f"Same action for {repeat_limit} times"

    return False, ""
@beartype
def test(
    args: argparse.Namespace,
    agent: Agent | PromptAgent,
    config_file_list: list[str],
) -> None:
    """Run the agent on every task config and log the results.

    For each config the environment is reset, the agent acts until it stops
    (or an early-stop condition fires), every step is rendered to HTML and the
    final trajectory is scored. Errors of a single task are logged (unhandled
    ones also go to ``<result_dir>/error.txt``) and the run continues with the
    next task; such tasks contribute no score.

    Args:
        args: run options; ``result_dir`` must point to an existing directory.
        agent: the agent under evaluation.
        config_file_list: paths of the task config files to run.
    """
    import traceback

    scores = []
    max_steps = args.max_steps
    early_stop_thresholds = {
        "parsing_failure": args.parsing_failure_th,
        "repeating_action": args.repeating_action_failure_th,
    }
    # Not every caller sets this attribute; fall back to "no screenshots".
    render_screenshot = getattr(args, "render_screenshot", False)

    env = ScriptBrowserEnv(
        headless=not args.render,
        slow_mo=args.slow_mo,
        observation_type=args.observation_type,
        current_viewport_only=args.current_viewport_only,
        viewport_size={
            "width": args.viewport_width,
            "height": args.viewport_height,
        },
        save_trace_enabled=args.save_trace_enabled,
        sleep_after_execution=args.sleep_after_execution,
    )

    try:
        for config_file in config_file_list:
            # Reset per task so the cleanup below never touches a helper that
            # belongs to a previous task or was never created.
            render_helper: RenderHelper | None = None
            try:
                render_helper = RenderHelper(
                    config_file, args.result_dir, args.action_set_tag
                )

                # get intent
                with open(config_file) as f:
                    _c = json.load(f)
                    intent = _c["intent"]
                    task_id = _c["task_id"]

                logger.info(f"[Config file]: {config_file}")
                logger.info(f"[Intent]: {intent}")

                agent.reset(config_file)
                trajectory: Trajectory = []
                obs, info = env.reset(options={"config_file": config_file})
                state_info: StateInfo = {"observation": obs, "info": info}
                trajectory.append(state_info)

                meta_data = {"action_history": ["None"]}
                while True:
                    early_stop_flag, stop_info = early_stop(
                        trajectory, max_steps, early_stop_thresholds
                    )

                    if early_stop_flag:
                        action = create_stop_action(f"Early stop: {stop_info}")
                    else:
                        try:
                            action = agent.next_action(
                                trajectory, intent, meta_data=meta_data
                            )
                        except ValueError as e:
                            # turn the error message into a stop action
                            action = create_stop_action(f"ERROR: {str(e)}")

                    trajectory.append(action)

                    action_str = get_action_description(
                        action,
                        state_info["info"]["observation_metadata"],
                        action_set_tag=args.action_set_tag,
                        prompt_constructor=agent.prompt_constructor
                        if isinstance(agent, PromptAgent)
                        else None,
                    )
                    render_helper.render(
                        action, state_info, meta_data, render_screenshot
                    )
                    meta_data["action_history"].append(action_str)

                    if action["action_type"] == ActionTypes.STOP:
                        break

                    obs, _, terminated, _, info = env.step(action)
                    state_info = {"observation": obs, "info": info}
                    trajectory.append(state_info)

                    if terminated:
                        # add an action place holder so the trajectory ends
                        # with an action
                        trajectory.append(create_stop_action(""))
                        break

                evaluator = evaluator_router(config_file)
                score = evaluator(
                    trajectory=trajectory,
                    config_file=config_file,
                    page=env.page,
                    client=env.get_page_client(env.page),
                )

                scores.append(score)

                if score == 1:
                    logger.info(f"[Result] (PASS) {config_file}")
                else:
                    logger.info(f"[Result] (FAIL) {config_file}")

                if args.save_trace_enabled:
                    env.save_trace(
                        Path(args.result_dir) / "traces" / f"{task_id}.zip"
                    )

            except openai.error.OpenAIError as e:
                logger.info(f"[OpenAI Error] {repr(e)}")
            except Exception as e:
                # Task boundary: record the failure and move on to the next
                # task instead of aborting the whole run.
                logger.info(f"[Unhandled Error] {repr(e)}")
                with open(Path(args.result_dir) / "error.txt", "a") as f:
                    f.write(f"[Config file]: {config_file}\n")
                    f.write(f"[Unhandled Error] {repr(e)}\n")
                    f.write(traceback.format_exc())  # stack trace
            finally:
                if render_helper is not None:
                    render_helper.close()
    finally:
        # Release the browser even when the loop is interrupted.
        env.close()

    if scores:
        logger.info(f"Average score: {sum(scores) / len(scores)}")
    else:
        # Nothing was scored (empty task list or every task failed); avoid
        # dividing by zero.
        logger.info("Average score: n/a (no task was scored)")
def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig:
    """Build the language-model configuration from the parsed options.

    Raises:
        NotImplementedError: for any provider other than ``"openai"``.
    """
    llm_config = lm_config.LMConfig(
        provider=args.provider, model=args.model, mode=args.mode
    )
    if args.provider != "openai":
        raise NotImplementedError(f"provider {args.provider} not implemented")

    # Generation options are copied over under the same names they have on
    # the command line.
    for option in (
        "temperature",
        "top_p",
        "context_length",
        "max_tokens",
        "stop_token",
        "max_obs_length",
    ):
        llm_config.gen_config[option] = getattr(args, option)
    return llm_config
def construct_agent(args: argparse.Namespace) -> Agent:
llm_config = construct_llm_config(args)
agent: Agent
if args.agent_type == "teacher_forcing":
agent = TeacherForcingAgent()
elif args.agent_type == "prompt":
with open(args.instruction_path) as f:
constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
tokenizer = tiktoken.encoding_for_model(llm_config.model)
prompt_constructor = eval(constructor_type)(
args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
)
agent = PromptAgent(
action_set_tag=args.action_set_tag,
lm_config=llm_config,
prompt_constructor=prompt_constructor,
)
else:
raise NotImplementedError(
f"agent type {args.agent_type} not implemented"
)
return agent
def prepare(args: argparse.Namespace) -> None:
    """Prepare prompts and the result directory for a run.

    Converts the prompt python files to json, makes sure the result directory
    and its ``traces`` sub-directory exist, and appends the current log file
    name to ``<result_dir>/log_files.txt``.

    Args:
        args: parsed options. When ``args.result_dir`` is empty a timestamped
            directory under ``cache/`` is chosen; in every case
            ``args.result_dir`` holds the directory in use afterwards.
    """
    # convert prompt python files to json
    from agent.prompts import to_json

    to_json.run()

    # prepare result dir
    result_dir = args.result_dir
    if not result_dir:
        result_dir = (
            f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}"
        )
    result_path = Path(result_dir)

    newly_created = not result_path.exists()
    # exist_ok avoids failing when the directory appears between the check
    # and the creation.
    result_path.mkdir(parents=True, exist_ok=True)
    if newly_created:
        logger.info(f"Create result dir: {result_dir}")

    # Always publish the directory in use, also when it already existed, so
    # later steps never see an empty ``result_dir``.
    args.result_dir = result_dir

    (result_path / "traces").mkdir(parents=True, exist_ok=True)

    # log the log file
    with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
        f.write(f"{LOG_FILE_NAME}\n")
def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
    """Return the config files that have no render file in ``result_dir`` yet.

    A task counts as finished when ``result_dir`` contains an ``.html`` file
    whose name (up to the first ``.``) is ``<prefix>_<task_id>``, where
    ``task_id`` is the config file's base name up to its first ``.``.

    Args:
        config_files: candidate config file paths.
        result_dir: directory that is scanned for ``*.html`` files.

    Returns:
        The unfinished config files, in their original order.
    """
    finished_ids = set()
    # glob.escape keeps special characters in the directory name literal.
    for result_file in glob.glob(f"{glob.escape(result_dir)}/*.html"):
        stem = os.path.basename(result_file).split(".")[0]
        # Split at the first underscore only, so ids that contain "_" survive;
        # files without an underscore are not render files and are skipped.
        _, separator, task_id = stem.partition("_")
        if separator:
            finished_ids.add(task_id)

    return [
        config_file
        for config_file in config_files
        if os.path.basename(config_file).split(".")[0] not in finished_ids
    ]
@beartype
def dump_config(args: argparse.Namespace) -> None:
    """Write the run options to ``<result_dir>/config.json`` once.

    An existing ``config.json`` is left untouched.
    """
    config_file = Path(args.result_dir) / "config.json"
    if config_file.exists():
        return

    with open(config_file, "w") as f:
        json.dump(vars(args), f, indent=4)
        logger.info(f"Dump config to {config_file}")
if __name__ == "__main__":
    args = config()
    # NOTE: the assignments to ``args`` in this block override whatever was
    # passed on the command line.
    args.sleep_after_execution = 2.5
    prepare(args)

    # Candidate tasks are config_files/<idx>.json for idx in [start, end);
    # tasks that already have a render file in the result dir are skipped.
    test_file_list = [
        f"config_files/{i}.json"
        for i in range(args.test_start_idx, args.test_end_idx)
    ]
    test_file_list = get_unfinished(test_file_list, args.result_dir)
    print(f"Total {len(test_file_list)} tasks left")

    args.render = True
    args.render_screenshot = True
    args.save_trace_enabled = True
    args.current_viewport_only = True
    dump_config(args)

    if test_file_list:
        agent = construct_agent(args)
        test(args, agent, test_file_list)
    else:
        # Nothing to evaluate: do not build the agent or start a browser.
        logger.info("No unfinished task left; nothing to run")