improve README, re-organize helper functions

alexisxy 2023-08-15 15:27:44 -04:00
parent 5949a4b4ab
commit 6ea72e0d84
11 changed files with 353 additions and 270 deletions

View File

@ -1,17 +1,23 @@
[![Python 3.10](https://img.shields.io/badge/python-3.10-blue.svg)](https://www.python.org/downloads/release/python-3109/)
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://pre-commit.com/)
<a href="https://github.com/psf/black"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
[![Checked with mypy](https://www.mypy-lang.org/static/mypy_badge.svg)](https://mypy-lang.org/)
[![bear-ified](https://raw.githubusercontent.com/beartype/beartype-assets/main/badge/bear-ified.svg)](https://beartype.readthedocs.io)
# WebArena: A Realistic Web Environment for Building Autonomous Agents
[[Website]](https://webarena.dev/)
[[Paper]](https://arxiv.org/pdf/2307.13854.pdf)
<p align="center">
<b>WebArena is a standalone, self-hostable web environment for building autonomous agents</b>
</p>
<p align="center">
<a href="https://www.python.org/downloads/release/python-3109/"><img src="https://img.shields.io/badge/python-3.10-blue.svg" alt="Python 3.10"></a>
<a href="https://pre-commit.com/"><img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white" alt="pre-commit"></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" alt="Code style: black"></a>
<a href="https://mypy-lang.org/"><img src="https://www.mypy-lang.org/static/mypy_badge.svg" alt="Checked with mypy"></a>
<a href="https://beartype.readthedocs.io"><img src="https://raw.githubusercontent.com/beartype/beartype-assets/main/badge/bear-ified.svg" alt="bear-ified"></a>
</p>
<p align="center">
<a href="https://webarena.dev/">Website</a>
<a href="https://arxiv.org/2307.13854">Paper</a>
</p>
![Overview](media/overview.png)
> WebArena is a standalone, self-hostable web environment for building autonomous agents
> **Note** This README is still under construction. Stay tuned!
## News
* [8/4/2023] Added the instructions and the docker resources to host your own WebArena Environment. Check out [this page](environment_docker/README.md) for details.
* [7/29/2023] Added [a well commented script](minimal_example.py) to walk through the environment setup.
@ -25,7 +31,7 @@ pip install -e .
# optional, dev only
pip install -e ".[dev]"
mypy --install-types --non-interactive browser_env
mypy --install-types --non-interactive browser_env agents evaluation_harness
pip install pre-commit
pre-commit install
```
@ -33,11 +39,70 @@ pre-commit install
Check out [this script](minimal_example.py) for a quick walkthrough on how to set up the environment and interact with it.
## To Reproduce Our Results
* Setup the `environ` as described in the quick walkthrough
* `python scripts/generate_test_data.py` will generate individual config file for each test example in [config_files](config_files)
* `bash prepare.sh` to obtain the auto-login cookies for all websites
* export OPENAI_API_KEY=your_key
* `python run.py --instruction_path agent/prompts/jsons/p_cot_id_actree_2s.json --test_start_idx 0 --test_end_idx 1 --model gpt-3.5-turbo --result_dir your_result_dir` to run the first example with GPT-3.5 reasoning agent. The trajectory will be saved in `your_result_dir/0.html`
1. Configure the URLs for each website. The following example uses the demo websites we host; you can replace the URLs with your own if you [host your own WebArena environment](./environment_docker/). A quick reachability check for these URLs is sketched after this list.
```bash
export SHOPPING="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:7770"
export SHOPPING_ADMIN="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:7780/admin"
export REDDIT="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:9999"
export GITLAB="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:8023"
export MAP="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:3000"
export WIKIPEDIA="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing"
export HOMEPAGE="PASS" # this is a placeholder
```
2. Generate a config file for each test example
```bash
python scripts/generate_test_data.py
```
You will see `*.json` files generated in the [config_files](./config_files) folder; each file contains the configuration for one test example.
3. Obtain the auto-login cookies for all websites
```bash
bash prepare.sh
```
4. Export your OpenAI API key: `export OPENAI_API_KEY=your_key`. A valid OpenAI API key starts with `sk-`.
5. Launch the evaluation
```bash
# this is the reasoning agent prompt we used in the paper
python run.py \
--instruction_path agent/prompts/jsons/p_cot_id_actree_2s.json \
--test_start_idx 0 \
--test_end_idx 1 \
--model gpt-3.5-turbo \
--result_dir <your_result_dir>
```
This command runs the first example with the GPT-3.5 reasoning agent; the trajectory will be saved in `<your_result_dir>/0.html`.
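Before launching, it can help to verify that the site URLs configured in step 1 are reachable. Below is a minimal sketch using only the Python standard library; the environment variable names match the exports in step 1, and the timeout value is arbitrary:
```python
# Illustrative sanity check: confirm each configured site responds.
import os
import urllib.request

for site in ["SHOPPING", "SHOPPING_ADMIN", "REDDIT", "GITLAB", "MAP", "WIKIPEDIA"]:
    url = os.environ[site]
    try:
        with urllib.request.urlopen(url, timeout=10) as resp:
            print(f"{site}: HTTP {resp.status}")
    except Exception as e:
        print(f"{site}: unreachable ({e})")
```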
## To Develop Your Prompt-based Agent
1. Define the prompts. We provide two baseline agents whose corresponding prompts are listed [here](./agent/prompts/raw). Each prompt is a dictionary with the following keys (a toy instantiation is sketched after this list):
```python
prompt = {
"intro": <The overall guideline which includes the task description, available action, hint and others>,
"examples": [
(
example_1_observation,
example_1_response
),
(
example_2_observation,
example_2_response
),
...
],
"template": <How to organize different information such as observation, previous action, instruction, url>,
"meta_data": {
"observation": <Which observation space the agent uses>,
"action_type": <Which action space the agent uses>,
"keywords": <The keywords used in the template, the program will later enumerate all keywords in the template to see if all of them are correctly replaced with the content>,
"prompt_constructor": <Which prompt construtor is in used, the prompt constructor will construct the input feed to an LLM and extract the action from the generation, more details below>,
"action_splitter": <Inside which splitter can we extract the action, used by the prompt constructor>
}
}
```
2. Implement the prompt constructor. An example prompt constructor using Chain-of-Thought/ReAct-style reasoning is [here](./agent/prompts/prompt_constructor.py#L184). The prompt constructor is a class with the following methods; a hedged skeleton follows after this list:
* `construct`: constructs the input fed to an LLM
* `_extract_action`: given the generation from an LLM, extracts the phrase that corresponds to the action
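For concreteness, here is a toy instantiation of the prompt dictionary from step 1. Every field value below is made up for illustration; the real prompts live in [agent/prompts/raw](./agent/prompts/raw):
```python
# Illustrative only: a minimal prompt dictionary following the schema above.
prompt = {
    "intro": "You are an autonomous agent operating a web browser. Issue one action per step ...",
    "examples": [
        (
            "OBSERVATION: [101] link 'Example product'\nURL: http://example.com\nOBJECTIVE: Find the price\nPREVIOUS ACTION: None",
            "Let's think step-by-step. ... In summary, the next action I will perform is ```click [101]```",
        ),
    ],
    "template": "OBSERVATION: {observation}\nURL: {url}\nOBJECTIVE: {objective}\nPREVIOUS ACTION: {previous_action}",
    "meta_data": {
        "observation": "accessibility_tree",
        "action_type": "id_accessibility_tree",
        "keywords": ["url", "objective", "observation", "previous_action"],
        "prompt_constructor": "CoTPromptConstructor",
        "action_splitter": "```",
    },
}
```
And a hedged skeleton of a custom prompt constructor for step 2. The signatures are simplified and the class body is a sketch, not the repository's implementation; consult [prompt_constructor.py](./agent/prompts/prompt_constructor.py) for the actual base class:
```python
# Illustrative skeleton of a custom prompt constructor.
import re

from agent.prompts import PromptConstructor  # assumed re-export of the base class
from browser_env import ActionParsingError


class MyPromptConstructor(PromptConstructor):
    def construct(self, trajectory, intent, meta_data):
        # Assemble the intro, the few-shot examples, and the current
        # observation (via self.instruction["template"]) into the
        # provider-specific input, e.g. an OpenAI chat message list.
        ...

    def _extract_action(self, response: str) -> str:
        # Pull out the phrase wrapped inside the action_splitter,
        # e.g. the text between a pair of ``` markers.
        splitter = self.instruction["meta_data"]["action_splitter"]
        match = re.search(
            rf"{re.escape(splitter)}((.|\n)*?){re.escape(splitter)}", response
        )
        if match is None:
            raise ActionParsingError(f"Cannot parse action from: {response}")
        return match.group(1).strip()
```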
## Citation
If you use our environment or data, please cite our paper:

View File

@ -1 +1,8 @@
from .agent import *
from .agent import (
Agent,
PromptAgent,
TeacherForcingAgent,
construct_agent,
)
__all__ = ["Agent", "TeacherForcingAgent", "PromptAgent", "construct_agent"]

View File

@ -1,10 +1,13 @@
import argparse
import json
from typing import Any
import tiktoken
from beartype import beartype
from beartype.door import is_bearable
from agent.prompts import *
from browser_env import Trajectory
from browser_env.actions import (
Action,
ActionParsingError,
@ -19,11 +22,6 @@ from llms.providers.openai_utils import (
generate_from_openai_completion,
)
from .utils import *
# from llms.providers.openai_utils import generate_from_openai_completion
# from llms.providers.openai_utils import fake_generate_from_openai_chat_completion as generate_from_openai_chat_completion
class Agent:
"""Base class for the agent"""
@ -175,3 +173,44 @@ class PromptAgent(Agent):
def reset(self, test_config_file: str) -> None:
pass
def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig:
llm_config = lm_config.LMConfig(
provider=args.provider, model=args.model, mode=args.mode
)
if args.provider == "openai":
llm_config.gen_config["temperature"] = args.temperature
llm_config.gen_config["top_p"] = args.top_p
llm_config.gen_config["context_length"] = args.context_length
llm_config.gen_config["max_tokens"] = args.max_tokens
llm_config.gen_config["stop_token"] = args.stop_token
llm_config.gen_config["max_obs_length"] = args.max_obs_length
else:
raise NotImplementedError(f"provider {args.provider} not implemented")
return llm_config
def construct_agent(args: argparse.Namespace) -> Agent:
llm_config = construct_llm_config(args)
agent: Agent
if args.agent_type == "teacher_forcing":
agent = TeacherForcingAgent()
elif args.agent_type == "prompt":
with open(args.instruction_path) as f:
constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
tokenizer = tiktoken.encoding_for_model(llm_config.model)
prompt_constructor = eval(constructor_type)(
args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
)
agent = PromptAgent(
action_set_tag=args.action_set_tag,
lm_config=llm_config,
prompt_constructor=prompt_constructor,
)
else:
raise NotImplementedError(
f"agent type {args.agent_type} not implemented"
)
return agent
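For reference, a hedged usage sketch of `construct_agent`: the namespace fields mirror the CLI flags that run.py passes in, but the values below are illustrative stand-ins, not documented defaults.
```python
# Illustrative only: build a prompt-based agent the way run.py's CLI would.
import argparse

from agent import construct_agent

args = argparse.Namespace(
    provider="openai",
    model="gpt-3.5-turbo",
    mode="chat",
    temperature=1.0,        # generation settings: illustrative values
    top_p=0.9,
    context_length=0,
    max_tokens=384,
    stop_token=None,
    max_obs_length=1920,
    agent_type="prompt",
    instruction_path="agent/prompts/jsons/p_cot_id_actree_2s.json",
    action_set_tag="id_accessibility_tree",
)
agent = construct_agent(args)  # a PromptAgent when agent_type == "prompt"
```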

View File

@ -6,8 +6,7 @@ from typing import Any, TypedDict
import tiktoken
from beartype import beartype
from agent.utils import Trajectory
from browser_env import Action, ActionParsingError
from browser_env import Action, ActionParsingError, Trajectory
from browser_env.env_config import URL_MAPPINGS
from browser_env.utils import StateInfo
from llms import lm_config

View File

@ -1,6 +0,0 @@
from typing import Union
from browser_env.actions import Action
from browser_env.utils import StateInfo
Trajectory = list[Union[StateInfo, Action]]

View File

@ -34,6 +34,7 @@ from .actions import (
from .async_envs import AsyncScriptBrowserEnv
from .envs import ScriptBrowserEnv
from .processors import ObservationMetadata
from .trajectory import Trajectory
from .utils import DetachedPage, StateInfo
__all__ = [
@ -71,4 +72,5 @@ __all__ = [
"create_select_option_action",
"create_stop_action",
"ActionParsingError",
"Trajectory",
]

View File

@ -0,0 +1,194 @@
import base64
import io
import json
import re
from pathlib import Path
from typing import Any
from beartype import beartype
from PIL import Image
from agent.prompts import *
from browser_env import (
Action,
ActionTypes,
ObservationMetadata,
StateInfo,
action2str,
)
HTML_TEMPLATE = """
<!DOCTYPE html>
<head>
<style>
pre {{
white-space: pre-wrap;
word-wrap: break-word;
}}
</style>
</head>
<html>
<body>
{body}
</body>
</html>
"""
@beartype
def get_render_action(
action: Action,
observation_metadata: dict[str, ObservationMetadata],
action_set_tag: str,
) -> str:
"""Parse the predicted actions for rendering purpose. More comprehensive information"""
match action_set_tag:
case "id_accessibility_tree":
text_meta_data = observation_metadata["text"]
if action["element_id"] in text_meta_data["obs_nodes_info"]:
node_content = text_meta_data["obs_nodes_info"][
action["element_id"]
]["text"]
else:
node_content = "No match found"
action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"
case "playwright":
action_str = action["pw_code"]
case _:
raise ValueError(f"Unknown action type {action['action_type']}")
return action_str
@beartype
def get_action_description(
action: Action,
observation_metadata: dict[str, ObservationMetadata],
action_set_tag: str,
prompt_constructor: PromptConstructor | None,
) -> str:
"""Generate the text version of the predicted actions to store in action history for prompt use.
May contain hint information to recover from the failures"""
match action_set_tag:
case "id_accessibility_tree":
text_meta_data = observation_metadata["text"]
if action["action_type"] in [
ActionTypes.CLICK,
ActionTypes.HOVER,
ActionTypes.TYPE,
]:
action_name = str(action["action_type"]).split(".")[1].lower()
if action["element_id"] in text_meta_data["obs_nodes_info"]:
node_content = text_meta_data["obs_nodes_info"][
action["element_id"]
]["text"]
node_content = " ".join(node_content.split()[1:])
action_str = action2str(
action, action_set_tag, node_content
)
else:
action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
else:
if (
action["action_type"] == ActionTypes.NONE
and prompt_constructor is not None
):
action_splitter = prompt_constructor.instruction[
"meta_data"
]["action_splitter"]
action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
else:
action_str = action2str(action, action_set_tag, "")
case "playwright":
action_str = action["pw_code"]
case _:
raise ValueError(f"Unknown action type {action['action_type']}")
return action_str
class RenderHelper(object):
"""Helper class to render text and image observations and meta data in the trajectory"""
def __init__(
self, config_file: str, result_dir: str, action_set_tag: str
) -> None:
with open(config_file, "r") as f:
_config = json.load(f)
_config_str = ""
for k, v in _config.items():
_config_str += f"{k}: {v}\n"
_config_str = f"<pre>{_config_str}</pre>\n"
task_id = _config["task_id"]
self.action_set_tag = action_set_tag
self.render_file = open(
Path(result_dir) / f"render_{task_id}.html", "a+"
)
self.render_file.truncate(0)
# write init template
self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
self.render_file.read()
self.render_file.flush()
def render(
self,
action: Action,
state_info: StateInfo,
meta_data: dict[str, Any],
render_screenshot: bool = False,
) -> None:
"""Render the trajectory"""
# text observation
observation = state_info["observation"]
text_obs = observation["text"]
info = state_info["info"]
new_content = f"<h2>New Page</h2>\n"
new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
new_content += f"<div class='state_obv'><pre>{text_obs}</pre><div>\n"
if render_screenshot:
# image observation
img_obs = observation["image"]
image = Image.fromarray(img_obs)
byte_io = io.BytesIO()
image.save(byte_io, format="PNG")
byte_io.seek(0)
image_bytes = base64.b64encode(byte_io.read())
image_str = image_bytes.decode("utf-8")
new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"
# meta data
new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"
# action
action_str = get_render_action(
action,
info["observation_metadata"],
action_set_tag=self.action_set_tag,
)
# with yellow background
action_str = f"<div class='predict_action'>{action_str}</div>"
new_content += f"{action_str}\n"
# add new content
self.render_file.seek(0)
html = self.render_file.read()
html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
html_body += new_content
html = HTML_TEMPLATE.format(body=html_body)
self.render_file.seek(0)
self.render_file.truncate()
self.render_file.write(html)
self.render_file.flush()
def close(self) -> None:
self.render_file.close()
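A hedged usage sketch of `RenderHelper`: the config path and the loop variables below are illustrative; in run.py one `(action, state_info)` pair is rendered per step.
```python
# Illustrative only: render one step of an episode to results/render_0.html.
helper = RenderHelper(
    config_file="config_files/0.json",  # hypothetical test config
    result_dir="results",
    action_set_tag="id_accessibility_tree",
)
# Inside the evaluation loop, after the agent predicts `action`
# for the current `state_info`:
helper.render(
    action,
    state_info,
    meta_data={"action_history": ["None"]},  # illustrative metadata
    render_screenshot=True,
)
helper.close()
```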

View File

@ -0,0 +1,6 @@
from typing import Union
from .actions import Action
from .utils import StateInfo
Trajectory = list[Union[StateInfo, Action]]
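A `Trajectory` simply interleaves the states the agent observes with the actions it takes. A hedged sketch of its shape (the append pattern follows run.py's evaluation loop):
```python
# Illustrative shape: [state_0, action_0, state_1, action_1, ..., stop_action]
from browser_env import Trajectory, create_stop_action

trajectory: Trajectory = []
# After each env.reset()/env.step(), append the new StateInfo:
#     trajectory.append(state_info)
# After each agent prediction, append the Action:
#     trajectory.append(action)
trajectory.append(create_stop_action("Early stop"))  # ends the episode
```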

View File

@ -4,22 +4,22 @@ app = Flask(__name__)
@app.route("/")
def index():
def index() -> str:
return render_template("index.html")
@app.route("/scratchpad.html")
def scratchpad():
def scratchpad() -> str:
return render_template("scratchpad.html")
@app.route("/calculator.html")
def calculator():
def calculator() -> str:
return render_template("calculator.html")
@app.route("/password.html")
def password():
def password() -> str:
return render_template("password.html")

View File

@ -51,8 +51,6 @@ print("Done generating config files with the correct URLs")
subprocess.run(["bash", "prepare.sh"])
print("Done saving account cookies")
from agent.utils import Trajectory
# Init an environment
from browser_env import (
Action,
@ -60,6 +58,7 @@ from browser_env import (
ObservationMetadata,
ScriptBrowserEnv,
StateInfo,
Trajectory,
action2str,
create_id_based_action,
create_stop_action,

run.py (246 lines changed)
View File

@ -1,39 +1,37 @@
"""Script to run end-to-end evaluation on the benchmark"""
import argparse
import base64
import glob
import io
import json
import logging
import os
import random
import re
import subprocess
import time
from itertools import chain
from pathlib import Path
from typing import Any
import openai
import tiktoken
from beartype import beartype
from PIL import Image
from prompt_toolkit import prompt
from agent import Agent, PromptAgent, TeacherForcingAgent
from agent import (
Agent,
PromptAgent,
TeacherForcingAgent,
construct_agent,
)
from agent.prompts import *
from browser_env import (
Action,
ActionTypes,
ObservationMetadata,
ScriptBrowserEnv,
StateInfo,
action2str,
Trajectory,
create_stop_action,
)
from browser_env.actions import is_equivalent
from browser_env.helper_functions import (
RenderHelper,
get_action_description,
)
from evaluation_harness import evaluator_router
from llms import lm_config
LOG_FOLDER = "log_files"
Path(LOG_FOLDER).mkdir(parents=True, exist_ok=True)
@ -55,24 +53,6 @@ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)
Trajectory = list[Action | StateInfo]
HTML_TEMPLATE = """
<!DOCTYPE html>
<head>
<style>
pre {{
white-space: pre-wrap;
word-wrap: break-word;
}}
</style>
</head>
<html>
<body>
{body}
</body>
</html>
"""
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
@ -164,165 +144,6 @@ def config() -> argparse.Namespace:
return args
@beartype
def get_render_action(
action: Action,
observation_metadata: dict[str, ObservationMetadata],
action_set_tag: str,
) -> str:
"""Parse the predicted actions for rendering purpose. More comprehensive information"""
match action_set_tag:
case "id_accessibility_tree":
text_meta_data = observation_metadata["text"]
if action["element_id"] in text_meta_data["obs_nodes_info"]:
node_content = text_meta_data["obs_nodes_info"][
action["element_id"]
]["text"]
else:
node_content = "No match found"
action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"
case "playwright":
action_str = action["pw_code"]
case _:
raise ValueError(f"Unknown action type {action['action_type']}")
return action_str
@beartype
def get_action_description(
action: Action,
observation_metadata: dict[str, ObservationMetadata],
action_set_tag: str,
prompt_constructor: PromptConstructor | None,
) -> str:
"""Generate the text version of the predicted actions to store in action history for prompt use.
May contain hint information to recover from the failures"""
match action_set_tag:
case "id_accessibility_tree":
text_meta_data = observation_metadata["text"]
if action["action_type"] in [
ActionTypes.CLICK,
ActionTypes.HOVER,
ActionTypes.TYPE,
]:
action_name = str(action["action_type"]).split(".")[1].lower()
if action["element_id"] in text_meta_data["obs_nodes_info"]:
node_content = text_meta_data["obs_nodes_info"][
action["element_id"]
]["text"]
node_content = " ".join(node_content.split()[1:])
action_str = action2str(
action, action_set_tag, node_content
)
else:
action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
else:
if (
action["action_type"] == ActionTypes.NONE
and prompt_constructor is not None
):
action_splitter = prompt_constructor.instruction[
"meta_data"
]["action_splitter"]
action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
else:
action_str = action2str(action, action_set_tag, "")
case "playwright":
action_str = action["pw_code"]
case _:
raise ValueError(f"Unknown action type {action['action_type']}")
return action_str
class RenderHelper(object):
"""Helper class to render text and image observations and meta data in the trajectory"""
def __init__(
self, config_file: str, result_dir: str, action_set_tag: str
) -> None:
with open(config_file, "r") as f:
_config = json.load(f)
_config_str = ""
for k, v in _config.items():
_config_str += f"{k}: {v}\n"
_config_str = f"<pre>{_config_str}</pre>\n"
task_id = _config["task_id"]
self.action_set_tag = action_set_tag
self.render_file = open(
Path(result_dir) / f"render_{task_id}.html", "a+"
)
self.render_file.truncate(0)
# write init template
self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
self.render_file.read()
self.render_file.flush()
def render(
self,
action: Action,
state_info: StateInfo,
meta_data: dict[str, Any],
render_screenshot: bool = False,
) -> None:
"""Render the trajectory"""
# text observation
observation = state_info["observation"]
text_obs = observation["text"]
info = state_info["info"]
new_content = f"<h2>New Page</h2>\n"
new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
new_content += f"<div class='state_obv'><pre>{text_obs}</pre><div>\n"
if render_screenshot:
# image observation
img_obs = observation["image"]
image = Image.fromarray(img_obs)
byte_io = io.BytesIO()
image.save(byte_io, format="PNG")
byte_io.seek(0)
image_bytes = base64.b64encode(byte_io.read())
image_str = image_bytes.decode("utf-8")
new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"
# meta data
new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"
# action
action_str = get_render_action(
action,
info["observation_metadata"],
action_set_tag=self.action_set_tag,
)
# with yellow background
action_str = f"<div class='predict_action'>{action_str}</div>"
new_content += f"{action_str}\n"
# add new content
self.render_file.seek(0)
html = self.render_file.read()
html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
html_body += new_content
html = HTML_TEMPLATE.format(body=html_body)
self.render_file.seek(0)
self.render_file.truncate()
self.render_file.write(html)
self.render_file.flush()
def close(self) -> None:
self.render_file.close()
@beartype
def early_stop(
trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
@ -383,7 +204,7 @@ def early_stop(
@beartype
def test(
args: argparse.Namespace,
agent: Agent | PromptAgent,
agent: Agent | PromptAgent | TeacherForcingAgent,
config_file_list: list[str],
) -> None:
scores = []
@ -504,55 +325,12 @@ def test(
f.write(f"[Unhandled Error] {repr(e)}\n")
f.write(traceback.format_exc()) # write stack trace to file
# logger.info(f"[Render] {render_helper.render_file.name}")
# subprocess.run(["open", render_helper.render_file.name])
render_helper.close()
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")
def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig:
llm_config = lm_config.LMConfig(
provider=args.provider, model=args.model, mode=args.mode
)
if args.provider == "openai":
llm_config.gen_config["temperature"] = args.temperature
llm_config.gen_config["top_p"] = args.top_p
llm_config.gen_config["context_length"] = args.context_length
llm_config.gen_config["max_tokens"] = args.max_tokens
llm_config.gen_config["stop_token"] = args.stop_token
llm_config.gen_config["max_obs_length"] = args.max_obs_length
else:
raise NotImplementedError(f"provider {args.provider} not implemented")
return llm_config
def construct_agent(args: argparse.Namespace) -> Agent:
llm_config = construct_llm_config(args)
agent: Agent
if args.agent_type == "teacher_forcing":
agent = TeacherForcingAgent()
elif args.agent_type == "prompt":
with open(args.instruction_path) as f:
constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
tokenizer = tiktoken.encoding_for_model(llm_config.model)
prompt_constructor = eval(constructor_type)(
args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
)
agent = PromptAgent(
action_set_tag=args.action_set_tag,
lm_config=llm_config,
prompt_constructor=prompt_constructor,
)
else:
raise NotImplementedError(
f"agent type {args.agent_type} not implemented"
)
return agent
def prepare(args: argparse.Namespace) -> None:
# convert prompt python files to json
from agent.prompts import to_json