How can we make transcribe() produce shorter segments?
The segment length doesn’t seem to be configurable. I’m not after any particular length, just… shorter, please.
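For reference, here’s roughly how I’m calling it (the file name is just a placeholder); the segments it returns are the ones I’d like to be shorter:

import whisper

model = whisper.load_model("base")
result = model.transcribe("song.mp3")

# each segment can easily span ~20 seconds of audio; I don't see an option to shorten them
for segment in result["segments"]:
    print(f'[{segment["start"]:7.2f} --> {segment["end"]:7.2f}] {segment["text"]}')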
People are out there tweaking the code in decoding.py (around line 440) until they find something that works for them, but that’s a rather inscrutable process.
Meanwhile, subtitles produced with this method end up throwing 20 seconds’ worth of words onto the screen all at once, rather than a few at a time as they are spoken.
Karaoke lyrics produced with this method don’t align with the singing, because the segments are so long they overflow the screen.
A bunch of people are asking for this over on GitHub, so I thought I’d mention it here.
Getting lyrics matched with singing seems like a thing AI could do, but I’m not sure a transcription model is the solution.
Transcription doesn’t really require super accurate syllable-level timing information, so it sounds like a different speech-to-text model may actually be needed.
Perhaps there are some tricks that could get close to what you want, but with everything else going on I’m not sure where this would end up priority-wise.
My apologies. If you have an editable install running, perhaps you could try this replacement for the ApplyTimestampRules class, which introduces a new optional max_line_length parameter to suppress the generation of further text tokens between timestamps.
# drop-in replacement for the ApplyTimestampRules class in whisper/decoding.py;
# Tokenizer, LogitFilter, Tensor, Optional, np, and F come from that file's existing imports
class ApplyTimestampRules(LogitFilter):
    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
        max_line_length: Optional[int] = None,  # new: cap on text tokens per segment
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index
        self.max_line_length = max_line_length  # store the cap
        # note: this single counter is shared across all sequences in the batch,
        # so the cap behaves as intended only for greedy decoding of one sequence
        self.line_length = 0  # counts text tokens emitted since the last timestamp

    def apply(self, logits: Tensor, tokens: Tensor):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = [t for t in sampled_tokens.tolist()]
            last_was_timestamp = (
                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            )
            penultimate_was_timestamp = (
                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
            )

            if last_was_timestamp:
                self.line_length = 0  # a timestamp was emitted; reset the counter
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf
            else:
                self.line_length += 1  # another text token since the last timestamp
                if (
                    self.max_line_length is not None
                    and self.line_length >= self.max_line_length
                ):
                    # the line is long enough: suppress further text tokens so the
                    # next token has to be a timestamp (or EOT), closing the segment
                    logits[k, : self.tokenizer.eot] = -np.inf

            timestamps = sampled_tokens[
                sampled_tokens.ge(self.tokenizer.timestamp_begin)
            ]
            if timestamps.numel() > 0:
                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                # also force each segment to have a nonzero length, to prevent infinite looping
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    timestamp_last = timestamps[-1] + 1
                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            # suppress generating non-timestamp tokens at the beginning
            logits[:, : self.tokenizer.timestamp_begin] = -np.inf

            # apply the `max_initial_timestamp` option
            if self.max_initial_timestamp_index is not None:
                last_allowed = (
                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
                )
                logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
                dim=-1
            )
            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf
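To actually exercise the new parameter, it also has to be passed where the filter is constructed, which in my copy of decoding.py is where DecodingTask sets up its logit_filters. A minimal sketch, assuming you simply hard-code a value rather than plumbing it through DecodingOptions (the surrounding lines may differ slightly in your version, and 12 is just an arbitrary starting point):

# inside DecodingTask.__init__ in whisper/decoding.py
if not options.without_timestamps:
    precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
    max_initial_timestamp_index = None
    if options.max_initial_timestamp:
        max_initial_timestamp_index = round(
            self.options.max_initial_timestamp / precision
        )
    self.logit_filters.append(
        ApplyTimestampRules(
            tokenizer,
            self.sample_begin,
            max_initial_timestamp_index,
            max_line_length=12,  # arbitrary cap on text tokens per segment; tune to taste
        )
    )

Since the cutoff is purely token-count based it will sometimes close a segment mid-phrase, but transcribe() should then return noticeably more, shorter segments.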