class BaseRenderer(ABC, Generic[_T]):
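    """
    Converts raw completion and chat inputs into engine-ready prompts.

    The rendering pipeline has four steps:

    1. Convert raw inputs (text, token ids, embeddings, chat messages)
       into prompt dictionaries.
    2. Tokenize the prompts if necessary.
    3. Add extra keys to the prompts.
    4. Convert the prompts into engine inputs.
    """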
@classmethod
@abstractmethod
def from_config(
cls,
config: "VllmConfig",
tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer":
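        """Construct a renderer from the vLLM config and tokenizer kwargs."""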
raise NotImplementedError
def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
super().__init__()
self.config = config
self.model_config = config.model_config
self.tokenizer = tokenizer
# Lazy initialization since offline LLM doesn't use async
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
self.mm_processor: BaseMultiModalProcessor | None = None
self._mm_cache_stats: MultiModalCacheStats | None = None
if config.model_config.is_multimodal_model:
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
mm_processor_cache = mm_registry.processor_cache_from_config(config)
with set_default_torch_num_threads():
self.mm_processor = mm_registry.create_processor(
config.model_config,
config.observability_config,
tokenizer=tokenizer,
cache=mm_processor_cache,
)
if mm_processor_cache:
self._mm_cache_stats = MultiModalCacheStats()
# This is used to generate internal request ID for MM processing
# It has no relation to the request ID for engine core
self._mm_req_counter = AtomicCounter()
def get_tokenizer(self) -> _T:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
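        """Return the async tokenizer, creating it lazily on first use."""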
if self._async_tokenizer is None:
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
return self._async_tokenizer
def get_mm_processor(self) -> "BaseMultiModalProcessor":
if self.mm_processor is None:
raise ValueError("Multi-modal processor not available for text-only models")
return self.mm_processor
@property
def mm_processor_cache(self) -> "BaseMultiModalProcessorCache | None":
if self.mm_processor is None:
return None
return self.mm_processor.cache
def stat_mm_cache(self) -> MultiModalCacheStats | None:
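        """
        Return the multi-modal cache stats accumulated since the last call
        and start a new accumulation window, or return `None` if stats
        collection is disabled.
        """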
mm_cache_stats = self._mm_cache_stats
if mm_cache_stats is None:
return None
self._mm_cache_stats = MultiModalCacheStats()
return mm_cache_stats
def update_mm_cache_stats(self) -> None:
mm_processor_cache = self.mm_processor_cache
mm_cache_stats = self._mm_cache_stats
if mm_processor_cache and mm_cache_stats:
delta = mm_processor_cache.make_stats(delta=True)
mm_cache_stats.record(delta.total, delta.hits)
def clear_mm_cache(self) -> None:
mm_processor_cache = self.mm_processor_cache
if mm_processor_cache is not None:
mm_processor_cache.clear_cache()
if self._mm_cache_stats is not None:
self._mm_cache_stats.reset = True
def shutdown(self) -> None:
mm_processor_cache = self.mm_processor_cache
if mm_processor_cache is not None:
mm_processor_cache.close()
def get_bos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for BOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.bos_token_id
def get_eos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for EOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.eos_token_id
def get_dec_start_token_id(self) -> int:
"""
        Obtain the decoder start token id used by an encoder/decoder model,
        falling back to the BOS token id and raising an error if neither is
        available.
"""
dec_start_token_id = getattr(
self.model_config.hf_config, "decoder_start_token_id", None
)
if dec_start_token_id is None:
            logger.warning_once(
                "Falling back to <BOS> for the decoder start token id "
                "because the model config does not define one."
            )
dec_start_token_id = self.get_bos_token_id()
if dec_start_token_id is None:
raise RuntimeError("Cannot find decoder start token id or <BOS>")
return dec_start_token_id
@cached_property
def default_cmpl_tok_params(self) -> TokenizeParams:
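        """
        Default tokenization parameters for completion-style requests.
        Special tokens are added since raw text prompts are not expected
        to already include them.
        """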
mm_processor = self.mm_processor
if mm_processor is not None:
return mm_processor.info.default_tok_params
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=True,
)
@cached_property
def default_chat_tok_params(self) -> TokenizeParams:
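        """
        Default tokenization parameters for chat-style requests.
        Special tokens are not added since the chat template is expected
        to insert them already.
        """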
mm_processor = self.mm_processor
if mm_processor is not None:
return mm_processor.info.default_tok_params
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=False,
)
# Step 1: Convert raw inputs to prompts
def render_prompt(
self,
prompt: DictPrompt | bytes,
) -> DictPrompt:
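        """
        Normalize a raw prompt. Raw bytes are interpreted as serialized
        prompt embeddings; dictionary prompts are passed through as-is.
        """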
if isinstance(prompt, bytes):
embeds = safe_load_prompt_embeds(self.model_config, prompt)
prompt = EmbedsPrompt(prompt_embeds=embeds)
return prompt
def render_prompts(
self,
prompts: Sequence[DictPrompt | bytes],
) -> list[DictPrompt]:
if len(prompts) == 0:
raise ValueError("You must pass at least one prompt")
return [self.render_prompt(prompt) for prompt in prompts]
async def render_prompts_async(
self,
prompts: Sequence[DictPrompt | bytes],
) -> list[DictPrompt]:
return self.render_prompts(prompts)
@abstractmethod
def render_messages(
self,
messages: list["ChatCompletionMessageParam"],
params: ChatParams,
) -> tuple[list["ConversationMessage"], DictPrompt]:
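        """
        Convert chat messages into a conversation and a prompt dictionary,
        typically by applying the model's chat template.
        """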
raise NotImplementedError
async def render_messages_async(
self,
messages: list["ChatCompletionMessageParam"],
params: ChatParams,
) -> tuple[list["ConversationMessage"], DictPrompt]:
return self.render_messages(messages, params)
# Step 2: Tokenize prompts if necessary
def _tokenize_prompt(
self,
prompt: TextPrompt,
params: TokenizeParams,
) -> TokensPrompt:
tokenizer = self.get_tokenizer()
prompt_token_ids = tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
async def _tokenize_prompt_async(
self,
prompt: TextPrompt,
params: TokenizeParams,
) -> TokensPrompt:
tokenizer = self.get_async_tokenizer()
prompt_token_ids = await tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
tokenizer = self.get_tokenizer()
prompt["prompt"] = tokenizer.decode(prompt["prompt_token_ids"])
return prompt
async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
tokenizer = self.get_async_tokenizer()
prompt["prompt"] = await tokenizer.decode(prompt["prompt_token_ids"])
return prompt
@overload
def _tokenize_singleton_prompt(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
def _tokenize_singleton_prompt( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
def _tokenize_singleton_prompt(
self,
prompt: SingletonDictPrompt,
params: TokenizeParams,
) -> SingletonTokPrompt:
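        """
        Tokenize a single prompt if it carries neither token ids nor
        embeddings, optionally detokenize back to text, and apply the
        pre-/post-tokenization hooks from `params`.
        """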
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
prompt = self._tokenize_prompt(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
@overload
async def _tokenize_singleton_prompt_async(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
async def _tokenize_singleton_prompt_async( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
async def _tokenize_singleton_prompt_async(
self,
prompt: SingletonDictPrompt,
params: TokenizeParams,
) -> SingletonTokPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
prompt = await self._tokenize_prompt_async(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
def _tokenize_enc_dec_prompt(
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = (
self._tokenize_singleton_prompt(prompt["encoder_prompt"], params),
(
None
if prompt["decoder_prompt"] is None
else self._tokenize_singleton_prompt(prompt["decoder_prompt"], params)
),
)
return EncoderDecoderTokPrompt(
encoder_prompt=enc_prompt,
decoder_prompt=dec_prompt,
)
async def _tokenize_enc_dec_prompt_async(
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = await asyncio.gather(
self._tokenize_singleton_prompt_async(prompt["encoder_prompt"], params),
(
                # Placeholder awaitable that resolves to None when there is
                # no decoder prompt to tokenize
                asyncio.sleep(0)
if prompt["decoder_prompt"] is None
else self._tokenize_singleton_prompt_async(
prompt["decoder_prompt"], params
)
),
)
return EncoderDecoderTokPrompt(
encoder_prompt=enc_prompt,
decoder_prompt=dec_prompt,
)
def tokenize_prompt(
self,
prompt: DictPrompt,
params: TokenizeParams,
) -> TokPrompt:
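        """Tokenize a prompt, handling encoder/decoder prompts as needed."""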
if "encoder_prompt" in prompt:
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
return self._tokenize_singleton_prompt(prompt, params)
def tokenize_prompts(
self,
prompts: Sequence[DictPrompt],
params: TokenizeParams,
) -> list[TokPrompt]:
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
async def tokenize_prompt_async(
self,
prompt: DictPrompt,
params: TokenizeParams,
) -> TokPrompt:
if "encoder_prompt" in prompt:
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
return await self._tokenize_singleton_prompt_async(prompt, params)
async def tokenize_prompts_async(
self,
prompts: Sequence[DictPrompt],
params: TokenizeParams,
) -> list[TokPrompt]:
return await asyncio.gather(
*(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
)
# Step 3: Add extra keys to the prompts
def _apply_prompt_extras(
self,
prompts: Sequence[TokPrompt],
prompt_extras: dict[str, Any] | None,
    ) -> None:
if not prompt_extras:
return
for prompt in prompts:
target_prompt = extract_target_prompt(self.model_config, prompt)
target_prompt.update(prompt_extras) # type: ignore[arg-type]
# Step 4: Convert to engine inputs
def _validate_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_data_items: "MultiModalDataItems",
mm_uuid_items: "MultiModalUUIDItems",
) -> None:
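        """
        Check that user-provided multi-modal UUIDs are consistent with the
        multi-modal data: every item omitted from the data must have a UUID
        to identify it by.
        """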
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in
# `mm_data_items`
modalities = mm_data.keys() | mm_uuid_items.keys()
for modality in modalities:
data_items = mm_data_items.get(modality)
uuid_items = mm_uuid_items.get(modality)
if data_items is None:
if uuid_items is None:
                    raise ValueError(
                        f"multi_modal_data[{modality!r}] is empty, so "
                        f"multi_modal_uuids[{modality!r}] must be provided."
                    )
elif uuid_items is not None:
if len(data_items) != len(uuid_items):
raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have "
f"same length as multi_modal_data[{modality!r}], but "
f"got {len(uuid_items)} vs {len(data_items)}."
)
for i, item in enumerate(data_items):
if item is None and uuid_items[i] is None:
                        raise ValueError(
                            f"multi_modal_data[{modality!r}][{i}] is empty, "
                            f"so multi_modal_uuids[{modality!r}][{i}] must "
                            f"be provided."
                        )
def _process_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_data_items: "MultiModalDataItems",
mm_uuid_items: "MultiModalUUIDItems",
mm_req_id: str,
):
model_config = self.model_config
# NOTE: When users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore identifying multimodal data items
# by their content is no longer necessary, and we create uuids with
# `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
if (
model_config.multimodal_config
and model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.config.cache_config.enable_prefix_caching
):
mm_uuid_items = {
modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
for modality, data_count in mm_data_items.get_all_counts().items()
}
self._validate_mm_uuids(mm_data, mm_data_items, mm_uuid_items)
return mm_uuid_items
# TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
def _process_multimodal(
self,
prompt: list[int] | str,
mm_data: "MultiModalDataDict",
mm_uuids: "MultiModalUUIDDict | None",
mm_processor_kwargs: Mapping[str, object] | None,
tokenization_kwargs: dict[str, Any] | None,
) -> "MultiModalInputs":
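        """
        Run the multi-modal processor on the prompt and multi-modal data,
        producing engine-ready multi-modal inputs and updating cache stats.
        """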
from vllm.multimodal.parse import parse_mm_uuids
from vllm.multimodal.processing.context import set_request_id
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
mm_processor = self.get_mm_processor()
mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids)
        mm_uuid_items = self._process_mm_uuids(
mm_data, mm_data_items, mm_uuid_items, mm_req_id
)
with set_request_id(mm_req_id), set_default_torch_num_threads():
mm_inputs = mm_processor.apply(
prompt,
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
self.update_mm_cache_stats()
return mm_inputs
def _process_tokens(
self,
prompt: TokensPrompt,
) -> "TokenInputs | MultiModalInputs":
prompt_token_ids = prompt["prompt_token_ids"]
inputs: TokenInputs | MultiModalInputs
if multi_modal_data := prompt.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
tokenization_kwargs=None, # Tokenization already done in Step 2
mm_uuids=prompt.get("multi_modal_uuids"),
)
else:
inputs = token_inputs(prompt_token_ids)
if prompt_text := prompt.get("prompt"):
inputs["prompt"] = prompt_text
if cache_salt := prompt.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _process_embeds(
self,
prompt: EmbedsPrompt,
) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
prompt_embeds = prompt["prompt_embeds"]
        # prompt_embeds must be of shape (seq_len, hidden_size), but if the
        # user passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # the intent is unambiguous, so we squeeze out the batch dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Moving them to CPU here ensures that there
        # is no hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu()
return embeds_inputs(
prompt_embeds=prompt_embeds,
cache_salt=prompt.get("cache_salt"),
)
def _process_singleton(
self,
prompt: SingletonTokPrompt,
) -> SingletonInputs:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
return self._process_tokens(prompt) # type: ignore[arg-type]
def _process_enc_dec(
self,
prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInputs:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
return build_enc_dec_inputs(
encoder_inputs=self._process_singleton(enc_prompt),
decoder_inputs=(
None if dec_prompt is None else self._process_singleton(dec_prompt)
),
decoder_start_token_id=self.get_dec_start_token_id(),
)
def process_for_engine(
self, prompt: TokPrompt, arrival_time: float
) -> ProcessorInputs:
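        """
        Convert a tokenized prompt into engine inputs, dispatching to the
        encoder/decoder path if needed and stamping the arrival time.
        """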
engine_prompt: ProcessorInputs
if "encoder_prompt" in prompt:
engine_prompt = self._process_enc_dec(prompt) # type: ignore[arg-type]
else:
engine_prompt = self._process_singleton(prompt)
engine_prompt["arrival_time"] = arrival_time
return engine_prompt
# Top-level methods
def render_cmpl(
self,
prompts: Sequence[DictPrompt | bytes],
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
    ) -> list[ProcessorInputs]:
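        """
        Render completion-style prompts into engine inputs by running the
        full pipeline: normalize raw inputs, tokenize, apply prompt
        extras, and convert to engine inputs.

        Illustrative usage (a sketch assuming a concrete subclass named
        `MyRenderer` and an already-built `VllmConfig` bound to `config`;
        these names are not part of this module):

            renderer = MyRenderer.from_config(config, tokenizer_kwargs={})
            engine_inputs = renderer.render_cmpl([{"prompt": "Hello"}])
        """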
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_cmpl_tok_params
dict_prompts = self.render_prompts(prompts)
tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras)
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
async def render_cmpl_async(
self,
prompts: Sequence[DictPrompt | bytes],
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
    ) -> list[ProcessorInputs]:
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_cmpl_tok_params
dict_prompts = await self.render_prompts_async(prompts)
tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras)
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
def render_chat(
self,
conversations: Sequence[list["ChatCompletionMessageParam"]],
chat_params: ChatParams,
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
    ) -> tuple[list[list["ConversationMessage"]], list[ProcessorInputs]]:
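        """
        Render chat conversations into engine inputs: apply the chat
        rendering from `render_messages` to each conversation, tokenize,
        apply prompt extras, and convert to engine inputs. Returns the
        processed conversations alongside the engine prompts.
        """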
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_chat_tok_params
rendered = [
self.render_messages(conversation, chat_params)
for conversation in conversations
]
out_conversations = list[list["ConversationMessage"]]()
dict_prompts = list[DictPrompt]()
for conv, prompt in rendered:
out_conversations.append(conv)
dict_prompts.append(prompt)
tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras)
eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
return out_conversations, eng_prompts
async def render_chat_async(
self,
conversations: Sequence[list["ChatCompletionMessageParam"]],
chat_params: ChatParams,
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
    ) -> tuple[list[list["ConversationMessage"]], list[ProcessorInputs]]:
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_chat_tok_params
rendered = [
self.render_messages_async(conversation, chat_params)
for conversation in conversations
]
out_conversations = list[list["ConversationMessage"]]()
dict_prompts = list[DictPrompt]()
for conv, prompt in await asyncio.gather(*rendered):
out_conversations.append(conv)
dict_prompts.append(prompt)
tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras)
eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
return out_conversations, eng_prompts