mirror of
https://github.com/web-arena-x/webarena.git
synced 2026-02-06 11:16:53 +00:00
fix type errors
This commit is contained in:
parent
7a1f8d6f18
commit
9f0900f506
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@ -34,7 +34,7 @@ jobs:
|
||||
mypy --version
|
||||
# Run this mypy instance against our main package.
|
||||
mypy --install-types --non-interactive .
|
||||
mypy --strict .
|
||||
mypy --strict . --exclude scripts
|
||||
- name: Enviroment prepare
|
||||
run: |
|
||||
bash prepare.sh
|
||||
|
||||
@ -9,9 +9,7 @@ import urllib
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple, Union
|
||||
|
||||
import evaluate # type: ignore[import]
|
||||
from beartype import beartype
|
||||
from beartype.door import is_bearable
|
||||
from nltk.tokenize import word_tokenize # type: ignore
|
||||
from playwright.sync_api import CDPSession, Page
|
||||
|
||||
@ -96,7 +94,7 @@ class StringEvaluator(Evaluator):
|
||||
|
||||
@staticmethod
|
||||
@beartype
|
||||
def must_include(ref: str, pred: str, tokenize=False) -> float:
|
||||
def must_include(ref: str, pred: str, tokenize: bool = False) -> float:
|
||||
clean_ref = StringEvaluator.clean_answer(ref)
|
||||
clean_pred = StringEvaluator.clean_answer(pred)
|
||||
# tokenize the answer if the ref is a single word
|
||||
@ -180,7 +178,7 @@ class URLEvaluator(Evaluator):
|
||||
|
||||
def parse_urls(
|
||||
urls: list[str],
|
||||
) -> tuple[list[str], list[str], dict[str, set[str]]]:
|
||||
) -> tuple[list[str], dict[str, set[str]]]:
|
||||
"""Parse a list of URLs."""
|
||||
base_paths = []
|
||||
queries = collections.defaultdict(set)
|
||||
@ -324,8 +322,8 @@ class EvaluatorComb:
|
||||
self,
|
||||
trajectory: Trajectory,
|
||||
config_file: Path | str,
|
||||
page: Page | PseudoPage | None = None,
|
||||
client: CDPSession | None = None,
|
||||
page: Page | PseudoPage,
|
||||
client: CDPSession,
|
||||
) -> float:
|
||||
|
||||
score = 1.0
|
||||
|
||||
@ -178,7 +178,7 @@ class PseudoPage:
|
||||
self.url = url
|
||||
self.original_page = original_page
|
||||
|
||||
def __getattr__(self, attr: str) -> any:
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
# Delegate attribute access to the original page object
|
||||
if attr not in ["url"]:
|
||||
return getattr(self.original_page, attr)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from text_generation import Client
|
||||
from text_generation import Client # type: ignore
|
||||
|
||||
|
||||
def generate_from_huggingface_completion(
|
||||
@ -10,7 +10,7 @@ def generate_from_huggingface_completion(
|
||||
stop_sequences: list[str] | None = None,
|
||||
) -> str:
|
||||
client = Client(model_endpoint, timeout=60)
|
||||
generation = client.generate(
|
||||
generation: str = client.generate(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
||||
@ -20,7 +20,7 @@ def retry_with_exponential_backoff( # type: ignore
|
||||
exponential_base: float = 2,
|
||||
jitter: bool = True,
|
||||
max_retries: int = 3,
|
||||
errors: tuple[Any] = (openai.error.RateLimitError),
|
||||
errors: tuple[Any] = (openai.error.RateLimitError,),
|
||||
):
|
||||
"""Retry a function with exponential backoff."""
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
from transformers import LlamaTokenizer
|
||||
from transformers import LlamaTokenizer # type: ignore
|
||||
|
||||
|
||||
class Tokenizer(object):
|
||||
@ -11,9 +11,9 @@ class Tokenizer(object):
|
||||
elif provider == "huggingface":
|
||||
self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
|
||||
# turn off adding special tokens automatically
|
||||
self.tokenizer.add_special_tokens = False
|
||||
self.tokenizer.add_bos_token = False
|
||||
self.tokenizer.add_eos_token = False
|
||||
self.tokenizer.add_special_tokens = False # type: ignore[attr-defined]
|
||||
self.tokenizer.add_bos_token = False # type: ignore[attr-defined]
|
||||
self.tokenizer.add_eos_token = False # type: ignore[attr-defined]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -13,10 +13,12 @@ APIInput = str | list[Any] | dict[str, Any]
|
||||
|
||||
def call_llm(
|
||||
lm_config: lm_config.LMConfig,
|
||||
prompt: list[Any] | str,
|
||||
) -> APIInput:
|
||||
prompt: APIInput,
|
||||
) -> str:
|
||||
response: str
|
||||
if lm_config.provider == "openai":
|
||||
if lm_config.mode == "chat":
|
||||
assert isinstance(prompt, list)
|
||||
response = generate_from_openai_chat_completion(
|
||||
messages=prompt,
|
||||
model=lm_config.model,
|
||||
@ -27,6 +29,7 @@ def call_llm(
|
||||
stop_token=None,
|
||||
)
|
||||
elif lm_config.mode == "completion":
|
||||
assert isinstance(prompt, str)
|
||||
response = generate_from_openai_completion(
|
||||
prompt=prompt,
|
||||
engine=lm_config.model,
|
||||
@ -40,6 +43,7 @@ def call_llm(
|
||||
f"OpenAI models do not support mode {lm_config.mode}"
|
||||
)
|
||||
elif lm_config.provider == "huggingface":
|
||||
assert isinstance(prompt, str)
|
||||
response = generate_from_huggingface_completion(
|
||||
prompt=prompt,
|
||||
model_endpoint=lm_config.gen_config["model_endpoint"],
|
||||
|
||||
@ -9,3 +9,5 @@ aiolimiter
|
||||
beartype==0.12.0
|
||||
flask
|
||||
nltk
|
||||
text-generation
|
||||
transformers
|
||||
|
||||
@ -20,7 +20,7 @@ def merge_logs(result_folder: str, args: argparse.Namespace) -> str:
|
||||
with open(file.strip(), "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
cur_log = []
|
||||
cur_log: list[str] = []
|
||||
index = None
|
||||
for line in lines:
|
||||
if "[Config file]" in line:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user