release commit

This commit is contained in:
alexisxy 2023-07-24 00:17:27 -10:00
commit b454f2dcfd
63 changed files with 36252 additions and 0 deletions

17
.github/workflows/pre-commit.yml vendored Normal file
View File

@ -0,0 +1,17 @@
name: pre-commit
on:
pull_request:
push:
branches: [main]
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: 3.10.9
- uses: pre-commit/action@v3.0.0

36
.github/workflows/tests.yml vendored Normal file
View File

@ -0,0 +1,36 @@
name: Python Package Pytest
on: [push]
jobs:
test-all:
runs-on: ubuntu-latest
strategy:
max-parallel: 5
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: 3.10.9
- name: Install dependencies
run: |
pip install -r requirements.txt
playwright install
pip install -e .[dev]
- name: Type-checking package with mypy
run: |
# Manually install mypy in the standard way.
pip --quiet install -U mypy
# Log this mypy version for debuggability.
mypy --version
# Run this mypy instance against our main package.
mypy --install-types --non-interactive .
mypy --strict .
- name: Environment prepare
run: |
bash prepare.sh
- name: Test with pytest
run: |
# ignore annotation notebook because it requires a browser
pytest

158
.gitignore vendored Normal file
View File

@ -0,0 +1,158 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# mac OS
*.DS_Store
.vscode
*tmp*
.auth/*
# local debug
run.sh
# trajectory visualization
render_cache/*
# TMP IGNORE
agent/prompts/jsons/*
log_files/
config_files/*0.json
config_files/*1.json
config_files/*2.json
config_files/*3.json
config_files/*4.json
config_files/*5.json
config_files/*6.json
config_files/*7.json
config_files/*8.json
config_files/*9.json
config_files/test.json

23
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,23 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 22.12.0
hooks:
- id: black
exclude: '^(agent/prompts/raw)'
args: [--line-length=79]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", --line-length=72]
- repo: https://github.com/kynan/nbstripout
rev: 0.6.0
hooks:
- id: nbstripout

201
LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

35
README.md Normal file
View File

@ -0,0 +1,35 @@
[![Python 3.10](https://img.shields.io/badge/python-3.10-blue.svg)](https://www.python.org/downloads/release/python-3109/)
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://pre-commit.com/)
<a href="https://github.com/psf/black"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
[![Checked with mypy](https://www.mypy-lang.org/static/mypy_badge.svg)](https://mypy-lang.org/)
[![bear-ified](https://raw.githubusercontent.com/beartype/beartype-assets/main/badge/bear-ified.svg)](https://beartype.readthedocs.io)
# WebArena: A Realistic Web Environment for Building Autonomous Agents
[[Website]](https://webarena.dev/)
[[Paper]]()
![Overview](media/overview.png)
> WebArena is a standalone, self-hostable web environment for building autonomous agents. WebArena creates websites from four popular categories with functionality and data mimicking their real-world equivalents. To emulate human problem-solving, WebArena also embeds tools and knowledge resources as independent websites. WebArena introduces a benchmark on interpreting high-level realistic natural language command to concrete web-based interactions. We provide annotated programs designed to programmatically validate the functional correctness of each task.
> **Note** This README is still under construction. Stay tuned!
## Install
```bash
# Python 3.10+
conda create -n webarena python=3.10; conda activate webarena
pip install -r requirements.txt
playwright install
pip install -e .
# optional, dev only
pip install -e ".[dev]"
mypy --install-types --non-interactive browser_env
pip install pre-commit
pre-commit install
```
## Preparation
* Config the URLs of each website in [env_config](browser_env/env_config.py)
* `python scripts/generate_test_data.py` will generate an individual config file for each test example in [config_files](config_files)
* `bash prepare.sh` to obtain the auto-login cookies for all websites
* export OPENAI_API_KEY=your_key
* `python run.py --instruction_path agent/prompts/jsons/p_cot_id_actree_2s.json --test_start_idx 0 --test_end_idx 1 --model gpt-3.5-turbo --result_dir your_result_dir` to run the first example with GPT-3.5 reasoning agent. The trajectory will be saved in `your_result_dir/0.html`

1
agent/__init__.py Normal file
View File

@ -0,0 +1 @@
from .agent import *

177
agent/agent.py Normal file
View File

@ -0,0 +1,177 @@
import json
from typing import Any
from beartype import beartype
from beartype.door import is_bearable
from agent.prompts import *
from browser_env.actions import (
Action,
ActionParsingError,
create_id_based_action,
create_none_action,
create_playwright_action,
)
from browser_env.utils import Observation, StateInfo
from llms import lm_config
from llms.providers.openai_utils import (
generate_from_openai_chat_completion,
generate_from_openai_completion,
)
from .utils import *
# from llms.providers.openai_utils import generate_from_openai_completion
# from llms.providers.openai_utils import fake_generate_from_openai_chat_completion as generate_from_openai_chat_completion
class Agent:
    """Abstract base class defining the agent interface.

    Concrete agents must override `next_action` and `reset`.
    """

    def __init__(self, *args: Any) -> None:
        # The base agent holds no state; subclasses set up their own.
        pass

    def next_action(
        self, trajectory: "Trajectory", intent: str, meta_data: Any
    ) -> "Action":
        """Predict the next action given the observation.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def reset(
        self,
        test_config_file: str,
    ) -> None:
        """Re-initialize per-episode state from a test config file.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError
class TeacherForcingAgent(Agent):
    """Agent that replays a pre-defined reference action sequence.

    The sequence is loaded from a test config file via `reset` and emitted
    one action per call to `next_action`.
    """

    def __init__(self) -> None:
        super().__init__()

    @beartype
    def set_action_set_tag(self, tag: str) -> None:
        """Select how raw action strings are parsed.

        Args:
            tag: either "playwright" or "id_accessibility_tree".
        """
        self.action_set_tag = tag

    @beartype
    def set_actions(self, action_seq: str | list[str]) -> None:
        """Parse a newline-separated string (or list) of action strings.

        Entries that fail to parse degrade to a "none" action carrying the
        raw string in `raw_prediction`, so replay never aborts mid-sequence.

        Raises:
            ValueError: if `action_set_tag` is not a known action set.
        """
        if isinstance(action_seq, str):
            action_strs = action_seq.strip().split("\n")
        else:
            action_strs = action_seq
        action_strs = [a.strip() for a in action_strs]

        actions = []
        for a_str in action_strs:
            try:
                if self.action_set_tag == "playwright":
                    cur_action = create_playwright_action(a_str)
                elif self.action_set_tag == "id_accessibility_tree":
                    cur_action = create_id_based_action(a_str)
                else:
                    raise ValueError(
                        f"Unknown action type {self.action_set_tag}"
                    )
            except ActionParsingError:
                # Keep the unparseable text around for debugging.
                cur_action = create_none_action()
                cur_action["raw_prediction"] = a_str
            actions.append(cur_action)

        self.actions: list[Action] = actions

    @beartype
    def next_action(
        self, trajectory: Trajectory, intent: str, meta_data: Any
    ) -> Action:
        """Return (and consume) the next action from the pre-set sequence."""
        return self.actions.pop(0)

    @beartype
    def reset(
        self,
        test_config_file: str,
    ) -> None:
        """Load the reference action sequence from the test config file."""
        with open(test_config_file) as f:
            ref_actions = json.load(f)["reference_action_sequence"]
            tag = ref_actions["action_set_tag"]
            action_seq = ref_actions["action_sequence"]
            self.set_action_set_tag(tag)
            self.set_actions(action_seq)
class PromptAgent(Agent):
    """Prompt-based agent that queries an LM for the next action."""

    def __init__(
        self,
        action_set_tag: str,
        lm_config: lm_config.LMConfig,
        prompt_constructor: PromptConstructor,
    ) -> None:
        super().__init__()
        self.lm_config = lm_config
        self.prompt_constructor = prompt_constructor
        self.action_set_tag = action_set_tag

    @beartype
    def set_action_set_tag(self, tag: str) -> None:
        """Select how LM responses are parsed into actions."""
        self.action_set_tag = tag

    @beartype
    def next_action(
        self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any]
    ) -> Action:
        """Build a prompt, query the LM, and parse the response into an Action.

        Unparseable responses degrade to a "none" action; the raw LM output
        is always preserved in `action["raw_prediction"]` for debugging.

        Raises:
            ValueError: on an unsupported OpenAI mode or action set tag.
            NotImplementedError: on an unsupported provider.
        """
        prompt = self.prompt_constructor.construct(
            trajectory, intent, meta_data
        )
        lm_config = self.lm_config
        if lm_config.provider == "openai":
            if lm_config.mode == "chat":
                response = generate_from_openai_chat_completion(
                    messages=prompt,
                    model=lm_config.model,
                    temperature=lm_config.gen_config["temperature"],
                    top_p=lm_config.gen_config["top_p"],
                    context_length=lm_config.gen_config["context_length"],
                    max_tokens=lm_config.gen_config["max_tokens"],
                    stop_token=None,
                )
            elif lm_config.mode == "completion":
                response = generate_from_openai_completion(
                    prompt=prompt,
                    engine=lm_config.model,
                    temperature=lm_config.gen_config["temperature"],
                    max_tokens=lm_config.gen_config["max_tokens"],
                    top_p=lm_config.gen_config["top_p"],
                    stop_token=lm_config.gen_config["stop_token"],
                )
            else:
                raise ValueError(
                    f"OpenAI models do not support mode {lm_config.mode}"
                )
        else:
            raise NotImplementedError(
                f"Provider {lm_config.provider} not implemented"
            )

        try:
            parsed_response = self.prompt_constructor.extract_action(response)
            if self.action_set_tag == "id_accessibility_tree":
                action = create_id_based_action(parsed_response)
            elif self.action_set_tag == "playwright":
                action = create_playwright_action(parsed_response)
            else:
                raise ValueError(f"Unknown action type {self.action_set_tag}")
            action["raw_prediction"] = response
        except ActionParsingError:
            # Fall back to a no-op action rather than crashing the episode.
            action = create_none_action()
            action["raw_prediction"] = response

        return action

    def reset(self, test_config_file: str) -> None:
        """No per-episode state to reset for a prompt-based agent."""
        pass

2
agent/prompts/README.md Normal file
View File

@ -0,0 +1,2 @@
## Naming of the prompt files
`description.action_space.observation_space.json`

View File

@ -0,0 +1 @@
from .prompt_constructor import *

View File

@ -0,0 +1,240 @@
import json
import re
from pathlib import Path
from typing import Any, TypedDict
import tiktoken
from beartype import beartype
from agent.utils import Trajectory
from browser_env import Action, ActionParsingError
from browser_env.env_config import URL_MAPPINGS
from browser_env.utils import StateInfo
from llms import lm_config
# Prompt payload handed to an LM API: a bare string (completion mode),
# a chat message list, or an arbitrary keyword mapping.
APIInput = str | list[Any] | dict[str, Any]


class Instruction(TypedDict):
    """Instruction for constructing prompt"""

    intro: str  # system-prompt preamble describing the task and action space
    examples: list[tuple[str, str]]  # (observation, gold response) few-shot pairs
    template: str  # format string with {objective}/{url}/{observation}/{previous_action}
    meta_data: dict[str, Any]  # e.g. keywords, action_splitter, answer_phrase
class PromptConstructor(object):
    """Base class that turns a trajectory + intent into an LM API input.

    Loads an `Instruction` JSON (intro, few-shot examples, template,
    meta_data) from `instruction_path` and exposes helpers to build the
    provider-specific prompt and to parse the LM response.
    """

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: tiktoken.core.Encoding,
    ):
        self.instruction_path = Path(instruction_path)
        # Backward-compat alias for the original (misspelled) attribute name.
        self.instrction_path = self.instruction_path
        self.obs_modality = "text"
        self.lm_config = lm_config
        # Context manager closes the handle promptly; the original
        # `json.load(open(...))` leaked the open file object.
        with open(self.instruction_path) as f:
            instruction = json.load(f)
        instruction["examples"] = [tuple(e) for e in instruction["examples"]]
        self.instruction: Instruction = instruction
        self.tokenizer = tokenizer

    @beartype
    def get_lm_api_input(
        self, intro: str, examples: list[tuple[str, str]], current: str
    ) -> APIInput:
        """Return the prompt in the format required by the configured API.

        Chat mode yields a message list (examples as named system messages);
        completion mode yields one flat string.

        Raises:
            ValueError: on an unsupported OpenAI mode.
            NotImplementedError: on an unsupported provider.
        """
        message: list[dict[str, str]] | str
        if "openai" in self.lm_config.provider:
            if self.lm_config.mode == "chat":
                message = [{"role": "system", "content": intro}]
                for (x, y) in examples:
                    message.append(
                        {
                            "role": "system",
                            "name": "example_user",
                            "content": x,
                        }
                    )
                    message.append(
                        {
                            "role": "system",
                            "name": "example_assistant",
                            "content": y,
                        }
                    )
                message.append({"role": "user", "content": current})
                return message
            elif self.lm_config.mode == "completion":
                message = f"{intro}\n\n"
                message += "Here are a few examples:\n"
                for example in examples:
                    message += f"Observation\n:{example[0]}\n\n"
                    message += f"Action: {example[1]}\n\n"
                message += "Now make prediction given the observation\n\n"
                message += f"Observation\n:{current}\n\n"
                message += "Action:"
                return message
            else:
                raise ValueError(
                    f"OpenAI models do not support mode {self.lm_config.mode}"
                )
        else:
            raise NotImplementedError(
                f"Provider {self.lm_config.provider} not implemented"
            )

    @beartype
    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},  # read-only; shared default is safe
    ) -> APIInput:
        """Build the full prompt for one step; implemented by subclasses."""
        raise NotImplementedError

    @beartype
    def map_url_to_real(self, url: str) -> str:
        """Map the urls to their real world counterparts"""
        for i, j in URL_MAPPINGS.items():
            if i in url:
                url = url.replace(i, j)
        return url

    @beartype
    def map_url_to_local(self, url: str) -> str:
        """Map the urls to their local counterparts"""
        for i, j in URL_MAPPINGS.items():
            if j in url:
                url = url.replace(j, i)
        return url

    @beartype
    def _extract_action(self, response: str) -> str:
        """Parse the raw LM response into an action string; subclass hook."""
        raise NotImplementedError

    @beartype
    def extract_action(self, response: str) -> str:
        """Extract the action and rewrite any real-world URLs back to local."""
        response = self._extract_action(response)
        response = self.map_url_to_local(response)
        return response
class DirectPromptConstructor(PromptConstructor):
    """The agent will direct predict the action"""

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: tiktoken.core.Encoding,
    ):
        super().__init__(instruction_path, lm_config, tokenizer)

    @beartype
    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        """Construct prompt given the trajectory.

        Truncates the observation to `max_obs_length` tokens when configured,
        fills the instruction template, and packages it for the LM API.
        """
        intro = self.instruction["intro"]
        examples = self.instruction["examples"]
        template = self.instruction["template"]
        keywords = self.instruction["meta_data"]["keywords"]
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]

        page = state_info["info"]["page"]
        url = page.url
        previous_action_str = meta_data["action_history"][-1]

        # input x
        current = template.format(
            objective=intent,
            url=self.map_url_to_real(url),
            observation=obs,
            previous_action=previous_action_str,
        )

        # Make sure no template placeholder (e.g. "{url}") survived format().
        # BUG FIX: the original used f"{{k}}", which is the literal string
        # "{k}" and never actually checked the keywords.
        assert all(f"{{{k}}}" not in current for k in keywords)

        prompt = self.get_lm_api_input(intro, examples, current)
        return prompt

    @beartype
    def _extract_action(self, response: str) -> str:
        """Return the text between the first pair of action_splitter markers.

        Raises:
            ActionParsingError: if no delimited action is found.
        """
        action_splitter = self.instruction["meta_data"]["action_splitter"]
        pattern = rf"{action_splitter}(.*?){action_splitter}"
        match = re.search(pattern, response)
        if match:
            return match.group(1)
        else:
            raise ActionParsingError(
                f"Cannot parse action from response {response}"
            )
class CoTPromptConstructor(PromptConstructor):
    """The agent will perform step-by-step reasoning before the answer"""

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: tiktoken.core.Encoding,
    ):
        super().__init__(instruction_path, lm_config, tokenizer)
        # Phrase that introduces the final action in the LM's CoT response.
        self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]

    @beartype
    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        """Build the chain-of-thought prompt for the current step.

        Truncates the observation to `max_obs_length` tokens when configured,
        fills the instruction template, and packages it for the LM API.
        """
        intro = self.instruction["intro"]
        examples = self.instruction["examples"]
        template = self.instruction["template"]
        keywords = self.instruction["meta_data"]["keywords"]
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]

        page = state_info["info"]["page"]
        url = page.url
        previous_action_str = meta_data["action_history"][-1]
        current = template.format(
            objective=intent,
            url=self.map_url_to_real(url),
            observation=obs,
            previous_action=previous_action_str,
        )

        # Make sure no template placeholder (e.g. "{url}") survived format().
        # BUG FIX: the original used f"{{k}}", which is the literal string
        # "{k}" and never actually checked the keywords.
        assert all(f"{{{k}}}" not in current for k in keywords)

        prompt = self.get_lm_api_input(intro, examples, current)
        return prompt

    @beartype
    def _extract_action(self, response: str) -> str:
        """Return the first occurrence of a delimited action in the response.

        Raises:
            ActionParsingError: if no delimited action is found.
        """
        action_splitter = self.instruction["meta_data"]["action_splitter"]
        pattern = rf"{action_splitter}(.*?){action_splitter}"
        match = re.search(pattern, response)
        if match:
            return match.group(1)
        else:
            raise ActionParsingError(
                f'Cannot find the answer phrase "{self.answer_phrase}" in "{response}"'
            )

View File

@ -0,0 +1,82 @@
# Two-shot chain-of-thought prompt for the id/accessibility-tree action space.
# Consumed by CoTPromptConstructor (see meta_data["prompt_constructor"]); the
# string values are runtime data and must not be edited casually.
prompt = {
"intro": """You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue.
Here's the information you'll have:
The user's objective: This is the task you're trying to complete.
The current web page's accessibility tree: This is a simplified representation of the webpage, providing key information.
The current web page's URL: This is the page you're currently navigating.
The open tabs: These are the tabs you have open.
The previous action: This is the action you just performed. It may be helpful to track your progress.
The actions you can perform fall into several categories:
Page Operation Actions:
`click [id]`: This action clicks on an element with a specific id on the webpage.
`type [id] [content] [press_enter_after=0|1]`: Use this to type the content into the field with id. By default, the "Enter" key is pressed after typing unless press_enter_after is set to 0.
`hover [id]`: Hover over an element with id.
`press [key_comb]`: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v).
`scroll [direction=down|up]`: Scroll the page up or down.
Tab Management Actions:
`new_tab`: Open a new, empty browser tab.
`tab_focus [tab_index]`: Switch the browser's focus to a specific tab using its index.
`close_tab`: Close the currently active tab.
URL Navigation Actions:
`goto [url]`: Navigate to a specific URL.
`go_back`: Navigate to the previously viewed page.
`go_forward`: Navigate to the next page (if a previous 'go_back' action was performed).
Completion Action:
`stop [answer]`: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket. If you believe the task is impossible to complete, provide the answer as "N/A" in the bracket.
Homepage:
If you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit.
http://homepage.com/password.html lists all the account name and password for the websites. You can use them to log in to the websites.
To be successful, it is very important to follow the following rules:
1. You should only issue an action that is valid given the current observation
2. You should only issue one action at a time.
3. You should follow the examples to reason step by step and then issue the next action.
4. Generate the action in the correct format. Start with a "In summary, the next action I will perform is" phrase, followed by action inside ``````. For example, "In summary, the next action I will perform is ```click [1234]```".
5. Issue stop action when you think you have achieved the objective. Don't generate anything after stop.""",
# Few-shot (observation, gold CoT response) pairs used as example messages.
"examples": [
(
"""OBSERVATION:
[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)'
[1749] StaticText '$279.49'
[1757] button 'Add to Cart'
[1760] button 'Add to Wish List'
[1761] button 'Add to Compare'
URL: http://onestopmarket.com/office-products/office-electronics.html
OBJECTIVE: What is the price of HP Inkjet Fax Machine
PREVIOUS ACTION: None""",
"Let's think step-by-step. This page list the information of HP Inkjet Fax Machine, which is the product identified in the objective. Its price is $279.49. I think I have achieved the objective. I will issue the stop action with the answer. In summary, the next action I will perform is ```stop [$279.49]```",
),
(
"""OBSERVATION:
[164] textbox 'Search' focused: True required: False
[171] button 'Go'
[174] link 'Find directions between two points'
[212] heading 'Search Results'
[216] button 'Close'
URL: http://openstreetmap.org
OBJECTIVE: Show me the restaurants near CMU
PREVIOUS ACTION: None""",
"Let's think step-by-step. This page has a search box whose ID is [164]. According to the nominatim rule of openstreetmap, I can search for the restaurants near a location by \"restaurants near\". I can submit my typing by pressing the Enter afterwards. In summary, the next action I will perform is ```type [164] [restaurants near CMU] [1]```",
),
],
# Per-step template; placeholders are filled by CoTPromptConstructor.construct.
"template": """OBSERVATION:
{observation}
URL: {url}
OBJECTIVE: {objective}
PREVIOUS ACTION: {previous_action}""",
# keywords: placeholder names the constructor asserts were all replaced;
# action_splitter: delimiter around the emitted action in the LM response.
"meta_data": {
"observation": "accessibility_tree",
"action_type": "id_accessibility_tree",
"keywords": ["url", "objective", "observation", "previous_action"],
"prompt_constructor": "CoTPromptConstructor",
"answer_phrase": "In summary, the next action I will perform is",
"action_splitter": "```"
},
}

View File

@ -0,0 +1,80 @@
prompt = {
"intro": """You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue.
Here's the information you'll have:
The user's objective: This is the task you're trying to complete.
The current web page's accessibility tree: This is a simplified representation of the webpage, providing key information.
The current web page's URL: This is the page you're currently navigating.
The open tabs: These are the tabs you have open.
The previous action: This is the action you just performed. It may be helpful to track your progress.
The actions you can perform fall into several categories:
Page Operation Actions:
`click [id]`: This action clicks on an element with a specific id on the webpage.
`type [id] [content] [press_enter_after=0|1]`: Use this to type the content into the field with id. By default, the "Enter" key is pressed after typing unless press_enter_after is set to 0.
`hover [id]`: Hover over an element with id.
`press [key_comb]`: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v).
`scroll [direction=down|up]`: Scroll the page up or down.
Tab Management Actions:
`new_tab`: Open a new, empty browser tab.
`tab_focus [tab_index]`: Switch the browser's focus to a specific tab using its index.
`close_tab`: Close the currently active tab.
URL Navigation Actions:
`goto [url]`: Navigate to a specific URL.
`go_back`: Navigate to the previously viewed page.
`go_forward`: Navigate to the next page (if a previous 'go_back' action was performed).
Completion Action:
`stop [answer]`: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket. If you believe the task is impossible to complete, provide the answer as "N/A" in the bracket.
Homepage:
If you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit.
http://homepage.com/password.html lists all the account name and password for the websites. You can use them to log in to the websites.
To be successful, it is very important to follow the following rules:
1. You should only issue an action that is valid given the current observation
2. You should only issue one action at a time.
3. Generate the action in the correct format. Always put the action inside a pair of ```. For example, ```click [1234]```.
5. Issue stop action when you think you have achieved the objective. Don't generate anything after stop.""",
"examples": [
(
"""OBSERVATION:
[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)'
[1749] StaticText '$279.49'
[1757] button 'Add to Cart'
[1760] button 'Add to Wish List'
[1761] button 'Add to Compare'
URL: http://onestopmarket.com/office-products/office-electronics.html
OBJECTIVE: What is the price of HP Inkjet Fax Machine
PREVIOUS ACTION: None""",
"```stop [$279.49]```",
),
(
"""OBSERVATION:
[164] textbox 'Search' focused: True required: False
[171] button 'Go'
[174] link 'Find directions between two points'
[212] heading 'Search Results'
[216] button 'Close'
URL: http://openstreetmap.org
OBJECTIVE: Show me the restaurants near CMU
PREVIOUS ACTION: None""",
"```type [164] [restaurants near CMU] [1]```",
),
],
"template": """OBSERVATION:
{observation}
URL: {url}
OBJECTIVE: {objective}
PREVIOUS ACTION: {previous_action}""",
"meta_data": {
"observation": "accessibility_tree",
"action_type": "id_accessibility_tree",
"keywords": ["url", "objective", "observation", "previous_action"],
"prompt_constructor": "DirectPromptConstructor",
"action_splitter": "```"
},
}

21
agent/prompts/to_json.py Normal file
View File

@ -0,0 +1,21 @@
import glob
import importlib
import json
import os
# use the current directory as the root
def run() -> None:
"""Convert all python files in agent/prompts to json files in agent/prompts/jsons
Python files are easiser to edit
"""
for p_file in glob.glob(f"agent/prompts/raw/*.py"):
# import the file as a module
base_name = os.path.basename(p_file).replace(".py", "")
module = importlib.import_module(f"agent.prompts.raw.{base_name}")
prompt = module.prompt
# save the prompt as a json file
with open(f"agent/prompts/jsons/{base_name}.json", "w+") as f:
json.dump(prompt, f, indent=2)
print(f"Done convert python files to json")

6
agent/utils.py Normal file
View File

@ -0,0 +1,6 @@
from typing import Union
from browser_env.actions import Action
from browser_env.utils import StateInfo
Trajectory = list[Union[StateInfo, Action]]

74
browser_env/__init__.py Normal file
View File

@ -0,0 +1,74 @@
import asyncio
from .actions import (
Action,
ActionParsingError,
ActionTypes,
action2create_function,
action2str,
create_check_action,
create_click_action,
create_focus_and_click_action,
create_focus_and_type_action,
create_go_back_action,
create_go_forward_action,
create_goto_url_action,
create_hover_action,
create_id_based_action,
create_key_press_action,
create_keyboard_type_action,
create_mouse_click_action,
create_mouse_hover_action,
create_new_tab_action,
create_none_action,
create_page_close_action,
create_page_focus_action,
create_playwright_action,
create_random_action,
create_scroll_action,
create_select_option_action,
create_stop_action,
create_type_action,
is_equivalent,
)
from .async_envs import AsyncScriptBrowserEnv
from .envs import ScriptBrowserEnv
from .processors import ObservationMetadata
from .utils import DetachedPage, StateInfo
__all__ = [
"ScriptBrowserEnv",
"AsyncScriptBrowserEnv",
"DetachedPage",
"StateInfo",
"ObservationMetadata",
"Action",
"ActionTypes",
"action2str",
"create_random_action",
"create_focus_and_click_action",
"create_focus_and_type_action",
"is_equivalent",
"create_mouse_click_action",
"create_mouse_hover_action",
"create_none_action",
"create_keyboard_type_action",
"create_page_focus_action",
"create_new_tab_action",
"create_go_back_action",
"create_go_forward_action",
"create_goto_url_action",
"create_page_close_action",
"action2create_function",
"create_playwright_action",
"create_id_based_action",
"create_scroll_action",
"create_key_press_action",
"create_check_action",
"create_click_action",
"create_type_action",
"create_hover_action",
"create_select_option_action",
"create_stop_action",
"ActionParsingError",
]

1610
browser_env/actions.py Normal file

File diff suppressed because it is too large Load Diff

160
browser_env/async_envs.py Normal file
View File

@ -0,0 +1,160 @@
import asyncio
import json
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import numpy.typing as npt
from beartype import beartype
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.async_api import Page, ViewportSize, async_playwright
from .actions import Action, aexecute_action, get_action_space
from .utils import DetachedPage, png_bytes_to_numpy
class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
"""
The goal of this environment is to produce a prototype of a browser environment.
In the end, we want to support a fully configurable browser environment with wide
range of action spaces and observation spaces, both structured and unstructured.
But in this prototype, we just support action space specified by Playwright script,
and observation space is the html content of the page.
"""
@beartype
def __init__(
self,
max_page_length: int = 2048,
headless: bool = True,
slow_mo: int = 0,
timeout: int = 30000,
viewport_size: ViewportSize = {"width": 1280, "height": 720},
):
self.observation_space = Box(
0,
255,
(viewport_size["height"], viewport_size["width"], 4),
np.uint8,
)
# TODO: make Space[Action] = ActionSpace
self.action_space = get_action_space() # type: ignore[assignment]
self.headless = headless
self.slow_mo = slow_mo
self.reset_finished = False
self.timeout = timeout
self.viewport_size = viewport_size
@beartype
async def setup(self, config_file: Path | None = None) -> None:
self.context_manager = async_playwright()
self.playwright = await self.context_manager.__aenter__()
self.browser = await self.playwright.chromium.launch(
headless=self.headless, slow_mo=self.slow_mo
)
if config_file:
with open(config_file, "r") as f:
instance_config = json.load(f)
else:
instance_config = {}
storage_state = instance_config.get("storage_state", None)
start_url = instance_config.get("start_url", None)
geolocation = instance_config.get("geolocation", None)
self.context = await self.browser.new_context(
viewport=self.viewport_size,
storage_state=storage_state,
geolocation=geolocation,
device_scale_factor=1,
)
self.page = await self.context.new_page()
if start_url:
await self.page.goto(start_url)
@beartype
async def areset(
self,
*,
seed: int | None = None,
options: dict[str, str] | None = None,
) -> tuple[npt.NDArray[np.uint8], dict[str, object]]:
"""
Reset the environment.
:param options: options for the environment. The options are:
- storage_state: the path to the storage state file
"""
super().reset(seed=seed, options=options)
if self.reset_finished:
await self.context_manager.__aexit__()
if options is not None and "config_file" in options:
config_file = Path(options["config_file"])
if config_file.exists():
await self.setup(config_file=config_file)
else:
raise ValueError(f"Config state {config_file} does not exist.")
else:
await self.setup()
self.reset_finished = True
content = await self.page.content()
screenshot = png_bytes_to_numpy(await self.page.screenshot())
return (
screenshot,
{"page": DetachedPage(self.page.url, content)},
)
@beartype
def reset(
self,
*,
seed: int | None = None,
options: dict[str, str] | None = None,
) -> tuple[npt.NDArray[np.uint8], dict[str, object]]:
return asyncio.run(self.areset(seed=seed, options=options))
async def aclose(self) -> None:
if self.reset_finished:
await self.context_manager.__aexit__()
def close(self) -> None:
asyncio.run(self.aclose())
@beartype
async def astep(
self, action: Action
) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]:
if not self.reset_finished:
raise RuntimeError("Call reset first before calling step.")
success = False
fail_error = ""
try:
self.page = await aexecute_action(action, self.page, self.context)
success = True
except Exception as e:
fail_error = str(e)
try:
content = await self.page.content()
screenshot = png_bytes_to_numpy(await self.page.screenshot())
except:
await self.page.wait_for_load_state("load")
content = await self.page.content()
screenshot = png_bytes_to_numpy(await self.page.screenshot())
return (
screenshot,
float(success),
False,
False,
{
"page": DetachedPage(self.page.url, content),
"fail_error": fail_error,
},
)
@beartype
def step(
self, action: Action
) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]:
return asyncio.run(self.astep(action), debug=True)

128
browser_env/auto_login.py Normal file
View File

@ -0,0 +1,128 @@
"""Script to automatically login each website"""
import glob
from itertools import combinations
from pathlib import Path
from beartype import beartype
from playwright.sync_api import sync_playwright
from browser_env.env_config import (
ACCOUNTS,
GITLAB,
REDDIT,
SHOPPING,
SHOPPING_ADMIN,
)
HEADLESS = True
SLOW_MO = 0
@beartype
def is_expired(
storage_state: Path, url: str, keyword: str, url_exact: bool = True
) -> bool:
"""Test whether the cookie is expired"""
if not storage_state.exists():
return True
context_manager = sync_playwright()
playwright = context_manager.__enter__()
browser = playwright.chromium.launch(headless=HEADLESS, slow_mo=SLOW_MO)
context = browser.new_context(storage_state=storage_state)
page = context.new_page()
page.goto(url)
d_url = page.url
content = page.content()
context_manager.__exit__()
if keyword:
return keyword not in content
else:
if url_exact:
return d_url != url
else:
return url not in d_url
@beartype
def renew_comb(comb: list[str]) -> None:
context_manager = sync_playwright()
playwright = context_manager.__enter__()
browser = playwright.chromium.launch(headless=HEADLESS)
context = browser.new_context()
page = context.new_page()
if "shopping" in comb:
username = ACCOUNTS["shopping"]["username"]
password = ACCOUNTS["shopping"]["password"]
page.goto(f"{SHOPPING}/customer/account/login/")
page.get_by_label("Email", exact=True).fill(username)
page.get_by_label("Password", exact=True).fill(password)
page.get_by_role("button", name="Sign In").click()
if "reddit" in comb:
username = ACCOUNTS["reddit"]["username"]
password = ACCOUNTS["reddit"]["password"]
page.goto(f"{REDDIT}/login")
page.get_by_label("Username").fill(username)
page.get_by_label("Password").fill(password)
page.get_by_role("button", name="Log in").click()
if "shopping_admin" in comb:
username = ACCOUNTS["shopping_admin"]["username"]
password = ACCOUNTS["shopping_admin"]["password"]
page.goto(f"{SHOPPING_ADMIN}")
page.get_by_placeholder("user name").fill(username)
page.get_by_placeholder("password").fill(password)
page.get_by_role("button", name="Sign in").click()
if "gitlab" in comb:
username = ACCOUNTS["gitlab"]["username"]
password = ACCOUNTS["gitlab"]["password"]
page.goto(f"{GITLAB}/users/sign_in")
page.get_by_test_id("username-field").click()
page.get_by_test_id("username-field").fill(username)
page.get_by_test_id("username-field").press("Tab")
page.get_by_test_id("password-field").fill(password)
page.get_by_test_id("sign-in-button").click()
context.storage_state(path=f"./.auth/{'.'.join(comb)}_state.json")
context_manager.__exit__()
@beartype
def main() -> None:
sites = ["gitlab", "shopping", "shopping_admin", "reddit"]
urls = [
f"{GITLAB}/-/profile",
f"{SHOPPING}/wishlist/",
f"{SHOPPING_ADMIN}/dashboard",
f"{REDDIT}/user/{ACCOUNTS['reddit']['username']}/account",
]
exact_match = [True, True, True, True]
keywords = ["", "", "Dashboard", "Delete"]
pairs = list(combinations(sites, 2))
for pair in pairs:
# TODO[shuyanzh] auth don't work on these two sites
if "reddit" in pair and (
"shopping" in pair or "shopping_admin" in pair
):
continue
renew_comb(list(sorted(pair)))
for site in sites:
renew_comb([site])
for c_file in glob.glob("./.auth/*.json"):
comb = c_file.split("/")[-1].rsplit("_", 1)[0].split(".")
for cur_site in comb:
url = urls[sites.index(cur_site)]
keyword = keywords[sites.index(cur_site)]
match = exact_match[sites.index(cur_site)]
assert not is_expired(Path(c_file), url, keyword, match)
if __name__ == "__main__":
main()

295
browser_env/constants.py Normal file
View File

@ -0,0 +1,295 @@
from typing import Literal
ROLES = (
"alert",
"alertdialog",
"application",
"article",
"banner",
"blockquote",
"button",
"caption",
"cell",
"checkbox",
"code",
"columnheader",
"combobox",
"complementary",
"contentinfo",
"definition",
"deletion",
"dialog",
"directory",
"document",
"emphasis",
"feed",
"figure",
"form",
"generic",
"grid",
"gridcell",
"group",
"heading",
"img",
"insertion",
"link",
"list",
"listbox",
"listitem",
"log",
"main",
"marquee",
"math",
"meter",
"menu",
"menubar",
"menuitem",
"menuitemcheckbox",
"menuitemradio",
"navigation",
"none",
"note",
"option",
"paragraph",
"presentation",
"progressbar",
"radio",
"radiogroup",
"region",
"row",
"rowgroup",
"rowheader",
"scrollbar",
"search",
"searchbox",
"separator",
"slider",
"spinbutton",
"status",
"strong",
"subscript",
"superscript",
"switch",
"tab",
"table",
"tablist",
"tabpanel",
"term",
"textbox",
"time",
"timer",
"toolbar",
"tooltip",
"tree",
"treegrid",
"treeitem",
)
SPECIAL_LOCATORS = (
"alt_text",
"label",
"placeholder",
)
ASCII_CHARSET = "".join(chr(x) for x in range(32, 128))
FREQ_UNICODE_CHARSET = "".join(chr(x) for x in range(129, 1000))
UTTERANCE_MAX_LENGTH = 8192
ATTRIBUTE_MAX_LENGTH = 256
TEXT_MAX_LENGTH = 256
TYPING_MAX_LENGTH = 64
URL_MAX_LENGTH = 256
MAX_ELEMENT_INDEX_IN_VIEWPORT = 10
MAX_ELEMENT_ID = 1000
MAX_ANSWER_LENGTH = 512
MIN_REF = -1000000
MAX_REF = 1000000
WINDOW_WIDTH = 500
WINDOW_HEIGHT = 240
TASK_WIDTH = 160
TASK_HEIGHT = 210
FLIGHT_WINDOW_WIDTH = 600
FLIGHT_WINDOW_HEIGHT = 700
FLIGHT_TASK_WIDTH = 375
FLIGHT_TASK_HEIGHT = 667
MAX_PAGE_NUMBER = 10
SPECIAL_KEYS = (
"Enter",
"Tab",
"Control",
"Shift",
"Meta",
"Backspace",
"Delete",
"Escape",
"ArrowUp",
"ArrowDown",
"ArrowLeft",
"ArrowRight",
"PageDown",
"PageUp",
"Meta+a",
)
SPECIAL_KEY_MAPPINGS = {
"backquote": "Backquote",
"minus": "Minus",
"equal": "Equal",
"backslash": "Backslash",
"backspace": "Backspace",
"meta": "Meta",
"tab": "Tab",
"delete": "Delete",
"escape": "Escape",
"arrowdown": "ArrowDown",
"end": "End",
"enter": "Enter",
"home": "Home",
"insert": "Insert",
"pagedown": "PageDown",
"pageup": "PageUp",
"arrowright": "ArrowRight",
"arrowup": "ArrowUp",
"f1": "F1",
"f2": "F2",
"f3": "F3",
"f4": "F4",
"f5": "F5",
"f6": "F6",
"f7": "F7",
"f8": "F8",
"f9": "F9",
"f10": "F10",
"f11": "F11",
"f12": "F12",
}
RolesType = Literal[
"alert",
"alertdialog",
"application",
"article",
"banner",
"blockquote",
"button",
"caption",
"cell",
"checkbox",
"code",
"columnheader",
"combobox",
"complementary",
"contentinfo",
"definition",
"deletion",
"dialog",
"directory",
"document",
"emphasis",
"feed",
"figure",
"form",
"generic",
"grid",
"gridcell",
"group",
"heading",
"img",
"insertion",
"link",
"list",
"listbox",
"listitem",
"log",
"main",
"marquee",
"math",
"meter",
"menu",
"menubar",
"menuitem",
"menuitemcheckbox",
"menuitemradio",
"navigation",
"none",
"note",
"option",
"paragraph",
"presentation",
"progressbar",
"radio",
"radiogroup",
"region",
"row",
"rowgroup",
"rowheader",
"scrollbar",
"search",
"searchbox",
"separator",
"slider",
"spinbutton",
"status",
"strong",
"subscript",
"superscript",
"switch",
"tab",
"table",
"tablist",
"tabpanel",
"term",
"textbox",
"time",
"timer",
"toolbar",
"tooltip",
"tree",
"treegrid",
"treeitem",
"alt_text",
"label",
"placeholder",
]
MAX_VANILLA_STR_LENGTH = 1000
PLAYWRIGHT_LOCATORS = (
"get_by_role",
"get_by_text",
"get_by_label",
"get_by_placeholder",
"get_by_alt_text",
"get_by_title",
"get_by_test_id",
"filter",
"frame_locator",
"locator",
)
PLAYWRIGHT_ACTIONS = (
"fill",
"check",
"select_option",
"click",
"hover",
"dclick",
"type",
"focus",
"goto",
"press",
"scroll",
)
IGNORED_ACTREE_PROPERTIES = (
"focusable",
"editable",
"readonly",
"level",
"settable",
"multiline",
"invalid",
)

39
browser_env/env_config.py Normal file
View File

@ -0,0 +1,39 @@
# websites domain
REDDIT = ""
SHOPPING = ""
SHOPPING_ADMIN = ""
GITLAB = ""
WIKIPEDIA = ""
MAP = ""
HOMEPAGE = ""
assert (
REDDIT
and SHOPPING
and SHOPPING_ADMIN
and GITLAB
and WIKIPEDIA
and MAP
and HOMEPAGE
), "Please setup the URLs to each site"
ACCOUNTS = {
"reddit": {"username": "MarvelsGrantMan136", "password": "test1234"},
"gitlab": {"username": "byteblaze", "password": "hello1234"},
"shopping": {
"username": "emma.lopez@gmail.com",
"password": "Password.123",
},
"shopping_admin": {"username": "admin", "password": "admin1234"},
"shopping_site_admin": {"username": "admin", "password": "admin1234"},
}
URL_MAPPINGS = {
REDDIT: "http://reddit.com",
SHOPPING: "http://onestopmarket.com",
SHOPPING_ADMIN: "http://luma.com/admin",
GITLAB: "http://gitlab.com",
WIKIPEDIA: "http://wikipedia.org",
MAP: "http://openstreetmap.org",
HOMEPAGE: "http://homepage.com",
}

274
browser_env/envs.py Normal file
View File

@ -0,0 +1,274 @@
import json
import re
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union
import numpy as np
import numpy.typing as npt
from beartype import beartype
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.sync_api import (
CDPSession,
Page,
Playwright,
ViewportSize,
expect,
sync_playwright,
)
from .actions import Action, execute_action, get_action_space
from .processors import ObservationHandler, ObservationMetadata
from .utils import (
AccessibilityTree,
DetachedPage,
Observation,
png_bytes_to_numpy,
)
@dataclass
class PlaywrightScript:
function: str # goto, get_by_role
destination: str # https://www.google.com/, combobox
name: str | None = None # Search, Avatar 2009
operation: str | None = None # click, fill, press
value: str | None = None # avatar movie, Enter
@beartype
def parse_action(action: str) -> PlaywrightScript:
splitted = action.strip().split(" ")
assert len(splitted) >= 2
match splitted[:2]:
case ["goto", url]:
assert len(splitted) == 2
return PlaywrightScript("goto", url)
case ["get_by_role", destination]:
assert len(splitted) >= 4
match splitted[2:]:
case [name, operation]:
return PlaywrightScript(
"get_by_role", destination, name, operation
)
case [name, operation, value]:
return PlaywrightScript(
"get_by_role", destination, name, operation, value
)
case _:
raise ValueError("Invalid action")
case _:
raise ValueError(f"Invalid action {action}")
class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
"""
The goal of this environment is to produce a prototype of a browser environment.
In the end, we want to support a fully configurable browser environment with wide
range of action spaces and observation spaces, both structured and unstructured.
But in this prototype, we just support action space specified by Playwright script,
and observation space is the html content of the page.
"""
@beartype
def __init__(
self,
max_page_length: int = 8192,
headless: bool = True,
slow_mo: int = 0,
observation_type: str = "html",
current_viewport_only: bool = False,
viewport_size: ViewportSize = {"width": 1280, "height": 720},
save_trace_enabled: bool = False,
sleep_after_execution: float = 0.0,
):
# TODO: make Space[Action] = ActionSpace
self.action_space = get_action_space() # type: ignore[assignment]
self.headless = headless
self.slow_mo = slow_mo
self.current_viewport_only = current_viewport_only
self.reset_finished = False
self.viewport_size = viewport_size
self.save_trace_enabled = save_trace_enabled
self.sleep_after_execution = sleep_after_execution
match observation_type:
case "html" | "accessibility_tree":
self.text_observation_type = observation_type
self.image_observation_type = ""
self.main_observation_type = "text"
case "image":
self.image_observation_type = observation_type
self.text_observation_type = "" # type: ignore[assignment]
self.main_observation_type = "image"
case _:
raise ValueError(
f"Unsupported observation type: {observation_type}"
)
self.observation_handler = ObservationHandler(
self.main_observation_type,
self.text_observation_type,
self.image_observation_type,
self.current_viewport_only,
self.viewport_size,
)
self.observation_space = (
self.observation_handler.get_observation_space()
)
@beartype
def setup(self, config_file: Path | None = None) -> None:
self.context_manager = sync_playwright()
self.playwright = self.context_manager.__enter__()
self.browser = self.playwright.chromium.launch(
headless=self.headless, slow_mo=self.slow_mo
)
if config_file:
with open(config_file, "r") as f:
instance_config = json.load(f)
else:
instance_config = {}
storage_state = instance_config.get("storage_state", None)
start_url = instance_config.get("start_url", None)
geolocation = instance_config.get("geolocation", None)
self.context = self.browser.new_context(
viewport=self.viewport_size,
storage_state=storage_state,
geolocation=geolocation,
device_scale_factor=1,
)
if self.save_trace_enabled:
self.context.tracing.start(screenshots=True, snapshots=True)
if start_url:
start_urls = start_url.split(" |AND| ")
for url in start_urls:
page = self.context.new_page()
client = page.context.new_cdp_session(
page
) # talk to chrome devtools
if self.text_observation_type == "accessibility_tree":
client.send("Accessibility.enable")
page.client = client # type: ignore # TODO[shuyanzh], fix this hackey client
page.goto(url)
# set the first page as the current page
self.page = self.context.pages[0]
self.page.bring_to_front()
else:
self.page = self.context.new_page()
client = self.page.context.new_cdp_session(self.page)
if self.text_observation_type == "accessibility_tree":
client.send("Accessibility.enable")
self.page.client = client # type: ignore
@beartype
def get_page_client(self, page: Page) -> CDPSession:
return page.client # type: ignore
@beartype
def _get_obs(self) -> dict[str, Observation]:
obs = self.observation_handler.get_observation(
self.page, self.get_page_client(self.page)
)
return obs
@beartype
def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
metadata = self.observation_handler.get_observation_metadata()
return metadata
@beartype
def reset(
self,
*,
seed: int | None = None,
options: dict[str, str] | None = None,
) -> tuple[dict[str, Observation], dict[str, Any]]:
"""
Reset the environment.
:param options: options for the environment. The current supported options are:
- "storage_state": the storage state of the browser. It is a file path to a json file.
"""
super().reset(seed=seed, options=options)
if self.reset_finished:
self.context_manager.__exit__()
if options is not None and "config_file" in options:
config_file = Path(options["config_file"])
if config_file.exists():
self.setup(config_file=config_file)
else:
raise ValueError(f"Config file {config_file} does not exist.")
else:
self.setup()
self.reset_finished = True
if self.sleep_after_execution > 0:
time.sleep(self.sleep_after_execution)
observation = self._get_obs()
observation_metadata = self._get_obs_metadata()
info = {
"page": DetachedPage(self.page.url, ""),
"fail_error": "",
"observation_metadata": observation_metadata,
}
return (observation, info)
@beartype
def save_trace(self, trace_path: str | Path) -> None:
if self.save_trace_enabled:
self.context.tracing.stop(path=trace_path)
@beartype
def close(self) -> None:
if self.reset_finished:
self.context_manager.__exit__()
def step(
self, action: Action
) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
if not self.reset_finished:
raise RuntimeError("Call reset first before calling step.")
success = False
fail_error = ""
try:
self.page = execute_action(
action,
self.page,
self.context,
self.observation_handler.action_processor,
)
success = True
except Exception as e:
fail_error = str(e)
# hard sleep TODO[shuyanzh] suboptimal, may need to check network
if self.sleep_after_execution > 0:
time.sleep(self.sleep_after_execution)
observation = self._get_obs()
observation_metadata = self._get_obs_metadata()
info = {
"page": DetachedPage(self.page.url, self.page.content()),
"fail_error": fail_error,
"observation_metadata": observation_metadata,
}
msg = (
observation,
float(success), # reward
False, # terminated
False, # truncated
info,
)
return msg

670
browser_env/processors.py Normal file
View File

@ -0,0 +1,670 @@
import json
import re
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, TypedDict, Union
import numpy as np
import numpy.typing as npt
from beartype import beartype
from gymnasium import spaces
from playwright.sync_api import CDPSession, Page, ViewportSize
from browser_env.constants import (
ASCII_CHARSET,
FREQ_UNICODE_CHARSET,
IGNORED_ACTREE_PROPERTIES,
UTTERANCE_MAX_LENGTH,
)
from .utils import (
AccessibilityTree,
BrowserConfig,
BrowserInfo,
Observation,
png_bytes_to_numpy,
)
class ObservationProcessor:
def process(self, page: Page, client: CDPSession) -> Observation:
raise NotImplementedError
class ObservationMetadata(TypedDict):
obs_nodes_info: dict[str, Any]
def create_empty_metadata() -> ObservationMetadata:
return {
"obs_nodes_info": {},
}
class TextObervationProcessor(ObservationProcessor):
def __init__(
self,
observation_type: str,
current_viewport_only: bool,
viewport_size: ViewportSize,
):
self.observation_type = observation_type
self.current_viewport_only = current_viewport_only
self.viewport_size = viewport_size
self.observation_tag = "text"
self.meta_data = (
create_empty_metadata()
) # use the store meta data of this observation type
@beartype
def fetch_browser_info(
self,
page: Page,
client: CDPSession,
) -> BrowserInfo:
# extract domtree
tree = client.send(
"DOMSnapshot.captureSnapshot",
{
"computedStyles": [],
"includeDOMRects": True,
"includePaintOrder": True,
},
)
# calibrate the bounds, in some cases, the bounds are scaled somehow
bounds = tree["documents"][0]["layout"]["bounds"]
b = bounds[0]
n = b[2] / self.viewport_size["width"]
bounds = [[x / n for x in bound] for bound in bounds]
tree["documents"][0]["layout"]["bounds"] = bounds
# add union bound placeholder
tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds]
# extract browser info
win_upper_bound = page.evaluate("window.pageYOffset")
win_left_bound = page.evaluate("window.pageXOffset")
win_width = page.evaluate("window.screen.width")
win_height = page.evaluate("window.screen.height")
win_right_bound = win_left_bound + win_width
win_lower_bound = win_upper_bound + win_height
device_pixel_ratio = page.evaluate("window.devicePixelRatio")
assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"
config: BrowserConfig = {
"win_upper_bound": win_upper_bound,
"win_left_bound": win_left_bound,
"win_width": win_width,
"win_height": win_height,
"win_right_bound": win_right_bound,
"win_lower_bound": win_lower_bound,
"device_pixel_ratio": device_pixel_ratio,
}
# assert len(tree['documents']) == 1, "More than one document in the DOM tree"
info: BrowserInfo = {"DOMTree": tree, "config": config}
return info
@beartype
@staticmethod
def partially_in_viewport(
bound: list[float], config: BrowserConfig
) -> bool:
[x, y, width, height] = bound
elem_left_bound = x
elem_top_bound = y
elem_right_bound = x + width
elem_lower_bound = y + height
ok = (
elem_left_bound < config["win_right_bound"]
and elem_right_bound >= config["win_left_bound"]
and elem_top_bound < config["win_lower_bound"]
and elem_lower_bound >= config["win_upper_bound"]
)
return ok
@beartype
def retrieve_viewport_info(self, info: BrowserInfo) -> None:
"""Add viewport related information to the DOMTree
1. add union bound, which is a union of all the bounds of the nodes in the subtree
This is only used when current_viewport_only is enabled since it is quite slow
TODO[robert1003]: improve
"""
tree = info["DOMTree"]
document = tree["documents"][0]
nodes = document["nodes"]
parent = nodes["parentIndex"]
node_names = nodes["nodeName"]
layout = document["layout"]
layout_node_cursor = layout["nodeIndex"]
bounds = layout["bounds"]
graph = defaultdict(lambda: [])
assert len(node_names) == len(parent)
for node_idx in range(len(node_names)):
parent_idx = parent[node_idx]
if parent_idx != -1:
graph[parent_idx].append(node_idx)
union_bounds: list[list[float] | None] = [None for _ in bounds]
def valid_bbox(bound: list[float] | None) -> bool:
if bound is None:
return False
# no width or height
if np.isclose(bound[2], 0):
return False
if np.isclose(bound[3], 0):
return False
return True
def add_union_bound(idx: int) -> list[float] | None:
if idx in layout_node_cursor:
cursor = layout_node_cursor.index(idx)
node_bound = bounds[cursor].copy()
tree_bounds: list[Any] = [node_bound]
for child_idx in graph[idx]:
child_bound = add_union_bound(child_idx)
tree_bounds.append(
child_bound.copy() if child_bound else None
)
tree_bounds = [b for b in tree_bounds if valid_bbox(b)]
# convert to absolute coordinates
for i in range(len(tree_bounds)):
tree_bounds[i][2] = tree_bounds[i][0] + tree_bounds[i][2]
tree_bounds[i][3] = tree_bounds[i][1] + tree_bounds[i][3]
if len(tree_bounds) == 0:
assert not valid_bbox(node_bound)
node_union_bound = [0.0, 0.0, 0.0, 0.0]
else:
left_bound = min([b[0] for b in tree_bounds])
top_bound = min([b[1] for b in tree_bounds])
right_bound = max([b[2] for b in tree_bounds])
bottom_bound = max([b[3] for b in tree_bounds])
node_union_bound = [
left_bound,
top_bound,
right_bound - left_bound,
bottom_bound - top_bound,
]
# update the list
union_bounds[cursor] = node_union_bound
else:
node_union_bound = None
return node_union_bound
add_union_bound(0)
info["DOMTree"]["documents"][0]["layout"]["unionBounds"] = union_bounds
@beartype
def current_viewport_html(self, info: BrowserInfo) -> str:
    """Serialize the DOM nodes that overlap the current viewport to HTML.

    Walks the flattened DOM snapshot in ``info["DOMTree"]`` depth-first and
    emits tags/attributes/text, pruning subtrees whose precomputed union
    bound (``layout["unionBounds"]``) lies fully outside the viewport
    described by ``info["config"]``.

    Adopted from natbot (https://github.com/nat/natbot).
    """
    tree = info["DOMTree"]
    strings = tree["strings"]
    document = tree["documents"][0]
    nodes = document["nodes"]
    attributes = nodes["attributes"]
    node_value = nodes["nodeValue"]
    parent = nodes["parentIndex"]
    node_names = nodes["nodeName"]

    layout = document["layout"]
    layout_node_cursor = layout["nodeIndex"]
    union_bounds = layout["unionBounds"]

    # Build a parent -> children adjacency list from the flat parent array.
    graph = defaultdict(lambda: [])
    for node_idx in range(len(node_names)):
        parent_idx = parent[node_idx]
        if parent_idx != -1:
            graph[parent_idx].append(node_idx)

    def dfs(idx: int) -> str:
        node_name = strings[node_names[idx]].lower().strip()
        # Pseudo nodes such as "#text" / "::before" contribute their text
        # and children, but no enclosing tag.
        can_skip = "#" in node_name or "::" in node_name

        inner_text = ""
        node_value_idx = node_value[idx]
        if node_value_idx >= 0 and node_value_idx < len(strings):
            inner_text = " ".join(strings[node_value_idx].split())

        # Attributes are stored as a flat [name, value, name, value, ...]
        # list of string-table indices.
        node_attributes = [strings[i] for i in attributes[idx]]
        node_attributes_str = ""
        for i in range(0, len(node_attributes), 2):
            a = node_attributes[i]
            b = node_attributes[i + 1]
            b = " ".join(b.split())
            node_attributes_str += f'{a}="{b}" '
        node_attributes_str = node_attributes_str.strip()

        html = ""
        if not can_skip:
            html += f"<{node_name}"
            # BUG FIX: the original tested `if {node_attributes_str}:` -- a
            # one-element set literal, which is always truthy -- so a tag
            # with no attributes still got a stray space appended.
            if node_attributes_str:
                html += f" {node_attributes_str}"
            html += f">{inner_text}"
        else:
            html += f"{inner_text}"

        for child_idx in graph[idx]:
            if child_idx in layout_node_cursor:
                cursor = layout_node_cursor.index(child_idx)
                union_bound = union_bounds[cursor]
                # Prune subtrees that are entirely outside the viewport.
                if not self.partially_in_viewport(
                    union_bound, info["config"]
                ):
                    continue
                html += dfs(child_idx)

        if not can_skip:
            html += f"</{node_name}>"
        return html

    html = dfs(0)
    return html
@beartype
def fetch_page_accessibility_tree(
    self, info: BrowserInfo, client: CDPSession
) -> AccessibilityTree:
    """Fetch the full accessibility (AX) tree via CDP and attach bounding
    boxes derived from the DOM snapshot to every AX node.

    Each node gains three keys: ``bound`` (raw layout box),
    ``union_bound`` (box unioned over the DOM subtree, from
    ``layout["unionBounds"]``) and ``offsetrect_bound``.  AX nodes whose
    DOM counterpart has no layout inherit the boxes of their nearest
    boxed AX ancestor; nodes with no DOM counterpart at all get None.
    """
    accessibility_tree: AccessibilityTree = client.send(
        "Accessibility.getFullAXTree", {}
    )["nodes"]
    # a few nodes are repeated in the accessibility tree; keep the first
    # occurrence of each nodeId
    seen_ids = set()
    _accessibility_tree = []
    for node in accessibility_tree:
        if node["nodeId"] not in seen_ids:
            _accessibility_tree.append(node)
            seen_ids.add(node["nodeId"])
    accessibility_tree = _accessibility_tree

    # add the bounding box of each node
    tree = info["DOMTree"]
    document = tree["documents"][0]
    nodes = document["nodes"]
    backend_node_id = nodes["backendNodeId"]
    node_names = nodes["nodeName"]

    layout = document["layout"]
    layout_node_cursor = layout["nodeIndex"]
    bounds = layout["bounds"]
    union_bounds = layout["unionBounds"]
    offsetrect_bounds = layout["offsetRects"]
    backend_id_to_bound = {}

    # get the mapping between backend node id and bounding box
    for idx in range(len(node_names)):
        # only laid-out DOM nodes appear in the layout arrays
        if idx not in layout_node_cursor:
            continue
        cursor = layout_node_cursor.index(idx)
        node_bound = bounds[cursor]
        node_union_bound = union_bounds[cursor]
        node_offsetrect_bound = offsetrect_bounds[cursor]
        node_backend_id = backend_node_id[idx]
        backend_id_to_bound[node_backend_id] = [
            node_bound,
            node_union_bound,
            node_offsetrect_bound,
        ]

    parent_graph: dict[str, str] = {}
    refine_node_ids: list[str] = []
    for node in accessibility_tree:
        if "parentId" in node:
            parent_graph[node["nodeId"]] = node["parentId"]
        if "backendDOMNodeId" not in node:
            # AX-only node with no DOM counterpart: no boxes available
            node["bound"] = None
            node["union_bound"] = None
            node["offsetrect_bound"] = None
        elif node["backendDOMNodeId"] not in backend_id_to_bound:
            # has a DOM node but no layout; resolved from ancestors below
            refine_node_ids.append(node["nodeId"])
        else:
            node["bound"] = backend_id_to_bound[node["backendDOMNodeId"]][
                0
            ]
            node["union_bound"] = backend_id_to_bound[
                node["backendDOMNodeId"]
            ][1]
            node["offsetrect_bound"] = backend_id_to_bound[
                node["backendDOMNodeId"]
            ][2]

    # refine the bounding box for nodes which only appear in the accessibility tree
    node_ids = [node["nodeId"] for node in accessibility_tree]
    for refine_node_id in refine_node_ids:
        # walk up the AX parent chain until an ancestor with a box is found
        child_id = refine_node_id
        parent_idx: None | int = None
        while child_id in parent_graph:
            parent_id = parent_graph[child_id]
            parent_idx = node_ids.index(parent_id)
            child_id = parent_id
            if accessibility_tree[parent_idx]["union_bound"] is not None:
                break
        refine_node_idx = node_ids.index(refine_node_id)
        if parent_idx is not None:
            # inherit all three boxes from the located ancestor
            accessibility_tree[refine_node_idx][
                "bound"
            ] = accessibility_tree[parent_idx]["bound"]
            accessibility_tree[refine_node_idx][
                "union_bound"
            ] = accessibility_tree[parent_idx]["union_bound"]
            accessibility_tree[refine_node_idx][
                "offsetrect_bound"
            ] = accessibility_tree[parent_idx]["offsetrect_bound"]
        else:
            # root of an orphan chain: nothing to inherit
            accessibility_tree[refine_node_idx]["bound"] = None
            accessibility_tree[refine_node_idx]["union_bound"] = None
            accessibility_tree[refine_node_idx]["offsetrect_bound"] = None

    return accessibility_tree
@beartype
def current_viewport_accessibility_tree(
    self,
    info: BrowserInfo,
    accessibility_tree: AccessibilityTree,
) -> AccessibilityTree:
    """Filter the accessibility tree down to nodes whose union bound
    intersects the current viewport rectangle."""
    config = info["config"]
    visible_nodes = []
    for ax_node in accessibility_tree:
        box = ax_node["union_bound"]
        if not box:
            continue
        left, top, width, height = box
        right = left + width
        bottom = top + height
        # Keep the node when its box overlaps the viewport on both axes.
        overlaps_x = (
            left < config["win_right_bound"]
            and right >= config["win_left_bound"]
        )
        overlaps_y = (
            top < config["win_lower_bound"]
            and bottom >= config["win_upper_bound"]
        )
        if overlaps_x and overlaps_y:
            visible_nodes.append(ax_node)
    return visible_nodes
@beartype
@staticmethod
def parse_accessibility_tree(
    accessibility_tree: AccessibilityTree,
) -> tuple[str, dict[str, Any]]:
    """Parse the accessibility tree into a string text

    Returns the rendered tree (one node per line, tab-indented by depth)
    and a mapping from the printed node id to that node's backend id,
    bounding boxes and rendered text, filled as a side effect of the DFS.
    """
    node_id_to_idx = {}
    for idx, node in enumerate(accessibility_tree):
        node_id_to_idx[node["nodeId"]] = idx

    obs_nodes_info = {}

    def dfs(idx: int, obs_node_id: str, depth: int) -> str:
        tree_str = ""
        node = accessibility_tree[idx]
        indent = "\t" * depth
        valid_node = True
        try:
            role = node["role"]["value"]
            name = node["name"]["value"]
            node_str = f"[{obs_node_id}] {role} {repr(name)}"
            properties = []
            for property in node.get("properties", []):
                try:
                    if property["name"] in IGNORED_ACTREE_PROPERTIES:
                        continue
                    properties.append(
                        f'{property["name"]}: {property["value"]["value"]}'
                    )
                except KeyError:
                    # property without a concrete value; skip it
                    pass
            if properties:
                node_str += " " + " ".join(properties)

            # check valid
            if not node_str.strip():
                valid_node = False

            # empty generic node: drop structural roles that carry no
            # name and no properties (they add tokens but no information)
            if not name.strip():
                if not properties:
                    if role in [
                        "generic",
                        "img",
                        "list",
                        "strong",
                        "paragraph",
                        "banner",
                        "navigation",
                        "Section",
                        "LabelText",
                        "Legend",
                        "listitem",
                    ]:
                        valid_node = False
                elif role in ["listitem"]:
                    # nameless listitem is dropped even when it has properties
                    valid_node = False

            if valid_node:
                tree_str += f"{indent}{node_str}"
                obs_nodes_info[obs_node_id] = {
                    "backend_id": node["backendDOMNodeId"],
                    "bound": node["bound"],
                    "union_bound": node["union_bound"],
                    "offsetrect_bound": node["offsetrect_bound"],
                    "text": node_str,
                }

        except Exception as e:
            # malformed node (missing role/name/box): render children only
            valid_node = False

        for _, child_node_id in enumerate(node["childIds"]):
            if child_node_id not in node_id_to_idx:
                continue
            # mark this to save some tokens: invalid nodes do not deepen
            # the indentation of their children
            child_depth = depth + 1 if valid_node else depth
            child_str = dfs(
                node_id_to_idx[child_node_id], child_node_id, child_depth
            )
            if child_str.strip():
                if tree_str.strip():
                    tree_str += "\n"
                tree_str += child_str

        return tree_str

    tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
    return tree_str, obs_nodes_info
@beartype
@staticmethod
def clean_accesibility_tree(tree_str: str) -> str:
    """further clean accesibility tree"""
    static_text_re = r"\[\d+\] StaticText '([^']+)'"
    kept: list[str] = []
    for line in tree_str.split("\n"):
        if "statictext" not in line.lower():
            kept.append(line)
            continue
        # Drop a StaticText line when its text already appears in one of
        # the three previously kept lines (or when it fails to parse).
        found = re.search(static_text_re, line)
        if found is None:
            continue
        text = found.group(1)
        if not any(text in earlier for earlier in kept[-3:]):
            kept.append(line)
    return "\n".join(kept)
@beartype
def process(self, page: Page, client: CDPSession) -> str:
    """Build the text observation: a tab-title header followed by either
    the page HTML or its accessibility tree, per ``self.observation_type``.

    Side effects (accessibility_tree mode): updates
    ``self.obs_nodes_info`` and ``self.meta_data["obs_nodes_info"]``;
    always stores ``self.browser_config`` for later coordinate lookups.
    """
    # get the tab info
    open_tabs = page.context.pages
    tab_titles = [tab.title() for tab in open_tabs]
    current_tab_idx = open_tabs.index(page)
    for idx in range(len(open_tabs)):
        # Reuse the titles fetched above instead of issuing a second
        # round-trip via tab.title() for every tab.
        if idx == current_tab_idx:
            tab_titles[idx] = f"Tab {idx} (current): {tab_titles[idx]}"
        else:
            tab_titles[idx] = f"Tab {idx}: {tab_titles[idx]}"
    tab_title_str = " | ".join(tab_titles)

    try:
        browser_info = self.fetch_browser_info(page, client)
    except Exception:
        # The CDP snapshot can fail while the page is still loading;
        # wait briefly and retry once.
        page.wait_for_load_state("load", timeout=500)
        browser_info = self.fetch_browser_info(page, client)

    if self.current_viewport_only:
        self.retrieve_viewport_info(browser_info)

    if self.observation_type == "html":
        if self.current_viewport_only:
            html = self.current_viewport_html(browser_info)
            content = html
        else:
            content = page.content()
    elif self.observation_type == "accessibility_tree":
        accessibility_tree = self.fetch_page_accessibility_tree(
            browser_info, client
        )
        if self.current_viewport_only:
            accessibility_tree = self.current_viewport_accessibility_tree(
                browser_info, accessibility_tree
            )
        content, obs_nodes_info = self.parse_accessibility_tree(
            accessibility_tree
        )
        content = self.clean_accesibility_tree(content)
        self.obs_nodes_info = obs_nodes_info
        self.meta_data["obs_nodes_info"] = obs_nodes_info
    else:
        # BUG FIX: corrected the typo "observatrion" in the error message.
        raise ValueError(
            f"Invalid observation type: {self.observation_type}"
        )

    self.browser_config = browser_info["config"]
    content = f"{tab_title_str}\n\n{content}"
    return content
@beartype
def get_element_center(self, element_id: str) -> tuple[float, float]:
    """Return the element's center point, normalized to the viewport size."""
    node_info = self.obs_nodes_info[element_id]
    left, top, width, height = node_info["bound"]
    # Translate from page coordinates into viewport coordinates.
    origin_x = self.browser_config["win_left_bound"]
    origin_y = self.browser_config["win_upper_bound"]
    center_x = left - origin_x + width / 2
    center_y = top - origin_y + height / 2
    return (
        center_x / self.viewport_size["width"],
        center_y / self.viewport_size["height"],
    )
class ImageObservationProcessor(ObservationProcessor):
    """Observation processor that returns a screenshot of the page as an
    RGB numpy array."""

    def __init__(self, observation_type: str):
        self.observation_type = observation_type
        self.observation_tag = "image"
        self.meta_data = create_empty_metadata()

    def process(self, page: Page, client: CDPSession) -> npt.NDArray[np.uint8]:
        """Take a screenshot; if the page is mid-load, wait for the load
        event and retry once."""
        try:
            screenshot = png_bytes_to_numpy(page.screenshot())
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt / SystemExit.
            page.wait_for_event("load")
            screenshot = png_bytes_to_numpy(page.screenshot())
        return screenshot
class ObservationHandler:
    """Main entry point to access all observation processor"""

    def __init__(
        self,
        main_observation_type: str,
        text_observation_type: str,
        image_observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ) -> None:
        self.main_observation_type = main_observation_type
        self.text_processor = TextObervationProcessor(
            text_observation_type, current_viewport_only, viewport_size
        )
        self.image_processor = ImageObservationProcessor(
            image_observation_type
        )
        self.viewport_size = viewport_size

    @beartype
    def get_observation_space(self) -> spaces.Dict:
        """Gymnasium observation space: free text plus an RGB screenshot."""
        text_space = spaces.Text(
            min_length=0,
            max_length=UTTERANCE_MAX_LENGTH,
            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
        )
        # Each position stores the RGB values. Note the swapped axes (height first).
        screen_shape = (
            self.viewport_size["height"],
            self.viewport_size["width"],
            3,
        )
        # BUG FIX: the upper bound was built as
        # `np.ones(..., dtype=np.uint8) * 255.0`, which silently promotes
        # the array to float64; np.full keeps it uint8 end to end.
        image_space = spaces.Box(
            np.zeros(screen_shape, dtype=np.uint8),
            np.full(screen_shape, 255, dtype=np.uint8),
            dtype=np.uint8,
        )
        return spaces.Dict({"text": text_space, "image": image_space})

    @beartype
    def get_observation(
        self, page: Page, client: CDPSession
    ) -> dict[str, Observation]:
        """Return both text and image observations for the current page."""
        text_obs = self.text_processor.process(page, client)
        image_obs = self.image_processor.process(page, client)
        return {"text": text_obs, "image": image_obs}

    @beartype
    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
        """Per-modality metadata gathered during the last process() call."""
        return {
            "text": self.text_processor.meta_data,
            "image": self.image_processor.meta_data,
        }

    @property
    def action_processor(self) -> ObservationProcessor:
        """Return the main processor that is associated with the action space"""
        if self.main_observation_type == "text":
            return self.text_processor
        elif self.main_observation_type == "image":
            return self.image_processor
        else:
            raise ValueError("Invalid main observation type")

0
browser_env/py.typed Normal file
View File

68
browser_env/utils.py Normal file
View File

@ -0,0 +1,68 @@
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Dict, TypedDict, Union
import numpy as np
import numpy.typing as npt
from beartype import beartype
from PIL import Image
@dataclass
class DetachedPage:
    """A lightweight, serializable snapshot of a page (URL plus raw HTML)."""

    url: str
    content: str  # html
@beartype
def png_bytes_to_numpy(png: bytes) -> npt.NDArray[np.uint8]:
    """Convert png bytes to numpy array

    Example:

    >>> fig = go.Figure(go.Scatter(x=[1], y=[1]))
    >>> plt.imshow(png_bytes_to_numpy(fig.to_image('png')))
    """
    # NOTE(review): the annotated uint8 dtype assumes the PNG decodes to an
    # 8-bit image -- TODO confirm for palette/16-bit PNGs.
    return np.array(Image.open(BytesIO(png)))
class AccessibilityTreeNode(TypedDict):
    """One node of the Chrome DevTools accessibility (AX) tree, augmented
    with bounding boxes computed from the DOM snapshot."""

    nodeId: str
    ignored: bool
    role: dict[str, Any]
    chromeRole: dict[str, Any]
    name: dict[str, Any]
    properties: list[dict[str, Any]]
    childIds: list[str]
    parentId: str
    backendDOMNodeId: int
    frameId: str
    # [x, y, width, height] boxes attached after the CDP fetch; None when
    # the node has no laid-out DOM counterpart.
    bound: list[float] | None
    union_bound: list[float] | None
    offsetrect_bound: list[float] | None
class BrowserConfig(TypedDict):
    """Geometry of the current browser viewport (upper/left origin plus the
    derived right/lower edges) and the device pixel ratio."""

    win_upper_bound: float
    win_left_bound: float
    win_width: float
    win_height: float
    win_right_bound: float
    win_lower_bound: float
    device_pixel_ratio: float
class BrowserInfo(TypedDict):
    """Raw CDP DOM snapshot plus the derived viewport configuration."""

    DOMTree: dict[str, Any]
    config: BrowserConfig


# An accessibility tree is a flat list of AX nodes linked by node ids.
AccessibilityTree = list[AccessibilityTreeNode]

# An observation is either text (html / accessibility tree) or a screenshot.
Observation = str | npt.NDArray[np.uint8]


class StateInfo(TypedDict):
    """One environment step: the observation dict plus auxiliary info."""

    observation: dict[str, Observation]
    info: Dict[str, Any]

33
check_errors.sh Executable file
View File

@ -0,0 +1,33 @@
#!/bin/zsh
# Scan cached run pages for auto-login failures, site by site.
# Usage: ./check_errors.sh <result_folder>   (a directory under cache/)
result_folder=$1
cd cache/$result_folder
# check whether there is any auto-login errors
# Each stanza greps the saved *.html pages for the banner text shown when a
# site is logged out, then extracts the numeric task ids from the matching
# file names and prints them comma-separated.
# NOTE(review): `echo $errors | wc -l` prints 1 even when $errors is empty
# (echo still emits a newline), so per-site totals appear as 1 when a site
# has no errors -- confirm before relying on the counts.
errors=$(grep -l "Creating an account has many benefits: check out faster" *.html | sort -u | grep -o '[0-9]\+')
c=$(echo $errors | wc -l)
echo "Shopping total errors: $c"
echo $errors | tr '\n' ','
echo '\n\n'
errors=$(grep -l "Welcome, please sign in" *.html | sort -u | grep -o '[0-9]\+')
c=$(echo $errors | wc -l)
echo "Admin total errors: $c"
echo $errors | tr '\n' ','
echo '\n\n'
errors=$(grep -l "Username or email" *.html | sort -u | grep -o '[0-9]\+')
c=$(echo $errors | wc -l)
echo "Gitlab errors: $c"
echo $errors | tr '\n' ','
echo '\n\n'
errors=$(grep -l "Keep me logged in" *.html | sort -u | grep -o '[0-9]\+')
c=$(echo $errors | wc -l)
echo "Reddit errors: $c"
echo $errors | tr '\n' ','

View File

@ -0,0 +1,31 @@
{
"sites": ["reddit"],
"task_id": 1,
"require_login": true,
"storage_state": "./.auth/reddit_state.json",
"start_url": "http://metis.lti.cs.cmu.edu:9999/",
"geolocation": null,
"intent_template": "tell me all subreddits starting with character '{{character}}'",
"instantiation_dict": {"character": "a"},
"intent": "tell me all subreddits starting with character 'a'",
"require_reset": false,
"eval": {
"eval_types": ["string_match"],
"reference_answers": ["announcements Art AskReddit askscience aww"],
"reference_url": "",
"program_html": [
{
"url": "",
"required_contents": []
}
]
},
"reference_action_sequence": {
"action_set_tag": "playwright",
"action_sequence": [
"page.get_by_role(\"link\", name=\"Forums\").click()",
"page.get_by_role(\"link\", name=\"Alphabetical\").click()",
"page.stop(\"announcements Art AskReddit askscience aww\")"
]
}
}

View File

@ -0,0 +1,30 @@
{
"sites": ["misc"],
"task_id": 2,
"require_login": false,
"storage_state": null,
"start_url": "https://russmaxdesign.github.io/exercise",
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "Check out the classification section",
"require_reset": false,
"eval": {
"eval_types": ["url_match"],
"reference_answers": null,
"reference_url": "https://russmaxdesign.github.io/exercise/#link-two",
"program_html": [
{
"url": "",
"required_contents": []
}
]
},
"reference_action_sequence": {
"action_set_tag": "playwright",
"action_sequence": [
"page.get_by_role(\"navigation\").get_by_role(\"link\", name=\"Classification\").click()",
"page.stop(\"Wilson and Reade\")"
]
}
}

View File

@ -0,0 +1,31 @@
{
"sites": ["misc"],
"task_id": 3,
"require_login": false,
"storage_state": null,
"start_url": "https://russmaxdesign.github.io/exercise",
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "Tell me who provide a collection of concise, detailed information for mammal classification in 2005",
"require_reset": false,
"eval": {
"eval_types": ["string_match"],
"reference_answers": ["Wilson and Reader"],
"reference_url": "",
"program_html": [
{
"url": "",
"required_contents": []
}
]
},
"reference_action_sequence": {
"action_set_tag": "id_accessibility_tree",
"action_sequence": [
"type [13] [xyz@gmail.com] [0]",
"click [65]",
"stop [Wilson and Reader]"
]
}
}

View File

@ -0,0 +1,31 @@
{
"sites": ["reddit"],
"task_id": 4,
"require_login": true,
"storage_state": "./.auth/reddit_state.json",
"start_url": "http://metis.lti.cs.cmu.edu:9999/",
"geolocation": null,
"intent_template": "list all subreddits in alphabetical order",
"instantiation_dict": {},
"intent": "list all subreddits in alphabetical order",
"require_reset": false,
"eval": {
"eval_types": ["url_match"],
"reference_answers": null,
"reference_url": "http://metis.lti.cs.cmu.edu:9999/forums/all",
"program_html": [
{
"url": "",
"required_contents": []
}
]
},
"reference_action_sequence": {
"action_set_tag": "playwright",
"action_sequence": [
"page.get_by_role(\"link\", name=\"Forums\").click()",
"page.get_by_role(\"link\", name=\"Alphabetical\").click()",
"page.stop()"
]
}
}

28537
config_files/test.raw.json Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,6 @@
from .evaluators import *
from .helper_functions import (
shopping_get_latest_order_url,
shopping_get_sku_latest_review_author,
shopping_get_sku_latest_review_rating,
)

View File

@ -0,0 +1,389 @@
"""base class for evaluation"""
# answer string match
import importlib
import json
import time
import urllib
from pathlib import Path
from typing import Any, Tuple, Union
import evaluate # type: ignore[import]
from beartype import beartype
from beartype.door import is_bearable
from playwright.sync_api import CDPSession, Page
from browser_env.actions import Action
from browser_env.utils import StateInfo
from evaluation_harness.helper_functions import (
gitlab_get_project_memeber_role,
llm_fuzzy_match,
reddit_get_post_url,
shopping_get_latest_order_url,
shopping_get_sku_latest_review_author,
shopping_get_sku_latest_review_rating,
)
Trajectory = list[Union[Action, StateInfo]]
@beartype
class Evaluator(object):
    """Base class for all evaluators.

    Subclasses implement ``__call__`` and return a score in [0, 1].
    """

    def __init__(self, eval_tag: str = "") -> None:
        self.eval_tag = eval_tag

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession,
    ) -> float:
        raise NotImplementedError

    @staticmethod
    def get_last_action(trajectory: Trajectory) -> Action:
        """Return the trajectory's final element, validating it is an Action."""
        try:
            # BUG FIX: `is_bearable` returns a bool and never raises, so the
            # original call silently accepted a trajectory ending in a
            # non-action.  Check the result and raise when it fails.
            if not is_bearable(trajectory[-1], Action):
                raise TypeError("last trajectory element is not an Action")
            last_action = trajectory[-1]
        except Exception:
            raise ValueError(
                "The last element of trajectory should be an action, add a fake stop action if needed"
            )

        return last_action  # type: ignore[return-value]

    @staticmethod
    def get_last_state(trajectory: Trajectory) -> StateInfo:
        """Return the second-to-last element, validating it is a StateInfo."""
        try:
            # BUG FIX: same ignored-return-value issue as get_last_action.
            if not is_bearable(trajectory[-2], StateInfo):
                raise TypeError(
                    "second last trajectory element is not a StateInfo"
                )
            last_state = trajectory[-2]
        except Exception:
            raise ValueError(
                "The second last element of trajectory should be a state, add a fake stop action if needed"
            )

        return last_state  # type: ignore[return-value]
@beartype
class StringExactEvaluator(Evaluator):
    """Check whether the answer is exactly the same as one of the reference answers"""

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | None = None,
        client: CDPSession | None = None,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        def strip_quotes(text: str) -> str:
            # Remove one matching pair of surrounding quotes, if present.
            for quote in ("'", '"'):
                if text.startswith(quote) and text.endswith(quote):
                    return text[1:-1]
            return text

        prediction = strip_quotes(self.get_last_action(trajectory)["answer"])
        references = [
            strip_quotes(answer)
            for answer in configs["eval"]["reference_answers"]
        ]
        return 1.0 if prediction in references else 0.0
@beartype
class StringEvaluator(Evaluator):
    """Check whether the answer is correct with:
    exact match: the answer is exactly the same as the reference answer
    must include: each phrase in the reference answer must be included in the answer
    fuzzy match: the answer is similar to the reference answer, using LLM judge
    """

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | None = None,
        client: CDPSession | None = None,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        def normalize(answer: str) -> str:
            # Strip one pair of surrounding quotes, then lowercase.
            if answer.startswith("'") and answer.endswith("'"):
                answer = answer[1:-1]
            elif answer.startswith('"') and answer.endswith('"'):
                answer = answer[1:-1]
            return answer.lower()

        pred = normalize(self.get_last_action(trajectory)["answer"])

        # Every configured check must pass; the score is their product.
        score = 1.0
        for approach, value in configs["eval"]["reference_answers"].items():
            if approach == "exact_match":
                assert isinstance(value, str)
                score = score * float(pred == normalize(value))
            elif approach == "must_include":
                assert isinstance(value, list)
                for phrase in value:
                    score = score * float(normalize(phrase) in pred)
            elif approach == "fuzzy_match":
                assert isinstance(value, list)
                intent = configs["intent"]
                for reference in value:
                    score = score * llm_fuzzy_match(pred, reference, intent)
        return score
@beartype
class StringSoftEvaluator(Evaluator):
    """Use text generation metrics such as BLEU, ROUGE, etc. to evaluate the answer"""

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | None = None,
        client: CDPSession | None = None,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        prediction = self.get_last_action(trajectory)["answer"]
        reference = configs["eval"]["reference_answers"]
        # Score with ROUGE-1 between predicted and reference answers.
        rouge_metric = evaluate.load("rouge")
        scores = rouge_metric.compute(
            predictions=[prediction], references=[reference]
        )
        return float(scores["rouge1"])
@beartype
class URLExactEvaluator(Evaluator):
    """Check whether the URL is exactly the same as of the reference URLs"""

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession | None = None,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        def normalize_url(url: str) -> str:
            # Drop one trailing slash so equivalent URLs compare equal.
            url = str(url)
            return url[:-1] if url.endswith("/") else url

        pred = normalize_url(page.url)
        ref_urls = [
            normalize_url(u)
            for u in configs["eval"]["reference_url"].split(" |OR| ")
        ]
        matching_rule = configs["eval"].get("url_note", "EXACT")
        if matching_rule == "EXACT":
            # The prediction must equal one of the alternatives.
            return 1.0 if pred in ref_urls else 0.0
        if matching_rule == "GOLD in PRED":
            # A reference may appear as a substring of the prediction.
            return 1.0 if any(ref in pred for ref in ref_urls) else 0.0
        raise ValueError(f"Unknown matching rule: {matching_rule}")
@beartype
class HTMLContentExactEvaluator(Evaluator):
    """Check whether the contents appear in the page"""

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession | None = None,
    ) -> float:
        def clean(text: str) -> str:
            # normalize both sides for case-insensitive substring matching
            text = str(text)
            return text.strip().lower()

        with open(config_file, "r") as f:
            configs = json.load(f)

        targets = configs["eval"]["program_html"]

        score = 1.0
        for target in targets:
            target_url: str = target["url"]  # which url to check
            if target_url.startswith("func"):
                # SECURITY NOTE(review): eval() executes arbitrary Python
                # from the config file; safe only while configs are trusted.
                func = target_url.split("func:")[1]
                func = func.replace("__last_url__", page.url)
                target_url = eval(func)

            required_contents: str = target[
                "required_contents"
            ]  # what contents to check
            locator: str = target["locator"]  # js element locator

            # navigate to that url
            if target_url != "last":
                page.goto(target_url)
                time.sleep(3)  # TODO [shuyanzh]: fix this hard-coded sleep

            # empty, use the full page
            if not locator.strip():
                selected_element = page.content()
            # use JS to select the element
            elif locator.startswith("document."):
                try:
                    selected_element = page.evaluate(f"() => {locator}")
                    if not selected_element:
                        selected_element = ""
                    selected_element = str(selected_element)
                except Exception:
                    # the page is wrong, return empty
                    selected_element = ""
            # run program to call API
            elif locator.startswith("func:"):  # a helper function
                # SECURITY NOTE(review): same trusted-config eval() as above.
                func = locator.split("func:")[1]
                func = func.replace("__page__", "page")
                selected_element = eval(func)
            else:
                raise ValueError(f"Unknown locator: {locator}")

            # any one of the " |OR| "-separated alternatives suffices,
            # but every target must be satisfied (score is a product)
            required_contents_or = [
                clean(x) for x in required_contents.split(" |OR| ")
            ]
            selected_element = clean(selected_element)
            score *= any(
                [
                    content in selected_element
                    for content in required_contents_or
                ]
            )

        return score
######
# soft matches.
# mainly for partial scores
# !!under development!!
# TODO[shuyanzh]
######
@beartype
class EvaluatorPartial(Evaluator):
    """Base class for evaluators that award partial credit.

    Still under development: instances cannot be constructed yet.
    """

    def __init__(self) -> None:
        raise NotImplementedError

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession,
    ) -> float:
        raise NotImplementedError
@beartype
class URLSoftEvaluator(EvaluatorPartial):
    """Parse the URL and compare the domain and parameters

    Score = domain match (0/1) * F1 over "key=value" query parameters.
    """

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        last_state = self.get_last_state(trajectory)
        pred = last_state["info"]["page"].url
        ref = configs["eval"]["reference_url"]

        # parse url to get domain, parameters, etc.
        parsed_pred = urllib.parse.urlparse(pred)
        parsed_ref = urllib.parse.urlparse(ref)

        # check domain
        domain_match = int(parsed_pred.netloc == parsed_ref.netloc)

        def get_param_set(query: dict[str, list[str]]) -> set[str]:
            """Flatten a parsed query dict into a set of "key=value" strings."""
            param_set = set()
            for k, v in query.items():
                for vv in v:
                    param_set.add(f"{k}={vv}")
            return param_set

        # calculate parameter f1
        param_set_ref = get_param_set(urllib.parse.parse_qs(parsed_ref.query))
        param_set_pred = get_param_set(
            urllib.parse.parse_qs(parsed_pred.query)
        )
        common = param_set_ref & param_set_pred
        # BUG FIX: the original divided by len(param_set_ref) and
        # len(param_set_pred) unconditionally, raising ZeroDivisionError
        # whenever either URL carries no query parameters.  An empty side
        # is treated as trivially perfect recall/precision.
        r = len(common) / len(param_set_ref) if param_set_ref else 1.0
        p = len(common) / len(param_set_pred) if param_set_pred else 1.0
        f1 = 2 * r * p / (r + p) if r + p > 0 else 1.0

        score = domain_match * f1  # domain match is a must
        return score
class EvaluatorComb:
    """Combine several evaluators; the overall score is their product."""

    def __init__(self, evaluators: list[Evaluator]) -> None:
        self.evaluators = evaluators

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession,
    ) -> float:
        total = 1.0
        for member in self.evaluators:
            total *= member(trajectory, config_file, page, client)
        return total
@beartype
def evaluator_router(config_file: Path | str) -> EvaluatorComb:
    """Router to get the evaluator class"""
    with open(config_file, "r") as f:
        configs = json.load(f)

    # Dispatch table: configured eval type -> evaluator implementation.
    registry = {
        "string_match": StringEvaluator,
        "url_match": URLExactEvaluator,
        "program_html": HTMLContentExactEvaluator,
    }
    evaluators: list[Evaluator | EvaluatorPartial] = []
    for eval_type in configs["eval"]["eval_types"]:
        if eval_type not in registry:
            raise ValueError(f"eval_type {eval_type} is not supported")
        evaluators.append(registry[eval_type]())

    return EvaluatorComb(evaluators)

View File

@ -0,0 +1,180 @@
"""Implements helper functions to assist evaluation cases where other evaluators are not suitable."""
import json
from typing import Any
from urllib.parse import urlparse
import requests
from beartype import beartype
from playwright.sync_api import CDPSession, Page
from browser_env.env_config import (
ACCOUNTS,
GITLAB,
MAP,
REDDIT,
SHOPPING,
SHOPPING_ADMIN,
WIKIPEDIA,
)
from llms.providers.openai_utils import (
generate_from_openai_chat_completion,
)
@beartype
def shopping_get_auth_token() -> str:
    """Fetch an admin REST API token from the shopping site."""
    response = requests.post(
        url=f"{SHOPPING}/rest/default/V1/integration/admin/token",
        headers={"content-type": "application/json"},
        data=json.dumps(
            {
                "username": ACCOUNTS["shopping_site_admin"]["username"],
                "password": ACCOUNTS["shopping_site_admin"]["password"],
            }
        ),
    )
    # The endpoint returns the bare token string as the JSON body.
    token: str = response.json()
    return token
@beartype
def shopping_get_latest_order_url() -> str:
    """Get the latest order url from the shopping website."""
    headers = {
        "Authorization": f"Bearer {shopping_get_auth_token()}",
        "Content-Type": "application/json",
    }
    # Ask the REST API for the single most recent order.
    query = {
        "searchCriteria[sortOrders][0][field]": "created_at",
        "searchCriteria[sortOrders][0][direction]": "DESC",
        "searchCriteria[pageSize]": "1",
    }
    response = requests.get(
        f"{SHOPPING}/rest/V1/orders", params=query, headers=headers
    )
    assert response.status_code == 200
    latest_order = response.json()["items"][0]
    order_id = int(latest_order["increment_id"])
    return f"{SHOPPING}/sales/order/view/order_id/{order_id}/"
@beartype
def shopping_get_sku_latest_review_author(sku: str) -> str:
    """Get the latest review for shopping admin."""
    headers = {
        "Authorization": f"Bearer {shopping_get_auth_token()}",
        "Content-Type": "application/json",
    }
    response = requests.get(
        f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=headers
    )
    assert response.status_code == 200
    reviews = response.json()
    if not reviews:
        return ""
    # The most recent review is the last entry of the list.
    author: str = reviews[-1]["nickname"]
    return author
@beartype
def shopping_get_sku_latest_review_rating(sku: str) -> str:
    """Return the rating percentage of the latest review for a product SKU,
    or "" when the product has no reviews."""
    header = {
        "Authorization": f"Bearer {shopping_get_auth_token()}",
        "Content-Type": "application/json",
    }
    response = requests.get(
        f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
    )
    assert response.status_code == 200
    response_obj = response.json()
    if len(response_obj) == 0:
        return ""
    # BUG FIX: the sanity check inspected response_obj[0] while the value is
    # read from response_obj[-1]; validate the same (latest) entry.
    assert response_obj[-1]["ratings"][0]["rating_name"] == "Rating"
    rating: str = str(response_obj[-1]["ratings"][0]["percent"])
    return rating
@beartype
def reddit_get_post_url(url: str) -> str:
    """Get the post url

    Strips anything after the post id so comment permalinks normalize to
    their parent post: http://domain/f/subreddit/post_id/...
    """
    # Parse once and reuse the pieces (the original re-parsed the same URL
    # five times).
    parsed = urlparse(url)
    tok_url = parsed.path.split("/")
    # not a valid post/comment url, return the url as is
    if len(tok_url) < 4:
        return url
    if tok_url[1] != "f":
        return url
    subreddit = tok_url[2]
    post_id = tok_url[3]
    post_url = f"{parsed.scheme}://{parsed.netloc}/f/{subreddit}/{post_id}/"
    return post_url
@beartype
def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str:
    """Read the role of ``account_name`` from the currently open GitLab
    project-members page by scraping the rendered table with JavaScript.

    Returns "" when the account or its role cell cannot be located.
    """
    # get the account index
    try:
        account_idx = page.evaluate(
            f"""(() => {{
                const elements = document.querySelectorAll("td[data-label='Account'] span.gl-avatar-labeled-sublabel");
                let index = -1;  // Default value if not found
                for(let i = 0; i < elements.length; i++) {{
                    if(elements[i].outerText === '@{account_name}') {{
                        index = i;
                        break;
                    }}
                }}
                return index;
            }})()"""
        )
        # get the role
        # NOTE(review): when account_idx is -1 the indexing below returns
        # undefined and .outerText throws, which the except clause turns
        # into "" -- relied-upon behavior, not an accident.
        role: str = page.evaluate(
            f"""(() => {{
                return document.querySelectorAll("td.col-max-role span")[{account_idx}].outerText;
            }})()"""
        )
    except Exception:
        role = ""

    return role
@beartype
def llm_fuzzy_match(pred: str, reference: str, question: str) -> float:
    """Check whether the prediction matches the reference with GPT-3.5"""
    messages: list[dict[str, Any]] = [
        {"role": "system", "content": "You are a helpful assistant"},
        {
            "role": "user",
            "content": f'Given the statement "{pred}", would it be correct to infer "{reference}"? Yes or No',
        },
    ]
    response = generate_from_openai_chat_completion(
        messages=messages,
        model="gpt-3.5-turbo",
        temperature=0,
        top_p=1,
        context_length=0,
        max_tokens=16,
        stop_token=None,
    )
    # Binary grading: full credit iff the judge answers "Yes".
    return 1.0 if "Yes" in response else 0.0

1
llms/__init__.py Normal file
View File

@ -0,0 +1 @@
"""This module is adapt from https://github.com/zeno-ml/zeno-build"""

29
llms/lm_config.py Normal file
View File

@ -0,0 +1,29 @@
"""Config for language models."""
from __future__ import annotations
import dataclasses
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class LMConfig:
    """A config for a language model.

    Attributes:
        provider: The name of the API provider.
        model: The name of the model.
        model_cls: The Python class corresponding to the model, mostly for
            Hugging Face transformers.
        tokenizer_cls: The Python class corresponding to the tokenizer, mostly
            for Hugging Face transformers.
        mode: The mode of the API calls, e.g., "chat" or "generation".
        gen_config: Extra keyword arguments forwarded to the generation call
            (temperature, max tokens, etc.).
    """

    provider: str
    model: str
    model_cls: type | None = None
    tokenizer_cls: type | None = None
    mode: str | None = None
    gen_config: dict[str, Any] = dataclasses.field(default_factory=dict)

View File

@ -0,0 +1,283 @@
"""Tools to generate from OpenAI prompts.
Adapted from https://github.com/zeno-ml/zeno-build/"""
import asyncio
import logging
import os
import random
import time
from typing import Any
import aiolimiter
import openai
import openai.error
from tqdm.asyncio import tqdm_asyncio
def retry_with_exponential_backoff(  # type: ignore
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 10,
    errors: tuple[type[Exception], ...] = (openai.error.RateLimitError,),
):
    """Retry *func* with exponential backoff.

    Args:
        func: The callable to wrap.
        initial_delay: Seconds to wait before the first retry.
        exponential_base: Multiplier applied to the delay after each retry.
        jitter: When True, randomize each delay by a factor in [1, 2).
        max_retries: Give up (and raise) after this many retries.
        errors: Exception types that trigger a retry; any other exception
            propagates to the caller unchanged.
    """
    from functools import wraps  # local import: keeps module-level deps unchanged

    @wraps(func)  # preserve the wrapped function's name and docstring
    def wrapper(*args, **kwargs):  # type: ignore
        num_retries = 0
        delay = initial_delay
        # Loop until success, an unlisted exception, or max_retries is hit.
        while True:
            try:
                return func(*args, **kwargs)
            # Retry only on the specified error types.
            except errors as e:
                num_retries += 1
                if num_retries > max_retries:
                    # Chain the last error so the traceback stays useful.
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded."
                    ) from e
                # Grow the delay (with optional jitter) before retrying.
                delay *= exponential_base * (1 + jitter * random.random())
                time.sleep(delay)

    return wrapper
async def _throttled_openai_completion_acreate(
    engine: str,
    prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    """Call the OpenAI Completion API once, under a shared rate limiter.

    Retries up to 3 times on rate-limit errors (sleeping 10s each time);
    on persistent failure or an API error, returns an empty-completion
    payload instead of raising so a batch can finish.
    """
    async with limiter:
        for _ in range(3):
            try:
                return await openai.Completion.acreate(  # type: ignore
                    engine=engine,
                    prompt=prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
            except openai.error.RateLimitError:
                logging.warning(
                    "OpenAI API rate limit exceeded. Sleeping for 10 seconds."
                )
                await asyncio.sleep(10)
            except openai.error.APIError as e:
                logging.warning(f"OpenAI API error: {e}")
                break
        # Bug fix: the fallback must mimic a *Completion* response, whose
        # choices carry a "text" field. The previous chat-style shape
        # {"message": {"content": ""}} caused a KeyError in
        # agenerate_from_openai_completion's x["choices"][0]["text"].
        return {"choices": [{"text": ""}]}
async def agenerate_from_openai_completion(
    prompts: list[str],
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    requests_per_minute: int = 300,
) -> list[str]:
    """Generate from OpenAI Completion API.
    Args:
        prompts: list of prompts
        engine: Name of the OpenAI engine/model to query.
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Length of context to use.
        requests_per_minute: Number of requests per minute to allow.
    Returns:
        List of generated responses.
    Raises:
        ValueError: If OPENAI_API_KEY is not set in the environment.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]
    # Client-side throttle shared by every request in this batch.
    limiter = aiolimiter.AsyncLimiter(requests_per_minute)
    async_responses = [
        _throttled_openai_completion_acreate(
            engine=engine,
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
        )
        for prompt in prompts
    ]
    # Run all requests concurrently, with a progress bar.
    responses = await tqdm_asyncio.gather(*async_responses)
    # NOTE(review): context_length is accepted but unused in this function.
    return [x["choices"][0]["text"] for x in responses]
@retry_with_exponential_backoff
def generate_from_openai_completion(
    prompt: str,
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    """Generate a single completion from the OpenAI Completion API.

    Args:
        prompt: The prompt string.
        engine: Name of the OpenAI engine/model to query.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling threshold.
        context_length: Accepted for interface symmetry; unused here.
        stop_token: Optional stop sequence; omitted entirely when None.

    Returns:
        The generated completion text.

    Raises:
        ValueError: If OPENAI_API_KEY is not set in the environment.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]
    response = openai.Completion.create(  # type: ignore
        prompt=prompt,
        engine=engine,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        # Bug fix: the old stop=[stop_token] sent [None] to the API when no
        # stop token was given; mirror the chat variant and omit it instead.
        stop=[stop_token] if stop_token else None,
    )
    answer: str = response["choices"][0]["text"]
    return answer
async def _throttled_openai_chat_completion_acreate(
    model: str,
    messages: list[dict[str, str]],
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    """Call the OpenAI ChatCompletion API once, under a shared rate limiter.

    Retries up to 3 times on rate-limit or timeout errors (sleeping 10s
    each time); on persistent failure or an API error, returns an
    empty-message payload instead of raising so a batch can finish.
    """
    async with limiter:
        for _ in range(3):
            try:
                return await openai.ChatCompletion.acreate(  # type: ignore
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
            except openai.error.RateLimitError:
                logging.warning(
                    "OpenAI API rate limit exceeded. Sleeping for 10 seconds."
                )
                await asyncio.sleep(10)
            except asyncio.exceptions.TimeoutError:
                logging.warning("OpenAI API timeout. Sleeping for 10 seconds.")
                await asyncio.sleep(10)
            except openai.error.APIError as e:
                logging.warning(f"OpenAI API error: {e}")
                break
        # Fallback mirrors the chat response shape read by the caller.
        return {"choices": [{"message": {"content": ""}}]}
async def agenerate_from_openai_chat_completion(
    messages_list: list[list[dict[str, str]]],
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    requests_per_minute: int = 300,
) -> list[str]:
    """Generate from OpenAI Chat Completion API.
    Args:
        messages_list: list of message list
        engine: Name of the OpenAI chat model to query.
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Length of context to use.
        requests_per_minute: Number of requests per minute to allow.
    Returns:
        List of generated responses.
    Raises:
        ValueError: If OPENAI_API_KEY is not set in the environment.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]
    # Client-side throttle shared by every request in this batch.
    limiter = aiolimiter.AsyncLimiter(requests_per_minute)
    async_responses = [
        _throttled_openai_chat_completion_acreate(
            model=engine,
            messages=message,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
        )
        for message in messages_list
    ]
    # Run all requests concurrently, with a progress bar.
    responses = await tqdm_asyncio.gather(*async_responses)
    # NOTE(review): context_length is accepted but unused in this function.
    return [x["choices"][0]["message"]["content"] for x in responses]
@retry_with_exponential_backoff
def generate_from_openai_chat_completion(
    messages: list[dict[str, str]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    """Call the OpenAI ChatCompletion API once and return the reply text."""
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]
    # A stop token is only forwarded when one was actually provided.
    stop = [stop_token] if stop_token else None
    completion = openai.ChatCompletion.create(  # type: ignore
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stop=stop,
    )
    reply: str = completion["choices"][0]["message"]["content"]
    return reply
@retry_with_exponential_backoff
# debug only
def fake_generate_from_openai_chat_completion(
    messages: list[dict[str, str]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    """Debug stand-in for generate_from_openai_chat_completion.

    Mirrors the real function's signature and API-key check but always
    returns a canned chain-of-thought answer without calling the API.
    All generation parameters are accepted and ignored.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]
    answer = "Let's think step-by-step. This page shows a list of links and buttons. There is a search box with the label 'Search query'. I will click on the search box to type the query. So the action I will perform is \"click [60]\"."
    return answer

14
llms/tokenizers.py Normal file
View File

@ -0,0 +1,14 @@
from typing import Any
import tiktoken
class Tokenizer(object):
    """Thin wrapper around tiktoken encoders for supported OpenAI models."""

    def __init__(self, model_name: str) -> None:
        # Bug fix: "gpt-3.5-turbo" was previously misspelled "gpt-turbo-3.5",
        # which made the real model name unusable here.
        if model_name in ["gpt-4", "gpt-3.5-turbo"]:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        else:
            raise NotImplementedError(
                f"tokenizer for model {model_name} is not supported"
            )

    def __call__(self, text: str) -> list[int]:
        """Encode *text* into a list of token ids."""
        return self.tokenizer.encode(text)

BIN
media/overview.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 332 KiB

6
prepare.sh Normal file
View File

@ -0,0 +1,6 @@
#!/bin/bash
# prepare the evaluation
# re-validate login information
mkdir -p ./.auth
python browser_env/auto_login.py

8
requirements.txt Normal file
View File

@ -0,0 +1,8 @@
gymnasium
playwright==1.32.1
Pillow
evaluate
openai
types-tqdm
tiktoken
aiolimiter

623
run.py Normal file
View File

@ -0,0 +1,623 @@
"""Script to run end-to-end evaluation on the benchmark"""
import argparse
import base64
import glob
import io
import json
import logging
import os
import random
import re
import subprocess
import time
from itertools import chain
from pathlib import Path
from typing import Any
import openai
import tiktoken
from beartype import beartype
from PIL import Image
from prompt_toolkit import prompt
from agent import Agent, PromptAgent, TeacherForcingAgent
from agent.prompts import *
from browser_env import (
Action,
ActionTypes,
ObservationMetadata,
ScriptBrowserEnv,
StateInfo,
action2str,
create_stop_action,
)
from browser_env.actions import is_equivalent
from evaluation_harness import evaluator_router
from llms import lm_config
# Shared logging setup: every run writes to a unique timestamped + random
# suffixed file so concurrent runs do not clobber each other's logs.
LOG_FOLDER = "log_files"
Path(LOG_FOLDER).mkdir(parents=True, exist_ok=True)
LOG_FILE_NAME = f"{LOG_FOLDER}/log_{time.strftime('%Y%m%d%H%M%S', time.localtime())}_{random.randint(0, 10000)}.log"
logger = logging.getLogger("logger")
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
file_handler = logging.FileHandler(LOG_FILE_NAME)
file_handler.setLevel(logging.DEBUG)
logger.addHandler(file_handler)
# Set the log format
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)
# A trajectory alternates StateInfo and Action entries (state, action, ...).
Trajectory = list[Action | StateInfo]
# NOTE(review): <head> sits outside <html> here, which is technically
# malformed HTML (browsers tolerate it); RenderHelper only relies on the
# <body>...</body> pair when re-reading rendered files.
HTML_TEMPLATE = """
<!DOCTYPE html>
<head>
<style>
pre {{
white-space: pre-wrap;
word-wrap: break-word;
}}
</style>
</head>
<html>
<body>
{body}
</body>
</html>
"""
def config() -> argparse.Namespace:
    """Parse command-line arguments for the end-to-end benchmark run.

    Returns:
        The parsed argparse namespace.

    Raises:
        ValueError: If the chosen action space is incompatible with the
            chosen observation space.
    """
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )
    parser.add_argument(
        "--render", action="store_true", help="Render the browser"
    )
    parser.add_argument(
        "--slow_mo",
        type=int,
        default=0,
        help="Slow down the browser by the specified amount",
    )
    parser.add_argument(
        "--action_set_tag", default="id_accessibility_tree", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["accessibility_tree", "html", "image"],
        default="accessibility_tree",
        help="Observation type",
    )
    parser.add_argument(
        "--current_viewport_only",
        action="store_true",
        help="Only use the current viewport for the observation",
    )
    parser.add_argument("--viewport_width", type=int, default=1280)
    parser.add_argument("--viewport_height", type=int, default=720)
    parser.add_argument("--save_trace_enabled", action="store_true")
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=30)
    # agent config
    parser.add_argument("--agent_type", type=str, default="prompt")
    parser.add_argument(
        "--instruction_path",
        type=str,
        default="agents/prompts/state_action_agent.json",
    )
    parser.add_argument(
        "--parsing_failure_th",
        # Typo fix: "concesecutive" -> "consecutive" in the user-facing help.
        help="When consecutive parsing failures exceed this threshold, the agent will stop",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--repeating_action_failure_th",
        help="When consecutive repeated actions exceed this threshold, the agent will stop",
        type=int,
        default=3,
    )
    # lm config
    parser.add_argument("--provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
    parser.add_argument("--mode", type=str, default="chat")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--context_length", type=int, default=0)
    parser.add_argument("--max_tokens", type=int, default=384)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument(
        "--max_obs_length",
        type=int,
        help="when not zero, will truncate the observation to this length before feeding to the model",
        default=1920,
    )
    # example config
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=1000)
    # logging related
    parser.add_argument("--result_dir", type=str, default="")
    args = parser.parse_args()
    # Sanity check: id-based actions require the accessibility-tree observation.
    if (
        args.action_set_tag == "id_accessibility_tree"
        and args.observation_type != "accessibility_tree"
    ):
        raise ValueError(
            f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
        )
    return args
@beartype
def get_render_action(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
) -> str:
    """Build the rich HTML snippet describing a predicted action for rendering."""
    if action_set_tag == "id_accessibility_tree":
        node_info = observation_metadata["text"]["obs_nodes_info"]
        element_id = action["element_id"]
        # Recover the text of the targeted node, if it still exists.
        node_content = (
            node_info[element_id]["text"]
            if element_id in node_info
            else "No match found"
        )
        parts = [
            f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>",
            f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>",
            f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>",
        ]
        return "".join(parts)
    elif action_set_tag == "playwright":
        return action["pw_code"]
    else:
        raise ValueError(f"Unknown action type {action['action_type']}")
@beartype
def get_action_description(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
    prompt_constructor: PromptConstructor | None,
) -> str:
    """Generate the text version of the predicted actions to store in action history for prompt use.
    May contain hint information to recover from the failures"""
    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                # e.g. "ActionTypes.CLICK" -> "click"
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in text_meta_data["obs_nodes_info"]:
                    node_content = text_meta_data["obs_nodes_info"][
                        action["element_id"]
                    ]["text"]
                    # Drop the leading token (the element id) from the node text.
                    node_content = " ".join(node_content.split()[1:])
                    action_str = action2str(
                        action, action_set_tag, node_content
                    )
                else:
                    # Hint for the agent: the referenced element vanished.
                    # NOTE(review): "perfom" typo is in the runtime string.
                    action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    # Parsing failed: echo the bad prediction with format hints.
                    action_splitter = prompt_constructor.instruction[
                        "meta_data"
                    ]["action_splitter"]
                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
                else:
                    action_str = action2str(action, action_set_tag, "")
        case "playwright":
            action_str = action["pw_code"]
        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")
    return action_str
class RenderHelper(object):
    """Helper class to render text and image observations and meta data in the trajectory"""

    def __init__(
        self, config_file: str, result_dir: str, action_set_tag: str
    ) -> None:
        # Serialize the task config into a <pre> block at the top of the page.
        with open(config_file, "r") as f:
            _config = json.load(f)
            _config_str = ""
            for k, v in _config.items():
                _config_str += f"{k}: {v}\n"
            _config_str = f"<pre>{_config_str}</pre>\n"
            task_id = _config["task_id"]
        self.action_set_tag = action_set_tag
        # "a+" allows read-back in render(); truncate(0) discards any
        # previous content so each run starts from a fresh template.
        self.render_file = open(
            Path(result_dir) / f"render_{task_id}.html", "a+"
        )
        self.render_file.truncate(0)
        # write init template
        self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
        self.render_file.read()
        self.render_file.flush()

    def render(
        self,
        action: Action,
        state_info: StateInfo,
        meta_data: dict[str, Any],
        render_screenshot: bool = False,
    ) -> None:
        """Render the trajectory"""
        # text observation
        observation = state_info["observation"]
        text_obs = observation["text"]
        info = state_info["info"]
        new_content = f"<h2>New Page</h2>\n"
        new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
        new_content += f"<div class='state_obv'><pre>{text_obs}</pre><div>\n"
        if render_screenshot:
            # image observation, embedded inline as base64 PNG
            img_obs = observation["image"]
            image = Image.fromarray(img_obs)
            byte_io = io.BytesIO()
            image.save(byte_io, format="PNG")
            byte_io.seek(0)
            image_bytes = base64.b64encode(byte_io.read())
            image_str = image_bytes.decode("utf-8")
            new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"
        # meta data
        new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"
        # action
        action_str = get_render_action(
            action,
            info["observation_metadata"],
            action_set_tag=self.action_set_tag,
        )
        # with yellow background
        action_str = f"<div class='predict_action'>{action_str}</div>"
        new_content += f"{action_str}\n"
        # add new content: re-read the file, splice into <body>, rewrite whole file
        self.render_file.seek(0)
        html = self.render_file.read()
        html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
        html_body += new_content
        html = HTML_TEMPLATE.format(body=html_body)
        self.render_file.seek(0)
        self.render_file.truncate()
        self.render_file.write(html)
        self.render_file.flush()

    def close(self) -> None:
        # Must be called by the owner; the file is not a context manager here.
        self.render_file.close()
@beartype
def early_stop(
    trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
    """Decide whether the episode should stop early.

    Stops on: reaching max_steps, k consecutive parsing failures, or the
    same action repeated k times (counted over the whole episode for TYPE).
    """
    # The trajectory alternates state/action entries; count completed steps.
    num_steps = (len(trajectory) - 1) / 2
    if num_steps >= max_steps:
        return True, f"Reach max steps {max_steps}"
    actions: list[Action] = trajectory[1::2]  # type: ignore[assignment]
    # Case: parsing failure for k times
    k = thresholds["parsing_failure"]
    recent = actions[-k:]
    if len(recent) >= k and all(
        a["action_type"] == ActionTypes.NONE for a in recent
    ):
        return True, f"Failed to parse actions for {k} times"
    # Case: same action for k times
    k = thresholds["repeating_action"]
    if not actions:
        return False, ""
    last_action: Action = actions[-1]
    if last_action["action_type"] != ActionTypes.TYPE:
        recent = actions[-k:]
        if len(recent) >= k and all(
            is_equivalent(a, last_action) for a in recent
        ):
            return True, f"Same action for {k} times"
    else:
        # Typing actions may legitimately repeat non-consecutively, so
        # count equivalents over the whole action sequence instead.
        if sum(is_equivalent(a, last_action) for a in actions) >= k:
            return True, f"Same typing action for {k} times"
    return False, ""
@beartype
def test(
    args: argparse.Namespace,
    agent: Agent | PromptAgent,
    config_file_list: list[str],
) -> None:
    """Run the agent on each task config, score it, and log/render results.

    One browser environment is shared across all tasks; each task gets its
    own render file, trajectory, and evaluator score.
    """
    scores = []
    max_steps = args.max_steps
    early_stop_thresholds = {
        "parsing_failure": args.parsing_failure_th,
        "repeating_action": args.repeating_action_failure_th,
    }
    env = ScriptBrowserEnv(
        headless=not args.render,
        slow_mo=args.slow_mo,
        observation_type=args.observation_type,
        current_viewport_only=args.current_viewport_only,
        viewport_size={
            "width": args.viewport_width,
            "height": args.viewport_height,
        },
        save_trace_enabled=args.save_trace_enabled,
        sleep_after_execution=args.sleep_after_execution,
    )
    for config_file in config_file_list:
        try:
            # NOTE(review): if this constructor raises, render_helper is
            # unbound and the close() below would raise NameError.
            render_helper = RenderHelper(
                config_file, args.result_dir, args.action_set_tag
            )
            # get intent
            with open(config_file) as f:
                _c = json.load(f)
                intent = _c["intent"]
                task_id = _c["task_id"]
            logger.info(f"[Config file]: {config_file}")
            logger.info(f"[Intent]: {intent}")
            agent.reset(config_file)
            trajectory: Trajectory = []
            obs, info = env.reset(options={"config_file": config_file})
            state_info: StateInfo = {"observation": obs, "info": info}
            trajectory.append(state_info)
            meta_data = {"action_history": ["None"]}
            while True:
                # Early-stop checks run before asking the agent to act.
                early_stop_flag, stop_info = early_stop(
                    trajectory, max_steps, early_stop_thresholds
                )
                if early_stop_flag:
                    action = create_stop_action(f"Early stop: {stop_info}")
                else:
                    try:
                        action = agent.next_action(
                            trajectory, intent, meta_data=meta_data
                        )
                    except ValueError as e:
                        # get the error message
                        action = create_stop_action(f"ERROR: {str(e)}")
                trajectory.append(action)
                action_str = get_action_description(
                    action,
                    state_info["info"]["observation_metadata"],
                    action_set_tag=args.action_set_tag,
                    prompt_constructor=agent.prompt_constructor
                    if isinstance(agent, PromptAgent)
                    else None,
                )
                # args.render_screenshot is set in __main__, not by config()
                # — presumably always present at runtime; TODO confirm.
                render_helper.render(
                    action, state_info, meta_data, args.render_screenshot
                )
                meta_data["action_history"].append(action_str)
                if action["action_type"] == ActionTypes.STOP:
                    break
                obs, _, terminated, _, info = env.step(action)
                state_info = {"observation": obs, "info": info}
                trajectory.append(state_info)
                if terminated:
                    # add a action place holder
                    trajectory.append(create_stop_action(""))
                    break
            evaluator = evaluator_router(config_file)
            score = evaluator(
                trajectory=trajectory,
                config_file=config_file,
                page=env.page,
                client=env.get_page_client(env.page),
            )
            scores.append(score)
            if score == 1:
                logger.info(f"[Result] (PASS) {config_file}")
            else:
                logger.info(f"[Result] (FAIL) {config_file}")
            if args.save_trace_enabled:
                env.save_trace(
                    Path(args.result_dir) / "traces" / f"{task_id}.zip"
                )
        except openai.error.OpenAIError as e:
            # API errors abort the task but not the whole run.
            logger.info(f"[OpenAI Error] {repr(e)}")
        except Exception as e:
            # NOTE(review): the trailing "]" in this log message looks like
            # a typo in the runtime string.
            logger.info(f"[Unhandled Error] {repr(e)}]")
            import traceback
            # write to error file
            with open(Path(args.result_dir) / "error.txt", "a") as f:
                f.write(f"[Config file]: {config_file}\n")
                f.write(f"[Unhandled Error] {repr(e)}\n")
                f.write(traceback.format_exc())  # write stack trace to file
        # logger.info(f"[Render] {render_helper.render_file.name}")
        # subprocess.run(["open", render_helper.render_file.name])
        render_helper.close()
    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")
def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig:
    """Translate parsed CLI arguments into an LMConfig (OpenAI provider only)."""
    cfg = lm_config.LMConfig(
        provider=args.provider, model=args.model, mode=args.mode
    )
    if args.provider != "openai":
        raise NotImplementedError(f"provider {args.provider} not implemented")
    # LMConfig is frozen, but gen_config is a plain dict we can populate.
    for key in (
        "temperature",
        "top_p",
        "context_length",
        "max_tokens",
        "stop_token",
        "max_obs_length",
    ):
        cfg.gen_config[key] = getattr(args, key)
    return cfg
def construct_agent(args: argparse.Namespace) -> Agent:
    """Build the agent named by --agent_type from the parsed CLI arguments."""
    llm_config = construct_llm_config(args)
    agent: Agent
    if args.agent_type == "teacher_forcing":
        agent = TeacherForcingAgent()
    elif args.agent_type == "prompt":
        # The instruction JSON names the prompt-constructor class to use.
        with open(args.instruction_path) as f:
            constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
        tokenizer = tiktoken.encoding_for_model(llm_config.model)
        # SECURITY NOTE(review): eval() on a string read from the instruction
        # file executes arbitrary code if that file is untrusted — consider a
        # whitelist lookup of constructor classes instead.
        prompt_constructor = eval(constructor_type)(
            args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
        )
        agent = PromptAgent(
            action_set_tag=args.action_set_tag,
            lm_config=llm_config,
            prompt_constructor=prompt_constructor,
        )
    else:
        raise NotImplementedError(
            f"agent type {args.agent_type} not implemented"
        )
    return agent
def prepare(args: argparse.Namespace) -> None:
    """Set up the result directory and regenerate prompt JSON files.

    Mutates args.result_dir in place when it was empty.
    """
    # convert prompt python files to json
    from agent.prompts import to_json
    to_json.run()
    # prepare result dir; default to a timestamped cache directory
    result_dir = args.result_dir
    if not result_dir:
        result_dir = (
            f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}"
        )
    if not Path(result_dir).exists():
        Path(result_dir).mkdir(parents=True, exist_ok=True)
        args.result_dir = result_dir
        logger.info(f"Create result dir: {result_dir}")
    if not (Path(result_dir) / "traces").exists():
        (Path(result_dir) / "traces").mkdir(parents=True)
    # log the log file: record which log file belongs to this result dir
    with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
        f.write(f"{LOG_FILE_NAME}\n")
def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
    """Return the config files whose render_<task_id>.html is not yet in result_dir.

    Args:
        config_files: Candidate task config paths, e.g. "config_files/5.json".
        result_dir: Directory that accumulates render_<task_id>.html files.

    Returns:
        The subset of config_files that have not been rendered yet.
    """
    result_files = glob.glob(f"{result_dir}/*.html")
    # "render_5.html" -> "5"; use a set for O(1) membership tests instead of
    # rescanning a list for every config file.
    finished_ids = {
        os.path.basename(f).split(".")[0].split("_")[1] for f in result_files
    }
    return [
        config_file
        for config_file in config_files
        # "config_files/5.json" -> "5"
        if os.path.basename(config_file).split(".")[0] not in finished_ids
    ]
@beartype
def dump_config(args: argparse.Namespace) -> None:
    """Write the parsed args to <result_dir>/config.json, once per directory."""
    config_file = Path(args.result_dir) / "config.json"
    # Never overwrite an existing config: a resumed run keeps its original.
    if not config_file.exists():
        with open(config_file, "w") as f:
            json.dump(vars(args), f, indent=4)
            logger.info(f"Dump config to {config_file}")
if __name__ == "__main__":
args = config()
args.sleep_after_execution = 2.5
prepare(args)
test_file_list = []
st_idx = args.test_start_idx
ed_idx = args.test_end_idx
for i in range(st_idx, ed_idx):
test_file_list.append(f"config_files/{i}.json")
test_file_list = get_unfinished(test_file_list, args.result_dir)
print(f"Total {len(test_file_list)} tasks left")
args.render = True
args.render_screenshot = True
args.save_trace_enabled = True
args.current_viewport_only = True
dump_config(args)
agent = construct_agent(args)
test(args, agent, test_file_list)

55
scripts/collect_obs.py Normal file
View File

@ -0,0 +1,55 @@
"""Simple script to quickly get the observation of a page"""
import json
import re
import time
from typing import Dict, Optional, Tuple, Type, Union, cast
import pytest
from beartype import beartype
from playwright.sync_api import Page, expect
from browser_env import (
ScriptBrowserEnv,
create_id_based_action,
create_key_press_action,
create_playwright_action,
create_scroll_action,
)
from browser_env.env_config import *
HEADLESS = False
@beartype
def gen_tmp_storage_state() -> None:
    """Write a throwaway config pointing at the saved Reddit auth state."""
    # Plain string literal: the previous f-string had no placeholders.
    with open("scripts/tmp_storage_state.json", "w") as f:
        json.dump({"storage_state": ".auth/reddit_state.json"}, f)
@beartype
def get_observation(
    observation_type: str, current_viewport_only: bool
) -> None:
    """Interactively print the observation after each scripted browser action."""
    env = ScriptBrowserEnv(
        observation_type=observation_type,
        current_viewport_only=current_viewport_only,
        headless=HEADLESS,
    )
    env.reset(options={"config_file": f"scripts/tmp_storage_state.json"})
    # Scripted playwright-style actions, one per line.
    s = f"""page.goto("{GITLAB}")
    page.scroll(down)"""
    action_seq = s.split("\n")
    for action in action_seq:
        action = action.strip()
        obs, success, _, _, info = env.step(create_playwright_action(action))
        print(obs["text"])
        # Pause so the user can inspect each observation before continuing.
        _ = input("Press enter to continue")
if __name__ == "__main__":
gen_tmp_storage_state()
obs_type = "accessibility_tree"
current_viewport_only = True
get_observation(obs_type, current_viewport_only)

View File

@ -0,0 +1,27 @@
"""Replace the website placeholders with website domains from env_config
Generate the test data"""
import json
from browser_env.env_config import *
def main() -> None:
    """Fill in the site placeholders and split test.raw.json into per-task files."""
    with open("config_files/test.raw.json", "r") as f:
        raw = f.read()
    # Map each placeholder to the concrete domain from env_config.
    replacements = {
        "__GITLAB__": GITLAB,
        "__REDDIT__": REDDIT,
        "__SHOPPING__": SHOPPING,
        "__SHOPPING_ADMIN__": SHOPPING_ADMIN,
        "__WIKIPEDIA__": WIKIPEDIA,
        "__MAP__": MAP,
    }
    for placeholder, domain in replacements.items():
        raw = raw.replace(placeholder, domain)
    with open("config_files/test.json", "w") as f:
        f.write(raw)
    # split to multiple files
    for idx, item in enumerate(json.loads(raw)):
        with open(f"config_files/{idx}.json", "w") as f:
            json.dump(item, f, indent=2)
if __name__ == "__main__":
    main()

26
setup.cfg Normal file
View File

@ -0,0 +1,26 @@
[metadata]
name = webarena
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
[options.extras_require]
dev =
pre-commit==3.0.1
pytest==7.1.2
mypy==0.991
beartype==0.12.0
nbmake
pytest-asyncio
types-requests
[options]
python_requires = >=3.7, <4
packages =
browser_env
agent
evaluation_harness
llms
[mypy]
strict = true

4
setup.py Normal file
View File

@ -0,0 +1,4 @@
from setuptools import setup
if __name__ == "__main__":
setup()

72
tests/conftest.py Normal file
View File

@ -0,0 +1,72 @@
from typing import AsyncGenerator, Generator
import pytest
import pytest_asyncio
from browser_env import AsyncScriptBrowserEnv, ScriptBrowserEnv
HEADLESS = True
SLOW_MO = 0
@pytest.fixture(scope="function")
def script_browser_env() -> Generator[ScriptBrowserEnv, None, None]:
"""Create a ScriptBrowserEnv instance for testing.
It is automatically closed after the test session.
This is helpful when the test failed and the browser is still open.
"""
env = ScriptBrowserEnv(
headless=HEADLESS,
slow_mo=SLOW_MO,
)
yield env
env.close()
@pytest.fixture(scope="function")
def current_viewport_script_browser_env() -> Generator[
ScriptBrowserEnv, None, None
]:
env = ScriptBrowserEnv(
headless=HEADLESS,
slow_mo=SLOW_MO,
current_viewport_only=True,
)
yield env
env.close()
@pytest.fixture(scope="function")
def accessibility_tree_script_browser_env() -> Generator[
ScriptBrowserEnv, None, None
]:
env = ScriptBrowserEnv(
headless=HEADLESS,
slow_mo=SLOW_MO,
observation_type="accessibility_tree",
)
yield env
env.close()
@pytest.fixture(scope="function")
def accessibility_tree_current_viewport_script_browser_env() -> Generator[
ScriptBrowserEnv, None, None
]:
env = ScriptBrowserEnv(
headless=HEADLESS,
slow_mo=SLOW_MO,
observation_type="accessibility_tree",
current_viewport_only=True,
)
yield env
env.close()
@pytest_asyncio.fixture(scope="function", autouse=True)
async def async_script_browser_env() -> AsyncGenerator[
    AsyncScriptBrowserEnv, None
]:
    """Async env fixture; autouse so every async test gets a fresh browser."""
    env = AsyncScriptBrowserEnv(headless=HEADLESS, slow_mo=SLOW_MO)
    yield env
    await env.aclose()

View File

@ -0,0 +1,273 @@
import re
from typing import Dict, Optional, Tuple, Type, Union, cast
import pytest
from playwright.sync_api import Page, expect
from browser_env import (
ScriptBrowserEnv,
create_id_based_action,
create_key_press_action,
create_playwright_action,
create_scroll_action,
)
HEADLESS = True
SLOW_MO = 0
def test_frame_locator(script_browser_env: ScriptBrowserEnv) -> None:
    """Playwright-style actions can target elements inside an iframe."""
    env = script_browser_env
    seq = """page.goto("https://www.littlewebhut.com/articles/html_iframe_example/")
    page.frame_locator("iframe[name=\\"imgbox\\"]").get_by_role("img").click()"""
    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success
def test_basic(script_browser_env: ScriptBrowserEnv) -> None:
    # click, fill, press, check, goto
    env = script_browser_env
    seq = """page.goto("https://demo.playwright.dev/todomvc/")
    page.get_by_placeholder("What needs to be done?").click()
    page.get_by_placeholder("What needs to be done?").fill("hello")
    page.get_by_placeholder("What needs to be done?").press("Enter")
    page.get_by_placeholder("What needs to be done?").fill("world")
    page.get_by_placeholder("What needs to be done?").press("Enter")
    page.get_by_placeholder("What needs to be done?").fill("yes")
    page.get_by_placeholder("What needs to be done?").press("Enter")
    page.get_by_placeholder("What needs to be done?").fill("no")
    page.get_by_placeholder("What needs to be done?").press("Enter")
    page.get_by_role("listitem").filter(has_text="world").get_by_role("checkbox", name="Toggle Todo").check()
    page.get_by_role("button", name="Clear completed").click()"""
    env.reset()
    # Each line is an independent playwright action; all must succeed.
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success
def test_hover(script_browser_env: ScriptBrowserEnv) -> None:
    """The hover action is accepted and executed on a real link."""
    env = script_browser_env
    seq = """page.goto("https://ianlunn.github.io/Hover/")
    page.get_by_role("link", name="Download on GitHub").hover()"""
    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success
def test_select_option(script_browser_env: ScriptBrowserEnv) -> None:
    """select_option works on a labeled combobox."""
    env = script_browser_env
    seq = """page.goto("https://russmaxdesign.github.io/exercise/#link-two")
    page.get_by_role("combobox", name="Favourite mammal").select_option("African Wild Dog")"""
    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success
def test_xpath(script_browser_env: ScriptBrowserEnv) -> None:
    """XPath locators (absolute and attribute-based) are supported."""
    env = script_browser_env
    seq = """page.goto("https://demo.playwright.dev/todomvc/")
    page.goto("https://demo.playwright.dev/todomvc/#/")
    page.get_by_placeholder("What needs to be done?").click()
    page.get_by_placeholder("What needs to be done?").fill("hello")
    page.get_by_placeholder("What needs to be done?").press("Enter")
    page.get_by_role("link", name="Completed").click()
    page.locator("xpath=/html/body/section/div/header/input").fill("no")
    page.get_by_placeholder("What needs to be done?").press("Enter")
    page.goto("https://bic-berkeley.github.io/psych-214-fall-2016/string_literals.html")
    page.locator("xpath=//*[@id=\'searchbox\']/div/form/input[1]").fill("type")"""
    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success
def test_inter_page_actions(script_browser_env: ScriptBrowserEnv) -> None:
    """Open/close/focus tabs and navigate back/forward across pages."""
    env = script_browser_env
    env.reset()
    commands = (
        'page.goto("https://demo.playwright.dev/todomvc/")',
        "browser.new_tab()",
        "browser.page_focus(0)",
        "browser.page_focus(1)",
        "page.page_close()",
        'page.goto("https://google.com")',
        'page.goto("https://demo.playwright.dev/todomvc/")',
        "page.go_back()",
        "page.go_forward()",
    )
    info = None
    for cmd in commands:
        _, ok, _, _, info = env.step(create_playwright_action(cmd))
        assert ok
    # after go_back + go_forward we must be back on the todomvc demo
    assert "https://demo.playwright.dev/todomvc" in info["page"].url
def test_scroll(current_viewport_script_browser_env: ScriptBrowserEnv) -> None:
    """Scrolling down and then back up should both report success."""
    env = current_viewport_script_browser_env
    env.reset()
    for direction in ("down", "up"):
        _, ok, _, _, _ = env.step(create_scroll_action(direction))
        assert ok
def test_id_click(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    """Click elements addressed by their accessibility-tree numeric ids.

    Each target's id is scraped from the observation text with a regex,
    an id-based click is issued, and the resulting page URL / observation
    is checked.
    """
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()
    obs, success, _, _, info = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert success
    assert "link 'McKenna/Bell'" in obs["text"]
    # get the id of the link; assert explicitly instead of `# type: ignore`
    match = re.search(r"\[(\d+)\] link 'McKenna/Bell'", obs["text"])
    assert match is not None
    obs, success, _, _, info = env.step(
        create_id_based_action(f"click [{match.group(1)}]")
    )
    assert success
    assert (
        info["page"].url
        == "https://russmaxdesign.github.io/exercise/#link-four"
    )
    obs, success, _, _, info = env.step(create_scroll_action("down"))
    # Fix: the scroll step's success flag was previously not asserted,
    # unlike every other step in this test.
    assert success
    assert "link 'Classification'" in obs["text"]
    match = re.search(r"\[(\d+)\] link 'Classification'", obs["text"])
    assert match is not None
    obs, success, _, _, info = env.step(
        create_id_based_action(f"click [{match.group(1)}]")
    )
    assert success
    assert (
        info["page"].url
        == "https://russmaxdesign.github.io/exercise/#link-two"
    )
    assert "radio 'Weekly'" in obs["text"]
    match = re.search(r"\[(\d+)\] radio 'Weekly'", obs["text"])
    assert match is not None
    obs, success, _, _, info = env.step(
        create_id_based_action(f"click [{match.group(1)}]")
    )
    assert success
    # clicking a radio button keeps it present in the refreshed observation
    assert "radio 'Weekly'" in obs["text"]
def test_id_hover(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    """Hover over an element addressed by its accessibility-tree id."""
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()
    obs, ok, _, _, _ = env.step(
        create_playwright_action(
            'page.goto("https://ianlunn.github.io/Hover/")'
        )
    )
    assert ok
    assert "link 'Download on GitHub'" in obs["text"]
    link_id = re.search(r"\[(\d+)\] link 'Download on GitHub'", obs["text"]).group(1)  # type: ignore
    _, ok, _, _, _ = env.step(create_id_based_action(f"hover [{link_id}]"))
    assert ok
def test_key_press(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    """Pressing Enter after typing moves focus to the next form field."""
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()
    obs, ok, _, _, _ = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert ok
    assert "textbox 'Full name'" in obs["text"]
    box_id = re.search(r"\[(\d+)\] textbox 'Full name'", obs["text"]).group(1)  # type: ignore
    typed_text = "My Name IS XYZ"
    obs, ok, _, _, _ = env.step(
        create_id_based_action(f"type [{box_id}] [{typed_text}] [0]")
    )
    assert ok
    expect(env.page.get_by_label("Full name")).to_be_focused()
    obs, ok, _, _, _ = env.step(create_key_press_action("Enter"))
    assert ok
    # Enter submits focus to the next field on this page
    expect(env.page.get_by_label("Email")).to_be_focused()
def test_id_type(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    """Typing by element id puts the text into the targeted textbox."""
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()
    obs, ok, _, _, _ = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert ok
    assert "textbox 'Full name'" in obs["text"]
    typed_text = "My Name IS XYZ"
    box_id = re.search(r"\[(\d+)\] textbox 'Full name'", obs["text"]).group(1)  # type: ignore
    _, ok, _, _, _ = env.step(
        create_id_based_action(f"type [{box_id}] [{typed_text}]")
    )
    assert ok
    expect(env.page.get_by_label("Full name")).to_have_value(typed_text)
def test_e2e_id_based_actions(
    accessibility_tree_script_browser_env: ScriptBrowserEnv,
) -> None:
    """End-to-end smoke test covering every id-based action kind.

    Exercises goto, click, type, scroll, tab management and history
    navigation, then verifies the URLs of the two open tabs.
    """
    env = accessibility_tree_script_browser_env
    env.reset()
    obs, *_ = env.step(
        create_id_based_action(
            "goto [https://russmaxdesign.github.io/exercise/]"
        )
    )
    # extract the numeric element id from the accessibility-tree text
    element_id = re.search(r"\[(\d+)\] link 'What are mammals\?'", obs["text"]).group(1)  # type: ignore
    obs, *_ = env.step(create_id_based_action(f"click [{element_id}]"))
    element_id = re.search(r"\[(\d+)\] textbox 'Email'", obs["text"]).group(1)  # type: ignore
    env.step(
        create_id_based_action(f"type [{element_id}] [test@gmail.com] [0]")
    )
    env.step(create_id_based_action("scroll [down]"))
    env.step(create_id_based_action("scroll [up]"))
    env.step(create_id_based_action("new_tab"))
    env.step(create_id_based_action("tab_focus [0]"))
    env.step(create_id_based_action("tab_focus [1]"))
    env.step(create_id_based_action("goto [https://example.com/]"))
    env.step(create_id_based_action("go_back"))
    # go_forward should land the second tab back on example.com
    x = env.step(create_id_based_action("go_forward"))
    assert x[-1]["page"].url == "https://example.com/"
    x = env.step(create_id_based_action("tab_focus [0]"))
    # the first tab still shows the anchor reached by the earlier click
    assert (
        x[-1]["page"].url
        == "https://russmaxdesign.github.io/exercise/#link-one"
    )

View File

@ -0,0 +1,87 @@
import numpy as np
from browser_env import *
def test_is_equivalent() -> None:
for action_type in ActionTypes.__members__.values():
action_a = create_random_action()
action_b = create_random_action()
if action_a["action_type"] != action_b["action_type"]:
assert not is_equivalent(action_a, action_b)
action_a["action_type"] = action_type
action_b["action_type"] = action_type
match action_type:
case ActionTypes.MOUSE_CLICK | ActionTypes.MOUSE_HOVER:
if not np.allclose(action_a["coords"], action_b["coords"]):
assert not is_equivalent(action_a, action_b)
action_a["coords"] = action_b["coords"]
assert is_equivalent(action_a, action_b)
case ActionTypes.KEYBOARD_TYPE:
if action_a["text"] != action_b["text"]:
assert not is_equivalent(action_a, action_b)
action_a["text"] = action_b["text"]
assert is_equivalent(action_a, action_b)
case ActionTypes.CLICK | ActionTypes.HOVER | ActionTypes.TYPE:
if action_a["element_id"] and action_b["element_id"]:
if action_a["element_id"] != action_b["element_id"]:
assert not is_equivalent(action_a, action_b)
action_a["element_id"] = action_b["element_id"]
assert is_equivalent(action_a, action_b)
elif action_a["element_id"] and action_b["element_id"]:
if action_a["element_role"] != action_b["element_role"]:
assert not is_equivalent(action_a, action_b)
action_a["element_role"] = action_b["element_role"]
if action_a["element_name"] != action_b["element_name"]:
assert not is_equivalent(action_a, action_b)
action_a["element_name"] = action_b["element_name"]
assert is_equivalent(action_a, action_b)
elif action_a["pw_code"] and action_b["pw_code"]:
if action_a["pw_code"] != action_b["pw_code"]:
assert not is_equivalent(action_a, action_b)
action_a["pw_code"] = action_b["pw_code"]
assert is_equivalent(action_a, action_b)
else:
action_a["element_id"] = action_b["element_id"]
assert is_equivalent(action_a, action_b)
case ActionTypes.GOTO_URL:
if action_a["url"] != action_b["url"]:
assert not is_equivalent(action_a, action_b)
action_a["url"] = action_b["url"]
assert is_equivalent(action_a, action_b)
case ActionTypes.PAGE_FOCUS:
if action_a["page_number"] != action_b["page_number"]:
assert not is_equivalent(action_a, action_b)
action_a["page_number"] = action_b["page_number"]
assert is_equivalent(action_a, action_b)
case ActionTypes.SCROLL:
da = "up" if "up" in action_a["direction"] else "down"
db = "up" if "up" in action_b["direction"] else "down"
if da != db:
assert not is_equivalent(action_a, action_b)
action_a["direction"] = action_b["direction"]
assert is_equivalent(action_a, action_b)
case ActionTypes.KEY_PRESS:
if action_a["key_comb"] != action_b["key_comb"]:
assert not is_equivalent(action_a, action_b)
action_a["key_comb"] = action_b["key_comb"]
assert is_equivalent(action_a, action_b)
case ActionTypes.CHECK | ActionTypes.SELECT_OPTION:
if action_a["pw_code"] != action_b["pw_code"]:
assert not is_equivalent(action_a, action_b)
action_a["pw_code"] = action_b["pw_code"]
assert is_equivalent(action_a, action_b)
case ActionTypes.STOP:
if action_a["answer"] != action_b["answer"]:
assert not is_equivalent(action_a, action_b)
action_a["answer"] = action_b["answer"]
assert is_equivalent(action_a, action_b)
case _:
assert is_equivalent(action_a, action_b)
def test_action2create_function() -> None:
    """Round-trip: action2create_function must emit code rebuilding the action."""
    for _ in range(1000):
        action = create_random_action()
        create_function = action2create_function(action)
        # eval() is acceptable here: the string is generated by our own code,
        # never from external input
        assert is_equivalent(action, eval(create_function))

View File

@ -0,0 +1,67 @@
import asyncio
import json
from browser_env import *
# Pre-authenticated Playwright storage state for www.saucedemo.com: a single
# session cookie logging in the demo "standard_user" account.
auth_json = {
    "cookies": [
        {
            "name": "session-username",
            "value": "standard_user",
            "domain": "www.saucedemo.com",
            "path": "/",
            "httpOnly": False,
            "secure": False,
            "sameSite": "Lax",
        }
    ],
    # no localStorage/sessionStorage entries are needed for this login
    "origins": [],
}
def test_auth_cookie() -> None:
    """An auth-gated page redirects without a cookie and loads with one.

    First visits the inventory page without credentials (expects a bounce to
    the login page), then reloads the env with a storage-state config and
    expects the page to load directly.
    """
    env = ScriptBrowserEnv()
    env.reset()
    _, reward, _, _, info = env.step(
        create_goto_url_action("https://www.saucedemo.com/inventory.html"),
    )
    assert reward == 1
    assert "page" in info and isinstance(info["page"], DetachedPage)
    # not logged in: the site bounces back to the login page
    assert info["page"].url == "https://www.saucedemo.com/"
    # Fix: close the JSON files deterministically with context managers
    # instead of relying on GC of the file objects returned by open().
    with open("/tmp/auth.json", "w") as f:
        json.dump(auth_json, f)
    instance_config = {"storage_state": "/tmp/auth.json"}
    with open("/tmp/config.json", "w") as f:
        json.dump(instance_config, f)
    env.reset(options={"config_file": "/tmp/config.json"})
    _, reward, _, _, info = env.step(
        create_goto_url_action("https://www.saucedemo.com/inventory.html"),
    )
    assert reward == 1
    assert "page" in info and isinstance(info["page"], DetachedPage)
    # the session cookie lets us reach the gated page directly
    assert info["page"].url == "https://www.saucedemo.com/inventory.html"
    env.close()
def test_async_auth_cookie() -> None:
    """Async variant of the auth-cookie test using AsyncScriptBrowserEnv."""
    env = AsyncScriptBrowserEnv()

    async def _test() -> None:
        await env.areset()
        _, reward, _, _, info = await env.astep(
            create_goto_url_action("https://www.saucedemo.com/inventory.html"),
        )
        assert reward == 1
        assert "page" in info and isinstance(info["page"], DetachedPage)
        # not logged in: redirected to the login page
        assert info["page"].url == "https://www.saucedemo.com/"
        # Fix: close the JSON files deterministically with context managers
        # instead of relying on GC of the file objects returned by open().
        with open("/tmp/auth.json", "w") as f:
            json.dump(auth_json, f)
        instance_config = {"storage_state": "/tmp/auth.json"}
        with open("/tmp/config.json", "w") as f:
            json.dump(instance_config, f)
        await env.areset(options={"config_file": "/tmp/config.json"})
        _, reward, _, _, info = await env.astep(
            create_goto_url_action("https://www.saucedemo.com/inventory.html"),
        )
        assert reward == 1
        assert "page" in info and isinstance(info["page"], DetachedPage)
        # the session cookie lets us reach the gated page directly
        assert info["page"].url == "https://www.saucedemo.com/inventory.html"
        await env.aclose()

    asyncio.run(_test())

View File

@ -0,0 +1,89 @@
from typing import Dict, Generator, Optional, Tuple, Type, Union, cast
import pytest
from playwright.sync_api import Page
from browser_env import ScriptBrowserEnv, create_playwright_action
# Browser launch settings used by this module's tests/fixtures.
HEADLESS = True
# extra per-operation delay in ms for debugging; 0 disables slow motion
SLOW_MO = 0
def test_frame_locator(script_browser_env: ScriptBrowserEnv) -> None:
    """Locate and click an element inside an iframe via frame_locator."""
    env = script_browser_env
    env.reset()
    script = """page.goto("https://www.littlewebhut.com/articles/html_iframe_example/")
page.frame_locator("iframe[name=\\"imgbox\\"]").get_by_role("img").click()"""
    for raw_line in script.splitlines():
        _, ok, _, _, _ = env.step(create_playwright_action(raw_line.strip()))
        assert ok
def test_basic(script_browser_env: ScriptBrowserEnv) -> None:
    """Smoke-test click, fill, press, check and goto on the TodoMVC demo."""
    env = script_browser_env
    env.reset()
    commands = (
        'page.goto("https://demo.playwright.dev/todomvc/")',
        'page.get_by_placeholder("What needs to be done?").click()',
        'page.get_by_placeholder("What needs to be done?").fill("hello")',
        'page.get_by_placeholder("What needs to be done?").press("Enter")',
        'page.get_by_placeholder("What needs to be done?").fill("world")',
        'page.get_by_placeholder("What needs to be done?").press("Enter")',
        'page.get_by_placeholder("What needs to be done?").fill("yes")',
        'page.get_by_placeholder("What needs to be done?").press("Enter")',
        'page.get_by_placeholder("What needs to be done?").fill("no")',
        'page.get_by_placeholder("What needs to be done?").press("Enter")',
        'page.get_by_role("listitem").filter(has_text="world").get_by_role("checkbox", name="Toggle Todo").check()',
        'page.get_by_role("button", name="Clear completed").click()',
    )
    for cmd in commands:
        _, ok, _, _, _ = env.step(create_playwright_action(cmd))
        assert ok
@pytest.mark.skip(reason="not important, but the site is flaky")
def test_hover(script_browser_env: ScriptBrowserEnv) -> None:
    """Hover over a link inside the w3schools try-it iframe."""
    env = script_browser_env
    env.reset()
    script = """page.goto("https://www.w3schools.com/cssref/tryit.php?filename=trycss_sel_hover")
page.frame_locator("iframe[name=\\'iframeResult\\']").get_by_role("link", name="w3schools.com").hover()"""
    for raw_line in script.splitlines():
        _, ok, _, _, _ = env.step(create_playwright_action(raw_line.strip()))
        assert ok
@pytest.mark.skip(reason="not important, but the site is flaky")
def test_select_option(script_browser_env: ScriptBrowserEnv) -> None:
    """Select a combobox option inside the w3schools try-it iframe."""
    env = script_browser_env
    env.reset()
    script = """page.goto("https://www.w3schools.com/tags/tryit.asp?filename=tryhtml_select")
page.frame_locator("iframe[name=\\'iframeResult\\']").get_by_role("combobox", name="Choose a car:").select_option("opel")"""
    for raw_line in script.splitlines():
        _, ok, _, _, _ = env.step(create_playwright_action(raw_line.strip()))
        assert ok
def test_xpath(script_browser_env: ScriptBrowserEnv) -> None:
    """Exercise absolute and attribute-based xpath locators."""
    env = script_browser_env
    env.reset()
    script = """page.goto("https://demo.playwright.dev/todomvc/")
page.goto("https://demo.playwright.dev/todomvc/#/")
page.get_by_placeholder("What needs to be done?").click()
page.get_by_placeholder("What needs to be done?").fill("hello")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_role("link", name="Completed").click()
page.locator("xpath=/html/body/section/div/header/input").fill("no")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.goto("https://bic-berkeley.github.io/psych-214-fall-2016/string_literals.html")
page.locator("xpath=//*[@id=\'searchbox\']/div/form/input[1]").fill("type")"""
    for raw_line in script.splitlines():
        _, ok, _, _, _ = env.step(create_playwright_action(raw_line.strip()))
        assert ok

View File

@ -0,0 +1,304 @@
import asyncio
import collections
import json
import tempfile
from typing import Callable, Dict, Optional, Tuple, Type, Union, cast
import pytest
from beartype.door import is_bearable
from gymnasium.vector import AsyncVectorEnv
from playwright.sync_api import Page
from browser_env import (
Action,
AsyncScriptBrowserEnv,
DetachedPage,
ScriptBrowserEnv,
create_focus_and_click_action,
create_goto_url_action,
create_keyboard_type_action,
create_playwright_action,
create_scroll_action,
)
from browser_env.actions import create_id_based_action
from browser_env.env_config import *
def test_script_browser_env(script_browser_env: ScriptBrowserEnv) -> None:
    """Navigate example.com and follow links via focus-and-click actions."""
    env = script_browser_env
    env.reset()
    for action in (
        create_goto_url_action("http://www.example.com"),
        create_focus_and_click_action(
            element_role="link", element_name="More"
        ),
    ):
        env.step(action)
    _, _, _, _, info = env.step(
        create_focus_and_click_action(
            element_role="link", element_name="2606"
        )
    )
    assert isinstance(info["page"], DetachedPage)
    # the "2606" link leads to the RFC 2606 document
    assert info["page"].url == "https://www.rfc-editor.org/rfc/rfc2606.html"
@pytest.mark.asyncio
async def test_async_script_browser_env(
    async_script_browser_env: AsyncScriptBrowserEnv,
) -> None:
    """Async variant: navigate example.com via focus-and-click actions."""
    env = async_script_browser_env
    await env.areset()
    for action in (
        create_goto_url_action("http://www.example.com"),
        create_focus_and_click_action(
            element_role="link", element_name="More"
        ),
    ):
        await env.astep(action)
    _, _, _, _, info = await env.astep(
        create_focus_and_click_action(
            element_role="link", element_name="2606"
        )
    )
    assert isinstance(info["page"], DetachedPage)
    # the "2606" link leads to the RFC 2606 document
    assert info["page"].url == "https://www.rfc-editor.org/rfc/rfc2606.html"
def collate_actions(actions: list[Action]) -> dict[str, list[object]]:
    """Transpose a list of action dicts into a dict of per-key value lists."""
    collated: dict[str, list[object]] = collections.defaultdict(list)
    for one_action in actions:
        for field, field_value in one_action.items():
            collated[field].append(field_value)
    return collated
@pytest.mark.skip(reason="Gym doesn't support self-defined observations")
def test_parallel_script_browser_env() -> None:
    """Drive two ScriptBrowserEnvs in lockstep through gym's AsyncVectorEnv."""
    vector_env = AsyncVectorEnv(
        [
            lambda: ScriptBrowserEnv(),
            lambda: ScriptBrowserEnv(),
        ],
        shared_memory=True,
    )
    vector_env.reset()
    # both envs perform the same navigation steps
    vector_env.step(
        collate_actions(
            [
                create_goto_url_action("http://www.example.com"),
            ]
            * 2
        )
    )
    vector_env.step(
        collate_actions(
            [
                create_focus_and_click_action(
                    element_role="link",
                    element_name="More",
                ),
            ]
            * 2
        )
    )
    # the two envs diverge on the final click (different RFC links)
    _, _, _, _, info = vector_env.step(
        collate_actions(
            [
                create_focus_and_click_action(
                    element_role="link",
                    element_name="2606",
                ),
                create_focus_and_click_action(
                    element_role="link",
                    element_name="6761",
                ),
            ]
        )
    )
    assert is_bearable(info["page"].tolist(), list[DetachedPage])
    assert info["page"][0].url == "https://www.rfc-editor.org/rfc/rfc2606.html"
    assert info["page"][1].url == "https://www.rfc-editor.org/rfc/rfc6761.html"
    vector_env.close()  # type: ignore[no-untyped-call]
def test_is_in_viewport(script_browser_env: ScriptBrowserEnv) -> None:
    """Element matching is viewport-sensitive: after scrolling, the first
    matching "IDN" link resolves to a different target."""
    env = script_browser_env
    env.reset()
    env.step(
        create_goto_url_action("https://www.iana.org/domains/reserved"),
    )
    # two "IDN" links exist; pick the second explicitly via nth
    _, _, _, _, step_info = env.step(
        create_focus_and_click_action(
            element_role="link", element_name="IDN", nth=1
        ),
    )
    assert (
        step_info["page"].url
        == "https://www.icann.org/resources/pages/idn-2012-02-25-en"
    )
    env.step(
        create_goto_url_action("https://www.iana.org/domains/reserved"),
    )
    # after paging down, the first visible "IDN" link is a different one
    _, _, _, _, step_info = env.step(
        create_keyboard_type_action(keys=["PageDown"])
    )
    _, _, _, _, step_info = env.step(
        create_focus_and_click_action(
            element_role="link", element_name="IDN"
        ),
    )
    assert step_info["page"].url == "https://www.iana.org/domains/idn-tables"
def test_focus_placeholder_and_label(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """Fill the demo login form located by placeholder/label and submit it."""
    env = script_browser_env
    env.reset()
    steps = (
        create_goto_url_action("https://demo.applitools.com"),
        create_focus_and_click_action("placeholder", "Enter your username"),
        create_keyboard_type_action("abc"),
        create_focus_and_click_action("placeholder", "Enter your password"),
        create_keyboard_type_action("123"),
        create_focus_and_click_action("label", "Remember Me"),
        create_focus_and_click_action("link", "Sign in"),
    )
    info = None
    for step in steps:
        _, ok, _, _, info = env.step(step)
        assert ok
    # signing in lands on the app page
    assert info["page"].url == "https://demo.applitools.com/app.html"
def test_current_viewport(
    current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    """Only the currently visible viewport's text appears in the observation."""
    top_marker = "detailed information about how mammals could be classified."
    bottom_marker = "Types of mammals"
    env = current_viewport_script_browser_env
    env.reset()
    obs, ok, _, _, _ = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert ok
    assert top_marker in obs["text"]
    assert bottom_marker not in obs["text"]
    obs, ok, _, _, _ = env.step(create_scroll_action("down"))
    assert ok
    # after scrolling, the bottom content replaces the top content
    assert top_marker not in obs["text"]
    assert bottom_marker in obs["text"]
def test_accessibility_tree(
    accessibility_tree_script_browser_env: ScriptBrowserEnv,
) -> None:
    """A full-page accessibility-tree observation includes off-screen widgets."""
    env = accessibility_tree_script_browser_env
    env.reset()
    obs, ok, _, _, _ = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert ok
    for marker in ("checkbox 'Yes'", "button 'Submit'"):
        assert marker in obs["text"]
def test_accessibility_tree_viewport(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    """Accessibility-tree observations track the viewport as we scroll."""
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()
    markers = (
        "combobox 'Favourite mammal'",
        "gridcell 'Canyon bat'",
        "heading 'Useful links'",
    )
    obs, ok, _, _, _ = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert ok
    # top of page: only the combobox is visible
    for marker, visible in zip(markers, (True, False, False)):
        assert (marker in obs["text"]) is visible
    obs, ok, _, _, _ = env.step(create_scroll_action("down"))
    assert ok
    # middle of page: only the gridcell is visible
    for marker, visible in zip(markers, (False, True, False)):
        assert (marker in obs["text"]) is visible
    obs, ok, _, _, _ = env.step(create_scroll_action("down"))
    assert ok
    # bottom of page: gridcell and heading are both visible
    for marker, visible in zip(markers, (False, True, True)):
        assert (marker in obs["text"]) is visible
def test_multiple_start_url(script_browser_env: ScriptBrowserEnv) -> None:
    """A '|AND|'-separated start_url opens one browser tab per URL.

    Fix: the NamedTemporaryFile was created with delete=False and never
    removed; it is now unlinked in a finally block.
    """
    import os

    temp_config = tempfile.NamedTemporaryFile("w", delete=False)
    try:
        config = {
            "require_login": False,
            "start_url": f"{REDDIT} |AND| {REDDIT}/forums",
        }
        json.dump(config, temp_config)
        temp_config.close()
        env = script_browser_env
        env.reset(options={"config_file": temp_config.name})
        # one page per "|AND|"-separated URL
        assert len(env.context.pages) == 2
        assert env.context.pages[0].url == f"{REDDIT}/"
        assert env.context.pages[1].url == f"{REDDIT}/forums", env.context.pages[
            1
        ].url
    finally:
        os.unlink(temp_config.name)
def test_observation_tab_information(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    """Observation text is prefixed with per-tab titles and the current tab."""
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()
    obs, *_ = env.step(
        create_id_based_action(
            "goto [https://russmaxdesign.github.io/exercise/]"
        )
    )
    obs, *_ = env.step(create_id_based_action("new_tab"))
    # Fix: the URL previously contained a malformed "https:///" triple slash.
    obs, *_ = env.step(
        create_id_based_action("goto [https://www.google.com]")
    )
    assert obs["text"].startswith(  # type: ignore[union-attr]
        "Tab 0: Exercise page for keyboard and screen reader use | Tab 1 (current): Google"
    )
    obs, *_ = env.step(create_id_based_action("tab_focus [0]"))
    # switching tabs moves the "(current)" marker
    assert obs["text"].startswith(  # type: ignore[union-attr]
        "Tab 0 (current): Exercise page for keyboard and screen reader use | Tab 1: Google"
    )
def test_accessibility_tree_observation_update(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    """Typing into a field is reflected in the next accessibility-tree observation."""
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()
    obs = None
    for command in (
        "page.goto('https://russmaxdesign.github.io/exercise/')",
        'page.get_by_label("Full name").fill("UNIQUE_NAME")',
    ):
        obs, *_ = env.step(create_playwright_action(command))
    assert "UNIQUE_NAME" in obs["text"]

View File

@ -0,0 +1,29 @@
{
"sites": ["shopping"],
"task_id": 0,
"require_login": true,
"storage_state": null,
"start_url": null,
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "",
"require_reset": false,
"eval": {
"eval_types": ["program_html"],
"reference_answers": [],
"reference_url": "",
"program_html": [
{
"url": "last",
"required_contents": "80",
"locator": "func:shopping_get_sku_latest_review_rating('B09BCM56J7')"
},
{
"url": "last",
"required_contents": "cupcakecupcake",
"locator": "func:shopping_get_sku_latest_review_author('B09BCM56J7')"
}
]
}
}

View File

@ -0,0 +1,29 @@
{
"sites": ["shopping"],
"task_id": 0,
"require_login": true,
"storage_state": null,
"start_url": null,
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "",
"require_reset": false,
"eval": {
"eval_types": ["program_html"],
"reference_answers": [],
"reference_url": "",
"program_html": [
{
"url": "last",
"required_contents": "100",
"locator": "func:shopping_get_sku_latest_review_rating('B09BCM56J7')"
},
{
"url": "last",
"required_contents": "cupcakecupcake",
"locator": "func:shopping_get_sku_latest_review_author('B09BCM56J7')"
}
]
}
}

View File

@ -0,0 +1,24 @@
{
"sites": ["reddit"],
"task_id": 0,
"require_login": true,
"storage_state": null,
"start_url": null,
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "",
"require_reset": false,
"eval": {
"eval_types": ["program_html"],
"reference_answers": [],
"reference_url": "",
"program_html": [
{
"url": "func:reddit_get_post_url('__last_url__')",
"locator": "document.querySelector('.submission__inner').outerText",
"required_contents": "&#x200B;"
}
]
}
}

View File

@ -0,0 +1,29 @@
{
"sites": ["gitlab"],
"task_id": 0,
"require_login": true,
"storage_state": "./.auth/gitlab_state.json",
"start_url": null,
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "",
"require_reset": false,
"eval": {
"eval_types": ["program_html"],
"reference_answers": [],
"reference_url": "",
"program_html": [
{
"url": "http://metis.lti.cs.cmu.edu:8023/primer/design/-/project_members",
"locator": "func:gitlab_get_project_memeber_role(__page__, 'byteblaze')",
"required_contents": "Developer"
},
{
"url": "http://metis.lti.cs.cmu.edu:8023/primer/design/-/project_members",
"locator": "func:gitlab_get_project_memeber_role(__page__, 'primer')",
"required_contents": "Owner"
}
]
}
}

View File

@ -0,0 +1,29 @@
{
"sites": ["gitlab"],
"task_id": 0,
"require_login": true,
"storage_state": "./.auth/gitlab_state.json",
"start_url": null,
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "",
"require_reset": false,
"eval": {
"eval_types": ["program_html"],
"reference_answers": [],
"reference_url": "",
"program_html": [
{
"url": "last",
"required_contents": "Hello World",
"locator": "document.querySelector('[id=\"form-name\"').value"
},
{
"url": "last",
"required_contents": "alexisxy@hotmail.com",
"locator": "document.querySelector('[id=\"form-email\"').value"
}
]
}
}

View File

@ -0,0 +1,34 @@
{
"sites": ["gitlab"],
"task_id": 0,
"require_login": true,
"storage_state": "./.auth/gitlab_state.json",
"start_url": null,
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "",
"require_reset": false,
"eval": {
"eval_types": ["program_html"],
"reference_answers": [],
"reference_url": "",
"program_html": [
{
"url": "last",
"required_contents": "Accessible light and dark syntax highlighting themes",
"locator": ""
},
{
"url": "http://metis.lti.cs.cmu.edu:8023/primer/design",
"required_contents": "Add more deploy triggers",
"locator": ""
},
{
"url": "http://metis.lti.cs.cmu.edu:8023/primer/design",
"required_contents": "Create MVP React component layout",
"locator": ""
}
]
}
}

View File

@ -0,0 +1,30 @@
{
"sites": ["gitlab"],
"task_id": 0,
"require_login": true,
"storage_state": null,
"start_url": null,
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "",
"require_reset": false,
"eval": {
"eval_types": ["program_html", "url_match"],
"reference_answers": [],
"reference_url": "https://russmaxdesign.github.io/",
"url_note": "GOLD in PRED",
"program_html": [
{
"url": "last",
"required_contents": "Hello World",
"locator": "document.querySelector('[id=\"form-name\"').value"
},
{
"url": "last",
"required_contents": "alexisxy@hotmail.com",
"locator": "document.querySelector('[id=\"form-email\"').value"
}
]
}
}

View File

@ -0,0 +1,25 @@
{
"sites": ["reddit"],
"task_id": 0,
"require_login": true,
"storage_state": "./.auth/reddit_state.json",
"start_url": null,
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "",
"require_reset": false,
"eval": {
"eval_types": ["string_match"],
"reference_answers": {
"must_include": ["1985/04/18"]
},
"reference_url": "",
"program_html": [
{
"url": "",
"required_contents": []
}
]
}
}

View File

@ -0,0 +1,23 @@
{
"sites": ["reddit"],
"task_id": 0,
"require_login": true,
"storage_state": "./.auth/reddit_state.json",
"start_url": null,
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "",
"require_reset": false,
"eval": {
"eval_types": ["url_match"],
"reference_answers": [],
"reference_url": "http://metis.lti.cs.cmu.edu:9999",
"program_html": [
{
"url": "",
"required_contents": []
}
]
}
}

View File

@ -0,0 +1,333 @@
import random
from glob import glob
from pathlib import Path
from typing import Any
import pytest
from beartype import beartype
from py import test
from agent import Agent, TeacherForcingAgent
from browser_env import ActionTypes, ScriptBrowserEnv
from browser_env.env_config import *
from evaluation_harness import (
HTMLContentExactEvaluator,
StringEvaluator,
URLExactEvaluator,
)
from evaluation_harness.evaluators import EvaluatorComb
# Run evaluation-harness browsers headless.
HEADLESS = True
# Directory holding the per-test evaluator config fixtures.
config_file_folder = "tests/test_evaluation_harness/configs"
def tf_roll_out(
    agent: Agent, env: ScriptBrowserEnv, config_file: str
) -> list[Any]:
    """Roll out the agent using teacher forcing actions.

    The returned trajectory alternates state dicts
    ({"observation": ..., "info": ...}) and actions, ending with the
    agent's STOP action.
    """
    obs, state_info = env.reset(options={"config_file": config_file})
    trajectory: list[Any] = [{"observation": obs, "info": state_info}]
    while True:
        action = agent.next_action(
            trajectory=trajectory, intent="", meta_data={}
        )
        trajectory.append(action)
        if action["action_type"] == ActionTypes.STOP:
            break
        # proceed to next action
        obs, reward, terminated, truncated, info = env.step(action)
        state_info = {"observation": obs, "info": info}
        trajectory.append(state_info)
    return trajectory
def test_string_match_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """StringEvaluator scores 1.0 when the stop answer contains the reference."""
    config_file = f"{config_file_folder}/string_match.json"
    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    agent.set_actions("""page.stop("The date is 1985/04/18")""")
    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)
    evaluator = StringEvaluator()
    result = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert result == 1.0
def test_string_match_fail(script_browser_env: ScriptBrowserEnv) -> None:
    """StringEvaluator scores 0.0 when the stop answer misses the reference."""
    config_file = f"{config_file_folder}/string_match.json"
    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    agent.set_actions("""page.stop("The date is 1936/04/18")""")
    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)
    evaluator = StringEvaluator()
    result = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert result == 0.0
def test_url_exact_match_success(script_browser_env: ScriptBrowserEnv) -> None:
    """URLExactEvaluator scores 1.0 when the final URL matches the reference."""
    config_file = f"{config_file_folder}/url_exact_match.json"
    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    agent.set_actions(
        f"""page.goto("{REDDIT}")
page.stop()"""
    )
    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)
    evaluator = URLExactEvaluator()
    result = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert result == 1.0
def test_url_exact_match_fail(script_browser_env: ScriptBrowserEnv) -> None:
    """URLExactEvaluator scores 0.0 when the final URL differs from the reference."""
    config_file = f"{config_file_folder}/url_exact_match.json"
    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = f"""page.goto("{GITLAB}")
page.stop()"""
    agent.set_actions(action_seq)
    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)
    evaluator = URLExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    # Removed a leftover debug `print(env.page.url)` to keep test output clean.
    assert score == 0.0
def test_html_content_match_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """HTML content check passes when the page holds the expected content."""
    config_file = f"{config_file_folder}/html_content_exact_match.json"
    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    agent.set_actions(
        f"""page.goto("{GITLAB}")
page.stop()"""
    )
    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)
    checker = HTMLContentExactEvaluator()
    result = checker(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert result == 1.0
def test_html_content_match_fail(script_browser_env: ScriptBrowserEnv) -> None:
    """HTML content check fails when the page lacks the expected content."""
    config_file = f"{config_file_folder}/html_content_exact_match.json"
    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    agent.set_actions(
        """page.goto("https://russmaxdesign.github.io/exercise")
page.stop()"""
    )
    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)
    checker = HTMLContentExactEvaluator()
    result = checker(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert result == 0.0
def test_html_content_element_match_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """Filling the demo form correctly should satisfy html_content_element_exact_match."""
    env = script_browser_env
    config_file = f"{config_file_folder}/html_content_element_exact_match.json"
    replayer = TeacherForcingAgent()
    replayer.set_action_set_tag(tag="playwright")
    # No interpolation needed, so a plain (non-f) string is equivalent.
    playwright_script = """page.goto("https://russmaxdesign.github.io/exercise/")
page.get_by_label("Full name").fill("Hello World")
page.get_by_label("Email").click()
page.get_by_label("Email").fill("alexisxy@hotmail.com")
page.stop()"""
    replayer.set_actions(playwright_script)
    trajectory = tf_roll_out(replayer, env, config_file)
    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0
def test_html_content_element_match_fail(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """A wrong form value ("Hello") must fail html_content_element_exact_match."""
    env = script_browser_env
    config_file = f"{config_file_folder}/html_content_element_exact_match.json"
    replayer = TeacherForcingAgent()
    replayer.set_action_set_tag(tag="playwright")
    # No interpolation needed, so a plain (non-f) string is equivalent.
    playwright_script = """page.goto("https://russmaxdesign.github.io/exercise/")
page.get_by_label("Full name").fill("Hello")
page.get_by_label("Email").click()
page.get_by_label("Email").fill("alexisxy@hotmail.com")
page.stop()"""
    replayer.set_actions(playwright_script)
    trajectory = tf_roll_out(replayer, env, config_file)
    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 0.0
def test_html_content_url_comb_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """Combined URL + HTML-content evaluators should both pass on this rollout."""
    env = script_browser_env
    config_file = f"{config_file_folder}/html_content_url_comb.json"
    replayer = TeacherForcingAgent()
    replayer.set_action_set_tag(tag="playwright")
    # No interpolation needed, so a plain (non-f) string is equivalent.
    playwright_script = """page.goto("https://russmaxdesign.github.io/exercise/")
page.get_by_label("Full name").fill("Hello World")
page.get_by_label("Email").click()
page.get_by_label("Email").fill("alexisxy@hotmail.com")
page.stop()"""
    replayer.set_actions(playwright_script)
    trajectory = tf_roll_out(replayer, env, config_file)
    combined = EvaluatorComb(
        [URLExactEvaluator(), HTMLContentExactEvaluator()]
    )
    score = combined(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0
@beartype
def test_func_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """The func_eval_success config should score 1.0 on this rollout."""
    env = script_browser_env
    config_file = f"{config_file_folder}/func_eval_success.json"
    replayer = TeacherForcingAgent()
    replayer.set_action_set_tag(tag="playwright")
    # No interpolation needed, so a plain (non-f) string is equivalent.
    playwright_script = """page.goto("https://russmaxdesign.github.io/exercise/")
page.stop()"""
    replayer.set_actions(playwright_script)
    trajectory = tf_roll_out(replayer, env, config_file)
    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0
@beartype
def test_func_fail(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """The func_eval_fail config should score 0.0 on this rollout."""
    env = script_browser_env
    config_file = f"{config_file_folder}/func_eval_fail.json"
    replayer = TeacherForcingAgent()
    replayer.set_action_set_tag(tag="playwright")
    # No interpolation needed, so a plain (non-f) string is equivalent.
    playwright_script = """page.goto("https://russmaxdesign.github.io/exercise/")
page.stop()"""
    replayer.set_actions(playwright_script)
    trajectory = tf_roll_out(replayer, env, config_file)
    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 0.0
@beartype
def test_func_url_func_last_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """func_url_func_1 should pass after visiting the target REDDIT comment page."""
    env = script_browser_env
    config_file = f"{config_file_folder}/func_url_func_1.json"
    replayer = TeacherForcingAgent()
    replayer.set_action_set_tag(tag="playwright")
    playwright_script = f"""page.goto("{REDDIT}/f/wallstreetbets/50431/-/comment/676875")
page.stop()"""
    replayer.set_actions(playwright_script)
    trajectory = tf_roll_out(replayer, env, config_file)
    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0
@beartype
def test_func_url_func_page_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """func_url_func_2 should pass with no navigation at all.

    The replayed script is a bare ``page.stop()``, so the evaluator runs
    against whichever page the environment starts on.
    """
    config_file = f"{config_file_folder}/func_url_func_2.json"
    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    # Fix: the original used an f-string with no placeholders (ruff F541);
    # a plain literal is the same string value and clearer.
    action_seq = "page.stop()"
    agent.set_actions(action_seq)
    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)
    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0

View File

@ -0,0 +1,33 @@
import json
import os
from pathlib import Path
from beartype import beartype
from browser_env import ScriptBrowserEnv
from browser_env.env_config import *
from evaluation_harness.helper_functions import (
gitlab_get_project_memeber_role,
)
# Run browsers without a visible window during tests.
# NOTE(review): not referenced in this chunk — presumably read by the
# browser-env fixtures; confirm before removing.
HEADLESS = True
# Folder holding the JSON task-config fixtures these tests read and write.
config_file_folder = "tests/test_evaluation_harness/configs"
def test_gitlab_get_project_memeber_role(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    """Verify member-role lookup on the GitLab project-members page.

    Writes a throwaway config pointing at the GitLab auth state, resets the
    environment with it, navigates to the project-members page, and checks
    the roles reported for two known members.
    """
    env = script_browser_env
    config_file = f"{config_file_folder}/tmp_config.json"
    with open(config_file, "w") as f:
        json.dump({"storage_state": ".auth/gitlab_state.json"}, f)
    try:
        env.reset(options={"config_file": config_file})
        env.page.goto(f"{GITLAB}/primer/design/-/project_members")
        role1 = gitlab_get_project_memeber_role(env.page, "byteblaze")
        assert role1 == "Developer"
        role2 = gitlab_get_project_memeber_role(env.page, "primer")
        assert role2 == "Owner"
    finally:
        # Fix: remove the tmp config even when an assertion above fails;
        # the original only cleaned up on the success path, leaving a
        # stale fixture behind after a failing run.
        os.remove(config_file)