vllm.model_executor.models.voxtral_realtime

TimeEmbedding

Bases: Module

Sinusoidal Embedding for encoding time

Source code in vllm/model_executor/models/voxtral_realtime.py
class TimeEmbedding(torch.nn.Module):
    """Sinusoidal Embedding for encoding time"""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = torch.exp(
            -math.log(self.theta)
            * torch.arange(self.dim // 2).float()
            / (self.dim // 2)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        t = t[..., None]  # (B,) -> (B, 1) or (B, T) -> (B, T, 1)
        inv_freq = self.inv_freq.to(device=t.device, dtype=t.dtype)
        emb = (
            t * inv_freq
        )  # (B, 1) x (D/2,) -> (B, D/2) or (B, T, 1) x (D/2,) -> (B, T, D/2)
        return torch.cat((emb.cos(), emb.sin()), dim=-1)  # (B, D) or (B, T, D)
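
Example (illustrative): for dim=8, a batch of scalar timesteps maps to an 8-dimensional embedding whose first half holds cosines and second half sines.

import torch
from vllm.model_executor.models.voxtral_realtime import TimeEmbedding

emb = TimeEmbedding(dim=8)
t = torch.tensor([0.0, 5.0, 10.0])  # (B,) timesteps
out = emb(t)                        # (B, 8) = cat(cos(t * inv_freq), sin(t * inv_freq))
print(out.shape)                    # torch.Size([3, 8])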

VoxtralRealtimeGeneration

Bases: VoxtralForConditionalGeneration, SupportsRealtime

Source code in vllm/model_executor/models/voxtral_realtime.py
@MULTIMODAL_REGISTRY.register_processor(
    VoxtralRealtimeMultiModalProcessor,
    info=VoxtralProcessingInfo,
    dummy_inputs=VoxtralDummyInputsBuilder,
)
@support_torch_compile
class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtime):
    requires_raw_input_tokens = True
    # transformers currently has limited support for the MistralCommon backend
    # and cached_get_processor; skip warmup audio preprocessing until this is fixed
    skip_warmup_audio_preprocessing = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        assert (
            not vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
        ), "Voxtral realtime doesn't support full cudagraphs yet. Please use PIECEWISE."

        self.time_embedding: TimeEmbedding = TimeEmbedding(
            dim=self.config.text_config.hidden_size
        )

        audio_config = self.tokenizer.instruct.audio_encoder.audio_config
        self.n_delay_tokens = audio_config.get_num_delay_tokens()

    # for realtime transcription
    @classmethod
    async def buffer_realtime_audio(
        cls,
        audio_stream: AsyncGenerator[np.ndarray, None],
        input_stream: asyncio.Queue[list[int]],
        model_config: ModelConfig,
    ) -> AsyncGenerator[PromptType, None]:
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_encoder = tokenizer.instruct.audio_encoder
        config = audio_encoder.audio_config

        # Get prompt tokens (streaming prefix tokens) without encoding audio
        prompt_tokens = (
            tokenizer.instruct.start() + audio_encoder.encode_streaming_tokens()
        )

        # Get left/right padding audio
        left_pad, right_pad = audio_encoder.get_padding_audio()

        buffer = VoxtralRealtimeBuffer(config, prompt_tokens)

        # Feed audio with padding into buffer in background
        async def feed_audio():
            yielded_first_chunk = False
            async for audio_chunk in audio_stream:
                if not yielded_first_chunk:
                    yielded_first_chunk = True
                    # Prepend left padding before first real audio
                    await buffer.append_audio(left_pad.audio_array)
                await buffer.append_audio(audio_chunk)
            # Append right padding at the end
            await buffer.append_audio(right_pad.audio_array)
            await buffer.append_audio(None)  # signal end

        # Feed output tokens back into buffer in background
        async def feed_tokens():
            while True:
                all_outputs = await asyncio.wait_for(
                    input_stream.get(),
                    timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S,
                )
                await buffer.append_tokens(all_outputs[-1:])

        audio_task = asyncio.create_task(feed_audio())
        token_task = asyncio.create_task(feed_tokens())

        try:
            async for streaming_input in buffer.get_input_stream():
                yield streaming_input.prompt
        finally:
            audio_task.cancel()
            token_task.cancel()

    @property
    def audio_config(self):
        return self.tokenizer.instruct.audio_encoder.audio_config

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Pass post-conv embeddings directly as input.

        For realtime models, multimodal embeddings are required at every
        decode step.  If they are missing (e.g. due to an empty audio
        commit, encoder-cache eviction under GPU memory pressure, or a
        client disconnect), return zero embeddings instead of crashing
        the engine so that all other in-flight requests stay alive.
        """
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            logger.warning(
                "Realtime model received empty multimodal embeddings "
                "for %d input tokens. Returning zero embeddings to "
                "avoid engine crash.",
                input_ids.shape[0],
            )
            pool_size = self.config.audio_config.block_pool_size
            embed_dim = self.config.audio_config.d_model * pool_size
            return torch.zeros(
                input_ids.shape[0],
                embed_dim,
                dtype=self.whisper_encoder.dtype,
                device=input_ids.device,
            )
        mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
        return mm_embeds_flat

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        assert inputs_embeds is not None
        assert input_ids is not None

        pool_size = self.config.audio_config.block_pool_size
        inputs_embeds = inputs_embeds.view(
            inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
        )

        whisper_positions = _expand_tensor(positions, pool_size)
        audio_hidden_states = self.whisper_encoder.whisper_encoder(
            inputs_embeds, whisper_positions
        )

        num_tokens, audio_hidden_size = audio_hidden_states.shape
        assert num_tokens % self.downsample_factor == 0
        audio_hidden_states = audio_hidden_states.reshape(
            num_tokens // self.downsample_factor,
            audio_hidden_size * self.downsample_factor,
        )
        audio_text_embeds = self.audio_language_adapter(audio_hidden_states)

        text_embeds = self.language_model.embed_input_ids(input_ids)

        # sum pool text and audio embeddings
        inputs_embeds = audio_text_embeds + text_embeds

        time_tensor = torch.full(
            (1,),
            fill_value=self.n_delay_tokens,
            device=inputs_embeds.device,
            dtype=inputs_embeds.dtype,
        )
        t_cond = self.time_embedding(time_tensor)

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
            t_cond=t_cond,
        )

        return hidden_states

    def embed_multimodal(
        self, **kwargs
    ) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
        """Transform audio waveforms -> initial whisper post-conv embeddings"""
        audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)

        if audio_inputs is None:
            logger.warning(
                "Realtime model received no audio inputs in "
                "embed_multimodal. Returning empty embeddings."
            )
            return []

        def _truncate_left(
            sample: torch.Tensor, mult_of: int, pos: int
        ) -> torch.Tensor:
            assert pos in [0, 1], pos
            if (ctx := sample.shape[pos] % mult_of) != 0:
                sample = sample[ctx:] if pos == 0 else sample[:, ctx:]
                assert sample.shape[pos] > 0, (
                    f"Sample is empty after truncation with ctx {ctx}"
                )

            return sample

        mel_features = [
            self.whisper_encoder.compute_whisper_melspec(audio).to(
                self.whisper_encoder.dtype
            )
            for audio in audio_inputs
        ]

        # truncate the leftmost mel frame
        # if the sequence length is odd
        mel_features = [_truncate_left(mel, 2, 1) for mel in mel_features]

        seq_lens = [mel.shape[1] for mel in mel_features]
        # [total_num_20ms_frames, hidden_size]
        audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
            mel_features
        )
        conv_stride = self.whisper_encoder.whisper_encoder.total_stride
        audio_embeddings_per_sample = audio_embeddings.split(
            [s // conv_stride for s in seq_lens], dim=0
        )

        # the number of audio embeddings per sample must be divisible by pool_size
        pool_size = self.config.audio_config.block_pool_size

        audio_embeddings_per_sample = [
            _truncate_left(sample, pool_size, 0)
            for sample in audio_embeddings_per_sample
        ]

        audio_embeddings_per_sample = [
            e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
            for e in audio_embeddings_per_sample
        ]
        return audio_embeddings_per_sample

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_config = tokenizer.instruct.audio_encoder.audio_config
        sample_rate = audio_config.sampling_rate
        return SpeechToTextConfig(
            max_audio_clip_s=None,  # only limited by memory
            sample_rate=sample_rate,
            min_energy_split_window_size=None,
        )

    @classmethod
    # for speech-to-text transcription
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        tokenizer = cached_tokenizer_from_config(model_config)
        audio = Audio(audio, int(stt_config.sample_rate), format="wav")  # lossless

        req = TranscriptionRequest(
            model=model_config.model,
            audio=RawAudio.from_audio(audio),
            language=language,
            streaming=StreamingMode.OFFLINE,
        )

        tokenized = tokenizer.instruct.encode_transcription(req)

        return TokensPrompt(
            prompt_token_ids=tokenized.tokens,
            multi_modal_data={
                "audio": (tokenized.audios[0].audio_array, stt_config.sample_rate)
            },
        )
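
The shape flow of forward() can be sketched with plain tensors. The sizes below (4 tokens, pool_size 4, whisper d_model 1280, text hidden size 3072) are hypothetical, the sketch assumes downsample_factor equals block_pool_size, and the whisper encoder and audio_language_adapter are replaced by stand-ins.

import torch

num_tokens, pool_size, d_model, hidden = 4, 4, 1280, 3072  # hypothetical sizes

# post-conv audio embeddings: one row per text token, pool_size frames stacked
inputs_embeds = torch.randn(num_tokens, d_model * pool_size)

# unstack to one row per frame before running the whisper encoder
frames = inputs_embeds.view(num_tokens * pool_size, d_model)    # (16, 1280)

audio_hidden = frames  # stand-in for the encoder output (frame count preserved)

# regroup frames per token (assumes downsample_factor == pool_size)
regrouped = audio_hidden.reshape(
    audio_hidden.shape[0] // pool_size, audio_hidden.shape[1] * pool_size
)                                                               # (4, 5120)

# stand-in for audio_language_adapter: project to the text hidden size
audio_text = torch.nn.Linear(d_model * pool_size, hidden)(regrouped)  # (4, 3072)

# sum-pool with the text token embeddings
text = torch.randn(num_tokens, hidden)
fused = audio_text + text                                       # (4, 3072)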

embed_input_ids

embed_input_ids(
    input_ids: Tensor,
    multimodal_embeddings: MultiModalEmbeddings
    | None = None,
    *,
    is_multimodal: Tensor | None = None,
    handle_oov_mm_token: bool = True,
) -> Tensor

Pass post-conv embeddings directly as input.

For realtime models, multimodal embeddings are required at every decode step. If they are missing (e.g. due to an empty audio commit, encoder-cache eviction under GPU memory pressure, or a client disconnect), return zero embeddings instead of crashing the engine so that all other in-flight requests stay alive.
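
As a concrete illustration of the fallback branch, with hypothetical audio_config values d_model=1280 and block_pool_size=4, a request with 3 input tokens receives zeros of shape (3, 5120):

import torch

d_model, block_pool_size = 1280, 4   # hypothetical config values
num_input_tokens = 3
fallback = torch.zeros(num_input_tokens, d_model * block_pool_size)
print(fallback.shape)                # torch.Size([3, 5120])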

Source code in vllm/model_executor/models/voxtral_realtime.py
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
    # Multi-modal token ID may exceed vocab size
    handle_oov_mm_token: bool = True,
) -> torch.Tensor:
    """Pass post-conv embeddings directly as input.

    For realtime models, multimodal embeddings are required at every
    decode step.  If they are missing (e.g. due to an empty audio
    commit, encoder-cache eviction under GPU memory pressure, or a
    client disconnect), return zero embeddings instead of crashing
    the engine so that all other in-flight requests stay alive.
    """
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        logger.warning(
            "Realtime model received empty multimodal embeddings "
            "for %d input tokens. Returning zero embeddings to "
            "avoid engine crash.",
            input_ids.shape[0],
        )
        pool_size = self.config.audio_config.block_pool_size
        embed_dim = self.config.audio_config.d_model * pool_size
        return torch.zeros(
            input_ids.shape[0],
            embed_dim,
            dtype=self.whisper_encoder.dtype,
            device=input_ids.device,
        )
    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
    return mm_embeds_flat

embed_multimodal

embed_multimodal(
    **kwargs,
) -> list[Tensor] | Tensor | tuple[Tensor, ...] | None

Transform audio waveforms -> initial whisper post-conv embeddings
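
A rough shape walk-through under hypothetical sizes (128 mel bins, 301 mel frames, conv stride 2, block_pool_size 4, d_model 1280); the conv front-end is replaced by a stand-in:

import torch

n_mels, T = 128, 301                  # hypothetical mel-spectrogram shape
conv_stride, pool_size, d_model = 2, 4, 1280

mel = torch.randn(n_mels, T)

# drop the leftmost frame so the frame count is even (_truncate_left(mel, 2, 1))
mel = mel[:, mel.shape[1] % 2:]       # (128, 300)

# stand-in for forward_conv: the conv stack reduces frames by its total stride
post_conv = torch.randn(mel.shape[1] // conv_stride, d_model)   # (150, 1280)

# drop leftmost rows until divisible by pool_size, then stack groups of pool_size
post_conv = post_conv[post_conv.shape[0] % pool_size:]          # (148, 1280)
stacked = post_conv.view(post_conv.shape[0] // pool_size, d_model * pool_size)
print(stacked.shape)                  # torch.Size([37, 5120])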

Source code in vllm/model_executor/models/voxtral_realtime.py
def embed_multimodal(
    self, **kwargs
) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
    """Transform audio waveforms -> initial whisper post-conv embeddings"""
    audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)

    if audio_inputs is None:
        logger.warning(
            "Realtime model received no audio inputs in "
            "embed_multimodal. Returning empty embeddings."
        )
        return []

    def _truncate_left(
        sample: torch.Tensor, mult_of: int, pos: int
    ) -> torch.Tensor:
        assert pos in [0, 1], pos
        if (ctx := sample.shape[pos] % mult_of) != 0:
            sample = sample[ctx:] if pos == 0 else sample[:, ctx:]
            assert sample.shape[pos] > 0, (
                f"Sample is empty after truncation with ctx {ctx}"
            )

        return sample

    mel_features = [
        self.whisper_encoder.compute_whisper_melspec(audio).to(
            self.whisper_encoder.dtype
        )
        for audio in audio_inputs
    ]

    # truncate the leftmost mel frame
    # if the sequence length is odd
    mel_features = [_truncate_left(mel, 2, 1) for mel in mel_features]

    seq_lens = [mel.shape[1] for mel in mel_features]
    # [total_num_20ms_frames, hidden_size]
    audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
        mel_features
    )
    conv_stride = self.whisper_encoder.whisper_encoder.total_stride
    audio_embeddings_per_sample = audio_embeddings.split(
        [s // conv_stride for s in seq_lens], dim=0
    )

    # the number of audio embeddings per sample must be divisible by pool_size
    pool_size = self.config.audio_config.block_pool_size

    audio_embeddings_per_sample = [
        _truncate_left(sample, pool_size, 0)
        for sample in audio_embeddings_per_sample
    ]

    audio_embeddings_per_sample = [
        e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
        for e in audio_embeddings_per_sample
    ]
    return audio_embeddings_per_sample