fix type errors

alexisxy 2023-10-21 00:20:30 -04:00
parent 7a1f8d6f18
commit 9f0900f506
9 changed files with 22 additions and 18 deletions

View File

@@ -34,7 +34,7 @@ jobs:
           mypy --version
           # Run this mypy instance against our main package.
           mypy --install-types --non-interactive .
-          mypy --strict .
+          mypy --strict . --exclude scripts
       - name: Environment prepare
         run: |
           bash prepare.sh

View File

@@ -9,9 +9,7 @@ import urllib
 from pathlib import Path
 from typing import Any, Tuple, Union

-import evaluate  # type: ignore[import]
 from beartype import beartype
-from beartype.door import is_bearable
 from nltk.tokenize import word_tokenize  # type: ignore
 from playwright.sync_api import CDPSession, Page
@@ -96,7 +94,7 @@ class StringEvaluator(Evaluator):
     @staticmethod
     @beartype
-    def must_include(ref: str, pred: str, tokenize=False) -> float:
+    def must_include(ref: str, pred: str, tokenize: bool = False) -> float:
         clean_ref = StringEvaluator.clean_answer(ref)
         clean_pred = StringEvaluator.clean_answer(pred)
         # tokenize the answer if the ref is a single word
@@ -180,7 +178,7 @@ class URLEvaluator(Evaluator):
     def parse_urls(
         urls: list[str],
-    ) -> tuple[list[str], list[str], dict[str, set[str]]]:
+    ) -> tuple[list[str], dict[str, set[str]]]:
         """Parse a list of URLs."""
         base_paths = []
         queries = collections.defaultdict(set)
@@ -324,8 +322,8 @@ class EvaluatorComb:
         self,
         trajectory: Trajectory,
         config_file: Path | str,
-        page: Page | PseudoPage | None = None,
-        client: CDPSession | None = None,
+        page: Page | PseudoPage,
+        client: CDPSession,
     ) -> float:
         score = 1.0
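A note on the three signature fixes above: mypy --strict enables --disallow-incomplete-defs, so a parameter with a default still needs an explicit annotation (hence tokenize: bool = False), and a declared return type must match what the function actually returns. A minimal sketch of the parse_urls shape, reduced to the two values it returns (the body here is illustrative, not the repo's implementation):

import collections

def parse_urls(urls: list[str]) -> tuple[list[str], dict[str, set[str]]]:
    # Two values are returned, so the annotation is a 2-tuple; the old
    # 3-tuple annotation could never match the return statement.
    base_paths: list[str] = []
    queries: dict[str, set[str]] = collections.defaultdict(set)
    for url in urls:
        base_paths.append(url.split("?")[0])
    return base_paths, queries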

View File

@@ -178,7 +178,7 @@ class PseudoPage:
         self.url = url
         self.original_page = original_page

-    def __getattr__(self, attr: str) -> any:
+    def __getattr__(self, attr: str) -> Any:
         # Delegate attribute access to the original page object
         if attr not in ["url"]:
             return getattr(self.original_page, attr)
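Worth noting why this one-character change matters: lowercase any is the builtin function, not a type, so mypy rejects it as an annotation ("Function ... is not valid as a type"), while typing.Any is the intended dynamic type. A minimal sketch of the delegation pattern, with Wrapper standing in for PseudoPage:

from typing import Any

class Wrapper:
    def __init__(self, target: object) -> None:
        self._target = target

    def __getattr__(self, attr: str) -> Any:  # "-> any" would not type-check
        # Delegate unknown attribute lookups to the wrapped object.
        return getattr(self._target, attr)

w = Wrapper("hello")
print(w.upper())  # delegates to str.upper -> "HELLO"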

View File

@@ -1,4 +1,4 @@
-from text_generation import Client
+from text_generation import Client  # type: ignore

 def generate_from_huggingface_completion(
@@ -10,7 +10,7 @@ def generate_from_huggingface_completion(
     stop_sequences: list[str] | None = None,
 ) -> str:
     client = Client(model_endpoint, timeout=60)
-    generation = client.generate(
+    generation: str = client.generate(
         prompt=prompt,
         temperature=temperature,
         top_p=top_p,
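Two strict-mode idioms here: text-generation ships no type stubs, so the bare import is an error under --strict and gets a # type: ignore; and because the untyped client returns Any, the explicit generation: str annotation pins the result so downstream code stays checked. A hedged sketch (assuming the call chain ends in the response's generated_text field, which this hunk does not show):

from text_generation import Client  # type: ignore  # library ships no stubs

def complete(endpoint: str, prompt: str) -> str:
    client = Client(endpoint, timeout=60)
    # client.generate returns Any to mypy; the annotation restores a
    # concrete type at the boundary of the untyped library.
    generation: str = client.generate(prompt=prompt).generated_text
    return generation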

View File

@@ -20,7 +20,7 @@ def retry_with_exponential_backoff(  # type: ignore
     exponential_base: float = 2,
     jitter: bool = True,
     max_retries: int = 3,
-    errors: tuple[Any] = (openai.error.RateLimitError),
+    errors: tuple[Any] = (openai.error.RateLimitError,),
 ):
     """Retry a function with exponential backoff."""

View File

@@ -1,7 +1,7 @@
 from typing import Any

 import tiktoken
-from transformers import LlamaTokenizer
+from transformers import LlamaTokenizer  # type: ignore

 class Tokenizer(object):
@@ -11,9 +11,9 @@ class Tokenizer(object):
         elif provider == "huggingface":
             self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
             # turn off adding special tokens automatically
-            self.tokenizer.add_special_tokens = False
-            self.tokenizer.add_bos_token = False
-            self.tokenizer.add_eos_token = False
+            self.tokenizer.add_special_tokens = False  # type: ignore[attr-defined]
+            self.tokenizer.add_bos_token = False  # type: ignore[attr-defined]
+            self.tokenizer.add_eos_token = False  # type: ignore[attr-defined]
         else:
             raise NotImplementedError
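The narrow # type: ignore[attr-defined] is preferable to a bare ignore: these tokenizer attributes are created at runtime, so mypy has no record of them, but scoping the ignore to attr-defined keeps any other error on the same line visible. A reduced sketch with a stand-in class (not the transformers API):

class RuntimeAttrs:
    def __init__(self) -> None:
        # Attribute created dynamically; invisible to the type checker.
        setattr(self, "add_bos_token", True)

tok = RuntimeAttrs()
# Without the scoped ignore, mypy reports:
#   "RuntimeAttrs" has no attribute "add_bos_token"  [attr-defined]
tok.add_bos_token = False  # type: ignore[attr-defined]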

View File

@@ -13,10 +13,12 @@ APIInput = str | list[Any] | dict[str, Any]
 def call_llm(
     lm_config: lm_config.LMConfig,
-    prompt: list[Any] | str,
-) -> APIInput:
+    prompt: APIInput,
+) -> str:
+    response: str
     if lm_config.provider == "openai":
         if lm_config.mode == "chat":
+            assert isinstance(prompt, list)
             response = generate_from_openai_chat_completion(
                 messages=prompt,
                 model=lm_config.model,
@@ -27,6 +29,7 @@ def call_llm(
                 stop_token=None,
             )
         elif lm_config.mode == "completion":
+            assert isinstance(prompt, str)
             response = generate_from_openai_completion(
                 prompt=prompt,
                 engine=lm_config.model,
@@ -40,6 +43,7 @@ def call_llm(
                 f"OpenAI models do not support mode {lm_config.mode}"
             )
     elif lm_config.provider == "huggingface":
+        assert isinstance(prompt, str)
         response = generate_from_huggingface_completion(
             prompt=prompt,
             model_endpoint=lm_config.gen_config["model_endpoint"],
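The inserted asserts are doing type narrowing: prompt arrives as the union APIInput, but each backend needs one concrete member of it (a message list for chat, a plain string for completion), and assert isinstance(...) is how mypy learns which branch holds. A reduced sketch of the pattern, with hypothetical sender functions:

from typing import Any

APIInput = str | list[Any] | dict[str, Any]

def send_chat(messages: list[Any]) -> str:
    return "chat response"

def send_completion(prompt: str) -> str:
    return "completion response"

def call(prompt: APIInput, mode: str) -> str:
    response: str
    if mode == "chat":
        assert isinstance(prompt, list)   # narrows APIInput -> list[Any]
        response = send_chat(prompt)
    else:
        assert isinstance(prompt, str)    # narrows APIInput -> str
        response = send_completion(prompt)
    return response

print(call([{"role": "user", "content": "hi"}], "chat"))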

View File

@@ -9,3 +9,5 @@ aiolimiter
 beartype==0.12.0
 flask
 nltk
+text-generation
+transformers

View File

@@ -20,7 +20,7 @@ def merge_logs(result_folder: str, args: argparse.Namespace) -> str:
         with open(file.strip(), "r") as f:
             lines = f.readlines()
-            cur_log = []
+            cur_log: list[str] = []
             index = None
             for line in lines:
                 if "[Config file]" in line: