transcribe() lines are WAY too long for both subtitles and karaoke

How can we make transcribe() produce shorter segments?

The segment length doesn’t seem to be configurable. I'm not after a specific length, just… shorter, please.

People are out there playing with the code in decoding.py (around line 440) until they find something that works for them, but it's a fairly inscrutable process.

Meanwhile, subtitles produced this way throw 20 seconds' worth of words on the screen all at once, rather than a few at a time as they are spoken.

Karaoke videos produced this way don’t line up with the singing, because the segments are so long that they overflow the screen.

A bunch of people are asking for this over on GitHub, so I thought I’d mention it here.

Matching lyrics to singing seems like something AI could do, but I’m not sure a transcription model is the solution.

Transcription doesn't really require super-accurate, syllable-level timing information; it sounds like a different speech-to-text model is what's actually required.

Perhaps there are some tricks that can get close to what you want, but with everything else going on, I’m not sure where this would land priority-wise.

With all due respect, your comment isn’t at all helpful.

Whisper already outputs subtitle files with timestamps.

People are asking for shorter lines.

That has literally nothing to do with per-syllable timing.
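To illustrate that shorter lines are a post-processing question rather than a timing-model one, here is a rough sketch (not part of whisper) that splits one entry of transcribe()'s `result["segments"]` into cues of at most `max_chars` characters. It assumes speech is spread evenly across the segment, so each piece gets a time span proportional to its character count; that is an approximation. `split_segment` and `max_chars` are hypothetical names.

```python
import textwrap


def split_segment(segment, max_chars=42):
    """Split one transcribe() segment ({"start", "end", "text"}) into shorter cues.

    Assumption: speech is spread evenly over the segment, so each chunk's
    duration is proportional to its character count. Whisper does not do this;
    it is only an approximation of where each chunk is spoken.
    """
    text = segment["text"].strip()
    # Break at word boundaries so that no chunk exceeds max_chars characters
    chunks = textwrap.wrap(text, width=max_chars) or [text]
    start, end = segment["start"], segment["end"]
    total_chars = sum(len(c) for c in chunks)

    cues = []
    t = start
    for i, chunk in enumerate(chunks):
        if i == len(chunks) - 1:
            chunk_end = end  # last chunk ends exactly at the segment end
        else:
            chunk_end = t + (end - start) * len(chunk) / total_chars
        cues.append({"start": t, "end": chunk_end, "text": chunk})
        t = chunk_end
    return cues
```

Applied to every entry of `result["segments"]`, this keeps each cue within the width limit, at the cost of approximate timing inside long segments; with `word_timestamps=True` you get real per-word times instead.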

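My apologies. If you have a local instance you can modify, perhaps you could try this replacement for the ApplyTimestampRules class in decoding.py. It introduces a new optional max_line_length parameter (counted in tokens, not characters) that suppresses further text tokens once a segment reaches that length, forcing a timestamp and therefore a new segment.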

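# These names are already imported inside whisper/decoding.py; the imports are only
# needed if you define this class in a separate module.
from typing import Optional

import numpy as np
import torch.nn.functional as F
from torch import Tensor

from whisper.decoding import LogitFilter
from whisper.tokenizer import Tokenizer

class ApplyTimestampRules(LogitFilter):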
    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
        max_line_length: Optional[int] = None,  # new: max text tokens per segment (None = no limit)
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index
        self.max_line_length = max_line_length
        # note: DecodingTask in decoding.py constructs this filter, so max_line_length
        # has to be passed there as well for it to take effect

    def apply(self, logits: Tensor, tokens: Tensor):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = [t for t in sampled_tokens.tolist()]
            last_was_timestamp = (
                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            )
            penultimate_was_timestamp = (
                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
            )

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf
            elif self.max_line_length is not None:
                # count text tokens since the most recent timestamp for this sequence only,
                # so the limit still holds with beam search / best_of (batch size > 1)
                line_length = 0
                for t in reversed(seq):
                    if t >= self.tokenizer.timestamp_begin:
                        break
                    line_length += 1
                if line_length >= self.max_line_length:
                    # suppress further text tokens, forcing a closing timestamp (or EOT)
                    logits[k, : self.tokenizer.eot] = -np.inf

            timestamps = sampled_tokens[
                sampled_tokens.ge(self.tokenizer.timestamp_begin)
            ]
            if timestamps.numel() > 0:
                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                # also force each segment to have a nonzero length, to prevent infinite looping
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    timestamp_last = timestamps[-1] + 1
                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            # suppress generating non-timestamp tokens at the beginning
            logits[:, : self.tokenizer.timestamp_begin] = -np.inf

            # apply the `max_initial_timestamp` option
            if self.max_initial_timestamp_index is not None:
                last_allowed = (
                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
                )
                logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
                dim=-1
            )
            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf

Hi @cliocjs

Are you self-hosting the model using the GitHub repo, or using the Audio API?


I’m self-hosting the model using the GitHub repo.

This piece of code is working for me.

import whisper
from whisper.utils import get_writer 

audio = './audio.mp3'
model = whisper.load_model('small')  # load_model's first parameter is `name`, not `model`
result = model.transcribe(audio=audio, language='en', word_timestamps=True, task="transcribe")

# Limit each VTT cue to one line of at most 42 characters (these options require word_timestamps=True)
word_options = {
    "highlight_words": False,
    "max_line_count": 1,
    "max_line_width": 42
}
vtt_writer = get_writer(output_format='vtt', output_dir='./')
vtt_writer(result, audio, word_options)
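If you would rather cap cues by word count (say, a few words at a time for karaoke), a minimal sketch is to regroup the word-level timestamps in `result` yourself. It assumes each `result["segments"][i]["words"]` entry has `"word"`, `"start"` and `"end"` keys, as produced with `word_timestamps=True`; `words_to_cues`, `to_srt`, `max_words` and `max_gap` are hypothetical names, not part of whisper.

```python
def format_srt_time(seconds):
    """Format seconds as an SRT timestamp, HH:MM:SS,mmm."""
    ms = round(seconds * 1000)
    h, ms = divmod(ms, 3_600_000)
    m, ms = divmod(ms, 60_000)
    s, ms = divmod(ms, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"


def words_to_cues(result, max_words=4, max_gap=1.0):
    """Group word timestamps into cues of at most max_words words.

    A new cue is also started when the silence between two words exceeds
    max_gap seconds, so a cue doesn't stay on screen across a pause.
    """
    groups, current = [], []
    for segment in result["segments"]:
        for word in segment.get("words", []):
            if current and (
                len(current) >= max_words
                or word["start"] - current[-1]["end"] > max_gap
            ):
                groups.append(current)
                current = []
            current.append(word)
    if current:
        groups.append(current)
    # Whisper's word strings carry their own leading spaces, so join with ""
    return [
        {
            "start": ws[0]["start"],
            "end": ws[-1]["end"],
            "text": "".join(w["word"] for w in ws).strip(),
        }
        for ws in groups
    ]


def to_srt(cues):
    """Render cues as SRT text: index, time range, text, blank line."""
    blocks = []
    for i, cue in enumerate(cues, start=1):
        time_range = f"{format_srt_time(cue['start'])} --> {format_srt_time(cue['end'])}"
        blocks.append(f"{i}\n{time_range}\n{cue['text']}\n")
    return "\n".join(blocks)
```

Usage with the `result` from above might look like `open("audio.srt", "w", encoding="utf-8").write(to_srt(words_to_cues(result, max_words=4)))`.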