Mirror of https://github.com/web-arena-x/webarena.git (synced 2026-02-06 11:16:53 +00:00)

Commit b454f2dcfd: release commit

.github/workflows/pre-commit.yml (vendored, new file, 17 lines)
@@ -0,0 +1,17 @@
name: pre-commit

on:
  pull_request:
  push:
    branches: [main]

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.10
        uses: actions/setup-python@v4
        with:
          python-version: 3.10.9
      - uses: pre-commit/action@v3.0.0

.github/workflows/tests.yml (vendored, new file, 36 lines)
@@ -0,0 +1,36 @@
name: Python Package Pytest
on: [push]

jobs:
  test-all:
    runs-on: ubuntu-latest
    strategy:
      max-parallel: 5

    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.10
        uses: actions/setup-python@v4
        with:
          python-version: 3.10.9
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          playwright install
          pip install -e .[dev]
      - name: Type-checking package with mypy
        run: |
          # Manually install mypy in the standard way.
          pip --quiet install -U mypy
          # Log this mypy version for debuggability.
          mypy --version
          # Run this mypy instance against our main package.
          mypy --install-types --non-interactive .
          mypy --strict .
      - name: Environment prepare
        run: |
          bash prepare.sh
      - name: Test with pytest
        run: |
          # ignore annotation notebook because it requires a browser
          pytest

.gitignore (vendored, new file, 158 lines)
@@ -0,0 +1,158 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# mac OS
*.DS_Store

.vscode
*tmp*

.auth/*

# local debug
run.sh

# trajectory visualization
render_cache/*

# TMP IGNORE
agent/prompts/jsons/*
log_files/
config_files/*0.json
config_files/*1.json
config_files/*2.json
config_files/*3.json
config_files/*4.json
config_files/*5.json
config_files/*6.json
config_files/*7.json
config_files/*8.json
config_files/*9.json
config_files/test.json

.pre-commit-config.yaml (new file, 23 lines)
@@ -0,0 +1,23 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
  - repo: https://github.com/psf/black
    rev: 22.12.0
    hooks:
      - id: black
        exclude: '^(agent/prompts/raw)'
        args: [--line-length=79]
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args: ["--profile", "black", "--line-length=72"]
  - repo: https://github.com/kynan/nbstripout
    rev: 0.6.0
    hooks:
      - id: nbstripout

LICENSE (new file, 201 lines)
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

README.md (new file, 35 lines)
@@ -0,0 +1,35 @@
[Python 3.10.9](https://www.python.org/downloads/release/python-3109/)
[pre-commit](https://pre-commit.com/)
<a href="https://github.com/psf/black"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
[Checked with mypy](https://mypy-lang.org/)
[bear-ified](https://beartype.readthedocs.io)

# WebArena: A Realistic Web Environment for Building Autonomous Agents
[[Website]](https://webarena.dev/)
[[Paper]]()

> WebArena is a standalone, self-hostable web environment for building autonomous agents. WebArena creates websites from four popular categories with functionality and data mimicking their real-world equivalents. To emulate human problem-solving, WebArena also embeds tools and knowledge resources as independent websites. WebArena introduces a benchmark for interpreting high-level, realistic natural language commands into concrete web-based interactions. We provide annotated programs designed to programmatically validate the functional correctness of each task.

> **Note** This README is still under construction. Stay tuned!

## Install
```bash
# Python 3.10+
conda create -n webarena python=3.10; conda activate webarena
pip install -r requirements.txt
playwright install
pip install -e .

# optional, dev only
pip install -e ".[dev]"
mypy --install-types --non-interactive browser_env
pip install pre-commit
pre-commit install
```

## Preparation
* Configure the URLs of each website in [env_config](browser_env/env_config.py)
* `python scripts/generate_test_data.py` generates an individual config file for each test example in [config_files](config_files)
* `bash prepare.sh` obtains the auto-login cookies for all websites
* `export OPENAI_API_KEY=your_key`
* `python run.py --instruction_path agent/prompts/jsons/p_cot_id_actree_2s.json --test_start_idx 0 --test_end_idx 1 --model gpt-3.5-turbo --result_dir your_result_dir` runs the first example with the GPT-3.5 reasoning agent. The trajectory will be saved in `your_result_dir/0.html`

agent/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .agent import *

agent/agent.py (new file, 177 lines)
@@ -0,0 +1,177 @@
import json
from typing import Any

from beartype import beartype
from beartype.door import is_bearable

from agent.prompts import *
from browser_env.actions import (
    Action,
    ActionParsingError,
    create_id_based_action,
    create_none_action,
    create_playwright_action,
)
from browser_env.utils import Observation, StateInfo
from llms import lm_config
from llms.providers.openai_utils import (
    generate_from_openai_chat_completion,
    generate_from_openai_completion,
)

from .utils import *

# from llms.providers.openai_utils import generate_from_openai_completion
# from llms.providers.openai_utils import fake_generate_from_openai_chat_completion as generate_from_openai_chat_completion


class Agent:
    """Base class for the agent"""

    def __init__(self, *args: Any) -> None:
        pass

    def next_action(
        self, trajectory: Trajectory, intent: str, meta_data: Any
    ) -> Action:
        """Predict the next action given the observation"""
        raise NotImplementedError

    def reset(
        self,
        test_config_file: str,
    ) -> None:
        raise NotImplementedError


class TeacherForcingAgent(Agent):
    """Agent that follows a pre-defined action sequence"""

    def __init__(self) -> None:
        super().__init__()

    @beartype
    def set_action_set_tag(self, tag: str) -> None:
        self.action_set_tag = tag

    @beartype
    def set_actions(self, action_seq: str | list[str]) -> None:
        if isinstance(action_seq, str):
            action_strs = action_seq.strip().split("\n")
        else:
            action_strs = action_seq
        action_strs = [a.strip() for a in action_strs]

        actions = []
        for a_str in action_strs:
            try:
                if self.action_set_tag == "playwright":
                    cur_action = create_playwright_action(a_str)
                elif self.action_set_tag == "id_accessibility_tree":
                    cur_action = create_id_based_action(a_str)
                else:
                    raise ValueError(
                        f"Unknown action type {self.action_set_tag}"
                    )
            except ActionParsingError:
                cur_action = create_none_action()

            cur_action["raw_prediction"] = a_str
            actions.append(cur_action)

        self.actions: list[Action] = actions

    @beartype
    def next_action(
        self, trajectory: Trajectory, intent: str, meta_data: Any
    ) -> Action:
        """Predict the next action given the observation"""
        return self.actions.pop(0)

    @beartype
    def reset(
        self,
        test_config_file: str,
    ) -> None:
        with open(test_config_file) as f:
            ref_actions = json.load(f)["reference_action_sequence"]
            tag = ref_actions["action_set_tag"]
            action_seq = ref_actions["action_sequence"]
            self.set_action_set_tag(tag)
            self.set_actions(action_seq)


class PromptAgent(Agent):
    """Prompt-based agent that emits an action given the history"""

    def __init__(
        self,
        action_set_tag: str,
        lm_config: lm_config.LMConfig,
        prompt_constructor: PromptConstructor,
    ) -> None:
        super().__init__()
        self.lm_config = lm_config
        self.prompt_constructor = prompt_constructor
        self.action_set_tag = action_set_tag

    @beartype
    def set_action_set_tag(self, tag: str) -> None:
        self.action_set_tag = tag

    @beartype
    def next_action(
        self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any]
    ) -> Action:
        prompt = self.prompt_constructor.construct(
            trajectory, intent, meta_data
        )
        lm_config = self.lm_config
        if lm_config.provider == "openai":
            if lm_config.mode == "chat":
                response = generate_from_openai_chat_completion(
                    messages=prompt,
                    model=lm_config.model,
                    temperature=lm_config.gen_config["temperature"],
                    top_p=lm_config.gen_config["top_p"],
                    context_length=lm_config.gen_config["context_length"],
                    max_tokens=lm_config.gen_config["max_tokens"],
                    stop_token=None,
                )
            elif lm_config.mode == "completion":
                response = generate_from_openai_completion(
                    prompt=prompt,
                    engine=lm_config.model,
                    temperature=lm_config.gen_config["temperature"],
                    max_tokens=lm_config.gen_config["max_tokens"],
                    top_p=lm_config.gen_config["top_p"],
                    stop_token=lm_config.gen_config["stop_token"],
                )
            else:
                raise ValueError(
                    f"OpenAI models do not support mode {lm_config.mode}"
                )
        else:
            raise NotImplementedError(
                f"Provider {lm_config.provider} not implemented"
            )

        try:
            parsed_response = self.prompt_constructor.extract_action(response)
            if self.action_set_tag == "id_accessibility_tree":
                action = create_id_based_action(parsed_response)
            elif self.action_set_tag == "playwright":
                action = create_playwright_action(parsed_response)
            else:
                raise ValueError(f"Unknown action type {self.action_set_tag}")
            action["raw_prediction"] = response
        except ActionParsingError:
            action = create_none_action()
            action["raw_prediction"] = response

        return action

    def reset(self, test_config_file: str) -> None:
        pass

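A minimal driving sketch for `TeacherForcingAgent` above (a hedged example: the two action strings are illustrative, while real reference sequences come from the per-example config files read by `reset`):

```python
from agent.agent import TeacherForcingAgent

agent = TeacherForcingAgent()
agent.set_action_set_tag("id_accessibility_tree")
# Two id-based actions, replayed one per next_action() call
agent.set_actions("click [1234]\nstop [$279.49]")
first = agent.next_action(trajectory=[], intent="", meta_data={})
second = agent.next_action(trajectory=[], intent="", meta_data={})
```
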
agent/prompts/README.md (new file, 2 lines)
@@ -0,0 +1,2 @@
## Naming of the prompt files
`description.action_space.observation_space.json`

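For example, `p_cot_id_actree_2s.json` (converted from the raw prompt files below) appears to decompose as: `p_cot` for the chain-of-thought description, `id` for the id-based action space, and `actree` for the accessibility-tree observation space; the trailing `2s` (two in-context examples) reads as part of the description field rather than a separate slot, an inference from the files in this commit rather than something the convention states.
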
agent/prompts/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .prompt_constructor import *

agent/prompts/prompt_constructor.py (new file, 240 lines)
@@ -0,0 +1,240 @@
import json
import re
from pathlib import Path
from typing import Any, TypedDict

import tiktoken
from beartype import beartype

from agent.utils import Trajectory
from browser_env import Action, ActionParsingError
from browser_env.env_config import URL_MAPPINGS
from browser_env.utils import StateInfo
from llms import lm_config

APIInput = str | list[Any] | dict[str, Any]


class Instruction(TypedDict):
    """Instruction for constructing prompt"""

    intro: str
    examples: list[tuple[str, str]]
    template: str
    meta_data: dict[str, Any]


class PromptConstructor(object):
    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: tiktoken.core.Encoding,
    ):
        self.instruction_path = Path(instruction_path)
        self.obs_modality = "text"
        self.lm_config = lm_config
        with open(self.instruction_path) as f:
            instruction = json.load(f)
        instruction["examples"] = [tuple(e) for e in instruction["examples"]]
        self.instruction: Instruction = instruction
        self.tokenizer = tokenizer

    @beartype
    def get_lm_api_input(
        self, intro: str, examples: list[tuple[str, str]], current: str
    ) -> APIInput:
        """Return the required format for an API"""
        message: list[dict[str, str]] | str
        if "openai" in self.lm_config.provider:
            if self.lm_config.mode == "chat":
                message = [{"role": "system", "content": intro}]
                for (x, y) in examples:
                    message.append(
                        {
                            "role": "system",
                            "name": "example_user",
                            "content": x,
                        }
                    )
                    message.append(
                        {
                            "role": "system",
                            "name": "example_assistant",
                            "content": y,
                        }
                    )
                message.append({"role": "user", "content": current})
                return message
            elif self.lm_config.mode == "completion":
                message = f"{intro}\n\n"
                message += "Here are a few examples:\n"
                for example in examples:
                    message += f"Observation:\n{example[0]}\n\n"
                    message += f"Action: {example[1]}\n\n"
                message += "Now make prediction given the observation\n\n"
                message += f"Observation:\n{current}\n\n"
                message += "Action:"
                return message
            else:
                raise ValueError(
                    f"OpenAI models do not support mode {self.lm_config.mode}"
                )
        else:
            raise NotImplementedError(
                f"Provider {self.lm_config.provider} not implemented"
            )

    @beartype
    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        raise NotImplementedError

    @beartype
    def map_url_to_real(self, url: str) -> str:
        """Map the urls to their real-world counterparts"""
        for i, j in URL_MAPPINGS.items():
            if i in url:
                url = url.replace(i, j)
        return url

    @beartype
    def map_url_to_local(self, url: str) -> str:
        """Map the urls to their local counterparts"""
        for i, j in URL_MAPPINGS.items():
            if j in url:
                url = url.replace(j, i)
        return url

    @beartype
    def _extract_action(self, response: str) -> str:
        raise NotImplementedError

    @beartype
    def extract_action(self, response: str) -> str:
        response = self._extract_action(response)
        response = self.map_url_to_local(response)
        return response


class DirectPromptConstructor(PromptConstructor):
    """The agent directly predicts the action"""

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: tiktoken.core.Encoding,
    ):
        super().__init__(instruction_path, lm_config, tokenizer)

    @beartype
    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        """Construct prompt given the trajectory"""
        intro = self.instruction["intro"]
        examples = self.instruction["examples"]
        template = self.instruction["template"]
        keywords = self.instruction["meta_data"]["keywords"]
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]

        page = state_info["info"]["page"]
        url = page.url
        previous_action_str = meta_data["action_history"][-1]

        # input x
        current = template.format(
            objective=intent,
            url=self.map_url_to_real(url),
            observation=obs,
            previous_action=previous_action_str,
        )

        # make sure all {keyword} placeholders are replaced
        assert all(f"{{{k}}}" not in current for k in keywords)
        prompt = self.get_lm_api_input(intro, examples, current)
        return prompt

    @beartype
    def _extract_action(self, response: str) -> str:
        action_splitter = self.instruction["meta_data"]["action_splitter"]
        pattern = rf"{action_splitter}(.*?){action_splitter}"
        match = re.search(pattern, response)
        if match:
            return match.group(1)
        else:
            raise ActionParsingError(
                f"Cannot parse action from response {response}"
            )


class CoTPromptConstructor(PromptConstructor):
    """The agent will perform step-by-step reasoning before the answer"""

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: tiktoken.core.Encoding,
    ):
        super().__init__(instruction_path, lm_config, tokenizer)
        self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]

    @beartype
    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        intro = self.instruction["intro"]
        examples = self.instruction["examples"]
        template = self.instruction["template"]
        keywords = self.instruction["meta_data"]["keywords"]
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]

        page = state_info["info"]["page"]
        url = page.url
        previous_action_str = meta_data["action_history"][-1]
        current = template.format(
            objective=intent,
            url=self.map_url_to_real(url),
            observation=obs,
            previous_action=previous_action_str,
        )

        # make sure all {keyword} placeholders are replaced
        assert all(f"{{{k}}}" not in current for k in keywords)

        prompt = self.get_lm_api_input(intro, examples, current)
        return prompt

    @beartype
    def _extract_action(self, response: str) -> str:
        # find the first occurrence of an action
        action_splitter = self.instruction["meta_data"]["action_splitter"]
        pattern = rf"{action_splitter}(.*?){action_splitter}"
        match = re.search(pattern, response)
        if match:
            return match.group(1)
        else:
            raise ActionParsingError(
                f'Cannot find the answer phrase "{self.answer_phrase}" in "{response}"'
            )

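A self-contained sketch of the action-extraction regex used by `_extract_action` above, with the triple-backtick action splitter from the prompt files in this commit:

```python
import re

response = (
    "Let's think step-by-step. The price is shown on the page. "
    "In summary, the next action I will perform is ```stop [$279.49]```"
)
action_splitter = "```"
# Non-greedy match grabs the first action between two splitters
pattern = rf"{action_splitter}(.*?){action_splitter}"
match = re.search(pattern, response)
assert match is not None and match.group(1) == "stop [$279.49]"
```
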
agent/prompts/raw/p_cot_id_actree_2s.py (new file, 82 lines)
@@ -0,0 +1,82 @@
prompt = {
    "intro": """You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue.

Here's the information you'll have:
The user's objective: This is the task you're trying to complete.
The current web page's accessibility tree: This is a simplified representation of the webpage, providing key information.
The current web page's URL: This is the page you're currently navigating.
The open tabs: These are the tabs you have open.
The previous action: This is the action you just performed. It may be helpful to track your progress.

The actions you can perform fall into several categories:

Page Operation Actions:
`click [id]`: This action clicks on an element with a specific id on the webpage.
`type [id] [content] [press_enter_after=0|1]`: Use this to type the content into the field with id. By default, the "Enter" key is pressed after typing unless press_enter_after is set to 0.
`hover [id]`: Hover over an element with id.
`press [key_comb]`: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v).
`scroll [direction=down|up]`: Scroll the page up or down.

Tab Management Actions:
`new_tab`: Open a new, empty browser tab.
`tab_focus [tab_index]`: Switch the browser's focus to a specific tab using its index.
`close_tab`: Close the currently active tab.

URL Navigation Actions:
`goto [url]`: Navigate to a specific URL.
`go_back`: Navigate to the previously viewed page.
`go_forward`: Navigate to the next page (if a previous 'go_back' action was performed).

Completion Action:
`stop [answer]`: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket. If you believe the task is impossible to complete, provide the answer as "N/A" in the bracket.

Homepage:
If you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit.
http://homepage.com/password.html lists all the account names and passwords for the websites. You can use them to log in to the websites.

To be successful, it is very important to follow the following rules:
1. You should only issue an action that is valid given the current observation.
2. You should only issue one action at a time.
3. You should follow the examples to reason step by step and then issue the next action.
4. Generate the action in the correct format. Start with an "In summary, the next action I will perform is" phrase, followed by the action inside ``````. For example, "In summary, the next action I will perform is ```click [1234]```".
5. Issue the stop action when you think you have achieved the objective. Don't generate anything after stop.""",
    "examples": [
        (
            """OBSERVATION:
[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)'
[1749] StaticText '$279.49'
[1757] button 'Add to Cart'
[1760] button 'Add to Wish List'
[1761] button 'Add to Compare'
URL: http://onestopmarket.com/office-products/office-electronics.html
OBJECTIVE: What is the price of HP Inkjet Fax Machine
PREVIOUS ACTION: None""",
            "Let's think step-by-step. This page lists the information of HP Inkjet Fax Machine, which is the product identified in the objective. Its price is $279.49. I think I have achieved the objective. I will issue the stop action with the answer. In summary, the next action I will perform is ```stop [$279.49]```",
        ),
        (
            """OBSERVATION:
[164] textbox 'Search' focused: True required: False
[171] button 'Go'
[174] link 'Find directions between two points'
[212] heading 'Search Results'
[216] button 'Close'
URL: http://openstreetmap.org
OBJECTIVE: Show me the restaurants near CMU
PREVIOUS ACTION: None""",
            "Let's think step-by-step. This page has a search box whose ID is [164]. According to the nominatim rule of openstreetmap, I can search for the restaurants near a location by \"restaurants near\". I can submit my typing by pressing the Enter afterwards. In summary, the next action I will perform is ```type [164] [restaurants near CMU] [1]```",
        ),
    ],
    "template": """OBSERVATION:
{observation}
URL: {url}
OBJECTIVE: {objective}
PREVIOUS ACTION: {previous_action}""",
    "meta_data": {
        "observation": "accessibility_tree",
        "action_type": "id_accessibility_tree",
        "keywords": ["url", "objective", "observation", "previous_action"],
        "prompt_constructor": "CoTPromptConstructor",
        "answer_phrase": "In summary, the next action I will perform is",
        "action_splitter": "```",
    },
}

agent/prompts/raw/p_direct_id_actree_2s.py (new file, 80 lines)
@@ -0,0 +1,80 @@
prompt = {
    "intro": """You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue.

Here's the information you'll have:
The user's objective: This is the task you're trying to complete.
The current web page's accessibility tree: This is a simplified representation of the webpage, providing key information.
The current web page's URL: This is the page you're currently navigating.
The open tabs: These are the tabs you have open.
The previous action: This is the action you just performed. It may be helpful to track your progress.

The actions you can perform fall into several categories:

Page Operation Actions:
`click [id]`: This action clicks on an element with a specific id on the webpage.
`type [id] [content] [press_enter_after=0|1]`: Use this to type the content into the field with id. By default, the "Enter" key is pressed after typing unless press_enter_after is set to 0.
`hover [id]`: Hover over an element with id.
`press [key_comb]`: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v).
`scroll [direction=down|up]`: Scroll the page up or down.

Tab Management Actions:
`new_tab`: Open a new, empty browser tab.
`tab_focus [tab_index]`: Switch the browser's focus to a specific tab using its index.
`close_tab`: Close the currently active tab.

URL Navigation Actions:
`goto [url]`: Navigate to a specific URL.
`go_back`: Navigate to the previously viewed page.
`go_forward`: Navigate to the next page (if a previous 'go_back' action was performed).

Completion Action:
`stop [answer]`: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket. If you believe the task is impossible to complete, provide the answer as "N/A" in the bracket.

Homepage:
If you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit.
http://homepage.com/password.html lists all the account names and passwords for the websites. You can use them to log in to the websites.

To be successful, it is very important to follow the following rules:
1. You should only issue an action that is valid given the current observation.
2. You should only issue one action at a time.
3. Generate the action in the correct format. Always put the action inside a pair of ```. For example, ```click [1234]```.
4. Issue the stop action when you think you have achieved the objective. Don't generate anything after stop.""",
    "examples": [
        (
            """OBSERVATION:
[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)'
[1749] StaticText '$279.49'
[1757] button 'Add to Cart'
[1760] button 'Add to Wish List'
[1761] button 'Add to Compare'
URL: http://onestopmarket.com/office-products/office-electronics.html
OBJECTIVE: What is the price of HP Inkjet Fax Machine
PREVIOUS ACTION: None""",
            "```stop [$279.49]```",
        ),
        (
            """OBSERVATION:
[164] textbox 'Search' focused: True required: False
[171] button 'Go'
[174] link 'Find directions between two points'
[212] heading 'Search Results'
[216] button 'Close'
URL: http://openstreetmap.org
OBJECTIVE: Show me the restaurants near CMU
PREVIOUS ACTION: None""",
            "```type [164] [restaurants near CMU] [1]```",
        ),
    ],
    "template": """OBSERVATION:
{observation}
URL: {url}
OBJECTIVE: {objective}
PREVIOUS ACTION: {previous_action}""",
    "meta_data": {
        "observation": "accessibility_tree",
        "action_type": "id_accessibility_tree",
        "keywords": ["url", "objective", "observation", "previous_action"],
        "prompt_constructor": "DirectPromptConstructor",
        "action_splitter": "```",
    },
}

agent/prompts/to_json.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import glob
import importlib
import json
import os


# use the current directory as the root
def run() -> None:
    """Convert all python files in agent/prompts/raw to json files in agent/prompts/jsons

    Python files are easier to edit
    """
    for p_file in glob.glob("agent/prompts/raw/*.py"):
        # import the file as a module
        base_name = os.path.basename(p_file).replace(".py", "")
        module = importlib.import_module(f"agent.prompts.raw.{base_name}")
        prompt = module.prompt
        # save the prompt as a json file
        with open(f"agent/prompts/jsons/{base_name}.json", "w+") as f:
            json.dump(prompt, f, indent=2)
    print("Done converting python prompt files to json")

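Since `run()` uses paths relative to the repository root, a minimal invocation sketch (assuming the package is importable and the `agent/prompts/jsons/` directory exists) would be:

```python
from agent.prompts import to_json

# Writes agent/prompts/jsons/<name>.json for each raw prompt file;
# must be invoked with the repository root as the working directory.
to_json.run()
```
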
agent/utils.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from typing import Union

from browser_env.actions import Action
from browser_env.utils import StateInfo

Trajectory = list[Union[StateInfo, Action]]

browser_env/__init__.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import asyncio

from .actions import (
    Action,
    ActionParsingError,
    ActionTypes,
    action2create_function,
    action2str,
    create_check_action,
    create_click_action,
    create_focus_and_click_action,
    create_focus_and_type_action,
    create_go_back_action,
    create_go_forward_action,
    create_goto_url_action,
    create_hover_action,
    create_id_based_action,
    create_key_press_action,
    create_keyboard_type_action,
    create_mouse_click_action,
    create_mouse_hover_action,
    create_new_tab_action,
    create_none_action,
    create_page_close_action,
    create_page_focus_action,
    create_playwright_action,
    create_random_action,
    create_scroll_action,
    create_select_option_action,
    create_stop_action,
    create_type_action,
    is_equivalent,
)
from .async_envs import AsyncScriptBrowserEnv
from .envs import ScriptBrowserEnv
from .processors import ObservationMetadata
from .utils import DetachedPage, StateInfo

__all__ = [
    "ScriptBrowserEnv",
    "AsyncScriptBrowserEnv",
    "DetachedPage",
    "StateInfo",
    "ObservationMetadata",
    "Action",
    "ActionTypes",
    "action2str",
    "create_random_action",
    "create_focus_and_click_action",
    "create_focus_and_type_action",
    "is_equivalent",
    "create_mouse_click_action",
    "create_mouse_hover_action",
    "create_none_action",
    "create_keyboard_type_action",
    "create_page_focus_action",
    "create_new_tab_action",
    "create_go_back_action",
    "create_go_forward_action",
    "create_goto_url_action",
    "create_page_close_action",
    "action2create_function",
    "create_playwright_action",
    "create_id_based_action",
    "create_scroll_action",
    "create_key_press_action",
    "create_check_action",
    "create_click_action",
    "create_type_action",
    "create_hover_action",
    "create_select_option_action",
    "create_stop_action",
    "ActionParsingError",
]

browser_env/actions.py (new file, 1610 lines)
(file diff suppressed because it is too large)

browser_env/async_envs.py (new file, 160 lines)
@@ -0,0 +1,160 @@
import asyncio
import json
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import numpy.typing as npt
from beartype import beartype
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.async_api import Page, ViewportSize, async_playwright

from .actions import Action, aexecute_action, get_action_space
from .utils import DetachedPage, png_bytes_to_numpy


class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
    """
    The goal of this environment is to produce a prototype of a browser
    environment. In the end, we want to support a fully configurable browser
    environment with a wide range of action spaces and observation spaces,
    both structured and unstructured. But in this prototype, we just support
    an action space specified by Playwright scripts, and an observation
    space that is the HTML content of the page.
    """

    @beartype
    def __init__(
        self,
        max_page_length: int = 2048,
        headless: bool = True,
        slow_mo: int = 0,
        timeout: int = 30000,
        viewport_size: ViewportSize = {"width": 1280, "height": 720},
    ):
        self.observation_space = Box(
            0,
            255,
            (viewport_size["height"], viewport_size["width"], 4),
            np.uint8,
        )
        # TODO: make Space[Action] = ActionSpace
        self.action_space = get_action_space()  # type: ignore[assignment]
        self.headless = headless
        self.slow_mo = slow_mo
        self.reset_finished = False
        self.timeout = timeout
        self.viewport_size = viewport_size

    @beartype
    async def setup(self, config_file: Path | None = None) -> None:
        self.context_manager = async_playwright()
        self.playwright = await self.context_manager.__aenter__()
        self.browser = await self.playwright.chromium.launch(
            headless=self.headless, slow_mo=self.slow_mo
        )
        if config_file:
            with open(config_file, "r") as f:
                instance_config = json.load(f)
        else:
            instance_config = {}

        storage_state = instance_config.get("storage_state", None)
        start_url = instance_config.get("start_url", None)
        geolocation = instance_config.get("geolocation", None)

        self.context = await self.browser.new_context(
            viewport=self.viewport_size,
            storage_state=storage_state,
            geolocation=geolocation,
            device_scale_factor=1,
        )
        self.page = await self.context.new_page()
        if start_url:
            await self.page.goto(start_url)

    @beartype
    async def areset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, str] | None = None,
    ) -> tuple[npt.NDArray[np.uint8], dict[str, object]]:
        """
        Reset the environment.
        :param options: options for the environment. The options are:
            - config_file: the path to the instance config file
        """
        super().reset(seed=seed, options=options)
        if self.reset_finished:
            await self.context_manager.__aexit__()
        if options is not None and "config_file" in options:
            config_file = Path(options["config_file"])
            if config_file.exists():
                await self.setup(config_file=config_file)
            else:
                raise ValueError(f"Config file {config_file} does not exist.")
        else:
            await self.setup()
        self.reset_finished = True
        content = await self.page.content()
        screenshot = png_bytes_to_numpy(await self.page.screenshot())
        return (
            screenshot,
            {"page": DetachedPage(self.page.url, content)},
        )

    @beartype
    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, str] | None = None,
    ) -> tuple[npt.NDArray[np.uint8], dict[str, object]]:
        return asyncio.run(self.areset(seed=seed, options=options))

    async def aclose(self) -> None:
        if self.reset_finished:
            await self.context_manager.__aexit__()

    def close(self) -> None:
        asyncio.run(self.aclose())

    @beartype
    async def astep(
        self, action: Action
    ) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]:
        if not self.reset_finished:
            raise RuntimeError("Call reset first before calling step.")
        success = False
        fail_error = ""
        try:
            self.page = await aexecute_action(action, self.page, self.context)
            success = True
        except Exception as e:
            fail_error = str(e)

        try:
            content = await self.page.content()
            screenshot = png_bytes_to_numpy(await self.page.screenshot())
        except Exception:
            await self.page.wait_for_load_state("load")
            content = await self.page.content()
            screenshot = png_bytes_to_numpy(await self.page.screenshot())

        return (
            screenshot,
            float(success),
            False,
            False,
            {
                "page": DetachedPage(self.page.url, content),
                "fail_error": fail_error,
            },
        )

    @beartype
    def step(
        self, action: Action
    ) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]:
        return asyncio.run(self.astep(action), debug=True)

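A minimal usage sketch for `AsyncScriptBrowserEnv` above, assuming Playwright browsers are installed and that the config path (a placeholder here) exists; `create_none_action()` serves as a no-op step:

```python
from browser_env import AsyncScriptBrowserEnv, create_none_action

env = AsyncScriptBrowserEnv(headless=True)
# options["config_file"] points at a per-example config (placeholder path)
obs, info = env.reset(options={"config_file": "config_files/0.json"})
# step() returns (screenshot, success-as-reward, terminated, truncated, info)
obs, reward, terminated, truncated, info = env.step(create_none_action())
env.close()
```
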
128
browser_env/auto_login.py
Normal file
128
browser_env/auto_login.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""Script to automatically login each website"""
|
||||
import glob
|
||||
from itertools import combinations
|
||||
from pathlib import Path
|
||||
|
||||
from beartype import beartype
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
from browser_env.env_config import (
|
||||
ACCOUNTS,
|
||||
GITLAB,
|
||||
REDDIT,
|
||||
SHOPPING,
|
||||
SHOPPING_ADMIN,
|
||||
)
|
||||
|
||||
HEADLESS = True
|
||||
SLOW_MO = 0
|
||||
|
||||
|
||||
@beartype
|
||||
def is_expired(
|
||||
    storage_state: Path, url: str, keyword: str, url_exact: bool = True
) -> bool:
    """Test whether the cookie is expired"""
    if not storage_state.exists():
        return True

    context_manager = sync_playwright()
    playwright = context_manager.__enter__()
    browser = playwright.chromium.launch(headless=HEADLESS, slow_mo=SLOW_MO)
    context = browser.new_context(storage_state=storage_state)
    page = context.new_page()
    page.goto(url)
    d_url = page.url
    content = page.content()
    context_manager.__exit__()
    if keyword:
        return keyword not in content
    else:
        if url_exact:
            return d_url != url
        else:
            return url not in d_url


@beartype
def renew_comb(comb: list[str]) -> None:
    context_manager = sync_playwright()
    playwright = context_manager.__enter__()
    browser = playwright.chromium.launch(headless=HEADLESS)
    context = browser.new_context()
    page = context.new_page()

    if "shopping" in comb:
        username = ACCOUNTS["shopping"]["username"]
        password = ACCOUNTS["shopping"]["password"]
        page.goto(f"{SHOPPING}/customer/account/login/")
        page.get_by_label("Email", exact=True).fill(username)
        page.get_by_label("Password", exact=True).fill(password)
        page.get_by_role("button", name="Sign In").click()

    if "reddit" in comb:
        username = ACCOUNTS["reddit"]["username"]
        password = ACCOUNTS["reddit"]["password"]
        page.goto(f"{REDDIT}/login")
        page.get_by_label("Username").fill(username)
        page.get_by_label("Password").fill(password)
        page.get_by_role("button", name="Log in").click()

    if "shopping_admin" in comb:
        username = ACCOUNTS["shopping_admin"]["username"]
        password = ACCOUNTS["shopping_admin"]["password"]
        page.goto(f"{SHOPPING_ADMIN}")
        page.get_by_placeholder("user name").fill(username)
        page.get_by_placeholder("password").fill(password)
        page.get_by_role("button", name="Sign in").click()

    if "gitlab" in comb:
        username = ACCOUNTS["gitlab"]["username"]
        password = ACCOUNTS["gitlab"]["password"]
        page.goto(f"{GITLAB}/users/sign_in")
        page.get_by_test_id("username-field").click()
        page.get_by_test_id("username-field").fill(username)
        page.get_by_test_id("username-field").press("Tab")
        page.get_by_test_id("password-field").fill(password)
        page.get_by_test_id("sign-in-button").click()

    context.storage_state(path=f"./.auth/{'.'.join(comb)}_state.json")

    context_manager.__exit__()


@beartype
def main() -> None:
    sites = ["gitlab", "shopping", "shopping_admin", "reddit"]
    urls = [
        f"{GITLAB}/-/profile",
        f"{SHOPPING}/wishlist/",
        f"{SHOPPING_ADMIN}/dashboard",
        f"{REDDIT}/user/{ACCOUNTS['reddit']['username']}/account",
    ]
    exact_match = [True, True, True, True]
    keywords = ["", "", "Dashboard", "Delete"]

    pairs = list(combinations(sites, 2))
    for pair in pairs:
        # TODO[shuyanzh] auth doesn't work on these two sites
        if "reddit" in pair and (
            "shopping" in pair or "shopping_admin" in pair
        ):
            continue
        renew_comb(list(sorted(pair)))

    for site in sites:
        renew_comb([site])

    for c_file in glob.glob("./.auth/*.json"):
        comb = c_file.split("/")[-1].rsplit("_", 1)[0].split(".")
        for cur_site in comb:
            url = urls[sites.index(cur_site)]
            keyword = keywords[sites.index(cur_site)]
            match = exact_match[sites.index(cur_site)]
            assert not is_expired(Path(c_file), url, keyword, match)


if __name__ == "__main__":
    main()
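Taken together, renew_comb writes a fresh storage state to ./.auth/ and is_expired verifies it against a page that requires login. A minimal usage sketch (the state path and profile URL simply mirror the conventions used in main() above):

from pathlib import Path

# Refresh the GitLab session cookies, then check they still grant access.
renew_comb(["gitlab"])
state_file = Path("./.auth/gitlab_state.json")
assert not is_expired(state_file, f"{GITLAB}/-/profile", keyword="", url_exact=True)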
295
browser_env/constants.py
Normal file
@ -0,0 +1,295 @@
from typing import Literal

ROLES = (
    "alert",
    "alertdialog",
    "application",
    "article",
    "banner",
    "blockquote",
    "button",
    "caption",
    "cell",
    "checkbox",
    "code",
    "columnheader",
    "combobox",
    "complementary",
    "contentinfo",
    "definition",
    "deletion",
    "dialog",
    "directory",
    "document",
    "emphasis",
    "feed",
    "figure",
    "form",
    "generic",
    "grid",
    "gridcell",
    "group",
    "heading",
    "img",
    "insertion",
    "link",
    "list",
    "listbox",
    "listitem",
    "log",
    "main",
    "marquee",
    "math",
    "meter",
    "menu",
    "menubar",
    "menuitem",
    "menuitemcheckbox",
    "menuitemradio",
    "navigation",
    "none",
    "note",
    "option",
    "paragraph",
    "presentation",
    "progressbar",
    "radio",
    "radiogroup",
    "region",
    "row",
    "rowgroup",
    "rowheader",
    "scrollbar",
    "search",
    "searchbox",
    "separator",
    "slider",
    "spinbutton",
    "status",
    "strong",
    "subscript",
    "superscript",
    "switch",
    "tab",
    "table",
    "tablist",
    "tabpanel",
    "term",
    "textbox",
    "time",
    "timer",
    "toolbar",
    "tooltip",
    "tree",
    "treegrid",
    "treeitem",
)

SPECIAL_LOCATORS = (
    "alt_text",
    "label",
    "placeholder",
)

ASCII_CHARSET = "".join(chr(x) for x in range(32, 128))
FREQ_UNICODE_CHARSET = "".join(chr(x) for x in range(129, 1000))
UTTERANCE_MAX_LENGTH = 8192
ATTRIBUTE_MAX_LENGTH = 256
TEXT_MAX_LENGTH = 256
TYPING_MAX_LENGTH = 64
URL_MAX_LENGTH = 256
MAX_ELEMENT_INDEX_IN_VIEWPORT = 10
MAX_ELEMENT_ID = 1000
MAX_ANSWER_LENGTH = 512

MIN_REF = -1000000
MAX_REF = 1000000

WINDOW_WIDTH = 500
WINDOW_HEIGHT = 240
TASK_WIDTH = 160
TASK_HEIGHT = 210

FLIGHT_WINDOW_WIDTH = 600
FLIGHT_WINDOW_HEIGHT = 700
FLIGHT_TASK_WIDTH = 375
FLIGHT_TASK_HEIGHT = 667
MAX_PAGE_NUMBER = 10

SPECIAL_KEYS = (
    "Enter",
    "Tab",
    "Control",
    "Shift",
    "Meta",
    "Backspace",
    "Delete",
    "Escape",
    "ArrowUp",
    "ArrowDown",
    "ArrowLeft",
    "ArrowRight",
    "PageDown",
    "PageUp",
    "Meta+a",
)

SPECIAL_KEY_MAPPINGS = {
    "backquote": "Backquote",
    "minus": "Minus",
    "equal": "Equal",
    "backslash": "Backslash",
    "backspace": "Backspace",
    "meta": "Meta",
    "tab": "Tab",
    "delete": "Delete",
    "escape": "Escape",
    "arrowdown": "ArrowDown",
    "end": "End",
    "enter": "Enter",
    "home": "Home",
    "insert": "Insert",
    "pagedown": "PageDown",
    "pageup": "PageUp",
    "arrowright": "ArrowRight",
    "arrowup": "ArrowUp",
    "f1": "F1",
    "f2": "F2",
    "f3": "F3",
    "f4": "F4",
    "f5": "F5",
    "f6": "F6",
    "f7": "F7",
    "f8": "F8",
    "f9": "F9",
    "f10": "F10",
    "f11": "F11",
    "f12": "F12",
}

RolesType = Literal[
    "alert",
    "alertdialog",
    "application",
    "article",
    "banner",
    "blockquote",
    "button",
    "caption",
    "cell",
    "checkbox",
    "code",
    "columnheader",
    "combobox",
    "complementary",
    "contentinfo",
    "definition",
    "deletion",
    "dialog",
    "directory",
    "document",
    "emphasis",
    "feed",
    "figure",
    "form",
    "generic",
    "grid",
    "gridcell",
    "group",
    "heading",
    "img",
    "insertion",
    "link",
    "list",
    "listbox",
    "listitem",
    "log",
    "main",
    "marquee",
    "math",
    "meter",
    "menu",
    "menubar",
    "menuitem",
    "menuitemcheckbox",
    "menuitemradio",
    "navigation",
    "none",
    "note",
    "option",
    "paragraph",
    "presentation",
    "progressbar",
    "radio",
    "radiogroup",
    "region",
    "row",
    "rowgroup",
    "rowheader",
    "scrollbar",
    "search",
    "searchbox",
    "separator",
    "slider",
    "spinbutton",
    "status",
    "strong",
    "subscript",
    "superscript",
    "switch",
    "tab",
    "table",
    "tablist",
    "tabpanel",
    "term",
    "textbox",
    "time",
    "timer",
    "toolbar",
    "tooltip",
    "tree",
    "treegrid",
    "treeitem",
    "alt_text",
    "label",
    "placeholder",
]

MAX_VANILLA_STR_LENGTH = 1000

PLAYWRIGHT_LOCATORS = (
    "get_by_role",
    "get_by_text",
    "get_by_label",
    "get_by_placeholder",
    "get_by_alt_text",
    "get_by_title",
    "get_by_test_id",
    "filter",
    "frame_locator",
    "locator",
)

PLAYWRIGHT_ACTIONS = (
    "fill",
    "check",
    "select_option",
    "click",
    "hover",
    "dclick",
    "type",
    "focus",
    "goto",
    "press",
    "scroll",
)

IGNORED_ACTREE_PROPERTIES = (
    "focusable",
    "editable",
    "readonly",
    "level",
    "settable",
    "multiline",
    "invalid",
)
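SPECIAL_KEY_MAPPINGS normalizes lower-cased key names to the capitalized forms Playwright's keyboard API expects. A hypothetical helper (normalize_key is not part of this module) showing the intended lookup:

def normalize_key(key: str) -> str:
    # Map e.g. "enter" -> "Enter"; unknown keys and key combos pass through unchanged.
    return SPECIAL_KEY_MAPPINGS.get(key.lower(), key)

assert normalize_key("enter") == "Enter"
assert normalize_key("Meta+a") == "Meta+a"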
39
browser_env/env_config.py
Normal file
@ -0,0 +1,39 @@
# websites domain
REDDIT = ""
SHOPPING = ""
SHOPPING_ADMIN = ""
GITLAB = ""
WIKIPEDIA = ""
MAP = ""
HOMEPAGE = ""

assert (
    REDDIT
    and SHOPPING
    and SHOPPING_ADMIN
    and GITLAB
    and WIKIPEDIA
    and MAP
    and HOMEPAGE
), "Please set up the URLs for each site"

ACCOUNTS = {
    "reddit": {"username": "MarvelsGrantMan136", "password": "test1234"},
    "gitlab": {"username": "byteblaze", "password": "hello1234"},
    "shopping": {
        "username": "emma.lopez@gmail.com",
        "password": "Password.123",
    },
    "shopping_admin": {"username": "admin", "password": "admin1234"},
    "shopping_site_admin": {"username": "admin", "password": "admin1234"},
}

URL_MAPPINGS = {
    REDDIT: "http://reddit.com",
    SHOPPING: "http://onestopmarket.com",
    SHOPPING_ADMIN: "http://luma.com/admin",
    GITLAB: "http://gitlab.com",
    WIKIPEDIA: "http://wikipedia.org",
    MAP: "http://openstreetmap.org",
    HOMEPAGE: "http://homepage.com",
}
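URL_MAPPINGS pairs each self-hosted instance URL with a canonical public domain. A hypothetical helper (not part of this module) sketching one way to rewrite a self-hosted URL for display, assuming the placeholders above have been filled in:

def to_public_url(url: str) -> str:
    # Replace the self-hosted prefix with its canonical public counterpart.
    for internal, public in URL_MAPPINGS.items():
        if internal and url.startswith(internal):
            return url.replace(internal, public, 1)
    return url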
274
browser_env/envs.py
Normal file
@ -0,0 +1,274 @@
import json
import re
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union

import numpy as np
import numpy.typing as npt
from beartype import beartype
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.sync_api import (
    CDPSession,
    Page,
    Playwright,
    ViewportSize,
    expect,
    sync_playwright,
)

from .actions import Action, execute_action, get_action_space
from .processors import ObservationHandler, ObservationMetadata
from .utils import (
    AccessibilityTree,
    DetachedPage,
    Observation,
    png_bytes_to_numpy,
)


@dataclass
class PlaywrightScript:
    function: str  # goto, get_by_role
    destination: str  # https://www.google.com/, combobox
    name: str | None = None  # Search, Avatar 2009
    operation: str | None = None  # click, fill, press
    value: str | None = None  # avatar movie, Enter


@beartype
def parse_action(action: str) -> PlaywrightScript:
    splitted = action.strip().split(" ")
    assert len(splitted) >= 2
    match splitted[:2]:
        case ["goto", url]:
            assert len(splitted) == 2
            return PlaywrightScript("goto", url)
        case ["get_by_role", destination]:
            assert len(splitted) >= 4
            match splitted[2:]:
                case [name, operation]:
                    return PlaywrightScript(
                        "get_by_role", destination, name, operation
                    )
                case [name, operation, value]:
                    return PlaywrightScript(
                        "get_by_role", destination, name, operation, value
                    )
                case _:
                    raise ValueError("Invalid action")
        case _:
            raise ValueError(f"Invalid action {action}")


class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
    """
    The goal of this environment is to produce a prototype of a browser environment.
    In the end, we want to support a fully configurable browser environment with a wide
    range of action spaces and observation spaces, both structured and unstructured.
    In this prototype, however, the action space is limited to Playwright scripts,
    and the observation space to the HTML content of the page.
    """

    @beartype
    def __init__(
        self,
        max_page_length: int = 8192,
        headless: bool = True,
        slow_mo: int = 0,
        observation_type: str = "html",
        current_viewport_only: bool = False,
        viewport_size: ViewportSize = {"width": 1280, "height": 720},
        save_trace_enabled: bool = False,
        sleep_after_execution: float = 0.0,
    ):
        # TODO: make Space[Action] = ActionSpace
        self.action_space = get_action_space()  # type: ignore[assignment]
        self.headless = headless
        self.slow_mo = slow_mo
        self.current_viewport_only = current_viewport_only
        self.reset_finished = False
        self.viewport_size = viewport_size
        self.save_trace_enabled = save_trace_enabled
        self.sleep_after_execution = sleep_after_execution

        match observation_type:
            case "html" | "accessibility_tree":
                self.text_observation_type = observation_type
                self.image_observation_type = ""
                self.main_observation_type = "text"
            case "image":
                self.image_observation_type = observation_type
                self.text_observation_type = ""  # type: ignore[assignment]
                self.main_observation_type = "image"
            case _:
                raise ValueError(
                    f"Unsupported observation type: {observation_type}"
                )

        self.observation_handler = ObservationHandler(
            self.main_observation_type,
            self.text_observation_type,
            self.image_observation_type,
            self.current_viewport_only,
            self.viewport_size,
        )

        self.observation_space = (
            self.observation_handler.get_observation_space()
        )

    @beartype
    def setup(self, config_file: Path | None = None) -> None:
        self.context_manager = sync_playwright()
        self.playwright = self.context_manager.__enter__()
        self.browser = self.playwright.chromium.launch(
            headless=self.headless, slow_mo=self.slow_mo
        )

        if config_file:
            with open(config_file, "r") as f:
                instance_config = json.load(f)
        else:
            instance_config = {}

        storage_state = instance_config.get("storage_state", None)
        start_url = instance_config.get("start_url", None)
        geolocation = instance_config.get("geolocation", None)

        self.context = self.browser.new_context(
            viewport=self.viewport_size,
            storage_state=storage_state,
            geolocation=geolocation,
            device_scale_factor=1,
        )
        if self.save_trace_enabled:
            self.context.tracing.start(screenshots=True, snapshots=True)
        if start_url:
            start_urls = start_url.split(" |AND| ")
            for url in start_urls:
                page = self.context.new_page()
                client = page.context.new_cdp_session(
                    page
                )  # talk to chrome devtools
                if self.text_observation_type == "accessibility_tree":
                    client.send("Accessibility.enable")
                page.client = client  # type: ignore # TODO[shuyanzh], fix this hacky client
                page.goto(url)
            # set the first page as the current page
            self.page = self.context.pages[0]
            self.page.bring_to_front()
        else:
            self.page = self.context.new_page()
            client = self.page.context.new_cdp_session(self.page)
            if self.text_observation_type == "accessibility_tree":
                client.send("Accessibility.enable")
            self.page.client = client  # type: ignore

    @beartype
    def get_page_client(self, page: Page) -> CDPSession:
        return page.client  # type: ignore

    @beartype
    def _get_obs(self) -> dict[str, Observation]:
        obs = self.observation_handler.get_observation(
            self.page, self.get_page_client(self.page)
        )
        return obs

    @beartype
    def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
        metadata = self.observation_handler.get_observation_metadata()
        return metadata

    @beartype
    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, str] | None = None,
    ) -> tuple[dict[str, Observation], dict[str, Any]]:
        """
        Reset the environment.
        :param options: options for the environment. The current supported options are:
            - "storage_state": the storage state of the browser. It is a file path to a json file.
        """
        super().reset(seed=seed, options=options)
        if self.reset_finished:
            self.context_manager.__exit__()

        if options is not None and "config_file" in options:
            config_file = Path(options["config_file"])
            if config_file.exists():
                self.setup(config_file=config_file)
            else:
                raise ValueError(f"Config file {config_file} does not exist.")
        else:
            self.setup()
        self.reset_finished = True

        if self.sleep_after_execution > 0:
            time.sleep(self.sleep_after_execution)

        observation = self._get_obs()
        observation_metadata = self._get_obs_metadata()
        info = {
            "page": DetachedPage(self.page.url, ""),
            "fail_error": "",
            "observation_metadata": observation_metadata,
        }

        return (observation, info)

    @beartype
    def save_trace(self, trace_path: str | Path) -> None:
        if self.save_trace_enabled:
            self.context.tracing.stop(path=trace_path)

    @beartype
    def close(self) -> None:
        if self.reset_finished:
            self.context_manager.__exit__()

    def step(
        self, action: Action
    ) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
        if not self.reset_finished:
            raise RuntimeError("Call reset first before calling step.")

        success = False
        fail_error = ""
        try:
            self.page = execute_action(
                action,
                self.page,
                self.context,
                self.observation_handler.action_processor,
            )
            success = True
        except Exception as e:
            fail_error = str(e)

        # hard sleep TODO[shuyanzh] suboptimal, may need to check network
        if self.sleep_after_execution > 0:
            time.sleep(self.sleep_after_execution)

        observation = self._get_obs()
        observation_metadata = self._get_obs_metadata()

        info = {
            "page": DetachedPage(self.page.url, self.page.content()),
            "fail_error": fail_error,
            "observation_metadata": observation_metadata,
        }
        msg = (
            observation,
            float(success),  # reward
            False,  # terminated
            False,  # truncated
            info,
        )
        return msg
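A short usage sketch of the environment above, using one of the example config files shipped in config_files/examples/ (assumes the site URLs in env_config.py are set and the referenced .auth state exists):

env = ScriptBrowserEnv(observation_type="accessibility_tree", headless=True)
obs, info = env.reset(options={"config_file": "config_files/examples/1.json"})
print(obs["text"][:200])  # serialized accessibility tree of the start page
env.close()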
670
browser_env/processors.py
Normal file
@ -0,0 +1,670 @@
import json
import re
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, TypedDict, Union

import numpy as np
import numpy.typing as npt
from beartype import beartype
from gymnasium import spaces
from playwright.sync_api import CDPSession, Page, ViewportSize

from browser_env.constants import (
    ASCII_CHARSET,
    FREQ_UNICODE_CHARSET,
    IGNORED_ACTREE_PROPERTIES,
    UTTERANCE_MAX_LENGTH,
)

from .utils import (
    AccessibilityTree,
    BrowserConfig,
    BrowserInfo,
    Observation,
    png_bytes_to_numpy,
)


class ObservationProcessor:
    def process(self, page: Page, client: CDPSession) -> Observation:
        raise NotImplementedError


class ObservationMetadata(TypedDict):
    obs_nodes_info: dict[str, Any]


def create_empty_metadata() -> ObservationMetadata:
    return {
        "obs_nodes_info": {},
    }


class TextObervationProcessor(ObservationProcessor):
    def __init__(
        self,
        observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ):
        self.observation_type = observation_type
        self.current_viewport_only = current_viewport_only
        self.viewport_size = viewport_size
        self.observation_tag = "text"
        self.meta_data = (
            create_empty_metadata()
        )  # used to store the meta data of this observation type

    @beartype
    def fetch_browser_info(
        self,
        page: Page,
        client: CDPSession,
    ) -> BrowserInfo:
        # extract domtree
        tree = client.send(
            "DOMSnapshot.captureSnapshot",
            {
                "computedStyles": [],
                "includeDOMRects": True,
                "includePaintOrder": True,
            },
        )

        # calibrate the bounds; in some cases the bounds are scaled
        bounds = tree["documents"][0]["layout"]["bounds"]
        b = bounds[0]
        n = b[2] / self.viewport_size["width"]
        bounds = [[x / n for x in bound] for bound in bounds]
        tree["documents"][0]["layout"]["bounds"] = bounds
        # add union bound placeholder
        tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds]

        # extract browser info
        win_upper_bound = page.evaluate("window.pageYOffset")
        win_left_bound = page.evaluate("window.pageXOffset")
        win_width = page.evaluate("window.screen.width")
        win_height = page.evaluate("window.screen.height")
        win_right_bound = win_left_bound + win_width
        win_lower_bound = win_upper_bound + win_height
        device_pixel_ratio = page.evaluate("window.devicePixelRatio")
        assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"

        config: BrowserConfig = {
            "win_upper_bound": win_upper_bound,
            "win_left_bound": win_left_bound,
            "win_width": win_width,
            "win_height": win_height,
            "win_right_bound": win_right_bound,
            "win_lower_bound": win_lower_bound,
            "device_pixel_ratio": device_pixel_ratio,
        }

        # assert len(tree['documents']) == 1, "More than one document in the DOM tree"
        info: BrowserInfo = {"DOMTree": tree, "config": config}

        return info

    @beartype
    @staticmethod
    def partially_in_viewport(
        bound: list[float], config: BrowserConfig
    ) -> bool:
        [x, y, width, height] = bound
        elem_left_bound = x
        elem_top_bound = y
        elem_right_bound = x + width
        elem_lower_bound = y + height

        ok = (
            elem_left_bound < config["win_right_bound"]
            and elem_right_bound >= config["win_left_bound"]
            and elem_top_bound < config["win_lower_bound"]
            and elem_lower_bound >= config["win_upper_bound"]
        )

        return ok

    @beartype
    def retrieve_viewport_info(self, info: BrowserInfo) -> None:
        """Add viewport related information to the DOMTree
        1. add union bound, which is a union of all the bounds of the nodes in the subtree
        This is only used when current_viewport_only is enabled since it is quite slow

        TODO[robert1003]: improve
        """
        tree = info["DOMTree"]
        document = tree["documents"][0]
        nodes = document["nodes"]
        parent = nodes["parentIndex"]
        node_names = nodes["nodeName"]

        layout = document["layout"]
        layout_node_cursor = layout["nodeIndex"]
        bounds = layout["bounds"]

        graph = defaultdict(lambda: [])
        assert len(node_names) == len(parent)
        for node_idx in range(len(node_names)):
            parent_idx = parent[node_idx]
            if parent_idx != -1:
                graph[parent_idx].append(node_idx)

        union_bounds: list[list[float] | None] = [None for _ in bounds]

        def valid_bbox(bound: list[float] | None) -> bool:
            if bound is None:
                return False
            # no width or height
            if np.isclose(bound[2], 0):
                return False
            if np.isclose(bound[3], 0):
                return False
            return True

        def add_union_bound(idx: int) -> list[float] | None:
            if idx in layout_node_cursor:
                cursor = layout_node_cursor.index(idx)
                node_bound = bounds[cursor].copy()
                tree_bounds: list[Any] = [node_bound]
                for child_idx in graph[idx]:
                    child_bound = add_union_bound(child_idx)
                    tree_bounds.append(
                        child_bound.copy() if child_bound else None
                    )

                tree_bounds = [b for b in tree_bounds if valid_bbox(b)]
                # convert (x, y, w, h) to absolute (x1, y1, x2, y2) coordinates
                for i in range(len(tree_bounds)):
                    tree_bounds[i][2] = tree_bounds[i][0] + tree_bounds[i][2]
                    tree_bounds[i][3] = tree_bounds[i][1] + tree_bounds[i][3]

                if len(tree_bounds) == 0:
                    assert not valid_bbox(node_bound)
                    node_union_bound = [0.0, 0.0, 0.0, 0.0]
                else:
                    left_bound = min([b[0] for b in tree_bounds])
                    top_bound = min([b[1] for b in tree_bounds])
                    right_bound = max([b[2] for b in tree_bounds])
                    bottom_bound = max([b[3] for b in tree_bounds])
                    node_union_bound = [
                        left_bound,
                        top_bound,
                        right_bound - left_bound,
                        bottom_bound - top_bound,
                    ]

                # update the list
                union_bounds[cursor] = node_union_bound
            else:
                node_union_bound = None

            return node_union_bound

        add_union_bound(0)
        info["DOMTree"]["documents"][0]["layout"]["unionBounds"] = union_bounds

    @beartype
    def current_viewport_html(self, info: BrowserInfo) -> str:
        # adopted from [natbot](https://github.com/nat/natbot)
        tree = info["DOMTree"]
        strings = tree["strings"]
        document = tree["documents"][0]
        nodes = document["nodes"]
        attributes = nodes["attributes"]
        node_value = nodes["nodeValue"]
        parent = nodes["parentIndex"]
        node_names = nodes["nodeName"]

        layout = document["layout"]
        layout_node_cursor = layout["nodeIndex"]
        union_bounds = layout["unionBounds"]

        graph = defaultdict(lambda: [])
        for node_idx in range(len(node_names)):
            parent_idx = parent[node_idx]
            if parent_idx != -1:
                graph[parent_idx].append(node_idx)

        def dfs(idx: int) -> str:
            node_name = strings[node_names[idx]].lower().strip()
            can_skip = "#" in node_name or "::" in node_name

            inner_text = ""
            node_value_idx = node_value[idx]
            if node_value_idx >= 0 and node_value_idx < len(strings):
                inner_text = " ".join(strings[node_value_idx].split())
            node_attributes = [strings[i] for i in attributes[idx]]
            node_attributes_str = ""
            for i in range(0, len(node_attributes), 2):
                a = node_attributes[i]
                b = node_attributes[i + 1]
                b = " ".join(b.split())
                node_attributes_str += f'{a}="{b}" '
            node_attributes_str = node_attributes_str.strip()

            html = ""
            if not can_skip:
                html += f"<{node_name}"
                if node_attributes_str:
                    html += f" {node_attributes_str}"
                html += f">{inner_text}"
            else:
                html += f"{inner_text}"

            for child_idx in graph[idx]:
                if child_idx in layout_node_cursor:
                    cursor = layout_node_cursor.index(child_idx)
                    union_bound = union_bounds[cursor]
                    if not self.partially_in_viewport(
                        union_bound, info["config"]
                    ):
                        continue
                    html += dfs(child_idx)

            if not can_skip:
                html += f"</{node_name}>"

            return html

        html = dfs(0)
        return html

    @beartype
    def fetch_page_accessibility_tree(
        self, info: BrowserInfo, client: CDPSession
    ) -> AccessibilityTree:
        accessibility_tree: AccessibilityTree = client.send(
            "Accessibility.getFullAXTree", {}
        )["nodes"]

        # a few nodes are repeated in the accessibility tree
        seen_ids = set()
        _accessibility_tree = []
        for node in accessibility_tree:
            if node["nodeId"] not in seen_ids:
                _accessibility_tree.append(node)
                seen_ids.add(node["nodeId"])
        accessibility_tree = _accessibility_tree

        # add the bounding box of each node
        tree = info["DOMTree"]
        document = tree["documents"][0]
        nodes = document["nodes"]
        backend_node_id = nodes["backendNodeId"]
        node_names = nodes["nodeName"]

        layout = document["layout"]
        layout_node_cursor = layout["nodeIndex"]
        bounds = layout["bounds"]
        union_bounds = layout["unionBounds"]
        offsetrect_bounds = layout["offsetRects"]
        backend_id_to_bound = {}

        # get the mapping between backend node id and bounding box
        for idx in range(len(node_names)):
            if idx not in layout_node_cursor:
                continue
            cursor = layout_node_cursor.index(idx)
            node_bound = bounds[cursor]
            node_union_bound = union_bounds[cursor]
            node_offsetrect_bound = offsetrect_bounds[cursor]
            node_backend_id = backend_node_id[idx]
            backend_id_to_bound[node_backend_id] = [
                node_bound,
                node_union_bound,
                node_offsetrect_bound,
            ]

        parent_graph: dict[str, str] = {}
        refine_node_ids: list[str] = []
        for node in accessibility_tree:
            if "parentId" in node:
                parent_graph[node["nodeId"]] = node["parentId"]
            if "backendDOMNodeId" not in node:
                node["bound"] = None
                node["union_bound"] = None
                node["offsetrect_bound"] = None
            elif node["backendDOMNodeId"] not in backend_id_to_bound:
                refine_node_ids.append(node["nodeId"])
            else:
                node["bound"] = backend_id_to_bound[node["backendDOMNodeId"]][
                    0
                ]
                node["union_bound"] = backend_id_to_bound[
                    node["backendDOMNodeId"]
                ][1]
                node["offsetrect_bound"] = backend_id_to_bound[
                    node["backendDOMNodeId"]
                ][2]

        # refine the bounding box for nodes which only appear in the accessibility tree
        node_ids = [node["nodeId"] for node in accessibility_tree]
        for refine_node_id in refine_node_ids:
            child_id = refine_node_id
            parent_idx: None | int = None
            while child_id in parent_graph:
                parent_id = parent_graph[child_id]
                parent_idx = node_ids.index(parent_id)
                child_id = parent_id
                if accessibility_tree[parent_idx]["union_bound"] is not None:
                    break

            refine_node_idx = node_ids.index(refine_node_id)

            if parent_idx is not None:
                accessibility_tree[refine_node_idx][
                    "bound"
                ] = accessibility_tree[parent_idx]["bound"]
                accessibility_tree[refine_node_idx][
                    "union_bound"
                ] = accessibility_tree[parent_idx]["union_bound"]
                accessibility_tree[refine_node_idx][
                    "offsetrect_bound"
                ] = accessibility_tree[parent_idx]["offsetrect_bound"]
            else:
                accessibility_tree[refine_node_idx]["bound"] = None
                accessibility_tree[refine_node_idx]["union_bound"] = None
                accessibility_tree[refine_node_idx]["offsetrect_bound"] = None

        return accessibility_tree

    @beartype
    def current_viewport_accessibility_tree(
        self,
        info: BrowserInfo,
        accessibility_tree: AccessibilityTree,
    ) -> AccessibilityTree:
        config = info["config"]
        subtree = []
        for node in accessibility_tree:
            if not node["union_bound"]:
                continue

            [x, y, width, height] = node["union_bound"]
            elem_left_bound = x
            elem_top_bound = y
            elem_right_bound = x + width
            elem_lower_bound = y + height

            ok = (
                elem_left_bound < config["win_right_bound"]
                and elem_right_bound >= config["win_left_bound"]
                and elem_top_bound < config["win_lower_bound"]
                and elem_lower_bound >= config["win_upper_bound"]
            )

            if ok:
                subtree.append(node)

        return subtree

    @beartype
    @staticmethod
    def parse_accessibility_tree(
        accessibility_tree: AccessibilityTree,
    ) -> tuple[str, dict[str, Any]]:
        """Parse the accessibility tree into a string text"""
        node_id_to_idx = {}
        for idx, node in enumerate(accessibility_tree):
            node_id_to_idx[node["nodeId"]] = idx

        obs_nodes_info = {}

        def dfs(idx: int, obs_node_id: str, depth: int) -> str:
            tree_str = ""
            node = accessibility_tree[idx]
            indent = "\t" * depth
            valid_node = True
            try:
                role = node["role"]["value"]
                name = node["name"]["value"]
                node_str = f"[{obs_node_id}] {role} {repr(name)}"
                properties = []
                for property in node.get("properties", []):
                    try:
                        if property["name"] in IGNORED_ACTREE_PROPERTIES:
                            continue
                        properties.append(
                            f'{property["name"]}: {property["value"]["value"]}'
                        )
                    except KeyError:
                        pass

                if properties:
                    node_str += " " + " ".join(properties)

                # check valid
                if not node_str.strip():
                    valid_node = False

                # empty generic node
                if not name.strip():
                    if not properties:
                        if role in [
                            "generic",
                            "img",
                            "list",
                            "strong",
                            "paragraph",
                            "banner",
                            "navigation",
                            "Section",
                            "LabelText",
                            "Legend",
                            "listitem",
                        ]:
                            valid_node = False
                    elif role in ["listitem"]:
                        valid_node = False

                if valid_node:
                    tree_str += f"{indent}{node_str}"
                    obs_nodes_info[obs_node_id] = {
                        "backend_id": node["backendDOMNodeId"],
                        "bound": node["bound"],
                        "union_bound": node["union_bound"],
                        "offsetrect_bound": node["offsetrect_bound"],
                        "text": node_str,
                    }

            except Exception:
                valid_node = False

            for _, child_node_id in enumerate(node["childIds"]):
                if child_node_id not in node_id_to_idx:
                    continue
                # mark this to save some tokens
                child_depth = depth + 1 if valid_node else depth
                child_str = dfs(
                    node_id_to_idx[child_node_id], child_node_id, child_depth
                )
                if child_str.strip():
                    if tree_str.strip():
                        tree_str += "\n"
                    tree_str += child_str

            return tree_str

        tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
        return tree_str, obs_nodes_info

    @beartype
    @staticmethod
    def clean_accesibility_tree(tree_str: str) -> str:
        """Further clean the accessibility tree by dropping StaticText lines
        that repeat content already shown in the previous few lines"""
        clean_lines: list[str] = []
        for line in tree_str.split("\n"):
            if "statictext" in line.lower():
                prev_lines = clean_lines[-3:]
                pattern = r"\[\d+\] StaticText '([^']+)'"

                match = re.search(pattern, line)
                if match:
                    static_text = match.group(1)
                    if all(
                        static_text not in prev_line
                        for prev_line in prev_lines
                    ):
                        clean_lines.append(line)
            else:
                clean_lines.append(line)

        return "\n".join(clean_lines)

    @beartype
    def process(self, page: Page, client: CDPSession) -> str:
        # get the tab info
        open_tabs = page.context.pages
        tab_titles = [tab.title() for tab in open_tabs]
        current_tab_idx = open_tabs.index(page)
        for idx in range(len(open_tabs)):
            if idx == current_tab_idx:
                tab_titles[
                    idx
                ] = f"Tab {idx} (current): {open_tabs[idx].title()}"
            else:
                tab_titles[idx] = f"Tab {idx}: {open_tabs[idx].title()}"
        tab_title_str = " | ".join(tab_titles)

        try:
            browser_info = self.fetch_browser_info(page, client)
        except Exception:
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page, client)

        if self.current_viewport_only:
            self.retrieve_viewport_info(browser_info)

        if self.observation_type == "html":
            if self.current_viewport_only:
                html = self.current_viewport_html(browser_info)
                content = html
            else:
                content = page.content()
        elif self.observation_type == "accessibility_tree":
            accessibility_tree = self.fetch_page_accessibility_tree(
                browser_info, client
            )
            if self.current_viewport_only:
                accessibility_tree = self.current_viewport_accessibility_tree(
                    browser_info, accessibility_tree
                )
            content, obs_nodes_info = self.parse_accessibility_tree(
                accessibility_tree
            )
            content = self.clean_accesibility_tree(content)
            self.obs_nodes_info = obs_nodes_info
            self.meta_data["obs_nodes_info"] = obs_nodes_info
        else:
            raise ValueError(
                f"Invalid observation type: {self.observation_type}"
            )

        self.browser_config = browser_info["config"]
        content = f"{tab_title_str}\n\n{content}"
        return content

    @beartype
    def get_element_center(self, element_id: str) -> tuple[float, float]:
        node_info = self.obs_nodes_info[element_id]
        node_bound = node_info["bound"]
        x, y, width, height = node_bound
        browser_config = self.browser_config
        b_x, b_y = (
            browser_config["win_left_bound"],
            browser_config["win_upper_bound"],
        )
        center_x = (x - b_x) + width / 2
        center_y = (y - b_y) + height / 2
        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )


class ImageObservationProcessor(ObservationProcessor):
    def __init__(self, observation_type: str):
        self.observation_type = observation_type
        self.observation_tag = "image"
        self.meta_data = create_empty_metadata()

    def process(self, page: Page, client: CDPSession) -> npt.NDArray[np.uint8]:
        try:
            screenshot = png_bytes_to_numpy(page.screenshot())
        except Exception:
            page.wait_for_event("load")
            screenshot = png_bytes_to_numpy(page.screenshot())
        return screenshot


class ObservationHandler:
    """Main entry point to access all observation processors"""

    def __init__(
        self,
        main_observation_type: str,
        text_observation_type: str,
        image_observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ) -> None:
        self.main_observation_type = main_observation_type
        self.text_processor = TextObervationProcessor(
            text_observation_type, current_viewport_only, viewport_size
        )
        self.image_processor = ImageObservationProcessor(
            image_observation_type
        )
        self.viewport_size = viewport_size

    @beartype
    def get_observation_space(self) -> spaces.Dict:
        text_space = spaces.Text(
            min_length=0,
            max_length=UTTERANCE_MAX_LENGTH,
            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
        )

        image_space = spaces.Box(
            # Each position stores the RGB values. Note the swapped axes (height first).
            np.zeros(
                (self.viewport_size["height"], self.viewport_size["width"], 3),
                dtype=np.uint8,
            ),
            np.ones(
                (self.viewport_size["height"], self.viewport_size["width"], 3),
                dtype=np.uint8,
            )
            * 255.0,
            dtype=np.uint8,
        )

        return spaces.Dict({"text": text_space, "image": image_space})

    @beartype
    def get_observation(
        self, page: Page, client: CDPSession
    ) -> dict[str, Observation]:
        text_obs = self.text_processor.process(page, client)
        image_obs = self.image_processor.process(page, client)
        return {"text": text_obs, "image": image_obs}

    @beartype
    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
        return {
            "text": self.text_processor.meta_data,
            "image": self.image_processor.meta_data,
        }

    @property
    def action_processor(self) -> ObservationProcessor:
        """Return the main processor that is associated with the action space"""
        if self.main_observation_type == "text":
            return self.text_processor
        elif self.main_observation_type == "image":
            return self.image_processor
        else:
            raise ValueError("Invalid main observation type")
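ObservationHandler above is the single entry point the environment uses. A minimal standalone sketch of the same wiring (page and client are assumed to come from a live Playwright session, as created in ScriptBrowserEnv.setup):

handler = ObservationHandler(
    main_observation_type="text",
    text_observation_type="accessibility_tree",
    image_observation_type="",
    current_viewport_only=True,
    viewport_size={"width": 1280, "height": 720},
)
obs = handler.get_observation(page, client)  # {"text": ..., "image": ...}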
0
browser_env/py.typed
Normal file
68
browser_env/utils.py
Normal file
@ -0,0 +1,68 @@
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Dict, TypedDict, Union

import numpy as np
import numpy.typing as npt
from beartype import beartype
from PIL import Image


@dataclass
class DetachedPage:
    url: str
    content: str  # html


@beartype
def png_bytes_to_numpy(png: bytes) -> npt.NDArray[np.uint8]:
    """Convert png bytes to numpy array

    Example:

    >>> fig = go.Figure(go.Scatter(x=[1], y=[1]))
    >>> plt.imshow(png_bytes_to_numpy(fig.to_image('png')))
    """
    return np.array(Image.open(BytesIO(png)))


class AccessibilityTreeNode(TypedDict):
    nodeId: str
    ignored: bool
    role: dict[str, Any]
    chromeRole: dict[str, Any]
    name: dict[str, Any]
    properties: list[dict[str, Any]]
    childIds: list[str]
    parentId: str
    backendDOMNodeId: int
    frameId: str
    bound: list[float] | None
    union_bound: list[float] | None
    offsetrect_bound: list[float] | None


class BrowserConfig(TypedDict):
    win_upper_bound: float
    win_left_bound: float
    win_width: float
    win_height: float
    win_right_bound: float
    win_lower_bound: float
    device_pixel_ratio: float


class BrowserInfo(TypedDict):
    DOMTree: dict[str, Any]
    config: BrowserConfig


AccessibilityTree = list[AccessibilityTreeNode]


Observation = str | npt.NDArray[np.uint8]


class StateInfo(TypedDict):
    observation: dict[str, Observation]
    info: Dict[str, Any]
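StateInfo is the shape each state entry of a trajectory takes. An illustrative literal (the values are placeholders):

state: StateInfo = {
    "observation": {"text": "Tab 0 (current): Example"},
    "info": {"page": DetachedPage(url="http://example.com", content="<html></html>")},
}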
33
check_errors.sh
Executable file
@ -0,0 +1,33 @@
#!/bin/zsh

result_folder=$1
cd cache/$result_folder


# check whether there are any auto-login errors
errors=$(grep -l "Creating an account has many benefits: check out faster" *.html | sort -u | grep -o '[0-9]\+')
c=$(echo $errors | wc -l)
echo "Shopping total errors: $c"
echo $errors | tr '\n' ','
echo '\n\n'


errors=$(grep -l "Welcome, please sign in" *.html | sort -u | grep -o '[0-9]\+')
c=$(echo $errors | wc -l)
echo "Admin total errors: $c"
echo $errors | tr '\n' ','
echo '\n\n'


errors=$(grep -l "Username or email" *.html | sort -u | grep -o '[0-9]\+')
c=$(echo $errors | wc -l)
echo "Gitlab errors: $c"
echo $errors | tr '\n' ','
echo '\n\n'


errors=$(grep -l "Keep me logged in" *.html | sort -u | grep -o '[0-9]\+')
c=$(echo $errors | wc -l)
echo "Reddit errors: $c"
echo $errors | tr '\n' ','
31
config_files/examples/1.json
Normal file
@ -0,0 +1,31 @@
{
    "sites": ["reddit"],
    "task_id": 1,
    "require_login": true,
    "storage_state": "./.auth/reddit_state.json",
    "start_url": "http://metis.lti.cs.cmu.edu:9999/",
    "geolocation": null,
    "intent_template": "tell me all subreddits starting with character '{{character}}'",
    "instantiation_dict": {"character": "a"},
    "intent": "tell me all subreddits starting with character 'a'",
    "require_reset": false,
    "eval": {
        "eval_types": ["string_match"],
        "reference_answers": ["announcements Art AskReddit askscience aww"],
        "reference_url": "",
        "program_html": [
            {
                "url": "",
                "required_contents": []
            }
        ]
    },
    "reference_action_sequence": {
        "action_set_tag": "playwright",
        "action_sequence": [
            "page.get_by_role(\"link\", name=\"Forums\").click()",
            "page.get_by_role(\"link\", name=\"Alphabetical\").click()",
            "page.stop(\"announcements Art AskReddit askscience aww\")"
        ]
    }
}
30
config_files/examples/2.json
Normal file
@ -0,0 +1,30 @@
{
    "sites": ["misc"],
    "task_id": 2,
    "require_login": false,
    "storage_state": null,
    "start_url": "https://russmaxdesign.github.io/exercise",
    "geolocation": null,
    "intent_template": "",
    "instantiation_dict": {},
    "intent": "Check out the classification section",
    "require_reset": false,
    "eval": {
        "eval_types": ["url_match"],
        "reference_answers": null,
        "reference_url": "https://russmaxdesign.github.io/exercise/#link-two",
        "program_html": [
            {
                "url": "",
                "required_contents": []
            }
        ]
    },
    "reference_action_sequence": {
        "action_set_tag": "playwright",
        "action_sequence": [
            "page.get_by_role(\"navigation\").get_by_role(\"link\", name=\"Classification\").click()",
            "page.stop(\"Wilson and Reade\")"
        ]
    }
}
||||
31
config_files/examples/3.json
Normal file
31
config_files/examples/3.json
Normal file
@ -0,0 +1,31 @@
|
||||
{
|
||||
"sites": ["misc"],
|
||||
"task_id": 3,
|
||||
"require_login": false,
|
||||
"storage_state": null,
|
||||
"start_url": "https://russmaxdesign.github.io/exercise",
|
||||
"geolocation": null,
|
||||
"intent_template": "",
|
||||
"instantiation_dict": {},
|
||||
"intent": "Tell me who provide a collection of concise, detailed information for mammal classification in 2005",
|
||||
"require_reset": false,
|
||||
"eval": {
|
||||
"eval_types": ["string_match"],
|
||||
"reference_answers": ["Wilson and Reader"],
|
||||
"reference_url": "",
|
||||
"program_html": [
|
||||
{
|
||||
"url": "",
|
||||
"required_contents": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"reference_action_sequence": {
|
||||
"action_set_tag": "id_accessibility_tree",
|
||||
"action_sequence": [
|
||||
"type [13] [xyz@gmail.com] [0]",
|
||||
"click [65]",
|
||||
"stop [Wilson and Reader]"
|
||||
]
|
||||
}
|
||||
}
|
||||
31
config_files/examples/4.json
Normal file
31
config_files/examples/4.json
Normal file
@ -0,0 +1,31 @@
|
||||
{
|
||||
"sites": ["reddit"],
|
||||
"task_id": 4,
|
||||
"require_login": true,
|
||||
"storage_state": "./.auth/reddit_state.json",
|
||||
"start_url": "http://metis.lti.cs.cmu.edu:9999/",
|
||||
"geolocation": null,
|
||||
"intent_template": "list all subreddits in alphabetical order",
|
||||
"instantiation_dict": {},
|
||||
"intent": "list all subreddits in alphabetical order",
|
||||
"require_reset": false,
|
||||
"eval": {
|
||||
"eval_types": ["url_match"],
|
||||
"reference_answers": null,
|
||||
"reference_url": "http://metis.lti.cs.cmu.edu:9999/forums/all",
|
||||
"program_html": [
|
||||
{
|
||||
"url": "",
|
||||
"required_contents": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"reference_action_sequence": {
|
||||
"action_set_tag": "playwright",
|
||||
"action_sequence": [
|
||||
"page.get_by_role(\"link\", name=\"Forums\").click()",
|
||||
"page.get_by_role(\"link\", name=\"Alphabetical\").click()",
|
||||
"page.stop()"
|
||||
]
|
||||
}
|
||||
}
|
||||
28537
config_files/test.raw.json
Normal file
File diff suppressed because it is too large
6
evaluation_harness/__init__.py
Normal file
@ -0,0 +1,6 @@
from .evaluators import *
from .helper_functions import (
    shopping_get_latest_order_url,
    shopping_get_sku_latest_review_author,
    shopping_get_sku_latest_review_rating,
)
389
evaluation_harness/evaluators.py
Normal file
@ -0,0 +1,389 @@
|
||||
"""base class for evaluation"""
|
||||
# answer string match
|
||||
import importlib
|
||||
import json
|
||||
import time
|
||||
import urllib
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple, Union
|
||||
|
||||
import evaluate # type: ignore[import]
|
||||
from beartype import beartype
|
||||
from beartype.door import is_bearable
|
||||
from playwright.sync_api import CDPSession, Page
|
||||
|
||||
from browser_env.actions import Action
|
||||
from browser_env.utils import StateInfo
|
||||
from evaluation_harness.helper_functions import (
|
||||
gitlab_get_project_memeber_role,
|
||||
llm_fuzzy_match,
|
||||
reddit_get_post_url,
|
||||
shopping_get_latest_order_url,
|
||||
shopping_get_sku_latest_review_author,
|
||||
shopping_get_sku_latest_review_rating,
|
||||
)
|
||||
|
||||
Trajectory = list[Union[Action, StateInfo]]
|
||||
|
||||
|
||||
@beartype
|
||||
class Evaluator(object):
|
||||
def __init__(self, eval_tag: str = "") -> None:
|
||||
self.eval_tag = eval_tag
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
config_file: Path | str,
|
||||
page: Page,
|
||||
client: CDPSession,
|
||||
) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_last_action(trajectory: Trajectory) -> Action:
|
||||
try:
|
||||
is_bearable(trajectory[-1], Action)
|
||||
last_action = trajectory[-1]
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"The last element of trajectory should be an action, add a fake stop action if needed"
|
||||
)
|
||||
|
||||
return last_action # type: ignore[return-value]
|
||||
|
||||
@staticmethod
|
||||
def get_last_state(trajectory: Trajectory) -> StateInfo:
|
||||
try:
|
||||
is_bearable(trajectory[-2], StateInfo)
|
||||
last_state = trajectory[-2]
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"The second last element of trajectory should be a state, add a fake stop action if needed"
|
||||
)
|
||||
|
||||
return last_state # type: ignore[return-value]
|
||||
|
||||
|
||||
@beartype
|
||||
class StringExactEvaluator(Evaluator):
|
||||
"""Check whether the answer is exactly the same as one of the reference answers"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
config_file: Path | str,
|
||||
page: Page | None = None,
|
||||
client: CDPSession | None = None,
|
||||
) -> float:
|
||||
with open(config_file, "r") as f:
|
||||
configs = json.load(f)
|
||||
|
||||
def clean_answer(answer: str) -> str:
|
||||
if answer.startswith("'") and answer.endswith("'"):
|
||||
answer = answer[1:-1]
|
||||
elif answer.startswith('"') and answer.endswith('"'):
|
||||
answer = answer[1:-1]
|
||||
return answer
|
||||
|
||||
last_action = self.get_last_action(trajectory)
|
||||
pred = clean_answer(last_action["answer"])
|
||||
ref = [clean_answer(x) for x in configs["eval"]["reference_answers"]]
|
||||
if pred in ref:
|
||||
return 1.0
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
|
||||
@beartype
|
||||
class StringEvaluator(Evaluator):
|
||||
"""Check whether the answer is correct with:
|
||||
exact match: the answer is exactly the same as the reference answer
|
||||
must include: each phrase in the reference answer must be included in the answer
|
||||
fuzzy match: the answer is similar to the reference answer, using LLM judge
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
config_file: Path | str,
|
||||
page: Page | None = None,
|
||||
client: CDPSession | None = None,
|
||||
) -> float:
|
||||
with open(config_file, "r") as f:
|
||||
configs = json.load(f)
|
||||
|
||||
def clean_answer(answer: str) -> str:
|
||||
if answer.startswith("'") and answer.endswith("'"):
|
||||
answer = answer[1:-1]
|
||||
elif answer.startswith('"') and answer.endswith('"'):
|
||||
answer = answer[1:-1]
|
||||
return answer.lower()
|
||||
|
||||
last_action = self.get_last_action(trajectory)
|
||||
pred = clean_answer(last_action["answer"])
|
||||
|
||||
score = 1.0
|
||||
for approach, value in configs["eval"]["reference_answers"].items():
|
||||
match approach:
|
||||
case "exact_match":
|
||||
assert isinstance(value, str)
|
||||
ref_answer = clean_answer(value)
|
||||
score = score * (pred == ref_answer)
|
||||
case "must_include":
|
||||
assert isinstance(value, list)
|
||||
for must_value in value:
|
||||
must_value = clean_answer(must_value)
|
||||
score = score * (must_value in pred)
|
||||
case "fuzzy_match":
|
||||
intent = configs["intent"]
|
||||
assert isinstance(value, list)
|
||||
for reference in value:
|
||||
fuzzy_score = llm_fuzzy_match(pred, reference, intent)
|
||||
score = score * fuzzy_score
|
||||
return score
|
||||
|
||||
|
||||
@beartype
|
||||
class StringSoftEvaluator(Evaluator):
|
||||
"""Use text generation metrics such as BLEU, ROUGE, etc. to evaluate the answer"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
config_file: Path | str,
|
||||
page: Page | None = None,
|
||||
client: CDPSession | None = None,
|
||||
) -> float:
|
||||
with open(config_file, "r") as f:
|
||||
configs = json.load(f)
|
||||
|
||||
last_action = self.get_last_action(trajectory)
|
||||
pred = last_action["answer"]
|
||||
ref = configs["eval"]["reference_answers"]
|
||||
# rouge
|
||||
m = evaluate.load("rouge")
|
||||
rouge = m.compute(predictions=[pred], references=[ref])
|
||||
return float(rouge["rouge1"])
|
||||
|
||||
|
||||
@beartype
|
||||
class URLExactEvaluator(Evaluator):
|
||||
"""Check whether the URL is exactly the same as of the reference URLs"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
config_file: Path | str,
|
||||
page: Page,
|
||||
client: CDPSession | None = None,
|
||||
) -> float:
|
||||
with open(config_file, "r") as f:
|
||||
configs = json.load(f)
|
||||
|
||||
def clean_url(url: str) -> str:
|
||||
url = str(url)
|
||||
if url.endswith("/"):
|
||||
url = url[:-1]
|
||||
return url
|
||||
|
||||
pred = clean_url(page.url)
|
||||
ref_urls = configs["eval"]["reference_url"].split(" |OR| ")
|
||||
ref_urls = [clean_url(url) for url in ref_urls]
|
||||
matching_rule = configs["eval"].get("url_note", "EXACT")
|
||||
if matching_rule == "EXACT":
|
||||
if pred in ref_urls:
|
||||
return 1.0
|
||||
else:
|
||||
return 0.0
|
||||
elif matching_rule == "GOLD in PRED":
|
||||
if any([ref in pred for ref in ref_urls]):
|
||||
return 1.0
|
||||
else:
|
||||
return 0.0
|
||||
else:
|
||||
raise ValueError(f"Unknown matching rule: {matching_rule}")
|
||||
|
||||
|
||||
@beartype
|
||||
class HTMLContentExactEvaluator(Evaluator):
|
||||
"""Check whether the contents appear in the page"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
config_file: Path | str,
|
||||
page: Page,
|
||||
client: CDPSession | None = None,
|
||||
) -> float:
|
||||
def clean(text: str) -> str:
|
||||
text = str(text)
|
||||
return text.strip().lower()
|
||||
|
||||
with open(config_file, "r") as f:
|
||||
configs = json.load(f)
|
||||
|
||||
targets = configs["eval"]["program_html"]
|
||||
|
||||
score = 1.0
|
||||
for target in targets:
|
||||
target_url: str = target["url"] # which url to check
|
||||
if target_url.startswith("func"):
|
||||
func = target_url.split("func:")[1]
|
||||
func = func.replace("__last_url__", page.url)
|
||||
target_url = eval(func)
|
||||
|
||||
required_contents: str = target[
|
||||
"required_contents"
|
||||
] # what contents to check
|
||||
locator: str = target["locator"] # js element locator
|
||||
|
||||
# navigate to that url
|
||||
if target_url != "last":
|
||||
page.goto(target_url)
|
||||
time.sleep(3) # TODO [shuyanzh]: fix this hard-coded sleep
|
||||
|
||||
# empty, use the full page
|
||||
if not locator.strip():
|
||||
selected_element = page.content()
|
||||
# use JS to select the element
|
||||
elif locator.startswith("document."):
|
||||
try:
|
||||
selected_element = page.evaluate(f"() => {locator}")
|
||||
if not selected_element:
|
||||
selected_element = ""
|
||||
selected_element = str(selected_element)
|
||||
except Exception:
|
||||
# the page is wrong, return empty
|
||||
selected_element = ""
|
||||
# run program to call API
|
||||
elif locator.startswith("func:"): # a helper function
|
||||
func = locator.split("func:")[1]
|
||||
func = func.replace("__page__", "page")
|
||||
selected_element = eval(func)
|
||||
else:
|
||||
raise ValueError(f"Unknown locator: {locator}")
|
||||
|
||||
required_contents_or = [
|
||||
clean(x) for x in required_contents.split(" |OR| ")
|
||||
]
|
||||
selected_element = clean(selected_element)
|
||||
score *= any(
|
||||
[
|
||||
content in selected_element
|
||||
for content in required_contents_or
|
||||
]
|
||||
)
|
||||
|
||||
return score
|
||||
|
||||
|
||||
######
|
||||
# soft matches.
|
||||
# mainly for partial scores
|
||||
# !!under development!!
|
||||
# TODO[shuyanzh]
|
||||
######
|
||||
|
||||
|
||||
@beartype
|
||||
class EvaluatorPartial(Evaluator):
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
config_file: Path | str,
|
||||
page: Page,
|
||||
client: CDPSession,
|
||||
) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@beartype
class URLSoftEvaluator(EvaluatorPartial):
    """Parse the URL and compare the domain and parameters"""

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        last_state = self.get_last_state(trajectory)
        pred = last_state["info"]["page"].url
        ref = configs["eval"]["reference_url"]

        # parse url to get domain, parameters, etc.
        parsed_pred = urllib.parse.urlparse(pred)
        parsed_ref = urllib.parse.urlparse(ref)

        # check domain
        domain_match = int(parsed_pred.netloc == parsed_ref.netloc)

        def get_param_set(query: dict[str, list[str]]) -> set[str]:
            param_set = set()
            for k, v in query.items():
                for vv in v:
                    param_set.add(f"{k}={vv}")
            return param_set

        # calculate parameter f1
        param_set_ref = get_param_set(urllib.parse.parse_qs(parsed_ref.query))
        param_set_pred = get_param_set(
            urllib.parse.parse_qs(parsed_pred.query)
        )
        if not param_set_ref or not param_set_pred:
            # guard against a zero division when either URL has no query
            # parameters: full credit only if both parameter sets are empty
            f1 = float(param_set_ref == param_set_pred)
        else:
            r = len(param_set_ref & param_set_pred) / len(param_set_ref)
            p = len(param_set_ref & param_set_pred) / len(param_set_pred)
            f1 = 2 * r * p / (r + p) if r + p > 0 else 1.0

        score = domain_match * f1  # domain match is a must

        return score


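# Worked example (illustrative, not part of the original file): with
# ref = "http://shop.com/search?q=shoe&page=1" and
# pred = "http://shop.com/search?q=shoe", the domains match (1),
# param_set_ref = {"q=shoe", "page=1"} and param_set_pred = {"q=shoe"},
# so recall r = 1/2, precision p = 1/1, and
# f1 = 2 * 0.5 * 1.0 / 1.5 ≈ 0.667, giving a soft score of 0.667.

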
class EvaluatorComb:
    def __init__(self, evaluators: list[Evaluator]) -> None:
        self.evaluators = evaluators

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession,
    ) -> float:

        score = 1.0
        for evaluator in self.evaluators:
            cur_score = evaluator(trajectory, config_file, page, client)
            score *= cur_score

        return score


@beartype
def evaluator_router(config_file: Path | str) -> EvaluatorComb:
    """Router to get the evaluator class"""
    with open(config_file, "r") as f:
        configs = json.load(f)

    eval_types = configs["eval"]["eval_types"]
    evaluators: list[Evaluator | EvaluatorPartial] = []
    for eval_type in eval_types:
        match eval_type:
            case "string_match":
                evaluators.append(StringEvaluator())
            case "url_match":
                evaluators.append(URLExactEvaluator())
            case "program_html":
                evaluators.append(HTMLContentExactEvaluator())
            case _:
                raise ValueError(f"eval_type {eval_type} is not supported")

    return EvaluatorComb(evaluators)
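# Usage sketch (illustrative, not part of the original file): routing a task
# config whose eval_types is ["string_match", "url_match"] and scoring a
# finished trajectory. The file name below is a hypothetical example.
#
#   evaluator = evaluator_router("config_files/0.json")
#   score = evaluator(trajectory, "config_files/0.json", page, client)
#
# EvaluatorComb multiplies the per-evaluator scores, so the combined score is
# 1.0 only when every evaluator passes.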
180
evaluation_harness/helper_functions.py
Normal file
@ -0,0 +1,180 @@
"""Implements helper functions to assist evaluation cases where other evaluators are not suitable."""
|
||||
import json
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from beartype import beartype
|
||||
from playwright.sync_api import CDPSession, Page
|
||||
|
||||
from browser_env.env_config import (
|
||||
ACCOUNTS,
|
||||
GITLAB,
|
||||
MAP,
|
||||
REDDIT,
|
||||
SHOPPING,
|
||||
SHOPPING_ADMIN,
|
||||
WIKIPEDIA,
|
||||
)
|
||||
from llms.providers.openai_utils import (
|
||||
generate_from_openai_chat_completion,
|
||||
)
|
||||
|
||||
|
||||
@beartype
|
||||
def shopping_get_auth_token() -> str:
|
||||
response = requests.post(
|
||||
url=f"{SHOPPING}/rest/default/V1/integration/admin/token",
|
||||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(
|
||||
{
|
||||
"username": ACCOUNTS["shopping_site_admin"]["username"],
|
||||
"password": ACCOUNTS["shopping_site_admin"]["password"],
|
||||
}
|
||||
),
|
||||
)
|
||||
token: str = response.json()
|
||||
return token
|
||||
|
||||
|
||||
@beartype
def shopping_get_latest_order_url() -> str:
    """Get the latest order url from the shopping website."""

    header = {
        "Authorization": f"Bearer {shopping_get_auth_token()}",
        "Content-Type": "application/json",
    }

    params = {
        "searchCriteria[sortOrders][0][field]": "created_at",
        "searchCriteria[sortOrders][0][direction]": "DESC",
        "searchCriteria[pageSize]": "1",
    }

    response = requests.get(
        f"{SHOPPING}/rest/V1/orders", params=params, headers=header
    )
    assert response.status_code == 200
    response_obj = response.json()["items"][0]
    order_id = int(response_obj["increment_id"])
    order_url = f"{SHOPPING}/sales/order/view/order_id/{order_id}/"
    return order_url


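# Illustrative sketch (not part of the original file): helpers like this one
# are meant to be referenced from a task config through the "func:" locators
# that HTMLContentExactEvaluator eval()s. The exact field values below are
# hypothetical.
#
#   "program_html": [{
#       "url": "func:shopping_get_latest_order_url()",
#       "locator": "document.body.outerText",
#       "required_contents": "pending |OR| processing"
#   }]

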
@beartype
def shopping_get_sku_latest_review_author(sku: str) -> str:
    """Get the author of a product's latest review for shopping admin."""
    header = {
        "Authorization": f"Bearer {shopping_get_auth_token()}",
        "Content-Type": "application/json",
    }
    response = requests.get(
        f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
    )
    assert response.status_code == 200
    response_obj = response.json()
    if len(response_obj) == 0:
        return ""
    author: str = response_obj[-1]["nickname"]
    return author


@beartype
def shopping_get_sku_latest_review_rating(sku: str) -> str:
    """Get the rating of a product's latest review for shopping admin."""
    header = {
        "Authorization": f"Bearer {shopping_get_auth_token()}",
        "Content-Type": "application/json",
    }
    response = requests.get(
        f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
    )
    assert response.status_code == 200
    response_obj = response.json()
    if len(response_obj) == 0:
        return ""
    assert response_obj[0]["ratings"][0]["rating_name"] == "Rating"
    rating: str = str(response_obj[-1]["ratings"][0]["percent"])
    return rating


@beartype
def reddit_get_post_url(url: str) -> str:
    """Get the post url"""
    # Url is http://domain/f/subreddit/post_id/...
    # get domain, subreddit, post_id
    domain = urlparse(url).netloc
    tok_url = urlparse(url).path.split("/")
    # not a valid post/comment url, return the url as is
    if len(tok_url) < 4:
        return url
    if tok_url[1] != "f":
        return url
    subreddit = urlparse(url).path.split("/")[2]
    post_id = urlparse(url).path.split("/")[3]
    scheme = urlparse(url).scheme
    post_url = f"{scheme}://{domain}/f/{subreddit}/{post_id}/"
    return post_url


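# Example (illustrative, hypothetical host): canonicalizing a comment URL down
# to its post URL.
#
#   reddit_get_post_url("http://localhost:9999/f/books/123/some-comment/456")
#   # -> "http://localhost:9999/f/books/123/"

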
@beartype
def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str:
    # get the account index
    try:
        account_idx = page.evaluate(
            f"""(() => {{
                const elements = document.querySelectorAll("td[data-label='Account'] span.gl-avatar-labeled-sublabel");
                let index = -1;  // Default value if not found

                for(let i = 0; i < elements.length; i++) {{
                    if(elements[i].outerText === '@{account_name}') {{
                        index = i;
                        break;
                    }}
                }}

                return index;
            }})()"""
        )

        # get the role
        role: str = page.evaluate(
            f"""(() => {{
                return document.querySelectorAll("td.col-max-role span")[{account_idx}].outerText;
            }})()"""
        )
    except Exception:
        role = ""

    return role


@beartype
def llm_fuzzy_match(pred: str, reference: str, question: str) -> float:
    """Check whether the prediction matches the reference with GPT-3.5"""
    messages: list[dict[str, Any]] = []
    messages.append(
        {"role": "system", "content": "You are a helpful assistant"}
    )

    messages.append(
        {
            "role": "user",
            "content": f'Given the statement "{pred}", would it be correct to infer "{reference}"? Yes or No',
        }
    )

    response = generate_from_openai_chat_completion(
        messages=messages,
        model="gpt-3.5-turbo",
        temperature=0,
        top_p=1,
        context_length=0,
        max_tokens=16,
        stop_token=None,
    )
    if "Yes" in response:
        return 1.0
    else:
        return 0.0
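# Usage sketch (illustrative, not part of the original file): grading a
# free-form answer against a reference; the strings below are hypothetical.
#
#   score = llm_fuzzy_match(
#       pred="The order arrived on May 3rd",
#       reference="The order was delivered on May 3",
#       question="When was the order delivered?",
#   )  # 1.0 if GPT-3.5 answers "Yes", else 0.0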
1
llms/__init__.py
Normal file
@ -0,0 +1 @@
"""This module is adapt from https://github.com/zeno-ml/zeno-build"""
|
||||
29
llms/lm_config.py
Normal file
@ -0,0 +1,29 @@
"""Config for language models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LMConfig:
|
||||
"""A config for a language model.
|
||||
|
||||
Attributes:
|
||||
provider: The name of the API provider.
|
||||
model: The name of the model.
|
||||
model_cls: The Python class corresponding to the model, mostly for
|
||||
Hugging Face transformers.
|
||||
tokenizer_cls: The Python class corresponding to the tokenizer, mostly
|
||||
for Hugging Face transformers.
|
||||
mode: The mode of the API calls, e.g., "chat" or "generation".
|
||||
"""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
model_cls: type | None = None
|
||||
tokenizer_cls: type | None = None
|
||||
mode: str | None = None
|
||||
gen_config: dict[str, Any] = dataclasses.field(default_factory=dict)
|
||||
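# Instantiation sketch (illustrative, not part of the original file): a chat
# config for the OpenAI provider. Since the dataclass is frozen, generation
# options go into the mutable gen_config dict rather than new attributes.
#
#   config = LMConfig(provider="openai", model="gpt-3.5-turbo", mode="chat")
#   config.gen_config["temperature"] = 1.0
#   config.gen_config["max_tokens"] = 384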
283
llms/providers/openai_utils.py
Normal file
@ -0,0 +1,283 @@
"""Tools to generate from OpenAI prompts.
|
||||
Adopted from https://github.com/zeno-ml/zeno-build/"""
|
||||

import asyncio
import logging
import os
import random
import time
from typing import Any

import aiolimiter
import openai
import openai.error
from tqdm.asyncio import tqdm_asyncio


def retry_with_exponential_backoff(  # type: ignore
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 10,
    errors: tuple[Any] = (openai.error.RateLimitError,),
):
    """Retry a function with exponential backoff."""

    def wrapper(*args, **kwargs):  # type: ignore
        # Initialize variables
        num_retries = 0
        delay = initial_delay

        # Loop until a successful response or max_retries is hit or an exception is raised
        while True:
            try:

                return func(*args, **kwargs)

            # Retry on specified errors
            except errors as e:
                # Increment retries
                num_retries += 1

                # Check if max retries has been reached
                if num_retries > max_retries:
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded."
                    )

                # Increment the delay
                delay *= exponential_base * (1 + jitter * random.random())

                # Sleep for the delay
                time.sleep(delay)

            # Raise exceptions for any errors not specified
            except Exception as e:
                raise e

    return wrapper


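# Usage sketch (illustrative, not part of the original file): the decorator
# wraps any callable and transparently retries it on RateLimitError with
# exponentially growing, jittered sleeps. flaky_call is a hypothetical name.
#
#   @retry_with_exponential_backoff
#   def flaky_call() -> str:
#       return openai.ChatCompletion.create(...)  # may raise RateLimitError
#
#   result = flaky_call()  # retried up to max_retries times

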
async def _throttled_openai_completion_acreate(
    engine: str,
    prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    async with limiter:
        for _ in range(3):
            try:
                return await openai.Completion.acreate(  # type: ignore
                    engine=engine,
                    prompt=prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
            except openai.error.RateLimitError:
                logging.warning(
                    "OpenAI API rate limit exceeded. Sleeping for 10 seconds."
                )
                await asyncio.sleep(10)
            except openai.error.APIError as e:
                logging.warning(f"OpenAI API error: {e}")
                break
        # the completion API returns choices with a "text" field, so the
        # fallback mirrors that shape (the original used a chat-style payload)
        return {"choices": [{"text": ""}]}


async def agenerate_from_openai_completion(
    prompts: list[str],
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    requests_per_minute: int = 300,
) -> list[str]:
    """Generate from OpenAI Completion API.

    Args:
        prompts: list of prompts
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Length of context to use.
        requests_per_minute: Number of requests per minute to allow.

    Returns:
        List of generated responses.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]

    limiter = aiolimiter.AsyncLimiter(requests_per_minute)
    async_responses = [
        _throttled_openai_completion_acreate(
            engine=engine,
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
        )
        for prompt in prompts
    ]
    responses = await tqdm_asyncio.gather(*async_responses)
    return [x["choices"][0]["text"] for x in responses]


@retry_with_exponential_backoff
def generate_from_openai_completion(
    prompt: str,
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]
    response = openai.Completion.create(  # type: ignore
        prompt=prompt,
        engine=engine,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stop=[stop_token] if stop_token else None,  # guard against a None stop token
    )
    answer: str = response["choices"][0]["text"]
    return answer


async def _throttled_openai_chat_completion_acreate(
    model: str,
    messages: list[dict[str, str]],
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    async with limiter:
        for _ in range(3):
            try:
                return await openai.ChatCompletion.acreate(  # type: ignore
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
            except openai.error.RateLimitError:
                logging.warning(
                    "OpenAI API rate limit exceeded. Sleeping for 10 seconds."
                )
                await asyncio.sleep(10)
            except asyncio.exceptions.TimeoutError:
                logging.warning("OpenAI API timeout. Sleeping for 10 seconds.")
                await asyncio.sleep(10)
            except openai.error.APIError as e:
                logging.warning(f"OpenAI API error: {e}")
                break
        return {"choices": [{"message": {"content": ""}}]}


async def agenerate_from_openai_chat_completion(
    messages_list: list[list[dict[str, str]]],
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    requests_per_minute: int = 300,
) -> list[str]:
    """Generate from OpenAI Chat Completion API.

    Args:
        messages_list: list of message list
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Length of context to use.
        requests_per_minute: Number of requests per minute to allow.

    Returns:
        List of generated responses.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]

    limiter = aiolimiter.AsyncLimiter(requests_per_minute)
    async_responses = [
        _throttled_openai_chat_completion_acreate(
            model=engine,
            messages=message,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
        )
        for message in messages_list
    ]
    responses = await tqdm_asyncio.gather(*async_responses)
    return [x["choices"][0]["message"]["content"] for x in responses]


@retry_with_exponential_backoff
def generate_from_openai_chat_completion(
    messages: list[dict[str, str]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]

    response = openai.ChatCompletion.create(  # type: ignore
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stop=[stop_token] if stop_token else None,
    )
    answer: str = response["choices"][0]["message"]["content"]
    return answer


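# Usage sketch (illustrative, not part of the original file): a minimal chat
# call through the retry wrapper. context_length is accepted for interface
# symmetry with the async variant and is not forwarded to the API.
#
#   answer = generate_from_openai_chat_completion(
#       messages=[{"role": "user", "content": "Say hi"}],
#       model="gpt-3.5-turbo",
#       temperature=0,
#       max_tokens=16,
#       top_p=1,
#       context_length=0,
#   )

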
# debug only
@retry_with_exponential_backoff
def fake_generate_from_openai_chat_completion(
    messages: list[dict[str, str]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]
    answer = "Let's think step-by-step. This page shows a list of links and buttons. There is a search box with the label 'Search query'. I will click on the search box to type the query. So the action I will perform is \"click [60]\"."
    return answer
14
llms/tokenizers.py
Normal file
@ -0,0 +1,14 @@
from typing import Any

import tiktoken


class Tokenizer(object):
    def __init__(self, model_name: str) -> None:
        if model_name in ["gpt-4", "gpt-3.5-turbo"]:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        else:
            raise NotImplementedError

    def __call__(self, text: str) -> list[int]:
        return self.tokenizer.encode(text)
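# Usage sketch (illustrative, not part of the original file):
#
#   tokenizer = Tokenizer("gpt-3.5-turbo")
#   n_tokens = len(tokenizer("observation text"))  # token count for budgeting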
BIN
media/overview.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 332 KiB
6
prepare.sh
Normal file
@ -0,0 +1,6 @@
#!/bin/bash

# prepare the evaluation
# re-validate login information
mkdir -p ./.auth
python browser_env/auto_login.py
8
requirements.txt
Normal file
@ -0,0 +1,8 @@
gymnasium
playwright==1.32.1
Pillow
evaluate
openai
types-tqdm
tiktoken
aiolimiter
623
run.py
Normal file
@ -0,0 +1,623 @@
"""Script to run end-to-end evaluation on the benchmark"""
|
||||
import argparse
|
||||
import base64
|
||||
import glob
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
import tiktoken
|
||||
from beartype import beartype
|
||||
from PIL import Image
|
||||
from prompt_toolkit import prompt
|
||||
|
||||
from agent import Agent, PromptAgent, TeacherForcingAgent
|
||||
from agent.prompts import *
|
||||
from browser_env import (
|
||||
Action,
|
||||
ActionTypes,
|
||||
ObservationMetadata,
|
||||
ScriptBrowserEnv,
|
||||
StateInfo,
|
||||
action2str,
|
||||
create_stop_action,
|
||||
)
|
||||
from browser_env.actions import is_equivalent
|
||||
from evaluation_harness import evaluator_router
|
||||
from llms import lm_config
|
||||
|
||||
LOG_FOLDER = "log_files"
|
||||
Path(LOG_FOLDER).mkdir(parents=True, exist_ok=True)
|
||||
LOG_FILE_NAME = f"{LOG_FOLDER}/log_{time.strftime('%Y%m%d%H%M%S', time.localtime())}_{random.randint(0, 10000)}.log"
|
||||
|
||||
logger = logging.getLogger("logger")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
file_handler = logging.FileHandler(LOG_FILE_NAME)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# Set the log format
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
console_handler.setFormatter(formatter)
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
Trajectory = list[Action | StateInfo]
|
||||
HTML_TEMPLATE = """
|
||||
<!DOCTYPE html>
|
||||
<head>
|
||||
<style>
|
||||
pre {{
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<html>
|
||||
<body>
|
||||
{body}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )
    parser.add_argument(
        "--render", action="store_true", help="Render the browser"
    )
    parser.add_argument(
        "--slow_mo",
        type=int,
        default=0,
        help="Slow down the browser by the specified amount",
    )
    parser.add_argument(
        "--action_set_tag", default="id_accessibility_tree", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["accessibility_tree", "html", "image"],
        default="accessibility_tree",
        help="Observation type",
    )
    parser.add_argument(
        "--current_viewport_only",
        action="store_true",
        help="Only use the current viewport for the observation",
    )
    parser.add_argument("--viewport_width", type=int, default=1280)
    parser.add_argument("--viewport_height", type=int, default=720)
    parser.add_argument("--save_trace_enabled", action="store_true")
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)

    parser.add_argument("--max_steps", type=int, default=30)

    # agent config
    parser.add_argument("--agent_type", type=str, default="prompt")
    parser.add_argument(
        "--instruction_path",
        type=str,
        default="agents/prompts/state_action_agent.json",
    )
    parser.add_argument(
        "--parsing_failure_th",
        help="When consecutive parsing failures exceed this threshold, the agent will stop",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--repeating_action_failure_th",
        help="When consecutive repeated actions exceed this threshold, the agent will stop",
        type=int,
        default=3,
    )

    # lm config
    parser.add_argument("--provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
    parser.add_argument("--mode", type=str, default="chat")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--context_length", type=int, default=0)
    parser.add_argument("--max_tokens", type=int, default=384)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument(
        "--max_obs_length",
        type=int,
        help="when not zero, will truncate the observation to this length before feeding to the model",
        default=1920,
    )

    # example config
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=1000)

    # logging related
    parser.add_argument("--result_dir", type=str, default="")
    args = parser.parse_args()

    # check whether the action space is compatible with the observation space
    if (
        args.action_set_tag == "id_accessibility_tree"
        and args.observation_type != "accessibility_tree"
    ):
        raise ValueError(
            f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
        )

    return args


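# Invocation sketch (illustrative, not part of the original file): running the
# first 10 tasks with the default prompt agent. All flag values below are
# examples, not required settings.
#
#   python run.py \
#       --test_start_idx 0 --test_end_idx 10 \
#       --model gpt-3.5-turbo-0613 --result_dir cache/demo_run

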
@beartype
def get_render_action(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
) -> str:
    """Parse the predicted actions for rendering purposes, with more comprehensive information"""
    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["element_id"] in text_meta_data["obs_nodes_info"]:
                node_content = text_meta_data["obs_nodes_info"][
                    action["element_id"]
                ]["text"]
            else:
                node_content = "No match found"

            action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
            action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
            action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"

        case "playwright":
            action_str = action["pw_code"]
        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")
    return action_str


@beartype
def get_action_description(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
    prompt_constructor: PromptConstructor | None,
) -> str:
    """Generate the text version of the predicted actions to store in action history for prompt use.
    May contain hint information to recover from the failures"""

    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in text_meta_data["obs_nodes_info"]:
                    node_content = text_meta_data["obs_nodes_info"][
                        action["element_id"]
                    ]["text"]
                    node_content = " ".join(node_content.split()[1:])
                    action_str = action2str(
                        action, action_set_tag, node_content
                    )
                else:
                    action_str = f"Attempt to perform \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    action_splitter = prompt_constructor.instruction[
                        "meta_data"
                    ]["action_splitter"]
                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
                else:
                    action_str = action2str(action, action_set_tag, "")

        case "playwright":
            action_str = action["pw_code"]

        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")

    return action_str


class RenderHelper(object):
    """Helper class to render text and image observations and meta data in the trajectory"""

    def __init__(
        self, config_file: str, result_dir: str, action_set_tag: str
    ) -> None:
        with open(config_file, "r") as f:
            _config = json.load(f)
            _config_str = ""
            for k, v in _config.items():
                _config_str += f"{k}: {v}\n"
            _config_str = f"<pre>{_config_str}</pre>\n"
            task_id = _config["task_id"]

        self.action_set_tag = action_set_tag

        self.render_file = open(
            Path(result_dir) / f"render_{task_id}.html", "a+"
        )
        self.render_file.truncate(0)
        # write init template
        self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
        self.render_file.read()
        self.render_file.flush()

    def render(
        self,
        action: Action,
        state_info: StateInfo,
        meta_data: dict[str, Any],
        render_screenshot: bool = False,
    ) -> None:
        """Render the trajectory"""
        # text observation
        observation = state_info["observation"]
        text_obs = observation["text"]
        info = state_info["info"]
        new_content = f"<h2>New Page</h2>\n"
        new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
        new_content += f"<div class='state_obv'><pre>{text_obs}</pre><div>\n"

        if render_screenshot:
            # image observation
            img_obs = observation["image"]
            image = Image.fromarray(img_obs)
            byte_io = io.BytesIO()
            image.save(byte_io, format="PNG")
            byte_io.seek(0)
            image_bytes = base64.b64encode(byte_io.read())
            image_str = image_bytes.decode("utf-8")
            new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"

        # meta data
        new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"

        # action
        action_str = get_render_action(
            action,
            info["observation_metadata"],
            action_set_tag=self.action_set_tag,
        )
        # with yellow background
        action_str = f"<div class='predict_action'>{action_str}</div>"
        new_content += f"{action_str}\n"

        # add new content
        self.render_file.seek(0)
        html = self.render_file.read()
        html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
        html_body += new_content

        html = HTML_TEMPLATE.format(body=html_body)
        self.render_file.seek(0)
        self.render_file.truncate()
        self.render_file.write(html)
        self.render_file.flush()

    def close(self) -> None:
        self.render_file.close()


@beartype
def early_stop(
    trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
    """Check whether the agent should stop early"""

    # reach the max step
    num_steps = (len(trajectory) - 1) / 2
    if num_steps >= max_steps:
        return True, f"Reach max steps {max_steps}"

    last_k_actions: list[Action]
    action_seq: list[Action]

    # Case: parsing failure for k times
    k = thresholds["parsing_failure"]
    last_k_actions = trajectory[1::2][-k:]  # type: ignore[assignment]
    if len(last_k_actions) >= k:
        if all(
            [
                action["action_type"] == ActionTypes.NONE
                for action in last_k_actions
            ]
        ):
            return True, f"Failed to parse actions for {k} times"

    # Case: same action for k times
    k = thresholds["repeating_action"]
    last_k_actions = trajectory[1::2][-k:]  # type: ignore[assignment]
    action_seq = trajectory[1::2]  # type: ignore[assignment]

    if len(action_seq) == 0:
        return False, ""

    last_action: Action = action_seq[-1]

    if last_action["action_type"] != ActionTypes.TYPE:
        if len(last_k_actions) >= k:
            if all(
                [
                    is_equivalent(action, last_action)
                    for action in last_k_actions
                ]
            ):
                return True, f"Same action for {k} times"

    else:
        # check the action sequence
        if (
            sum([is_equivalent(action, last_action) for action in action_seq])
            >= k
        ):
            return True, f"Same typing action for {k} times"

    return False, ""


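# Note on the trajectory layout (a sketch, matching the Trajectory alias
# above): states and actions alternate, so trajectory[1::2] selects actions.
#
#   trajectory = [state_0, action_0, state_1, action_1, ...]
#   num_steps  = (len(trajectory) - 1) / 2   # completed state/action pairs

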
@beartype
def test(
    args: argparse.Namespace,
    agent: Agent | PromptAgent,
    config_file_list: list[str],
) -> None:
    scores = []
    max_steps = args.max_steps

    early_stop_thresholds = {
        "parsing_failure": args.parsing_failure_th,
        "repeating_action": args.repeating_action_failure_th,
    }

    env = ScriptBrowserEnv(
        headless=not args.render,
        slow_mo=args.slow_mo,
        observation_type=args.observation_type,
        current_viewport_only=args.current_viewport_only,
        viewport_size={
            "width": args.viewport_width,
            "height": args.viewport_height,
        },
        save_trace_enabled=args.save_trace_enabled,
        sleep_after_execution=args.sleep_after_execution,
    )

    for config_file in config_file_list:
        try:
            render_helper = RenderHelper(
                config_file, args.result_dir, args.action_set_tag
            )

            # get intent
            with open(config_file) as f:
                _c = json.load(f)
                intent = _c["intent"]
                task_id = _c["task_id"]

            logger.info(f"[Config file]: {config_file}")
            logger.info(f"[Intent]: {intent}")

            agent.reset(config_file)
            trajectory: Trajectory = []
            obs, info = env.reset(options={"config_file": config_file})
            state_info: StateInfo = {"observation": obs, "info": info}
            trajectory.append(state_info)

            meta_data = {"action_history": ["None"]}
            while True:
                early_stop_flag, stop_info = early_stop(
                    trajectory, max_steps, early_stop_thresholds
                )

                if early_stop_flag:
                    action = create_stop_action(f"Early stop: {stop_info}")
                else:
                    try:
                        action = agent.next_action(
                            trajectory, intent, meta_data=meta_data
                        )
                    except ValueError as e:
                        # get the error message
                        action = create_stop_action(f"ERROR: {str(e)}")

                trajectory.append(action)

                action_str = get_action_description(
                    action,
                    state_info["info"]["observation_metadata"],
                    action_set_tag=args.action_set_tag,
                    prompt_constructor=agent.prompt_constructor
                    if isinstance(agent, PromptAgent)
                    else None,
                )
                render_helper.render(
                    action, state_info, meta_data, args.render_screenshot
                )
                meta_data["action_history"].append(action_str)

                if action["action_type"] == ActionTypes.STOP:
                    break

                obs, _, terminated, _, info = env.step(action)
                state_info = {"observation": obs, "info": info}
                trajectory.append(state_info)

                if terminated:
                    # add an action placeholder
                    trajectory.append(create_stop_action(""))
                    break

            evaluator = evaluator_router(config_file)
            score = evaluator(
                trajectory=trajectory,
                config_file=config_file,
                page=env.page,
                client=env.get_page_client(env.page),
            )

            scores.append(score)

            if score == 1:
                logger.info(f"[Result] (PASS) {config_file}")
            else:
                logger.info(f"[Result] (FAIL) {config_file}")

            if args.save_trace_enabled:
                env.save_trace(
                    Path(args.result_dir) / "traces" / f"{task_id}.zip"
                )

        except openai.error.OpenAIError as e:
            logger.info(f"[OpenAI Error] {repr(e)}")
        except Exception as e:
            logger.info(f"[Unhandled Error] {repr(e)}")
            import traceback

            # write to error file
            with open(Path(args.result_dir) / "error.txt", "a") as f:
                f.write(f"[Config file]: {config_file}\n")
                f.write(f"[Unhandled Error] {repr(e)}\n")
                f.write(traceback.format_exc())  # write stack trace to file

        # logger.info(f"[Render] {render_helper.render_file.name}")
        # subprocess.run(["open", render_helper.render_file.name])
        render_helper.close()

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")


def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig:
    llm_config = lm_config.LMConfig(
        provider=args.provider, model=args.model, mode=args.mode
    )
    if args.provider == "openai":
        llm_config.gen_config["temperature"] = args.temperature
        llm_config.gen_config["top_p"] = args.top_p
        llm_config.gen_config["context_length"] = args.context_length
        llm_config.gen_config["max_tokens"] = args.max_tokens
        llm_config.gen_config["stop_token"] = args.stop_token
        llm_config.gen_config["max_obs_length"] = args.max_obs_length
    else:
        raise NotImplementedError(f"provider {args.provider} not implemented")
    return llm_config


def construct_agent(args: argparse.Namespace) -> Agent:
    llm_config = construct_llm_config(args)

    agent: Agent
    if args.agent_type == "teacher_forcing":
        agent = TeacherForcingAgent()
    elif args.agent_type == "prompt":
        with open(args.instruction_path) as f:
            constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
        tokenizer = tiktoken.encoding_for_model(llm_config.model)
        prompt_constructor = eval(constructor_type)(
            args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
        )
        agent = PromptAgent(
            action_set_tag=args.action_set_tag,
            lm_config=llm_config,
            prompt_constructor=prompt_constructor,
        )
    else:
        raise NotImplementedError(
            f"agent type {args.agent_type} not implemented"
        )
    return agent


def prepare(args: argparse.Namespace) -> None:
    # convert prompt python files to json
    from agent.prompts import to_json

    to_json.run()

    # prepare result dir
    result_dir = args.result_dir
    if not result_dir:
        result_dir = (
            f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}"
        )
    if not Path(result_dir).exists():
        Path(result_dir).mkdir(parents=True, exist_ok=True)
        args.result_dir = result_dir
        logger.info(f"Create result dir: {result_dir}")

    if not (Path(result_dir) / "traces").exists():
        (Path(result_dir) / "traces").mkdir(parents=True)

    # log the log file
    with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
        f.write(f"{LOG_FILE_NAME}\n")


def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
    result_files = glob.glob(f"{result_dir}/*.html")
    task_ids = [
        os.path.basename(f).split(".")[0].split("_")[1] for f in result_files
    ]
    unfinished_configs = []
    for config_file in config_files:
        task_id = os.path.basename(config_file).split(".")[0]
        if task_id not in task_ids:
            unfinished_configs.append(config_file)
    return unfinished_configs


@beartype
def dump_config(args: argparse.Namespace) -> None:
    config_file = Path(args.result_dir) / "config.json"
    if not config_file.exists():
        with open(config_file, "w") as f:
            json.dump(vars(args), f, indent=4)
            logger.info(f"Dump config to {config_file}")


if __name__ == "__main__":
    args = config()
    args.sleep_after_execution = 2.5
    prepare(args)

    test_file_list = []
    st_idx = args.test_start_idx
    ed_idx = args.test_end_idx
    for i in range(st_idx, ed_idx):
        test_file_list.append(f"config_files/{i}.json")
    test_file_list = get_unfinished(test_file_list, args.result_dir)
    print(f"Total {len(test_file_list)} tasks left")
    args.render = True
    args.render_screenshot = True
    args.save_trace_enabled = True

    args.current_viewport_only = True
    dump_config(args)

    agent = construct_agent(args)
    test(args, agent, test_file_list)
55
scripts/collect_obs.py
Normal file
@ -0,0 +1,55 @@
"""Simple script to quickly get the observation of a page"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple, Type, Union, cast
|
||||
|
||||
import pytest
|
||||
from beartype import beartype
|
||||
from playwright.sync_api import Page, expect
|
||||
|
||||
from browser_env import (
|
||||
ScriptBrowserEnv,
|
||||
create_id_based_action,
|
||||
create_key_press_action,
|
||||
create_playwright_action,
|
||||
create_scroll_action,
|
||||
)
|
||||
from browser_env.env_config import *
|
||||
|
||||
HEADLESS = False
|
||||
|
||||
|
||||
@beartype
|
||||
def gen_tmp_storage_state() -> None:
|
||||
with open(f"scripts/tmp_storage_state.json", "w") as f:
|
||||
json.dump({"storage_state": ".auth/reddit_state.json"}, f)
|
||||
|
||||
|
||||
@beartype
|
||||
def get_observation(
|
||||
observation_type: str, current_viewport_only: bool
|
||||
) -> None:
|
||||
env = ScriptBrowserEnv(
|
||||
observation_type=observation_type,
|
||||
current_viewport_only=current_viewport_only,
|
||||
headless=HEADLESS,
|
||||
)
|
||||
env.reset(options={"config_file": f"scripts/tmp_storage_state.json"})
|
||||
s = f"""page.goto("{GITLAB}")
|
||||
page.scroll(down)"""
|
||||
action_seq = s.split("\n")
|
||||
|
||||
for action in action_seq:
|
||||
action = action.strip()
|
||||
obs, success, _, _, info = env.step(create_playwright_action(action))
|
||||
print(obs["text"])
|
||||
_ = input("Press enter to continue")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gen_tmp_storage_state()
|
||||
obs_type = "accessibility_tree"
|
||||
current_viewport_only = True
|
||||
get_observation(obs_type, current_viewport_only)
|
||||
27
scripts/generate_test_data.py
Normal file
@ -0,0 +1,27 @@
"""Replace the website placeholders with website domains from env_config
|
||||
Generate the test data"""
|
||||
import json
|
||||
|
||||
from browser_env.env_config import *
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with open("config_files/test.raw.json", "r") as f:
|
||||
raw = f.read()
|
||||
raw = raw.replace("__GITLAB__", GITLAB)
|
||||
raw = raw.replace("__REDDIT__", REDDIT)
|
||||
raw = raw.replace("__SHOPPING__", SHOPPING)
|
||||
raw = raw.replace("__SHOPPING_ADMIN__", SHOPPING_ADMIN)
|
||||
raw = raw.replace("__WIKIPEDIA__", WIKIPEDIA)
|
||||
raw = raw.replace("__MAP__", MAP)
|
||||
with open("config_files/test.json", "w") as f:
|
||||
f.write(raw)
|
||||
# split to multiple files
|
||||
data = json.loads(raw)
|
||||
for idx, item in enumerate(data):
|
||||
with open(f"config_files/{idx}.json", "w") as f:
|
||||
json.dump(item, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
26
setup.cfg
Normal file
@ -0,0 +1,26 @@
[metadata]
name = webarena

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"

[options.extras_require]
dev =
    pre-commit==3.0.1
    pytest==7.1.2
    mypy==0.991
    beartype==0.12.0
    nbmake
    pytest-asyncio
    types-requests

[options]
python_requires = >=3.7, <4
packages =
    browser_env
    agent
    evaluation_harness
    llms
[mypy]
strict = true
4
setup.py
Normal file
@ -0,0 +1,4 @@
from setuptools import setup

if __name__ == "__main__":
    setup()
72
tests/conftest.py
Normal file
@ -0,0 +1,72 @@
from typing import AsyncGenerator, Generator

import pytest
import pytest_asyncio

from browser_env import AsyncScriptBrowserEnv, ScriptBrowserEnv

HEADLESS = True
SLOW_MO = 0


@pytest.fixture(scope="function")
def script_browser_env() -> Generator[ScriptBrowserEnv, None, None]:
    """Create a ScriptBrowserEnv instance for testing.
    It is automatically closed after the test session.
    This is helpful when the test failed and the browser is still open.
    """
    env = ScriptBrowserEnv(
        headless=HEADLESS,
        slow_mo=SLOW_MO,
    )
    yield env
    env.close()


@pytest.fixture(scope="function")
def current_viewport_script_browser_env() -> Generator[
    ScriptBrowserEnv, None, None
]:
    env = ScriptBrowserEnv(
        headless=HEADLESS,
        slow_mo=SLOW_MO,
        current_viewport_only=True,
    )
    yield env
    env.close()


@pytest.fixture(scope="function")
def accessibility_tree_script_browser_env() -> Generator[
    ScriptBrowserEnv, None, None
]:
    env = ScriptBrowserEnv(
        headless=HEADLESS,
        slow_mo=SLOW_MO,
        observation_type="accessibility_tree",
    )
    yield env
    env.close()


@pytest.fixture(scope="function")
def accessibility_tree_current_viewport_script_browser_env() -> Generator[
    ScriptBrowserEnv, None, None
]:
    env = ScriptBrowserEnv(
        headless=HEADLESS,
        slow_mo=SLOW_MO,
        observation_type="accessibility_tree",
        current_viewport_only=True,
    )
    yield env
    env.close()


@pytest_asyncio.fixture(scope="function", autouse=True)
async def async_script_browser_env() -> AsyncGenerator[
    AsyncScriptBrowserEnv, None
]:
    env = AsyncScriptBrowserEnv(headless=HEADLESS, slow_mo=SLOW_MO)
    yield env
    await env.aclose()
273
tests/test_browser_env/test_action_functionalities.py
Normal file
@ -0,0 +1,273 @@
import re
from typing import Dict, Optional, Tuple, Type, Union, cast

import pytest
from playwright.sync_api import Page, expect

from browser_env import (
    ScriptBrowserEnv,
    create_id_based_action,
    create_key_press_action,
    create_playwright_action,
    create_scroll_action,
)

HEADLESS = True
SLOW_MO = 0


def test_frame_locator(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env
    seq = """page.goto("https://www.littlewebhut.com/articles/html_iframe_example/")
page.frame_locator("iframe[name=\\"imgbox\\"]").get_by_role("img").click()"""

    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success


def test_basic(script_browser_env: ScriptBrowserEnv) -> None:
    # click, fill, press, check, goto
    env = script_browser_env
    seq = """page.goto("https://demo.playwright.dev/todomvc/")
page.get_by_placeholder("What needs to be done?").click()
page.get_by_placeholder("What needs to be done?").fill("hello")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_placeholder("What needs to be done?").fill("world")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_placeholder("What needs to be done?").fill("yes")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_placeholder("What needs to be done?").fill("no")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_role("listitem").filter(has_text="world").get_by_role("checkbox", name="Toggle Todo").check()
page.get_by_role("button", name="Clear completed").click()"""

    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success


def test_hover(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env
    seq = """page.goto("https://ianlunn.github.io/Hover/")
page.get_by_role("link", name="Download on GitHub").hover()"""

    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success


def test_select_option(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env
    seq = """page.goto("https://russmaxdesign.github.io/exercise/#link-two")
page.get_by_role("combobox", name="Favourite mammal").select_option("African Wild Dog")"""
    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success


def test_xpath(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env

    seq = """page.goto("https://demo.playwright.dev/todomvc/")
page.goto("https://demo.playwright.dev/todomvc/#/")
page.get_by_placeholder("What needs to be done?").click()
page.get_by_placeholder("What needs to be done?").fill("hello")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_role("link", name="Completed").click()
page.locator("xpath=/html/body/section/div/header/input").fill("no")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.goto("https://bic-berkeley.github.io/psych-214-fall-2016/string_literals.html")
page.locator("xpath=//*[@id=\'searchbox\']/div/form/input[1]").fill("type")"""
    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success


def test_inter_page_actions(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env
    seq = """page.goto("https://demo.playwright.dev/todomvc/")
browser.new_tab()
browser.page_focus(0)
browser.page_focus(1)
page.page_close()
page.goto("https://google.com")
page.goto("https://demo.playwright.dev/todomvc/")
page.go_back()
page.go_forward()"""
    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success
    assert "https://demo.playwright.dev/todomvc" in info["page"].url


def test_scroll(current_viewport_script_browser_env: ScriptBrowserEnv) -> None:
    env = current_viewport_script_browser_env
    env.reset()
    _, success, _, _, _ = env.step(create_scroll_action("down"))
    assert success
    _, success, _, _, _ = env.step(create_scroll_action("up"))
    assert success


def test_id_click(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()

    obs, success, _, _, info = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert success
    assert "link 'McKenna/Bell'" in obs["text"]
    # get the id of the link
    element_id = re.search(r"\[(\d+)\] link 'McKenna/Bell'", obs["text"]).group(1)  # type: ignore

    obs, success, _, _, info = env.step(
        create_id_based_action(f"click [{element_id}]")
    )
    assert success
    assert (
        info["page"].url
        == "https://russmaxdesign.github.io/exercise/#link-four"
    )

    obs, success, _, _, info = env.step(create_scroll_action("down"))
    assert "link 'Classification'" in obs["text"]
    element_id = re.search(r"\[(\d+)\] link 'Classification'", obs["text"]).group(1)  # type: ignore

    obs, success, _, _, info = env.step(
        create_id_based_action(f"click [{element_id}]")
    )
    assert success
    assert (
        info["page"].url
        == "https://russmaxdesign.github.io/exercise/#link-two"
    )
    assert "radio 'Weekly'" in obs["text"]
    element_id = re.search(r"\[(\d+)\] radio 'Weekly'", obs["text"]).group(1)  # type: ignore

    obs, success, _, _, info = env.step(
        create_id_based_action(f"click [{element_id}]")
    )
    assert success
    assert "radio 'Weekly'" in obs["text"]


def test_id_hover(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()

    obs, success, _, _, info = env.step(
        create_playwright_action(
            'page.goto("https://ianlunn.github.io/Hover/")'
        )
    )
    assert success
    assert "link 'Download on GitHub'" in obs["text"]
    element_id = re.search(r"\[(\d+)\] link 'Download on GitHub'", obs["text"]).group(1)  # type: ignore

    obs, success, _, _, info = env.step(
        create_id_based_action(f"hover [{element_id}]")
    )
    assert success


def test_key_press(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()

    obs, success, _, _, info = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert success
    assert "textbox 'Full name'" in obs["text"]
    element_id = re.search(r"\[(\d+)\] textbox 'Full name'", obs["text"]).group(1)  # type: ignore
    s = "My Name IS XYZ"

    obs, success, _, _, info = env.step(
        create_id_based_action(f"type [{element_id}] [{s}] [0]")
    )

    assert success
    expect(env.page.get_by_label("Full name")).to_be_focused()

    obs, success, _, _, info = env.step(create_key_press_action("Enter"))
    assert success
    expect(env.page.get_by_label("Email")).to_be_focused()


def test_id_type(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()
    obs, success, _, _, info = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert success
    assert "textbox 'Full name'" in obs["text"]
    s = "My Name IS XYZ"
    element_id = re.search(r"\[(\d+)\] textbox 'Full name'", obs["text"]).group(1)  # type: ignore

    obs, success, _, _, info = env.step(
        create_id_based_action(f"type [{element_id}] [{s}]")
    )
    assert success
    locator = env.page.get_by_label("Full name")
    expect(locator).to_have_value(s)


def test_e2e_id_based_actions(
    accessibility_tree_script_browser_env: ScriptBrowserEnv,
) -> None:
    env = accessibility_tree_script_browser_env
    env.reset()
    obs, *_ = env.step(
        create_id_based_action(
            "goto [https://russmaxdesign.github.io/exercise/]"
        )
    )
    element_id = re.search(r"\[(\d+)\] link 'What are mammals\?'", obs["text"]).group(1)  # type: ignore
    obs, *_ = env.step(create_id_based_action(f"click [{element_id}]"))
    element_id = re.search(r"\[(\d+)\] textbox 'Email'", obs["text"]).group(1)  # type: ignore
    env.step(
        create_id_based_action(f"type [{element_id}] [test@gmail.com] [0]")
    )
    env.step(create_id_based_action("scroll [down]"))
    env.step(create_id_based_action("scroll [up]"))
    env.step(create_id_based_action("new_tab"))
    env.step(create_id_based_action("tab_focus [0]"))
    env.step(create_id_based_action("tab_focus [1]"))
    env.step(create_id_based_action("goto [https://example.com/]"))
    env.step(create_id_based_action("go_back"))
    x = env.step(create_id_based_action("go_forward"))
    assert x[-1]["page"].url == "https://example.com/"
    x = env.step(create_id_based_action("tab_focus [0]"))
    assert (
        x[-1]["page"].url
        == "https://russmaxdesign.github.io/exercise/#link-one"
    )
87
tests/test_browser_env/test_actions.py
Normal file
@ -0,0 +1,87 @@
import numpy as np

from browser_env import *


def test_is_equivalent() -> None:
    for action_type in ActionTypes.__members__.values():
        action_a = create_random_action()
        action_b = create_random_action()
        if action_a["action_type"] != action_b["action_type"]:
            assert not is_equivalent(action_a, action_b)
        action_a["action_type"] = action_type
        action_b["action_type"] = action_type
        match action_type:
            case ActionTypes.MOUSE_CLICK | ActionTypes.MOUSE_HOVER:
                if not np.allclose(action_a["coords"], action_b["coords"]):
                    assert not is_equivalent(action_a, action_b)
                action_a["coords"] = action_b["coords"]
                assert is_equivalent(action_a, action_b)
            case ActionTypes.KEYBOARD_TYPE:
                if action_a["text"] != action_b["text"]:
                    assert not is_equivalent(action_a, action_b)
                action_a["text"] = action_b["text"]
                assert is_equivalent(action_a, action_b)
            case ActionTypes.CLICK | ActionTypes.HOVER | ActionTypes.TYPE:
                if action_a["element_id"] and action_b["element_id"]:
                    if action_a["element_id"] != action_b["element_id"]:
                        assert not is_equivalent(action_a, action_b)
                    action_a["element_id"] = action_b["element_id"]
                    assert is_equivalent(action_a, action_b)
                elif action_a["element_role"] and action_b["element_role"]:
                    # the original repeated the element_id condition here,
                    # which made this role/name branch unreachable
                    if action_a["element_role"] != action_b["element_role"]:
                        assert not is_equivalent(action_a, action_b)
                    action_a["element_role"] = action_b["element_role"]
                    if action_a["element_name"] != action_b["element_name"]:
                        assert not is_equivalent(action_a, action_b)
                    action_a["element_name"] = action_b["element_name"]
                    assert is_equivalent(action_a, action_b)
                elif action_a["pw_code"] and action_b["pw_code"]:
                    if action_a["pw_code"] != action_b["pw_code"]:
                        assert not is_equivalent(action_a, action_b)
                    action_a["pw_code"] = action_b["pw_code"]
                    assert is_equivalent(action_a, action_b)
                else:
                    action_a["element_id"] = action_b["element_id"]
                    assert is_equivalent(action_a, action_b)
            case ActionTypes.GOTO_URL:
                if action_a["url"] != action_b["url"]:
                    assert not is_equivalent(action_a, action_b)
                action_a["url"] = action_b["url"]
                assert is_equivalent(action_a, action_b)
            case ActionTypes.PAGE_FOCUS:
                if action_a["page_number"] != action_b["page_number"]:
                    assert not is_equivalent(action_a, action_b)
                action_a["page_number"] = action_b["page_number"]
                assert is_equivalent(action_a, action_b)
            case ActionTypes.SCROLL:
                da = "up" if "up" in action_a["direction"] else "down"
                db = "up" if "up" in action_b["direction"] else "down"
                if da != db:
                    assert not is_equivalent(action_a, action_b)
                action_a["direction"] = action_b["direction"]
                assert is_equivalent(action_a, action_b)
            case ActionTypes.KEY_PRESS:
                if action_a["key_comb"] != action_b["key_comb"]:
                    assert not is_equivalent(action_a, action_b)
                action_a["key_comb"] = action_b["key_comb"]
                assert is_equivalent(action_a, action_b)
            case ActionTypes.CHECK | ActionTypes.SELECT_OPTION:
                if action_a["pw_code"] != action_b["pw_code"]:
                    assert not is_equivalent(action_a, action_b)
                action_a["pw_code"] = action_b["pw_code"]
                assert is_equivalent(action_a, action_b)
            case ActionTypes.STOP:
                if action_a["answer"] != action_b["answer"]:
                    assert not is_equivalent(action_a, action_b)
                action_a["answer"] = action_b["answer"]
assert is_equivalent(action_a, action_b)
|
||||
case _:
|
||||
assert is_equivalent(action_a, action_b)
|
||||
|
||||
|
||||
def test_action2create_function() -> None:
|
||||
for _ in range(1000):
|
||||
action = create_random_action()
|
||||
create_function = action2create_function(action)
|
||||
assert is_equivalent(action, eval(create_function))
|
||||
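What test_action2create_function checks, in isolation: action2create_function serializes any action back into the Python call that would rebuild it, and evaluating that call round-trips to an equivalent action. A minimal sketch using the same browser_env exports the test imports:

from browser_env import (
    action2create_function,
    create_random_action,
    is_equivalent,
)

action = create_random_action()
# source is Python text that rebuilds the action when evaluated.
source = action2create_function(action)
assert is_equivalent(action, eval(source))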
67
tests/test_browser_env/test_auth_cookie.py
Normal file
@ -0,0 +1,67 @@
import asyncio
import json

from browser_env import *

auth_json = {
    "cookies": [
        {
            "name": "session-username",
            "value": "standard_user",
            "domain": "www.saucedemo.com",
            "path": "/",
            "httpOnly": False,
            "secure": False,
            "sameSite": "Lax",
        }
    ],
    "origins": [],
}


def test_auth_cookie() -> None:
    env = ScriptBrowserEnv()
    env.reset()
    _, reward, _, _, info = env.step(
        create_goto_url_action("https://www.saucedemo.com/inventory.html"),
    )
    assert reward == 1
    assert "page" in info and isinstance(info["page"], DetachedPage)
    assert info["page"].url == "https://www.saucedemo.com/"
    json.dump(auth_json, open("/tmp/auth.json", "w"))
    instance_config = {"storage_state": "/tmp/auth.json"}
    json.dump(instance_config, open("/tmp/config.json", "w"))
    env.reset(options={"config_file": "/tmp/config.json"})
    _, reward, _, _, info = env.step(
        create_goto_url_action("https://www.saucedemo.com/inventory.html"),
    )
    assert reward == 1
    assert "page" in info and isinstance(info["page"], DetachedPage)
    assert info["page"].url == "https://www.saucedemo.com/inventory.html"
    env.close()


def test_async_auth_cookie() -> None:
    env = AsyncScriptBrowserEnv()

    async def _test() -> None:
        await env.areset()
        _, reward, _, _, info = await env.astep(
            create_goto_url_action("https://www.saucedemo.com/inventory.html"),
        )
        assert reward == 1
        assert "page" in info and isinstance(info["page"], DetachedPage)
        assert info["page"].url == "https://www.saucedemo.com/"
        json.dump(auth_json, open("/tmp/auth.json", "w"))
        instance_config = {"storage_state": "/tmp/auth.json"}
        json.dump(instance_config, open("/tmp/config.json", "w"))
        await env.areset(options={"config_file": "/tmp/config.json"})
        _, reward, _, _, info = await env.astep(
            create_goto_url_action("https://www.saucedemo.com/inventory.html"),
        )
        assert reward == 1
        assert "page" in info and isinstance(info["page"], DetachedPage)
        assert info["page"].url == "https://www.saucedemo.com/inventory.html"
        await env.aclose()

    asyncio.run(_test())
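The auth_json blob above is simply Playwright's storage-state JSON. Instead of hand-writing it, the same file can be captured from a live logged-in context; a hedged sketch with the plain Playwright sync API (the saucedemo form selectors are the demo site's conventional ids and should be verified):

from playwright.sync_api import sync_playwright

with sync_playwright() as p:
    browser = p.chromium.launch()
    context = browser.new_context()
    page = context.new_page()
    page.goto("https://www.saucedemo.com/")
    page.fill("#user-name", "standard_user")
    page.fill("#password", "secret_sauce")
    page.click("#login-button")
    # Writes cookies/localStorage in the same JSON shape as auth_json above.
    context.storage_state(path="/tmp/auth.json")
    browser.close()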
89
tests/test_browser_env/test_playwright_actions.py
Normal file
@ -0,0 +1,89 @@
from typing import Dict, Generator, Optional, Tuple, Type, Union, cast

import pytest
from playwright.sync_api import Page

from browser_env import ScriptBrowserEnv, create_playwright_action

HEADLESS = True
SLOW_MO = 0


def test_frame_locator(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env
    seq = """page.goto("https://www.littlewebhut.com/articles/html_iframe_example/")
page.frame_locator("iframe[name=\\"imgbox\\"]").get_by_role("img").click()"""

    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success


def test_basic(script_browser_env: ScriptBrowserEnv) -> None:
    # click, fill, press, check, goto
    env = script_browser_env
    seq = """page.goto("https://demo.playwright.dev/todomvc/")
page.get_by_placeholder("What needs to be done?").click()
page.get_by_placeholder("What needs to be done?").fill("hello")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_placeholder("What needs to be done?").fill("world")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_placeholder("What needs to be done?").fill("yes")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_placeholder("What needs to be done?").fill("no")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_role("listitem").filter(has_text="world").get_by_role("checkbox", name="Toggle Todo").check()
page.get_by_role("button", name="Clear completed").click()"""

    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success


@pytest.mark.skip(reason="not important, but the site is flaky")
def test_hover(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env
    seq = """page.goto("https://www.w3schools.com/cssref/tryit.php?filename=trycss_sel_hover")
page.frame_locator("iframe[name=\\'iframeResult\\']").get_by_role("link", name="w3schools.com").hover()"""

    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success


@pytest.mark.skip(reason="not important, but the site is flaky")
def test_select_option(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env
    seq = """page.goto("https://www.w3schools.com/tags/tryit.asp?filename=tryhtml_select")
page.frame_locator("iframe[name=\\'iframeResult\\']").get_by_role("combobox", name="Choose a car:").select_option("opel")"""

    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success


def test_xpath(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env
    seq = """page.goto("https://demo.playwright.dev/todomvc/")
page.goto("https://demo.playwright.dev/todomvc/#/")
page.get_by_placeholder("What needs to be done?").click()
page.get_by_placeholder("What needs to be done?").fill("hello")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.get_by_role("link", name="Completed").click()
page.locator("xpath=/html/body/section/div/header/input").fill("no")
page.get_by_placeholder("What needs to be done?").press("Enter")
page.goto("https://bic-berkeley.github.io/psych-214-fall-2016/string_literals.html")
page.locator("xpath=//*[@id=\'searchbox\']/div/form/input[1]").fill("type")"""
    env.reset()
    for action in seq.split("\n"):
        action = action.strip()
        _, success, _, _, info = env.step(create_playwright_action(action))
        assert success
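Every test above follows one pattern: a newline-separated Playwright trace is split into lines, each parsed by create_playwright_action and replayed through env.step. A condensed sketch of that loop for an arbitrary trace (constructing ScriptBrowserEnv() directly is an assumption; the tests get theirs from a fixture):

from browser_env import ScriptBrowserEnv, create_playwright_action

env = ScriptBrowserEnv()
env.reset()
seq = """page.goto("https://demo.playwright.dev/todomvc/")
page.get_by_placeholder("What needs to be done?").fill("try me")
page.get_by_placeholder("What needs to be done?").press("Enter")"""
for line in seq.split("\n"):
    _, success, _, _, info = env.step(create_playwright_action(line.strip()))
    assert success
env.close()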
304
tests/test_browser_env/test_script_browser_env.py
Normal file
@ -0,0 +1,304 @@
import asyncio
import collections
import json
import tempfile
from typing import Callable, Dict, Optional, Tuple, Type, Union, cast

import pytest
from beartype.door import is_bearable
from gymnasium.vector import AsyncVectorEnv
from playwright.sync_api import Page

from browser_env import (
    Action,
    AsyncScriptBrowserEnv,
    DetachedPage,
    ScriptBrowserEnv,
    create_focus_and_click_action,
    create_goto_url_action,
    create_keyboard_type_action,
    create_playwright_action,
    create_scroll_action,
)
from browser_env.actions import create_id_based_action
from browser_env.env_config import *


def test_script_browser_env(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env
    env.reset()
    env.step(
        create_goto_url_action("http://www.example.com"),
    )
    env.step(
        create_focus_and_click_action(
            element_role="link",
            element_name="More",
        ),
    )
    _, _, _, _, info = env.step(
        create_focus_and_click_action(
            element_role="link",
            element_name="2606",
        )
    )
    assert isinstance(info["page"], DetachedPage)
    assert info["page"].url == "https://www.rfc-editor.org/rfc/rfc2606.html"


@pytest.mark.asyncio
async def test_async_script_browser_env(
    async_script_browser_env: AsyncScriptBrowserEnv,
) -> None:
    env = async_script_browser_env
    await env.areset()
    await env.astep(
        create_goto_url_action("http://www.example.com"),
    )
    await env.astep(
        create_focus_and_click_action(
            element_role="link",
            element_name="More",
        ),
    )
    _, _, _, _, info = await env.astep(
        create_focus_and_click_action(
            element_role="link",
            element_name="2606",
        )
    )
    assert isinstance(info["page"], DetachedPage)
    assert info["page"].url == "https://www.rfc-editor.org/rfc/rfc2606.html"


def collate_actions(actions: list[Action]) -> dict[str, list[object]]:
    action_dict = collections.defaultdict(list)
    for action in actions:
        for key, value in action.items():
            action_dict[key].append(value)
    return action_dict


@pytest.mark.skip(reason="Gym doesn't support self-defined observations")
def test_parallel_script_browser_env() -> None:
    vector_env = AsyncVectorEnv(
        [
            lambda: ScriptBrowserEnv(),
            lambda: ScriptBrowserEnv(),
        ],
        shared_memory=True,
    )
    vector_env.reset()
    vector_env.step(
        collate_actions(
            [
                create_goto_url_action("http://www.example.com"),
            ]
            * 2
        )
    )
    vector_env.step(
        collate_actions(
            [
                create_focus_and_click_action(
                    element_role="link",
                    element_name="More",
                ),
            ]
            * 2
        )
    )
    _, _, _, _, info = vector_env.step(
        collate_actions(
            [
                create_focus_and_click_action(
                    element_role="link",
                    element_name="2606",
                ),
                create_focus_and_click_action(
                    element_role="link",
                    element_name="6761",
                ),
            ]
        )
    )
    assert is_bearable(info["page"].tolist(), list[DetachedPage])
    assert info["page"][0].url == "https://www.rfc-editor.org/rfc/rfc2606.html"
    assert info["page"][1].url == "https://www.rfc-editor.org/rfc/rfc6761.html"
    vector_env.close()  # type: ignore[no-untyped-call]


def test_is_in_viewport(script_browser_env: ScriptBrowserEnv) -> None:
    env = script_browser_env
    env.reset()
    env.step(
        create_goto_url_action("https://www.iana.org/domains/reserved"),
    )
    _, _, _, _, info = env.step(
        create_focus_and_click_action(
            element_role="link",
            element_name="IDN",
            nth=1,
        ),
    )
    assert (
        info["page"].url
        == "https://www.icann.org/resources/pages/idn-2012-02-25-en"
    )
    env.step(
        create_goto_url_action("https://www.iana.org/domains/reserved"),
    )
    _, _, _, _, info = env.step(create_keyboard_type_action(keys=["PageDown"]))
    _, _, _, _, info = env.step(
        create_focus_and_click_action(
            element_role="link",
            element_name="IDN",
        ),
    )
    assert info["page"].url == "https://www.iana.org/domains/idn-tables"


def test_focus_placeholder_and_label(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    env = script_browser_env
    env.reset()
    for action in [
        create_goto_url_action("https://demo.applitools.com"),
        create_focus_and_click_action("placeholder", "Enter your username"),
        create_keyboard_type_action("abc"),
        create_focus_and_click_action("placeholder", "Enter your password"),
        create_keyboard_type_action("123"),
        create_focus_and_click_action("label", "Remember Me"),
        create_focus_and_click_action("link", "Sign in"),
    ]:
        _, success, _, _, info = env.step(action)
        assert success
    assert info["page"].url == "https://demo.applitools.com/app.html"


def test_current_viewport(
    current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    s1 = "detailed information about how mammals could be classified."
    s2 = "Types of mammals"
    env = current_viewport_script_browser_env
    env.reset()
    obs, success, _, _, info = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert success
    assert s1 in obs["text"] and s2 not in obs["text"]
    obs, success, _, _, info = env.step(create_scroll_action("down"))
    assert success
    assert s1 not in obs["text"] and s2 in obs["text"]


def test_accessibility_tree(
    accessibility_tree_script_browser_env: ScriptBrowserEnv,
) -> None:
    s1 = "checkbox 'Yes'"
    s2 = "button 'Submit'"
    env = accessibility_tree_script_browser_env
    env.reset()
    obs, success, _, _, info = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert success
    assert s1 in obs["text"] and s2 in obs["text"]


def test_accessibility_tree_viewport(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    s1 = "combobox 'Favourite mammal'"
    s2 = "gridcell 'Canyon bat'"
    s3 = "heading 'Useful links'"
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()

    obs, success, _, _, info = env.step(
        create_playwright_action(
            'page.goto("https://russmaxdesign.github.io/exercise/")'
        )
    )
    assert success
    assert (
        s1 in obs["text"] and s2 not in obs["text"] and s3 not in obs["text"]
    )

    obs, success, _, _, info = env.step(create_scroll_action("down"))
    assert success
    assert (
        s1 not in obs["text"] and s2 in obs["text"] and s3 not in obs["text"]
    )

    obs, success, _, _, info = env.step(create_scroll_action("down"))
    assert success
    assert s1 not in obs["text"] and s2 in obs["text"] and s3 in obs["text"]


def test_multiple_start_url(script_browser_env: ScriptBrowserEnv) -> None:
    temp_config = tempfile.NamedTemporaryFile("w", delete=False)
    config = {
        "require_login": False,
        "start_url": f"{REDDIT} |AND| {REDDIT}/forums",
    }
    json.dump(config, temp_config)
    temp_config.close()

    env = script_browser_env
    env.reset(options={"config_file": temp_config.name})
    assert len(env.context.pages) == 2
    assert env.context.pages[0].url == f"{REDDIT}/"
    assert env.context.pages[1].url == f"{REDDIT}/forums", env.context.pages[
        1
    ].url


def test_observation_tab_information(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()
    obs, *_ = env.step(
        create_id_based_action(
            "goto [https://russmaxdesign.github.io/exercise/]"
        )
    )
    obs, *_ = env.step(create_id_based_action("new_tab"))

    obs, *_ = env.step(
        create_id_based_action("goto [https://www.google.com]")
    )
    assert obs["text"].startswith(  # type: ignore[union-attr]
        "Tab 0: Exercise page for keyboard and screen reader use | Tab 1 (current): Google"
    )

    obs, *_ = env.step(create_id_based_action("tab_focus [0]"))

    assert obs["text"].startswith(  # type: ignore[union-attr]
        "Tab 0 (current): Exercise page for keyboard and screen reader use | Tab 1: Google"
    )


def test_accessibility_tree_observation_update(
    accessibility_tree_current_viewport_script_browser_env: ScriptBrowserEnv,
) -> None:
    env = accessibility_tree_current_viewport_script_browser_env
    env.reset()
    obs, *_ = env.step(
        create_playwright_action(
            "page.goto('https://russmaxdesign.github.io/exercise/')"
        )
    )
    obs, *_ = env.step(
        create_playwright_action(
            'page.get_by_label("Full name").fill("UNIQUE_NAME")'
        )
    )
    assert "UNIQUE_NAME" in obs["text"]
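The fixtures these tests consume (script_browser_env, accessibility_tree_current_viewport_script_browser_env, and the rest) are not part of this diff; they presumably live in a conftest.py. A plausible sketch of one such fixture, where the constructor keywords are assumptions inferred from the fixture names rather than confirmed API:

# Hypothetical conftest.py fragment; keyword names are unverified guesses.
import pytest

from browser_env import ScriptBrowserEnv


@pytest.fixture
def accessibility_tree_current_viewport_script_browser_env():
    env = ScriptBrowserEnv(
        observation_type="accessibility_tree",
        current_viewport_only=True,
    )
    yield env
    env.close()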
29
tests/test_evaluation_harness/configs/func_eval_fail.json
Normal file
@ -0,0 +1,29 @@
{
    "sites": ["shopping"],
    "task_id": 0,
    "require_login": true,
    "storage_state": null,
    "start_url": null,
    "geolocation": null,
    "intent_template": "",
    "instantiation_dict": {},
    "intent": "",
    "require_reset": false,
    "eval": {
        "eval_types": ["program_html"],
        "reference_answers": [],
        "reference_url": "",
        "program_html": [
            {
                "url": "last",
                "required_contents": "80",
                "locator": "func:shopping_get_sku_latest_review_rating('B09BCM56J7')"
            },
            {
                "url": "last",
                "required_contents": "cupcakecupcake",
                "locator": "func:shopping_get_sku_latest_review_author('B09BCM56J7')"
            }
        ]
    }
}
29
tests/test_evaluation_harness/configs/func_eval_success.json
Normal file
@ -0,0 +1,29 @@
{
    "sites": ["shopping"],
    "task_id": 0,
    "require_login": true,
    "storage_state": null,
    "start_url": null,
    "geolocation": null,
    "intent_template": "",
    "instantiation_dict": {},
    "intent": "",
    "require_reset": false,
    "eval": {
        "eval_types": ["program_html"],
        "reference_answers": [],
        "reference_url": "",
        "program_html": [
            {
                "url": "last",
                "required_contents": "100",
                "locator": "func:shopping_get_sku_latest_review_rating('B09BCM56J7')"
            },
            {
                "url": "last",
                "required_contents": "cupcakecupcake",
                "locator": "func:shopping_get_sku_latest_review_author('B09BCM56J7')"
            }
        ]
    }
}
24
tests/test_evaluation_harness/configs/func_url_func_1.json
Normal file
@ -0,0 +1,24 @@
{
    "sites": ["shopping"],
    "task_id": 0,
    "require_login": true,
    "storage_state": null,
    "start_url": null,
    "geolocation": null,
    "intent_template": "",
    "instantiation_dict": {},
    "intent": "",
    "require_reset": false,
    "eval": {
        "eval_types": ["program_html"],
        "reference_answers": [],
        "reference_url": "",
        "program_html": [
            {
                "url": "func:reddit_get_post_url('__last_url__')",
                "locator": "document.querySelector('.submission__inner').outerText",
                "required_contents": "​"
            }
        ]
    }
}
29
tests/test_evaluation_harness/configs/func_url_func_2.json
Normal file
@ -0,0 +1,29 @@
{
    "sites": ["shopping"],
    "task_id": 0,
    "require_login": true,
    "storage_state": "./.auth/gitlab_state.json",
    "start_url": null,
    "geolocation": null,
    "intent_template": "",
    "instantiation_dict": {},
    "intent": "",
    "require_reset": false,
    "eval": {
        "eval_types": ["program_html"],
        "reference_answers": [],
        "reference_url": "",
        "program_html": [
            {
                "url": "http://metis.lti.cs.cmu.edu:8023/primer/design/-/project_members",
                "locator": "func:gitlab_get_project_memeber_role(__page__, 'byteblaze')",
                "required_contents": "Developer"
            },
            {
                "url": "http://metis.lti.cs.cmu.edu:8023/primer/design/-/project_members",
                "locator": "func:gitlab_get_project_memeber_role(__page__, 'primer')",
                "required_contents": "Owner"
            }
        ]
    }
}
29
tests/test_evaluation_harness/configs/html_content_element_exact_match.json
Normal file
@ -0,0 +1,29 @@
{
    "sites": ["gitlab"],
    "task_id": 0,
    "require_login": true,
    "storage_state": "./.auth/gitlab_state.json",
    "start_url": null,
    "geolocation": null,
    "intent_template": "",
    "instantiation_dict": {},
    "intent": "",
    "require_reset": false,
    "eval": {
        "eval_types": ["program_html"],
        "reference_answers": [],
        "reference_url": "",
        "program_html": [
            {
                "url": "last",
                "required_contents": "Hello World",
                "locator": "document.querySelector('[id=\"form-name\"]').value"
            },
            {
                "url": "last",
                "required_contents": "alexisxy@hotmail.com",
                "locator": "document.querySelector('[id=\"form-email\"]').value"
            }
        ]
    }
}
34
tests/test_evaluation_harness/configs/html_content_exact_match.json
Normal file
@ -0,0 +1,34 @@
{
    "sites": ["gitlab"],
    "task_id": 0,
    "require_login": true,
    "storage_state": "./.auth/gitlab_state.json",
    "start_url": null,
    "geolocation": null,
    "intent_template": "",
    "instantiation_dict": {},
    "intent": "",
    "require_reset": false,
    "eval": {
        "eval_types": ["program_html"],
        "reference_answers": [],
        "reference_url": "",
        "program_html": [
            {
                "url": "last",
                "required_contents": "Accessible light and dark syntax highlighting themes",
                "locator": ""
            },
            {
                "url": "http://metis.lti.cs.cmu.edu:8023/primer/design",
                "required_contents": "Add more deploy triggers",
                "locator": ""
            },
            {
                "url": "http://metis.lti.cs.cmu.edu:8023/primer/design",
                "required_contents": "Create MVP React component layout",
                "locator": ""
            }
        ]
    }
}
30
tests/test_evaluation_harness/configs/html_content_url_comb.json
Normal file
@ -0,0 +1,30 @@
{
    "sites": ["gitlab"],
    "task_id": 0,
    "require_login": true,
    "storage_state": null,
    "start_url": null,
    "geolocation": null,
    "intent_template": "",
    "instantiation_dict": {},
    "intent": "",
    "require_reset": false,
    "eval": {
        "eval_types": ["program_html", "url_match"],
        "reference_answers": [],
        "reference_url": "https://russmaxdesign.github.io/",
        "url_note": "GOLD in PRED",
        "program_html": [
            {
                "url": "last",
                "required_contents": "Hello World",
                "locator": "document.querySelector('[id=\"form-name\"]').value"
            },
            {
                "url": "last",
                "required_contents": "alexisxy@hotmail.com",
                "locator": "document.querySelector('[id=\"form-email\"]').value"
            }
        ]
    }
}
25
tests/test_evaluation_harness/configs/string_match.json
Normal file
@ -0,0 +1,25 @@
{
    "sites": ["reddit"],
    "task_id": 0,
    "require_login": true,
    "storage_state": "./.auth/reddit_state.json",
    "start_url": null,
    "geolocation": null,
    "intent_template": "",
    "instantiation_dict": {},
    "intent": "",
    "require_reset": false,
    "eval": {
        "eval_types": ["string_match"],
        "reference_answers": {
            "must_include": ["1985/04/18"]
        },
        "reference_url": "",
        "program_html": [
            {
                "url": "",
                "required_contents": []
            }
        ]
    }
}
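string_match.json only populates reference_answers.must_include; judging from test_string_match_success and test_string_match_fail in the evaluator tests below, the semantics amount to requiring every listed string to appear in the agent's final stop answer. A minimal sketch of that rule (not the evaluator's actual implementation):

def must_include_ok(answer: str, must_include: list[str]) -> bool:
    # Full score only if every required substring occurs in the answer.
    return all(needle in answer for needle in must_include)

assert must_include_ok("The date is 1985/04/18", ["1985/04/18"])
assert not must_include_ok("The date is 1936/04/18", ["1985/04/18"])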
23
tests/test_evaluation_harness/configs/url_exact_match.json
Normal file
@ -0,0 +1,23 @@
{
    "sites": ["reddit"],
    "task_id": 0,
    "require_login": true,
    "storage_state": "./.auth/reddit_state.json",
    "start_url": null,
    "geolocation": null,
    "intent_template": "",
    "instantiation_dict": {},
    "intent": "",
    "require_reset": false,
    "eval": {
        "eval_types": ["url_match"],
        "reference_answers": [],
        "reference_url": "http://metis.lti.cs.cmu.edu:9999",
        "program_html": [
            {
                "url": "",
                "required_contents": []
            }
        ]
    }
}
333
tests/test_evaluation_harness/test_exact_evaluators.py
Normal file
@ -0,0 +1,333 @@
import random
from glob import glob
from pathlib import Path
from typing import Any

import pytest
from beartype import beartype

from agent import Agent, TeacherForcingAgent
from browser_env import ActionTypes, ScriptBrowserEnv
from browser_env.env_config import *
from evaluation_harness import (
    HTMLContentExactEvaluator,
    StringEvaluator,
    URLExactEvaluator,
)
from evaluation_harness.evaluators import EvaluatorComb

HEADLESS = True
config_file_folder = "tests/test_evaluation_harness/configs"


def tf_roll_out(
    agent: Agent, env: ScriptBrowserEnv, config_file: str
) -> list[Any]:
    """Roll out the agent using teacher-forcing actions."""
    obs, state_info = env.reset(options={"config_file": config_file})

    trajectory: list[Any] = [{"observation": obs, "info": state_info}]
    while True:
        action = agent.next_action(
            trajectory=trajectory, intent="", meta_data={}
        )
        trajectory.append(action)
        if action["action_type"] == ActionTypes.STOP:
            break

        # proceed to the next action
        obs, reward, terminated, truncated, info = env.step(action)
        state_info = {"observation": obs, "info": info}
        trajectory.append(state_info)

    return trajectory


def test_string_match_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    config_file = f"{config_file_folder}/string_match.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = """page.stop("The date is 1985/04/18")"""
    agent.set_actions(action_seq)

    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = StringEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )

    assert score == 1.0


def test_string_match_fail(script_browser_env: ScriptBrowserEnv) -> None:
    config_file = f"{config_file_folder}/string_match.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = """page.stop("The date is 1936/04/18")"""
    agent.set_actions(action_seq)

    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = StringEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )

    assert score == 0.0


def test_url_exact_match_success(script_browser_env: ScriptBrowserEnv) -> None:
    config_file = f"{config_file_folder}/url_exact_match.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = f"""page.goto("{REDDIT}")
page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env

    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = URLExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0


def test_url_exact_match_fail(script_browser_env: ScriptBrowserEnv) -> None:
    config_file = f"{config_file_folder}/url_exact_match.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = f"""page.goto("{GITLAB}")
page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env

    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = URLExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    print(env.page.url)
    assert score == 0.0


def test_html_content_match_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    config_file = f"{config_file_folder}/html_content_exact_match.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = f"""page.goto("{GITLAB}")
page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env

    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0


def test_html_content_match_fail(script_browser_env: ScriptBrowserEnv) -> None:
    config_file = f"{config_file_folder}/html_content_exact_match.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = """page.goto("https://russmaxdesign.github.io/exercise")
page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env

    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 0.0


def test_html_content_element_match_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    config_file = f"{config_file_folder}/html_content_element_exact_match.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = """page.goto("https://russmaxdesign.github.io/exercise/")
page.get_by_label("Full name").fill("Hello World")
page.get_by_label("Email").click()
page.get_by_label("Email").fill("alexisxy@hotmail.com")
page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env

    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0


def test_html_content_element_match_fail(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    config_file = f"{config_file_folder}/html_content_element_exact_match.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = """page.goto("https://russmaxdesign.github.io/exercise/")
page.get_by_label("Full name").fill("Hello")
page.get_by_label("Email").click()
page.get_by_label("Email").fill("alexisxy@hotmail.com")
page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env

    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 0.0


def test_html_content_url_comb_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    config_file = f"{config_file_folder}/html_content_url_comb.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = """page.goto("https://russmaxdesign.github.io/exercise/")
page.get_by_label("Full name").fill("Hello World")
page.get_by_label("Email").click()
page.get_by_label("Email").fill("alexisxy@hotmail.com")
page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env

    trajectory = tf_roll_out(agent, env, config_file)

    evaluators = EvaluatorComb(
        [URLExactEvaluator(), HTMLContentExactEvaluator()]
    )
    score = evaluators(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0


@beartype
def test_func_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    config_file = f"{config_file_folder}/func_eval_success.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = """page.goto("https://russmaxdesign.github.io/exercise/")
page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0


@beartype
def test_func_fail(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    config_file = f"{config_file_folder}/func_eval_fail.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = """page.goto("https://russmaxdesign.github.io/exercise/")
page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 0.0


@beartype
def test_func_url_func_last_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    config_file = f"{config_file_folder}/func_url_func_1.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = f"""page.goto("{REDDIT}/f/wallstreetbets/50431/-/comment/676875")
page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0


@beartype
def test_func_url_func_page_success(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    config_file = f"{config_file_folder}/func_url_func_2.json"

    agent = TeacherForcingAgent()
    agent.set_action_set_tag(tag="playwright")
    action_seq = """page.stop()"""
    agent.set_actions(action_seq)

    env = script_browser_env
    trajectory = tf_roll_out(agent, env, config_file)

    evaluator = HTMLContentExactEvaluator()
    score = evaluator(
        trajectory, config_file, env.page, env.get_page_client(env.page)
    )
    assert score == 1.0
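tf_roll_out above is also the recipe for scoring any scripted trajectory outside pytest: feed a TeacherForcingAgent a fixed action string, roll it out against a config file, then call an evaluator on the trajectory. A condensed sketch reusing the helper and imports from this test file (the tests obtain env from a fixture; direct construction is an assumption):

agent = TeacherForcingAgent()
agent.set_action_set_tag(tag="playwright")
agent.set_actions('page.stop("The date is 1985/04/18")')

env = ScriptBrowserEnv()
config = f"{config_file_folder}/string_match.json"
trajectory = tf_roll_out(agent, env, config)  # helper defined above
score = StringEvaluator()(
    trajectory, config, env.page, env.get_page_client(env.page)
)
assert score == 1.0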
33
tests/test_evaluation_harness/test_helper_functions.py
Normal file
@ -0,0 +1,33 @@
import json
import os
from pathlib import Path

from beartype import beartype

from browser_env import ScriptBrowserEnv
from browser_env.env_config import *
from evaluation_harness.helper_functions import (
    gitlab_get_project_memeber_role,
)

HEADLESS = True
config_file_folder = "tests/test_evaluation_harness/configs"


def test_gitlab_get_project_memeber_role(
    script_browser_env: ScriptBrowserEnv,
) -> None:
    env = script_browser_env
    config_file = f"{config_file_folder}/tmp_config.json"

    with open(config_file, "w") as f:
        json.dump({"storage_state": ".auth/gitlab_state.json"}, f)
    env.reset(options={"config_file": config_file})
    env.page.goto(f"{GITLAB}/primer/design/-/project_members")
    role1 = gitlab_get_project_memeber_role(env.page, "byteblaze")
    assert role1 == "Developer"
    role2 = gitlab_get_project_memeber_role(env.page, "primer")
    assert role2 == "Owner"

    # remove tmp config file
    os.remove(config_file)