branch: master
download_models.py
1462 bytesRaw
import yaml, argparse
from pathlib import Path
from huggingface_hub import snapshot_download

def download_models(yaml_file: str, download_dir: str) -> None:
  """Download every repository listed in a YAML manifest and record where it landed.

  The manifest must contain a top-level ``repositories`` mapping of
  Hugging Face repo id -> per-repo data. Each per-repo entry must have a
  ``files`` list whose items carry a ``file`` pattern. For every repo, the
  listed files plus any ``*config.json`` files are fetched, and the returned
  snapshot directory is stored under the entry's ``download_path`` key. Once
  all repos are processed, the updated manifest is written back to
  ``yaml_file``.

  Args:
    yaml_file: Path of the YAML manifest; it is read first and overwritten at
      the end with the ``download_path`` entries added.
    download_dir: Directory handed to ``snapshot_download`` as ``cache_dir``.

  Raises:
    TypeError: If the manifest's top level is not a mapping (e.g. an empty file).
    KeyError: If ``repositories``, ``files`` or ``file`` is missing.
  """
  # YAML is UTF-8 by specification; do not depend on the platform's locale encoding.
  with open(yaml_file, 'r', encoding="utf-8") as f:
    metadata = yaml.safe_load(f)
  if not isinstance(metadata, dict):
    raise TypeError(
      f"{yaml_file}: expected a mapping with a 'repositories' key at the top level, "
      f"got {type(metadata).__name__}"
    )

  # A bare `repositories:` key loads as None; treat it as "nothing to download".
  repositories = metadata["repositories"] or {}
  n = len(repositories)

  for i, (model_id, model_data) in enumerate(repositories.items(), start=1):
    print(f"Downloading {i}/{n}: {model_id}...")
    allow_patterns = [file_info["file"] for file_info in model_data["files"]]
    # Download configs too (the sizes are small). A file is fetched when it
    # matches any pattern, so one call covers both the weights and the configs.
    allow_patterns.append("*config.json")
    root_path = Path(snapshot_download(repo_id=model_id, allow_patterns=allow_patterns, cache_dir=download_dir))
    print(f"Downloaded model files to: {root_path}")
    model_data["download_path"] = str(root_path)

  # Save the updated metadata back to the YAML file
  with open(yaml_file, 'w', encoding="utf-8") as f:
    yaml.dump(metadata, f, sort_keys=False)
  print("Download completed according to YAML file.")

if __name__ == "__main__":
  # Command line: a single positional argument naming the YAML manifest.
  cli = argparse.ArgumentParser(
    description="Download models from Huggingface Hub based on a YAML configuration file.",
  )
  cli.add_argument(
    "input",
    type=str,
    help="Path to the input YAML configuration file containing model information.",
  )
  yaml_path = cli.parse_args().input

  # Everything is cached in a "models" directory that sits next to this script.
  cache_root = Path(__file__).parent.joinpath("models")
  cache_root.mkdir(parents=True, exist_ok=True)
  download_models(yaml_path, str(cache_root))