Whisper ASR Model Skipping Chunks in Audio Transcription

Description: I’m using the Whisper ASR model (openai/whisper-tiny or openai/whisper-medium) to transcribe audio files. To handle long recordings, the audio is split into 30-second chunks with a 5-second overlap. Most of the transcription works as expected, but I’ve noticed that some chunks are skipped entirely or only partially transcribed.
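
For reference, this is the chunk layout I expect from that scheme (a minimal sketch of the arithmetic only, assuming 16 kHz audio; it is separate from the pipeline code further down):

```python
# Expected chunk boundaries for 30 s chunks with a 5 s overlap at 16 kHz.
sample_rate = 16_000
chunk_size = 30 * sample_rate       # 480,000 samples per chunk
overlap_size = 5 * sample_rate      # 80,000 samples shared between neighbouring chunks
stride = chunk_size - overlap_size  # 400,000 samples: each chunk starts 25 s after the previous one

for i in range(3):
    start_s = i * stride / sample_rate
    end_s = start_s + chunk_size / sample_rate
    print(f"chunk {i + 1}: {start_s:.0f}s - {end_s:.0f}s")
# chunk 1: 0s - 30s
# chunk 2: 25s - 55s
# chunk 3: 50s - 80s
```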

Problem Example: For an audio file with the following content in the first 30 seconds:

“Hello Everyone, My name is John Doe.”
The ASR model outputs only “Hello Everyone,” and skips the rest of the chunk. The transcription then continues from the next chunk, but the skipped content is permanently lost. This behavior is inconsistent and doesn’t happen for every chunk.
Code for Reference: Below is the relevant portion of my code:
```python
import io

import numpy as np
import torch
import torchaudio
from pydub import AudioSegment
from transformers import WhisperProcessor, WhisperForConditionalGeneration


def initialize_whisper_model():
    model_name = "openai/whisper-tiny.en"
    print(f"Loading Whisper model: {model_name}")
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)
    return model, processor

# Function to split long audio into 30-second chunks with a 5-second overlap

def split_audio_into_chunks(waveform, chunk_duration_sec, sample_rate, overlap_sec=5):
    chunk_size = chunk_duration_sec * sample_rate
    overlap_size = overlap_sec * sample_rate
    stride = chunk_size - overlap_size  # Calculate the stride

    chunks = []
    start = 0

    while start < waveform.size(0):
        end = min(start + chunk_size, waveform.size(0))  # Ensure end doesn't exceed waveform size
        chunks.append(waveform[start:end])
        start += stride  # Move by the stride (with overlap)

    # Debug information about each chunk
    for idx, chunk in enumerate(chunks):
        print(f"Chunk {idx + 1}: Start = {idx * stride}, Length = {len(chunk)} samples, Max = {chunk.max()}, Min = {chunk.min()}")

    return chunks

# Function to transcribe audio chunks using Whisper

def transcribe_audio_with_whisper(waveform, processor, model, sampling_rate=16000):
    chunks = split_audio_into_chunks(waveform, 30, sampling_rate)
    transcriptions = []

    for idx, chunk in enumerate(chunks):
        if chunk.size(0) == 0:
            print(f"Chunk {idx+1} is empty, skipping.")
            continue
        print(f"Processing Chunk {idx+1}...")
        inputs = processor(chunk, sampling_rate=sampling_rate, return_tensors="pt", language='en')
        input_features = inputs.input_features
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)[0]
        print(f"Chunk {idx+1} Transcription: {transcription}")
        unwanted_prefix = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
        unwanted_2 = "<|startoftranscript|><|notimestamps|>"
        if transcription.startswith(unwanted_prefix):
            transcription = transcription[len(unwanted_prefix):]
        elif transcription.startswith(unwanted_2):
            transcription = transcription[len(unwanted_2):]
        transcriptions.append(transcription)

    return " ".join(transcriptions)

def load_wav_from_gcs(blob):
    # Download the audio file bytes
    file_bytes = blob.download_as_bytes()
    print(f"Downloaded file: {blob.name}, Size: {len(file_bytes)} bytes")

    # Load audio using pydub
    audio = AudioSegment.from_wav(io.BytesIO(file_bytes))
    print(f"Original Duration: {audio.duration_seconds:.2f} seconds, Frame Rate: {audio.frame_rate} Hz, Channels: {audio.channels}")

    # Convert audio to mono if necessary
    if audio.channels > 1:
        print("Converting stereo/multichannel audio to mono...")
        audio = audio.set_channels(1)

    # Convert audio samples to numpy array and normalize
    samples = np.array(audio.get_array_of_samples(), dtype=np.float32) / 2**15  # Normalize to range [-1, 1]
    samples = torch.from_numpy(samples)

    # Resample audio if necessary
    original_sample_rate = audio.frame_rate
    target_sample_rate = 16000
    if original_sample_rate != target_sample_rate:
        print(f"Resampling from {original_sample_rate} Hz to {target_sample_rate} Hz...")
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sample_rate)
        samples = resampler(samples)

    # Normalize the audio to [-1, 1] range
    samples = samples / samples.abs().max()

    print(f"Waveform shape: {samples.shape}, Normalized range: [{samples.min().item()}, {samples.max().item()}]")
    return samples, target_sample_rate
```
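
For comparison, I can also feed the same waveform to Transformers’ built-in chunked ASR pipeline, which performs the 30-second windowing and 5-second stride internally (a minimal sketch; `chunk_length_s`, `stride_length_s`, and `return_timestamps` are standard pipeline parameters and are not part of my code above):

```python
from transformers import pipeline

# Sketch: same model, but letting the pipeline do the chunking and merging.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    chunk_length_s=30,
    stride_length_s=5,
)

# `samples` is the 16 kHz mono tensor returned by load_wav_from_gcs above.
result = asr({"raw": samples.numpy(), "sampling_rate": 16000}, return_timestamps=True)
print(result["text"])
```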