diff --git a/export.py b/export.py index 7fb4366..d909c9f 100644 --- a/export.py +++ b/export.py @@ -280,7 +280,12 @@ def load_checkpoint(checkpoint): def load_hf_model(model_path): - from transformers import AutoModelForCausalLM + try: + from transformers import AutoModelForCausalLM + except ImportError: + print("Error: transformers package is required to load huggingface models") + print("Please run `pip install transformers` to install it") + return None # load HF model hf_model = AutoModelForCausalLM.from_pretrained(model_path) @@ -357,5 +362,8 @@ if __name__ == "__main__": else: parser.error("Input model missing: --checkpoint or --hf is required") + if model is None: + parser.error("Can't load input model!") + # export model_export(model, args.filepath, args.version) diff --git a/requirements.txt b/requirements.txt index b4054e1..7187a73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,3 @@ sentencepiece==0.1.99 torch==2.0.1 tqdm==4.64.1 wandb==0.15.5 -transformers==4.31.0