# branch: master
# external_test_whisper_librispeech.py
# 2726 bytes | Raw
import unittest
import torch
import tqdm
import torchaudio
import pathlib
import jiwer
import os
import numpy as np
from whisper.normalizers import EnglishTextNormalizer
from examples.whisper import init_whisper, transcribe_waveform

class TestWhisperLibriSpeech(unittest.TestCase):
  # Every case hands run_evaluation a model name plus two word error rates (WER):
  #   tinygrad_expected_wer: the value the measured tinygrad WER is compared against
  #   reference_wer: determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb
  # The reference values should be consistent with the paper, section D.1.1: https://cdn.openai.com/papers/whisper.pdf#page=22
  # The tinygrad WERs do not match them perfectly, due to what seem to be precision differences vs torch.
  def test_en_tiny(self):
    run_evaluation(
      model_name="tiny.en",
      tinygrad_expected_wer=0.056629001883239174,
      reference_wer=0.05655609406528749,
    )

  def test_tiny(self):
    run_evaluation(
      model_name="tiny",
      tinygrad_expected_wer=0.0771121409407306,
      reference_wer=0.07558413638335187,
    )

  def test_en_base(self):
    run_evaluation(
      model_name="base.en",
      tinygrad_expected_wer=0.041412520064205455,
      reference_wer=0.04271408904897505,
    )

  def test_en_small(self):
    run_evaluation(
      model_name="small.en",
      tinygrad_expected_wer=0.03369011117172363,
      reference_wer=0.030531615969223228,
    )

def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
  """Transcribe the LibriSpeech dataset with a whisper model and check its word error rate.

  Args:
    model_name: model identifier forwarded to init_whisper (e.g. "tiny.en").
    tinygrad_expected_wer: WER the measured value must match; compared with
      np.testing.assert_almost_equal at numpy's default precision.
    reference_wer: WER of the reference implementation; only printed, never asserted.

  Raises:
    AssertionError: if the measured WER differs from tinygrad_expected_wer.
  """
  samples_per_batch = 16
  batches = torch.utils.data.DataLoader(LibriSpeech(), batch_size=samples_per_batch)

  model, enc = init_whisper(model_name, batch_size=samples_per_batch)

  predicted, expected = [], []
  for waveforms, transcripts in tqdm.tqdm(batches):
    # transcribe_waveform is fed a numpy array rather than the torch tensor the loader yields
    predicted += transcribe_waveform(model, enc, waveforms.numpy(), truncate=True)
    expected += transcripts

  # Both sides go through the same normalizer before scoring.
  normalize = EnglishTextNormalizer()
  clean_predicted = list(map(normalize, predicted))
  clean_expected = list(map(normalize, expected))
  # NOTE(review): jiwer documents the argument order as wer(reference, hypothesis), but the
  # hypotheses are passed first here. The pinned tinygrad_expected_wer values were presumably
  # produced with this order, so swapping the arguments would require regenerating them - confirm.
  wer = jiwer.wer(clean_predicted, clean_expected)

  np.testing.assert_almost_equal(wer, tinygrad_expected_wer)
  # Only reached when the assertion above passes.
  print(f'tinygrad WER {wer} vs reference WER {reference_wer}')
  # Drop the local references to the model and encoder before returning.
  del model, enc

class LibriSpeech(torch.utils.data.Dataset):
  """The LibriSpeech "test-clean" split exposed as (audio, text) items.

  The data is stored under <three directories above this file>/extra/datasets/librispeech
  and is fetched through torchaudio with download=True.
  """

  def __init__(self):
    folder = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech"
    # exist_ok=True removes the check-then-create race: another process creating the
    # directory between an existence check and the creation no longer raises.
    folder.mkdir(parents=True, exist_ok=True)

    self.dataset = torchaudio.datasets.LIBRISPEECH(
      root=folder,
      url="test-clean",
      download=True,
    )

  def __len__(self):
    """Return the number of items in the wrapped torchaudio dataset."""
    return len(self.dataset)

  def __getitem__(self, item):
    """Return (audio, text) for index `item`, with audio padded or trimmed by pad_or_trim_tensor.

    Raises:
      AssertionError: if the item's sample rate is not 16000.
    """
    # Only the first three of the six fields torchaudio returns are used.
    audio, sample_rate, text, _, _, _ = self.dataset[item]
    assert sample_rate == 16000
    # audio[0] selects the first entry along the leading axis - presumably the first
    # (only) channel of a (channels, samples) tensor; confirm against torchaudio.
    return pad_or_trim_tensor(audio[0]), text

def pad_or_trim_tensor(tensor, target_len=480000):
  """Return `tensor` with its length forced to exactly `target_len`.

  Args:
    tensor: tensor whose length (len(tensor)) is adjusted; a 1-D tensor is expected,
      since the padding that gets appended is 1-D.
    target_len: desired length. The default, 480000, is 30 seconds at 16 kHz.

  Returns:
    `tensor` itself when it already has the target length, its first `target_len`
    elements when it is longer, or `tensor` followed by zeros when it is shorter.
  """
  missing = target_len - len(tensor)
  if missing == 0:
    return tensor
  if missing < 0:
    return tensor[:target_len]
  # new_zeros takes dtype and device from `tensor`, so padding neither promotes
  # integer input to float nor fails for tensors that live off the CPU.
  return torch.cat((tensor, tensor.new_zeros(missing)))


# When this file is executed directly as a script, hand control to unittest's command-line runner.
if __name__ == '__main__':
  unittest.main()