
Running inference per batch alters the generated transcript #2986

@Craya

Description

Describe the bug

I have a fine‑tuned wav2vec 2.0 model paired with a KenLM language model that delivers good performance.

To make better use of the GPU I switched from single‑utterance inference to batch inference. While the batch approach improves speed, the transcriptions it generates differ from those obtained with the classic (per‑utterance) inference, even though the same model and LM are used. 

Is this discrepancy expected, or does it indicate a problem with my implementation?

Expected behaviour

Single‑utterance and batched inference should produce identical transcriptions.

To Reproduce

Inference on two audio files (4.47s and 6.69s)

Script:

import multiprocessing
import torch
import torchaudio
import tqdm
import pandas as pd
import numpy as np
from jiwer import wer
from tinytag import TinyTag
from pyctcdecode import build_ctcdecoder, BeamSearchDecoderCTC
from speechbrain.dataio.batch import PaddedBatch
from speechbrain.inference.ASR import EncoderASR

# Common objects: index constants into the beam tuples returned by pyctcdecode
class OutputBeam:
  TRANSCRIPT = 0
  LM_STATE = 1
  WORD_LOGIT_MATRIX = 2
  LOGIT_SCORE = 3
  LM_SCORE = 4

class OutputBeamMPSafe:
  TRANSCRIPT = 0
  WORD_LOGIT_MATRIX = 1
  LOGIT_SCORE = 2
  LM_SCORE = 3

# General configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.set_num_threads(56)

with torch.no_grad():
    # Load the acoustic model
    asr_model : EncoderASR = EncoderASR.from_hparams(
        source="/data/model/v1.0.2/", 
        savedir="/data/model/v1.0.2/",
        hparams_file="hyperparams_inference.yaml",
        run_opts={"device": device})

    # Load the language model

    ## Labels from the ASR engine 
    labels = [asr_model.tokenizer.id_to_piece(id) for id in range(asr_model.tokenizer.get_piece_size())]
    ## It's important to replace the <unk> token, otherwise the LM will be dysfunctional!
    labels[0] = '<pad>'

    language_model : BeamSearchDecoderCTC = build_ctcdecoder(labels, "/data//lm/lm_v1.0.2.arpa", alpha=1.0, beta=0.6)

    # Load data
    data_df = pd.DataFrame([{"wav":"/data/audio/1.wav","wrd":"some words for a transcription test"},
                {"wav":"/data/audio/2.wav","wrd":"Is the test successful or not"}])

    gnd_truth_full = []
    transcripts_full = []

    ######################
    #        Loop        #
    ######################
    with tqdm.tqdm(total=len(data_df)) as pbar:
        for index, row in data_df.iterrows(): 
            # Perform inference 
            with open(row["wav"], 'rb') as f:
                audio = f.read()

            audio = np.frombuffer(audio, np.int16)
            # Convert the raw audio to tensor
            audio = torch.tensor(audio[:,np.newaxis], dtype=torch.float32, device=device)
            wave = asr_model.audio_normalizer(audio, 16000)
            batch : torch.Tensor = wave.unsqueeze(0)

            # Process the transcription
            t_len = torch.tensor([1.0])  # relative length: the full signal
            encoder_out : torch.Tensor = asr_model.encode_batch(batch, t_len)
            logits : np.ndarray = encoder_out[0].detach().cpu().clone().numpy()
            beams = language_model.decode_beams(logits)

            transcript = beams[0][OutputBeam.TRANSCRIPT]

            del audio,wave,batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Store the transcript and associated ground truth, to compute the final WER/CER
            gnd_truth_full.append(row["wrd"])
            transcripts_full.append(transcript)
            
            pbar.update(1)

    print(f">>> Loop WER: {wer(gnd_truth_full, transcripts_full)}")

    ######################
    #        Batch       #
    ######################

    gnd_truth_full = []
    transcripts_full = []

    ## Sort data by audio duration
    durations = []
    for index, row in data_df.iterrows():
        file_path = row["wav"]
        durations.append(TinyTag.get(file_path).duration)
    data_df["duration"] = durations
    data_df = data_df.sort_values(by=["duration"])

    wavs = []
    res = []
    batch_duration = 0.0
    longest_duration = None

    # Create sub-batches to optimize GPU memory utilization (~350 s of audio for 40 GB)
    for index, row in data_df.iterrows():
        wav, sr = torchaudio.load(row['wav'])
        wav_len = wav.shape[1]
        duration = wav_len/sr
        if longest_duration is None:
            longest_duration = duration
        batch_duration += duration
        wavs.append({"wav" : wav.to(device).squeeze(), "duration": duration})
        if batch_duration > 350.0 or index == data_df.index[-1]:
            padded_batch = PaddedBatch(wavs, padded_keys=["wav"])

            # Processing
            encoder_out : torch.Tensor = asr_model.encode_batch(padded_batch.wav.data, padded_batch.wav.lengths)
            logits_list = encoder_out.detach().cpu().clone().numpy()
            del encoder_out

            with multiprocessing.get_context("fork").Pool(56) as pool:
                beams_list = language_model.decode_beams_batch(pool, logits_list)
            for beams in beams_list:
                beam = beams[0]

                res.append(beam[OutputBeamMPSafe.TRANSCRIPT])

            if torch.cuda.is_available():
                torch.cuda.empty_cache()


            wavs = []
            batch_duration = 0.0
            longest_duration = None

    for idx, (index, row) in enumerate(data_df.iterrows()):
        gnd_truth_full.append(row["wrd"])
        transcripts_full.append(res[idx])

    print(f">>> Batch WER: {wer(gnd_truth_full, transcripts_full)}")
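One thing worth checking (an assumption on my part, not confirmed by the report): in the batch path, the padded logit matrices go to `decode_beams_batch` untrimmed, so the decoder also scores frames produced from padding, which could explain part of the discrepancy. A minimal sketch that trims each utterance's logits to its true frame count before decoding, using the relative lengths `PaddedBatch` provides:

```python
import numpy as np

def trim_padded_logits(logits_batch, rel_lens):
    """Split a (batch, max_frames, vocab) array into per-utterance logit
    matrices, dropping frames that were produced from padding.
    rel_lens are relative lengths in [0, 1], as PaddedBatch returns them."""
    max_frames = logits_batch.shape[1]
    trimmed = []
    for logits, rel in zip(logits_batch, rel_lens):
        n_frames = int(round(float(rel) * max_frames))
        trimmed.append(logits[:n_frames])
    return trimmed
```

With something like `trim_padded_logits(logits_list, padded_batch.wav.lengths.cpu().numpy())` fed to `decode_beams_batch`, only real frames would be decoded; whether this fully closes the WER gap is untested here.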

hyperparams_inference.yaml

# ################################
# Model: wav2vec2 + DNN + CTC/Attention
# Augmentation: SpecAugment
# Authors: Titouan Parcollet 2021
# ################################

sample_rate: 16000
wav2vec2_hub: LeBenchmark/wav2vec2-FR-7K-large
wav2vec2_folder:  /data//wav2vec2_checkpoints/

# BPE parameters
token_type: bpe
character_coverage: 1.0

# Model parameters
# activation: !name:torch.nn.LeakyReLU
dnn_layers: 2
dnn_neurons: 1024
emb_size: 128
dec_neurons: 1024
freeze_feature_extractor: false
dropout: 0.1
warmup_steps: 500 # The wav2vec 2 model isn't updated for this amount of steps

# Outputs
output_neurons: 600

# Decoding parameters
# Be sure that the bos and eos index match with the BPEs ones
blank_index: 0
bos_index: 1
eos_index: 2

enc: !new:speechbrain.nnet.containers.Sequential
  input_shape: [null, null, 1024]
  linear1: !name:speechbrain.nnet.linear.Linear
    n_neurons: 1024
    bias: true
  bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
  activation: !new:torch.nn.LeakyReLU
  drop: !new:torch.nn.Dropout
    p: 0.15
  linear2: !name:speechbrain.nnet.linear.Linear
    n_neurons: 1024
    bias: true
  bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
  activation2: !new:torch.nn.LeakyReLU
  drop2: !new:torch.nn.Dropout
    p: 0.15
  linear3: !name:speechbrain.nnet.linear.Linear
    n_neurons: 1024
    bias: true
  bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
  activation3: !new:torch.nn.LeakyReLU

wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
  source: !ref <wav2vec2_hub>
  output_norm: true
  freeze: true
  freeze_feature_extractor: false
  save_path: !ref <wav2vec2_folder>

ctc_lin: !new:speechbrain.nnet.linear.Linear
  input_size: !ref <dnn_neurons>
  n_neurons: !ref <output_neurons>

log_softmax: !new:speechbrain.nnet.activations.Softmax
  apply_log: true

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
  blank_index: !ref <blank_index>

asr_model: !new:torch.nn.ModuleList
- [!ref <enc>, !ref <ctc_lin>]
tokenizer: !new:sentencepiece.SentencePieceProcessor

encoder: !new:speechbrain.nnet.containers.LengthsCapableSequential
  wav2vec2: !ref <wav2vec2>
  enc: !ref <enc>
  ctc_lin: !ref <ctc_lin>

decoding_function: !name:speechbrain.decoders.ctc_greedy_decode
  blank_id: !ref <blank_index>

modules:
  encoder: !ref <encoder>

pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
  loadables:
    wav2vec2: !ref <wav2vec2>
    asr: !ref <asr_model>
    tokenizer: !ref <tokenizer>
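For reference, the `decoding_function` configured above (`ctc_greedy_decode` with blank index 0) amounts to: take the per-frame argmax, collapse repeats, drop blanks. A minimal sketch of that idea, not SpeechBrain's actual implementation:

```python
import numpy as np

def ctc_greedy_decode(logits: np.ndarray, blank_id: int = 0):
    """CTC greedy path: per-frame argmax, collapse repeated ids, drop blanks."""
    best = logits.argmax(axis=-1)
    out = []
    prev = None
    for t in best:
        if t != prev and t != blank_id:
            out.append(int(t))
        prev = t
    return out
```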

Environment Details

Python 3.9.23

pandas==2.0.3
numpy==1.23.3
pytest==8.3.4
jiwer==3.0.3
soundfile==0.12.1
librosa==0.11.0
setuptools==75.3.2
pydub==0.25.1
jsonschema==3.2.0
unidecode==1.3.8
pyctcdecode==0.4.0
kenlm==0.3.0
transformers==4.51.3
speechbrain==1.0.2
torch==2.2.2
torchvision==0.17.2
torchaudio==2.2.2
pygrammalecte==1.3.0
tinytag>=2.1.2

Relevant Log Output

server-0:~/speechbrain$ uv run loop_vs_batch.py 
INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [disable_jit_profiling, allow_tf32]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []
INFO:speechbrain.utils.fetching:Fetch hyperparams_inference.yaml: Using file found at '/data/model/v1.0.2/hyperparams_inference.yaml'
WARNING:speechbrain.lobes.models.huggingface_transformers.huggingface:speechbrain.lobes.models.huggingface_transformers.huggingface - Wav2Vec2Model is frozen.
INFO:speechbrain.utils.fetching:Fetch wav2vec2.ckpt: Using file found at '/data/model/v1.0.2/wav2vec2.ckpt'
INFO:speechbrain.utils.fetching:Fetch asr.ckpt: Using file found at '/data/model/v1.0.2/asr.ckpt'
INFO:speechbrain.utils.fetching:Fetch tokenizer.ckpt: Using file found at '/data/model/v1.0.2/tokenizer.ckpt'
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: wav2vec2, asr, tokenizer
Loading the LM will be faster if you build a binary file.
Reading /data/lm/lm_v1.0.2.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
INFO:pyctcdecode.decoder:Using arpa instead of binary LM file, decoder instantiation might be slow.
INFO:pyctcdecode.alphabet:Alphabet determined to be of BPE style.
INFO:pyctcdecode.alphabet:Found <pad> in vocabulary, substituting with .
WARNING:pyctcdecode.alphabet:UNK token ▁⁇▁ not found, is this a mistake?
WARNING:pyctcdecode.alphabet:Unigrams and labels don't seem to agree.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.32it/s]
>>> Loop WER: 0.0
>>> Batch WER: 0.025
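The Loop/Batch WER figures above come from jiwer's `wer`; conceptually this is word-level Levenshtein distance divided by the reference length. A minimal pure-Python sketch of the metric (not jiwer's implementation; assumes a non-empty reference):

```python
def word_error_rate(ref: str, hyp: str) -> float:
    """Word-level edit distance (substitutions, insertions, deletions)
    divided by the number of reference words."""
    r, h = ref.split(), hyp.split()
    d = list(range(len(h) + 1))  # one DP row: distance(r[:0], h[:j]) = j
    for i, rw in enumerate(r, 1):
        prev = d[0]          # distance(r[:i-1], h[:j-1]) for j = 1
        d[0] = i             # distance(r[:i], h[:0]) = i deletions
        for j, hw in enumerate(h, 1):
            cur = d[j]
            d[j] = min(d[j] + 1,              # delete reference word
                       d[j - 1] + 1,          # insert hypothesis word
                       prev + (rw != hw))     # substitute (or match)
            prev = cur
    return d[len(h)] / len(r)
```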

Additional Context

No response
