Mirror of https://github.com/web-arena-x/webarena.git
improve README, re-organize helper functions
This commit is contained in:
parent 5949a4b4ab
commit 6ea72e0d84

README.md (97 lines changed)
@@ -1,17 +1,23 @@
[![Python 3.10](https://img.shields.io/badge/python-3.10-blue.svg)](https://www.python.org/downloads/release/python-3109/)
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://pre-commit.com/)
<a href="https://github.com/psf/black"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
[![Checked with mypy](https://www.mypy-lang.org/static/mypy_badge.svg)](https://mypy-lang.org/)
[![bear-ified](https://raw.githubusercontent.com/beartype/beartype-assets/main/badge/bear-ified.svg)](https://beartype.readthedocs.io)

# WebArena: A Realistic Web Environment for Building Autonomous Agents
[[Website]](https://webarena.dev/)
[[Paper]](https://arxiv.org/pdf/2307.13854.pdf)
<p align="center">
<b>WebArena is a standalone, self-hostable web environment for building autonomous agents</b>
</p>

<p align="center">
<a href="https://www.python.org/downloads/release/python-3109/"><img src="https://img.shields.io/badge/python-3.10-blue.svg" alt="Python 3.10"></a>
<a href="https://pre-commit.com/"><img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white" alt="pre-commit"></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" alt="Code style: black"></a>
<a href="https://mypy-lang.org/"><img src="https://www.mypy-lang.org/static/mypy_badge.svg" alt="Checked with mypy"></a>
<a href="https://beartype.readthedocs.io"><img src="https://raw.githubusercontent.com/beartype/beartype-assets/main/badge/bear-ified.svg" alt="bear-ified"></a>
</p>

<p align="center">
<a href="https://webarena.dev/">Website</a> •
<a href="https://arxiv.org/2307.13854">Paper</a>
</p>

> WebArena is a standalone, self-hostable web environment for building autonomous agents
> **Note** This README is still under construction. Stay tuned!

## News

* [8/4/2023] Added instructions and Docker resources to host your own WebArena environment. Check out [this page](environment_docker/README.md) for details.
* [7/29/2023] Added [a well-commented script](minimal_example.py) to walk through the environment setup.
@@ -25,7 +31,7 @@ pip install -e .

# optional, dev only
pip install -e ".[dev]"
mypy --install-types --non-interactive browser_env
mypy --install-types --non-interactive browser_env agents evaluation_harness
pip install pre-commit
pre-commit install
```
@@ -33,11 +39,70 @@ pre-commit install
Check out [this script](minimal_example.py) for a quick walkthrough on how to set up the environment and interact with it.
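For orientation, the loop that script walks through looks roughly like the sketch below. The constructor arguments and the config-file path are illustrative assumptions; defer to `minimal_example.py` for the authoritative version.

```python
# Minimal sketch of the interaction loop; argument values are assumptions.
from browser_env import ScriptBrowserEnv, create_id_based_action

env = ScriptBrowserEnv(headless=True, observation_type="accessibility_tree")
# Point the env at one generated test config (hypothetical path).
obs, info = env.reset(options={"config_file": "config_files/0.json"})

# Issue a single id-based action against the accessibility-tree observation.
action = create_id_based_action("click [1234]")  # 1234 is a made-up element id
obs, reward, terminated, truncated, info = env.step(action)
env.close()
```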
## To Reproduce Our Results
* Setup the `environ` as described in the quick walkthrough
* `python scripts/generate_test_data.py` will generate individual config file for each test example in [config_files](config_files)
* `bash prepare.sh` to obtain the auto-login cookies for all websites
* export OPENAI_API_KEY=your_key
* `python run.py --instruction_path agent/prompts/jsons/p_cot_id_actree_2s.json --test_start_idx 0 --test_end_idx 1 --model gpt-3.5-turbo --result_dir your_result_dir` to run the first example with GPT-3.5 reasoning agent. The trajectory will be saved in `your_result_dir/0.html`
1. Configure the URLs for each website. The following example uses the demo websites we host; you can replace the URLs with your own if you [host your own WebArena environment](./environment_docker/).
```bash
export SHOPPING="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:7770"
export SHOPPING_ADMIN="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:7780/admin"
export REDDIT="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:9999"
export GITLAB="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:8023"
export MAP="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:3000"
export WIKIPEDIA="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing"
export HOMEPAGE="PASS" # this is a placeholder
```
2. Generate a config file for each test example:
```bash
python scripts/generate_test_data.py
```
You will see `*.json` files generated in the [config_files](./config_files) folder. Each file contains the configuration for one test example.
3. Obtain the auto-login cookies for all websites:
```bash
bash prepare.sh
```
4. Export your OpenAI API key: `export OPENAI_API_KEY=your_key` (a valid OpenAI API key starts with `sk-`).
5. Launch the evaluation:
```bash
# p_cot_id_actree_2s.json is the reasoning agent prompt we used in the paper
python run.py \
  --instruction_path agent/prompts/jsons/p_cot_id_actree_2s.json \
  --test_start_idx 0 \
  --test_end_idx 1 \
  --model gpt-3.5-turbo \
  --result_dir <your_result_dir>
```
This runs the first example with the GPT-3.5 reasoning agent. The trajectory will be saved in `<your_result_dir>/0.html`.
## To Develop Your Prompt-based Agent
1. Define the prompts. We provide two baseline agents whose corresponding prompts are listed [here](./agent/prompts/raw). Each prompt is a dictionary with the following keys (a filled-in example follows this schema):
```python
prompt = {
    "intro": <The overall guideline, including the task description, available actions, hints and others>,
    "examples": [
        (
            example_1_observation,
            example_1_response
        ),
        (
            example_2_observation,
            example_2_response
        ),
        ...
    ],
    "template": <How to organize different pieces of information such as the observation, previous action, instruction and URL>,
    "meta_data": {
        "observation": <Which observation space the agent uses>,
        "action_type": <Which action space the agent uses>,
        "keywords": <The placeholder keywords used in the template; the program later enumerates every keyword in the template to check that each one is correctly replaced with content>,
        "prompt_constructor": <Which prompt constructor is in use; the prompt constructor builds the input fed to an LLM and extracts the action from the generation, more details below>,
        "action_splitter": <The delimiter inside which the action can be extracted, used by the prompt constructor>
    }
}
```
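For concreteness, a filled-in prompt might look like the sketch below. The field values are invented for illustration (the shipped prompts in [agent/prompts/raw](./agent/prompts/raw) are the reference); only the keys are prescribed by the schema above.

```python
# Illustrative instance of the prompt schema; values are invented.
prompt = {
    "intro": "You are an autonomous web agent. Issue exactly one action per step ...",
    "examples": [
        (
            "OBSERVATION: [1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine' ...",
            "Let's think step-by-step. The objective asks for ... In summary, "
            "the next action I will perform is ```click [1744]```",
        ),
    ],
    "template": "OBSERVATION: {observation} URL: {url} OBJECTIVE: {objective} PREVIOUS ACTION: {previous_action}",
    "meta_data": {
        "observation": "accessibility_tree",
        "action_type": "id_accessibility_tree",
        "keywords": ["url", "objective", "observation", "previous_action"],
        "prompt_constructor": "CoTPromptConstructor",
        "action_splitter": "```",
    },
}
```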
2. Implement the prompt constructor. An example prompt constructor using chain-of-thought/ReAct-style reasoning is [here](./agent/prompts/prompt_constructor.py#L184). The prompt constructor is a class with the following methods (see the sketch after this list):
* `construct`: construct the input fed to an LLM
* `_extract_action`: given the generation from an LLM, extract the phrase that corresponds to the action
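A skeletal implementation, assuming the base class exposes the `instruction` dictionary and that the last trajectory element is the current `StateInfo` (both true of the CoT example), might look like:

```python
import re

class MyPromptConstructor(PromptConstructor):
    """Sketch of a custom prompt constructor; method bodies are illustrative."""

    def construct(self, trajectory, intent, meta_data):
        # Fill the template with the current observation, URL, objective,
        # and the previous action kept in meta_data["action_history"].
        state_info = trajectory[-1]
        current = self.instruction["template"].format(
            observation=state_info["observation"]["text"],
            url=state_info["info"]["page"].url,
            objective=intent,
            previous_action=meta_data["action_history"][-1],
        )
        # How intro/examples/current are packed into messages depends on the LLM mode.
        return [self.instruction["intro"], self.instruction["examples"], current]

    def _extract_action(self, response: str) -> str:
        # Pull the action out of the generation using the configured splitter.
        splitter = re.escape(self.instruction["meta_data"]["action_splitter"])
        match = re.search(rf"{splitter}(.*?){splitter}", response, re.DOTALL)
        if match is None:
            raise ActionParsingError(f"Cannot parse action from: {response}")
        return match.group(1).strip()
```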
## Citation
If you use our environment or data, please cite our paper:
agent/__init__.py
@@ -1 +1,8 @@
from .agent import *
from .agent import (
    Agent,
    PromptAgent,
    TeacherForcingAgent,
    construct_agent,
)

__all__ = ["Agent", "TeacherForcingAgent", "PromptAgent", "construct_agent"]
agent/agent.py
@@ -1,10 +1,13 @@
import argparse
import json
from typing import Any

import tiktoken
from beartype import beartype
from beartype.door import is_bearable

from agent.prompts import *
from browser_env import Trajectory
from browser_env.actions import (
    Action,
    ActionParsingError,
@@ -19,11 +22,6 @@ from llms.providers.openai_utils import (
    generate_from_openai_completion,
)

from .utils import *

# from llms.providers.openai_utils import generate_from_openai_completion
# from llms.providers.openai_utils import fake_generate_from_openai_chat_completion as generate_from_openai_chat_completion


class Agent:
    """Base class for the agent"""
@@ -175,3 +173,44 @@ class PromptAgent(Agent):

    def reset(self, test_config_file: str) -> None:
        pass


def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig:
    llm_config = lm_config.LMConfig(
        provider=args.provider, model=args.model, mode=args.mode
    )
    if args.provider == "openai":
        llm_config.gen_config["temperature"] = args.temperature
        llm_config.gen_config["top_p"] = args.top_p
        llm_config.gen_config["context_length"] = args.context_length
        llm_config.gen_config["max_tokens"] = args.max_tokens
        llm_config.gen_config["stop_token"] = args.stop_token
        llm_config.gen_config["max_obs_length"] = args.max_obs_length
    else:
        raise NotImplementedError(f"provider {args.provider} not implemented")
    return llm_config


def construct_agent(args: argparse.Namespace) -> Agent:
    llm_config = construct_llm_config(args)

    agent: Agent
    if args.agent_type == "teacher_forcing":
        agent = TeacherForcingAgent()
    elif args.agent_type == "prompt":
        with open(args.instruction_path) as f:
            constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
        tokenizer = tiktoken.encoding_for_model(llm_config.model)
        prompt_constructor = eval(constructor_type)(
            args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
        )
        agent = PromptAgent(
            action_set_tag=args.action_set_tag,
            lm_config=llm_config,
            prompt_constructor=prompt_constructor,
        )
    else:
        raise NotImplementedError(
            f"agent type {args.agent_type} not implemented"
        )
    return agent
agent/prompts/prompt_constructor.py
@@ -6,8 +6,7 @@ from typing import Any, TypedDict
import tiktoken
from beartype import beartype

from agent.utils import Trajectory
from browser_env import Action, ActionParsingError
from browser_env import Action, ActionParsingError, Trajectory
from browser_env.env_config import URL_MAPPINGS
from browser_env.utils import StateInfo
from llms import lm_config
agent/utils.py (deleted)
@@ -1,6 +0,0 @@
from typing import Union

from browser_env.actions import Action
from browser_env.utils import StateInfo

Trajectory = list[Union[StateInfo, Action]]
browser_env/__init__.py
@@ -34,6 +34,7 @@ from .actions import (
from .async_envs import AsyncScriptBrowserEnv
from .envs import ScriptBrowserEnv
from .processors import ObservationMetadata
from .trajectory import Trajectory
from .utils import DetachedPage, StateInfo

__all__ = [
@@ -71,4 +72,5 @@ __all__ = [
    "create_select_option_action",
    "create_stop_action",
    "ActionParsingError",
    "Trajectory",
]
browser_env/helper_functions.py (new file, 194 lines)
@@ -0,0 +1,194 @@
import base64
import io
import json
import re
from pathlib import Path
from typing import Any

from beartype import beartype
from PIL import Image

from agent.prompts import *
from browser_env import (
    Action,
    ActionTypes,
    ObservationMetadata,
    StateInfo,
    action2str,
)

HTML_TEMPLATE = """
<!DOCTYPE html>
<head>
    <style>
        pre {{
            white-space: pre-wrap;
            word-wrap: break-word;
        }}
    </style>
</head>
<html>
    <body>
    {body}
    </body>
</html>
"""


@beartype
def get_render_action(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
) -> str:
    """Parse the predicted actions for rendering purpose. More comprehensive information"""
    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["element_id"] in text_meta_data["obs_nodes_info"]:
                node_content = text_meta_data["obs_nodes_info"][
                    action["element_id"]
                ]["text"]
            else:
                node_content = "No match found"

            action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
            action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
            action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"

        case "playwright":
            action_str = action["pw_code"]
        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")
    return action_str


@beartype
def get_action_description(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
    prompt_constructor: PromptConstructor | None,
) -> str:
    """Generate the text version of the predicted actions to store in action history for prompt use.
    May contain hint information to recover from the failures"""

    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in text_meta_data["obs_nodes_info"]:
                    node_content = text_meta_data["obs_nodes_info"][
                        action["element_id"]
                    ]["text"]
                    node_content = " ".join(node_content.split()[1:])
                    action_str = action2str(
                        action, action_set_tag, node_content
                    )
                else:
                    action_str = f"Attempt to perform \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    action_splitter = prompt_constructor.instruction[
                        "meta_data"
                    ]["action_splitter"]
                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
                else:
                    action_str = action2str(action, action_set_tag, "")

        case "playwright":
            action_str = action["pw_code"]

        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")

    return action_str


class RenderHelper(object):
    """Helper class to render text and image observations and meta data in the trajectory"""

    def __init__(
        self, config_file: str, result_dir: str, action_set_tag: str
    ) -> None:
        with open(config_file, "r") as f:
            _config = json.load(f)
            _config_str = ""
            for k, v in _config.items():
                _config_str += f"{k}: {v}\n"
            _config_str = f"<pre>{_config_str}</pre>\n"
            task_id = _config["task_id"]

        self.action_set_tag = action_set_tag

        self.render_file = open(
            Path(result_dir) / f"render_{task_id}.html", "a+"
        )
        self.render_file.truncate(0)
        # write init template
        self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
        self.render_file.read()
        self.render_file.flush()

    def render(
        self,
        action: Action,
        state_info: StateInfo,
        meta_data: dict[str, Any],
        render_screenshot: bool = False,
    ) -> None:
        """Render the trajectory"""
        # text observation
        observation = state_info["observation"]
        text_obs = observation["text"]
        info = state_info["info"]
        new_content = f"<h2>New Page</h2>\n"
        new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
        new_content += f"<div class='state_obv'><pre>{text_obs}</pre><div>\n"

        if render_screenshot:
            # image observation
            img_obs = observation["image"]
            image = Image.fromarray(img_obs)
            byte_io = io.BytesIO()
            image.save(byte_io, format="PNG")
            byte_io.seek(0)
            image_bytes = base64.b64encode(byte_io.read())
            image_str = image_bytes.decode("utf-8")
            new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"

        # meta data
        new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"

        # action
        action_str = get_render_action(
            action,
            info["observation_metadata"],
            action_set_tag=self.action_set_tag,
        )
        # with yellow background
        action_str = f"<div class='predict_action'>{action_str}</div>"
        new_content += f"{action_str}\n"

        # add new content
        self.render_file.seek(0)
        html = self.render_file.read()
        html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
        html_body += new_content

        html = HTML_TEMPLATE.format(body=html_body)
        self.render_file.seek(0)
        self.render_file.truncate()
        self.render_file.write(html)
        self.render_file.flush()

    def close(self) -> None:
        self.render_file.close()
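For context on how this new helper is consumed, run.py drives it roughly as follows per test example (a sketch; the literal paths and tag are placeholder values):

```python
# Rough usage sketch of RenderHelper; paths and tags are placeholders.
render_helper = RenderHelper(
    config_file="config_files/0.json",   # placeholder test config
    result_dir="results",                # placeholder output directory
    action_set_tag="id_accessibility_tree",
)
# After each agent step, append the step to the HTML trace:
render_helper.render(action, state_info, meta_data, render_screenshot=True)
# When the episode ends:
render_helper.close()
```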
browser_env/trajectory.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from typing import Union

from .actions import Action
from .utils import StateInfo

Trajectory = list[Union[StateInfo, Action]]
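The alias encodes the invariant the rest of the codebase relies on: a trajectory alternates states and actions, `[StateInfo, Action, StateInfo, Action, ...]`. A small sketch of how a consumer can split it, assuming that invariant holds:

```python
# Sketch: even indices hold StateInfo entries, odd indices hold Action entries.
def split_trajectory(trajectory: Trajectory):
    """Separate a trajectory into its states and actions."""
    states = trajectory[::2]    # StateInfo entries
    actions = trajectory[1::2]  # Action entries
    return states, actions
```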
@@ -4,22 +4,22 @@ app = Flask(__name__)


@app.route("/")
def index():
def index() -> str:
    return render_template("index.html")


@app.route("/scratchpad.html")
def scratchpad():
def scratchpad() -> str:
    return render_template("scratchpad.html")


@app.route("/calculator.html")
def calculator():
def calculator() -> str:
    return render_template("calculator.html")


@app.route("/password.html")
def password():
def password() -> str:
    return render_template("password.html")
minimal_example.py
@@ -51,8 +51,6 @@ print("Done generating config files with the correct URLs")
subprocess.run(["bash", "prepare.sh"])
print("Done saving account cookies")

from agent.utils import Trajectory

# Init an environment
from browser_env import (
    Action,
@@ -60,6 +58,7 @@ from browser_env import (
    ObservationMetadata,
    ScriptBrowserEnv,
    StateInfo,
    Trajectory,
    action2str,
    create_id_based_action,
    create_stop_action,
run.py (246 lines changed)
@@ -1,39 +1,37 @@
"""Script to run end-to-end evaluation on the benchmark"""
|
||||
import argparse
|
||||
import base64
|
||||
import glob
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
import tiktoken
|
||||
from beartype import beartype
|
||||
from PIL import Image
|
||||
from prompt_toolkit import prompt
|
||||
|
||||
from agent import Agent, PromptAgent, TeacherForcingAgent
|
||||
from agent import (
|
||||
Agent,
|
||||
PromptAgent,
|
||||
TeacherForcingAgent,
|
||||
construct_agent,
|
||||
)
|
||||
from agent.prompts import *
|
||||
from browser_env import (
|
||||
Action,
|
||||
ActionTypes,
|
||||
ObservationMetadata,
|
||||
ScriptBrowserEnv,
|
||||
StateInfo,
|
||||
action2str,
|
||||
Trajectory,
|
||||
create_stop_action,
|
||||
)
|
||||
from browser_env.actions import is_equivalent
|
||||
from browser_env.helper_functions import (
|
||||
RenderHelper,
|
||||
get_action_description,
|
||||
)
|
||||
from evaluation_harness import evaluator_router
|
||||
from llms import lm_config
|
||||
|
||||
LOG_FOLDER = "log_files"
|
||||
Path(LOG_FOLDER).mkdir(parents=True, exist_ok=True)
|
||||
@@ -55,24 +53,6 @@ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)

Trajectory = list[Action | StateInfo]
HTML_TEMPLATE = """
<!DOCTYPE html>
<head>
    <style>
        pre {{
            white-space: pre-wrap;
            word-wrap: break-word;
        }}
    </style>
</head>
<html>
    <body>
    {body}
    </body>
</html>
"""
def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
@@ -164,165 +144,6 @@ def config() -> argparse.Namespace:
    return args
@beartype
def get_render_action(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
) -> str:
    """Parse the predicted actions for rendering purpose. More comprehensive information"""
    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["element_id"] in text_meta_data["obs_nodes_info"]:
                node_content = text_meta_data["obs_nodes_info"][
                    action["element_id"]
                ]["text"]
            else:
                node_content = "No match found"

            action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
            action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
            action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"

        case "playwright":
            action_str = action["pw_code"]
        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")
    return action_str


@beartype
def get_action_description(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
    prompt_constructor: PromptConstructor | None,
) -> str:
    """Generate the text version of the predicted actions to store in action history for prompt use.
    May contain hint information to recover from the failures"""

    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in text_meta_data["obs_nodes_info"]:
                    node_content = text_meta_data["obs_nodes_info"][
                        action["element_id"]
                    ]["text"]
                    node_content = " ".join(node_content.split()[1:])
                    action_str = action2str(
                        action, action_set_tag, node_content
                    )
                else:
                    action_str = f"Attempt to perform \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    action_splitter = prompt_constructor.instruction[
                        "meta_data"
                    ]["action_splitter"]
                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
                else:
                    action_str = action2str(action, action_set_tag, "")

        case "playwright":
            action_str = action["pw_code"]

        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")

    return action_str


class RenderHelper(object):
    """Helper class to render text and image observations and meta data in the trajectory"""

    def __init__(
        self, config_file: str, result_dir: str, action_set_tag: str
    ) -> None:
        with open(config_file, "r") as f:
            _config = json.load(f)
            _config_str = ""
            for k, v in _config.items():
                _config_str += f"{k}: {v}\n"
            _config_str = f"<pre>{_config_str}</pre>\n"
            task_id = _config["task_id"]

        self.action_set_tag = action_set_tag

        self.render_file = open(
            Path(result_dir) / f"render_{task_id}.html", "a+"
        )
        self.render_file.truncate(0)
        # write init template
        self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
        self.render_file.read()
        self.render_file.flush()

    def render(
        self,
        action: Action,
        state_info: StateInfo,
        meta_data: dict[str, Any],
        render_screenshot: bool = False,
    ) -> None:
        """Render the trajectory"""
        # text observation
        observation = state_info["observation"]
        text_obs = observation["text"]
        info = state_info["info"]
        new_content = f"<h2>New Page</h2>\n"
        new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
        new_content += f"<div class='state_obv'><pre>{text_obs}</pre><div>\n"

        if render_screenshot:
            # image observation
            img_obs = observation["image"]
            image = Image.fromarray(img_obs)
            byte_io = io.BytesIO()
            image.save(byte_io, format="PNG")
            byte_io.seek(0)
            image_bytes = base64.b64encode(byte_io.read())
            image_str = image_bytes.decode("utf-8")
            new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"

        # meta data
        new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"

        # action
        action_str = get_render_action(
            action,
            info["observation_metadata"],
            action_set_tag=self.action_set_tag,
        )
        # with yellow background
        action_str = f"<div class='predict_action'>{action_str}</div>"
        new_content += f"{action_str}\n"

        # add new content
        self.render_file.seek(0)
        html = self.render_file.read()
        html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
        html_body += new_content

        html = HTML_TEMPLATE.format(body=html_body)
        self.render_file.seek(0)
        self.render_file.truncate()
        self.render_file.write(html)
        self.render_file.flush()

    def close(self) -> None:
        self.render_file.close()
@beartype
def early_stop(
    trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
@@ -383,7 +204,7 @@ def early_stop(
@beartype
def test(
    args: argparse.Namespace,
    agent: Agent | PromptAgent,
    agent: Agent | PromptAgent | TeacherForcingAgent,
    config_file_list: list[str],
) -> None:
    scores = []
@@ -504,55 +325,12 @@ def test(
                f.write(f"[Unhandled Error] {repr(e)}\n")
                f.write(traceback.format_exc())  # write stack trace to file

        # logger.info(f"[Render] {render_helper.render_file.name}")
        # subprocess.run(["open", render_helper.render_file.name])
        render_helper.close()

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")


def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig:
    llm_config = lm_config.LMConfig(
        provider=args.provider, model=args.model, mode=args.mode
    )
    if args.provider == "openai":
        llm_config.gen_config["temperature"] = args.temperature
        llm_config.gen_config["top_p"] = args.top_p
        llm_config.gen_config["context_length"] = args.context_length
        llm_config.gen_config["max_tokens"] = args.max_tokens
        llm_config.gen_config["stop_token"] = args.stop_token
        llm_config.gen_config["max_obs_length"] = args.max_obs_length
    else:
        raise NotImplementedError(f"provider {args.provider} not implemented")
    return llm_config


def construct_agent(args: argparse.Namespace) -> Agent:
    llm_config = construct_llm_config(args)

    agent: Agent
    if args.agent_type == "teacher_forcing":
        agent = TeacherForcingAgent()
    elif args.agent_type == "prompt":
        with open(args.instruction_path) as f:
            constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
        tokenizer = tiktoken.encoding_for_model(llm_config.model)
        prompt_constructor = eval(constructor_type)(
            args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
        )
        agent = PromptAgent(
            action_set_tag=args.action_set_tag,
            lm_config=llm_config,
            prompt_constructor=prompt_constructor,
        )
    else:
        raise NotImplementedError(
            f"agent type {args.agent_type} not implemented"
        )
    return agent


def prepare(args: argparse.Namespace) -> None:
    # convert prompt python files to json
    from agent.prompts import to_json