-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Description
Describe the bug
I have a fine‑tuned wav2vec 2.0 model paired with a KenLM language model that delivers good performance.
To make better use of the GPU I switched from single‑utterance inference to batch inference. While the batch approach improves speed, the transcriptions it generates differ from those obtained with the classic (per‑utterance) inference, even though the same model and LM are used.
Is this discrepancy expected, or does it indicate a problem with my implementation?
Expected behaviour
Single‑utterance inference and batch inference should generate identical transcriptions.
To Reproduce
Inference on two audio files (4.47s and 6.69s)
Script:
import multiprocessing
import torch
import torchaudio
import tqdm
import pandas as pd
import numpy as np
from jiwer import wer
from tinytag import TinyTag
from pyctcdecode import build_ctcdecoder, BeamSearchDecoderCTC
from speechbrain.dataio.batch import PaddedBatch
from speechbrain.inference.ASR import EncoderASR
# Common objects
class OutputBeam:
    """Column indices into the beam tuples returned by
    ``BeamSearchDecoderCTC.decode_beams`` (single-utterance path).

    Defined so beam fields can be addressed by name instead of magic
    integers, e.g. ``beams[0][OutputBeam.TRANSCRIPT]``.
    """

    TRANSCRIPT = 0  # decoded text of the beam
    LM_STATE = 1  # language-model state (not picklable)
    WORD_LOGIT_MATRIX = 2  # per-word frame/logit alignment info
    LOGIT_SCORE = 3  # acoustic (CTC logit) score
    LM_SCORE = 4  # LM-weighted combined score
class OutputBeamMPSafe:
    """Column indices into the beam tuples returned by
    ``BeamSearchDecoderCTC.decode_beams_batch`` (multiprocessing path).

    Compared to ``OutputBeam``, there is no ``LM_STATE`` field here
    (presumably dropped because it is not picklable across processes),
    so every later index is shifted down by one.
    """

    TRANSCRIPT = 0  # decoded text of the beam
    WORD_LOGIT_MATRIX = 1  # per-word frame/logit alignment info
    LOGIT_SCORE = 2  # acoustic (CTC logit) score
    LM_SCORE = 3  # LM-weighted combined score
# Make some general configuration
# NOTE(review): indentation was lost when this script was pasted into the
# issue; the with/for block structure must be restored before it can run.
# Select the first CUDA device when available, otherwise fall back to CPU.
device = "cuda:"+format(0) if torch.cuda.is_available() else 'cpu'
torch.set_num_threads(56)
with torch.no_grad():
# Load the acoustic model
asr_model : EncoderASR = EncoderASR.from_hparams(
source="/data/model/v1.0.2/",
savedir="/data/model/v1.0.2/",
hparams_file="hyperparams_inference.yaml",
run_opts={"device": device})
# Load the language model
## Labels from the ASR engine (one label per SentencePiece id)
labels = [asr_model.tokenizer.id_to_piece(id) for id in range(asr_model.tokenizer.get_piece_size())]
## It's important to replace the <unk> token or the lm will be dysfunctional !
labels[0] = '<pad>'
language_model : BeamSearchDecoderCTC = build_ctcdecoder(labels, "/data//lm/lm_v1.0.2.arpa", alpha=1.0,beta=0.6)
# Load data: each row pairs an audio path ("wav") with its reference text ("wrd")
data_df = pd.DataFrame([{"wav":"/data/audio/1.wav","wrd":"some words for a transcription test"},
{"wav":"/data/audio/2.wav","wrd":"Is the test successful or not"}])
gnd_truth_full = []
transcripts_full = []
######################
# Loop #
# Per-utterance (batch size 1) reference path.
######################
with tqdm.tqdm(total=len(data_df)) as pbar:
for index, row in data_df.iterrows():
# Perform inference
# Read raw PCM and interpret it as int16 samples.
# NOTE(review): reading the whole file with frombuffer includes the WAV
# header bytes as samples — presumably these are headerless PCM files;
# confirm, otherwise use torchaudio.load as the batch path does.
with open(row["wav"], 'rb') as f:
audio = f.read()
audio = np.frombuffer(audio, np.int16)
# Convert the raw audio to tensor
audio = torch.tensor(audio[:,np.newaxis], dtype=torch.float32, device=device)
wave = asr_model.audio_normalizer(audio, 16000)
batch : torch.Tensor = wave.unsqueeze(0)
# Process the transcription
# Relative length 1.0: the single utterance is fully valid (no padding).
t_len = torch.tensor([1.0])
encoder_out : torch.Tensor = asr_model.encode_batch(batch, t_len)
logits : np.ndarray = encoder_out[0].detach().cpu().clone().numpy()
beams = language_model.decode_beams(logits)
transcript = beams[0][OutputBeam.TRANSCRIPT]
del audio,wave,batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Store transcript and associated inference, to compute final WER/CER
gnd_truth_full.append(row["wrd"])
transcripts_full.append(transcript)
pbar.update(1)
print(f">>> Loop WER: {wer(gnd_truth_full, transcripts_full)}")
######################
# Batch #
# Padded-batch path whose transcripts differ from the loop above.
######################
gnd_truth_full = []
transcripts_full = []
## Sort data by audio duration (reduces padding within each sub-batch)
durations = []
for index, row in data_df.iterrows():
file_path = row["wav"]
durations.append(TinyTag.get(file_path).duration)
data_df["duration"] = durations
data_df = data_df.sort_values(by=["duration"])
wavs = []
res = []
batch_duration = 0.0
longest_duration = None
# Creating sub batches to optimize GPU memory utilization (~350 s of audio per ~40 GB GPU)
for index, row in data_df.iterrows():
wav, sr = torchaudio.load(row['wav'])
wav_len = wav.shape[1]
duration = wav_len/sr
if longest_duration is None:
longest_duration = duration
batch_duration += duration
wavs.append({"wav" : wav.to(device).squeeze(), "duration": duration})
# NOTE(review): longest_duration is tracked but never used in this
# condition — possibly a leftover from a longest*count padding-aware
# budget; confirm intent.
if batch_duration > 350.0 or index == data_df.index[-1]:
padded_batch = PaddedBatch(wavs, padded_keys=["wav"])
# Processing
# NOTE(review): encode_batch on a padded batch can produce slightly
# different logits than the per-utterance path (padding interacts with
# the encoder's normalization/attention unless masked end-to-end) —
# presumably the source of the reported loop-vs-batch discrepancy;
# confirm against SpeechBrain's handling of wav_lens.
encoder_out : torch.Tensor = asr_model.encode_batch(padded_batch.wav.data, padded_batch.wav.lengths)
logits_list = encoder_out.detach().cpu().clone().numpy()
del encoder_out
# Decode all utterances of the sub-batch in parallel worker processes.
with multiprocessing.get_context("fork").Pool(56) as pool:
beams_list = language_model.decode_beams_batch(pool, logits_list)
for beams in beams_list:
beam = beams[0]
res.append(beam[OutputBeamMPSafe.TRANSCRIPT])
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Reset the sub-batch accumulators.
wavs = []
batch_duration = 0.0
longest_duration = None
# Re-align references with results: res follows the duration-sorted order,
# and data_df was sorted in place above, so idx matches res positions.
for idx, (index, row) in enumerate(data_df.iterrows()):
gnd_truth_full.append(row["wrd"])
transcripts_full.append(res[idx])
print(f">>> Batch WER: {wer(gnd_truth_full, transcripts_full)}")
hyperparams_inference.yaml
# ################################
# Model: wav2vec2 + DNN + CTC/Attention
# Augmentation: SpecAugment
# Authors: Titouan Parcollet 2021
# ################################
# NOTE(review): the YAML nesting appears to have been flattened by the paste —
# keys such as input_shape/linear1/bn1 presumably indent under `enc`, and
# source/freeze under `wav2vec2`; confirm against the original file.
sample_rate: 16000
wav2vec2_hub: LeBenchmark/wav2vec2-FR-7K-large
wav2vec2_folder: /data//wav2vec2_checkpoints/
# BPE parameters
token_type: bpe
character_coverage: 1.0
# Model parameters
# activation: !name:torch.nn.LeakyReLU
dnn_layers: 2
dnn_neurons: 1024
emb_size: 128
dec_neurons: 1024
freeze_feature_extractor: false
dropout: 0.1
warmup_steps: 500 # The wav2vec 2 model isn't updated for this amount of steps
# Outputs
output_neurons: 600
# Decoding parameters
# Be sure that the bos and eos index match with the BPEs ones
blank_index: 0
bos_index: 1
eos_index: 2
# Encoder head: 3 x (Linear -> BatchNorm -> LeakyReLU -> Dropout) on top of wav2vec2 features.
enc: !new:speechbrain.nnet.containers.Sequential
input_shape: [null, null, 1024]
linear1: !name:speechbrain.nnet.linear.Linear
n_neurons: 1024
bias: true
bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
activation: !new:torch.nn.LeakyReLU
drop: !new:torch.nn.Dropout
p: 0.15
linear2: !name:speechbrain.nnet.linear.Linear
n_neurons: 1024
bias: true
bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
activation2: !new:torch.nn.LeakyReLU
drop2: !new:torch.nn.Dropout
p: 0.15
linear3: !name:speechbrain.nnet.linear.Linear
n_neurons: 1024
bias: true
bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
activation3: !new:torch.nn.LeakyReLU
# Frozen wav2vec2 feature extractor; output_norm normalizes encoder outputs.
wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
source: !ref <wav2vec2_hub>
output_norm: true
freeze: true
freeze_feature_extractor: false
save_path: !ref <wav2vec2_folder>
# CTC output projection to the BPE vocabulary.
ctc_lin: !new:speechbrain.nnet.linear.Linear
input_size: !ref <dnn_neurons>
n_neurons: !ref <output_neurons>
log_softmax: !new:speechbrain.nnet.activations.Softmax
apply_log: true
ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
blank_index: !ref <blank_index>
asr_model: !new:torch.nn.ModuleList
- [!ref <enc>, !ref <ctc_lin>]
tokenizer: !new:sentencepiece.SentencePieceProcessor
# Full inference encoder used by EncoderASR.encode_batch.
encoder: !new:speechbrain.nnet.containers.LengthsCapableSequential
wav2vec2: !ref <wav2vec2>
enc: !ref <enc>
ctc_lin: !ref <ctc_lin>
decoding_function: !name:speechbrain.decoders.ctc_greedy_decode
blank_id: !ref <blank_index>
modules:
encoder: !ref <encoder>
# Checkpoints loaded at from_hparams time.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
loadables:
wav2vec2: !ref <wav2vec2>
asr: !ref <asr_model>
tokenizer: !ref <tokenizer>
Environment Details
Python 3.9.23
pandas==2.0.3
numpy==1.23.3
pytest==8.3.4
jiwer==3.0.3
soundfile==0.12.1
librosa==0.11.0
setuptools==75.3.2
pydub==0.25.1
jsonschema==3.2.0
unidecode==1.3.8
pyctcdecode==0.4.0
kenlm==0.3.0
transformers==4.51.3
speechbrain==1.0.2
torch==2.2.2
torchvision==0.17.2
torchaudio==2.2.2
pygrammalecte==1.3.0
tinytag>=2.1.2
Relevant Log Output
server-0:~/speechbrain$ uv run loop_vs_batch.py
INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [disable_jit_profiling, allow_tf32]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []
INFO:speechbrain.utils.fetching:Fetch hyperparams_inference.yaml: Using file found at '/data/model/v1.0.2/hyperparams_inference.yaml'
WARNING:speechbrain.lobes.models.huggingface_transformers.huggingface:speechbrain.lobes.models.huggingface_transformers.huggingface - Wav2Vec2Model is frozen.
INFO:speechbrain.utils.fetching:Fetch wav2vec2.ckpt: Using file found at '/data/model/v1.0.2/wav2vec2.ckpt'
INFO:speechbrain.utils.fetching:Fetch asr.ckpt: Using file found at '/data/model/v1.0.2/asr.ckpt'
INFO:speechbrain.utils.fetching:Fetch tokenizer.ckpt: Using file found at '/data/model/v1.0.2/tokenizer.ckpt'
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: wav2vec2, asr, tokenizer
Loading the LM will be faster if you build a binary file.
Reading /data/lm/lm_v1.0.2.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
INFO:pyctcdecode.decoder:Using arpa instead of binary LM file, decoder instantiation might be slow.
INFO:pyctcdecode.alphabet:Alphabet determined to be of BPE style.
INFO:pyctcdecode.alphabet:Found <pad> in vocabulary, substituting with .
WARNING:pyctcdecode.alphabet:UNK token ▁⁇▁ not found, is this a mistake?
WARNING:pyctcdecode.alphabet:Unigrams and labels don't seem to agree.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5.32it/s]
>>> Loop WER: 0.0
>>> Batch WER: 0.025

Additional Context
No response