vllm.model_executor.models.mistral

Mistral adaptation of the LLaMA architecture.
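
As context for the classes below, a minimal sketch of loading a Mistral-architecture checkpoint through the public vLLM API; the model ID and sampling settings are illustrative, not prescribed by this module.

from vllm import LLM, SamplingParams

# Illustrative model ID; any Mistral-architecture checkpoint is served by the
# classes documented on this page.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)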

MistralAttention

Bases: LlamaAttention

Source code in vllm/model_executor/models/mistral.py
class MistralAttention(LlamaAttention):
    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__(
            config=config,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=prefix,
            attn_type=attn_type,
        )

        llama_4_scaling_config: dict[str, int | float | str] | None = getattr(
            config, "llama_4_scaling", None
        )
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            assert llama_4_scaling_config is not None
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

do_llama_4_scaling instance-attribute

do_llama_4_scaling = llama_4_scaling_config is not None

llama_4_scaling_beta instance-attribute

llama_4_scaling_beta = llama_4_scaling_config['beta']

llama_4_scaling_original_max_position_embeddings instance-attribute

llama_4_scaling_original_max_position_embeddings = (
    llama_4_scaling_config[
        "original_max_position_embeddings"
    ]
)

__init__

__init__(
    config: LlamaConfig,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position_embeddings: int = 8192,
    quant_config: QuantizationConfig | None = None,
    bias: bool = False,
    bias_o_proj: bool = False,
    cache_config: CacheConfig | None = None,
    prefix: str = "",
    attn_type: str = AttentionType.DECODER,
) -> None
Source code in vllm/model_executor/models/mistral.py
def __init__(
    self,
    config: LlamaConfig,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position_embeddings: int = 8192,
    quant_config: QuantizationConfig | None = None,
    bias: bool = False,
    bias_o_proj: bool = False,
    cache_config: CacheConfig | None = None,
    prefix: str = "",
    attn_type: str = AttentionType.DECODER,
) -> None:
    super().__init__(
        config=config,
        hidden_size=hidden_size,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        max_position_embeddings=max_position_embeddings,
        quant_config=quant_config,
        bias=bias,
        bias_o_proj=bias_o_proj,
        cache_config=cache_config,
        prefix=prefix,
        attn_type=attn_type,
    )

    llama_4_scaling_config: dict[str, int | float | str] | None = getattr(
        config, "llama_4_scaling", None
    )
    self.do_llama_4_scaling = llama_4_scaling_config is not None
    if self.do_llama_4_scaling:
        assert llama_4_scaling_config is not None
        self.llama_4_scaling_original_max_position_embeddings = (
            llama_4_scaling_config["original_max_position_embeddings"]
        )
        self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

_get_llama_4_attn_scale

_get_llama_4_attn_scale(positions: Tensor) -> Tensor
Source code in vllm/model_executor/models/mistral.py
def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
    # Llama4 scaling
    scaling = 1 + self.llama_4_scaling_beta * torch.log(
        1
        + torch.floor(
            positions / self.llama_4_scaling_original_max_position_embeddings
        )
    )
    # Broadcast over head_dim
    return scaling.unsqueeze(-1)
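
A worked illustration of the scale computed above, with example values (original_max_position_embeddings = 8192 and beta = 0.1 are placeholders, not taken from a real config): positions below the original context window get a scale of exactly 1, and the scale grows logarithmically with how many windows a position lies beyond it.

import torch

beta = 0.1                # placeholder for llama_4_scaling_beta
original_max = 8192       # placeholder for original_max_position_embeddings
positions = torch.tensor([0, 4096, 8192, 65536])

scaling = 1 + beta * torch.log(1 + torch.floor(positions / original_max))
# positions 0 and 4096: floor(pos / 8192) == 0, so scaling == 1.0 (q unchanged)
# position 8192:        floor(.) == 1, scaling == 1 + 0.1 * ln(2) ≈ 1.07
# position 65536:       floor(.) == 8, scaling == 1 + 0.1 * ln(9) ≈ 1.22
print(scaling)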

forward

forward(positions: Tensor, hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/mistral.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    if self.do_llama_4_scaling:
        attn_scale = self._get_llama_4_attn_scale(positions)
        q = (q * attn_scale).to(q.dtype)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output

MistralDecoderLayer

Bases: LlamaDecoderLayer

Source code in vllm/model_executor/models/mistral.py
class MistralDecoderLayer(LlamaDecoderLayer):
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            config=config,
            attn_layer_type=MistralAttention,
        )

        self.layer_idx = int(prefix.split(sep=".")[-1])
        quant_config = self.get_quant_config(vllm_config)
        config = config or vllm_config.model_config.hf_config

        do_fusion = getattr(
            quant_config, "enable_quantization_scaling_fusion", False
        ) and vllm_config.cache_config.cache_dtype.startswith("fp8")
        if do_fusion:
            self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
            self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj

layer_idx instance-attribute

layer_idx = int(prefix.split(sep=".")[-1])

__init__

__init__(
    vllm_config: VllmConfig,
    prefix: str = "",
    config: LlamaConfig | None = None,
) -> None
Source code in vllm/model_executor/models/mistral.py
def __init__(
    self,
    vllm_config: VllmConfig,
    prefix: str = "",
    config: LlamaConfig | None = None,
) -> None:
    super().__init__(
        vllm_config=vllm_config,
        prefix=prefix,
        config=config,
        attn_layer_type=MistralAttention,
    )

    self.layer_idx = int(prefix.split(sep=".")[-1])
    quant_config = self.get_quant_config(vllm_config)
    config = config or vllm_config.model_config.hf_config

    do_fusion = getattr(
        quant_config, "enable_quantization_scaling_fusion", False
    ) and vllm_config.cache_config.cache_dtype.startswith("fp8")
    if do_fusion:
        self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
        self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj
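
Two small illustrations of the constructor logic above; the prefix and dtype strings are example values of the form vLLM passes in, not hard-coded anywhere in this module.

# layer_idx is parsed from the module prefix, e.g. "model.layers.7" -> 7.
prefix = "model.layers.7"
layer_idx = int(prefix.split(sep=".")[-1])

# Quantization-scaling fusion is only wired up when the quant config opts in
# and the KV cache runs in an fp8 dtype (example values shown).
enable_fusion = True        # stands in for quant_config.enable_quantization_scaling_fusion
cache_dtype = "fp8_e4m3"    # example cache_config.cache_dtype value
do_fusion = enable_fusion and cache_dtype.startswith("fp8")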

MistralForCausalLM

Bases: LlamaForCausalLM

Source code in vllm/model_executor/models/mistral.py
class MistralForCausalLM(LlamaForCausalLM):
    # Mistral: We don't support LoRA on the embedding layers.
    embedding_modules: dict[str, str] = {}

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        return MistralModel(
            vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            attn_in = self.config.head_dim * n_heads

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

        num_modules = len(modules)
        for i in range(num_modules):
            item = modules[i]
            next_item = modules[i + 1] if i < num_modules - 1 else None

            combined_item = f"{item}.{next_item}" if next_item is not None else None

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight

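To make the permute helper in maybe_remap_mistral concrete, here is a toy-sized trace (one head, head_dim = 4, a single output column; shapes chosen purely for illustration). The view/transpose/reshape regroups each head's rows so that even-indexed rows come first and odd-indexed rows second, i.e. the interleaved rotary pairs of the consolidated checkpoint become the half-split layout used after remapping.

import torch

# Toy shapes for illustration: each row is just a number tagging its original position.
n_heads, head_dim, attn_out = 1, 4, 1
attn_in = head_dim * n_heads
w = torch.arange(attn_in, dtype=torch.float32).reshape(attn_in, attn_out)  # rows 0, 1, 2, 3

permuted = (
    w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
    .transpose(1, 2)
    .reshape(attn_in, attn_out)
)
print(permuted.flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]: even rows first, then odd rows
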
embedding_modules class-attribute instance-attribute

embedding_modules: dict[str, str] = {}

mistral_mapping class-attribute instance-attribute

mistral_mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "qscale_act": "input_scale",
    "qscale_weight": "weight_scale",
    "kv_fake_quantizer.qscale_act": "kv_scale",
    "q_fake_quantizer.qscale_act": "attn.q_scale",
    "k_fake_quantizer.qscale_act": "k_scale",
    "v_fake_quantizer.qscale_act": "v_scale",
    "wq": "q_proj",
    "wk": "k_proj",
    "wv": "v_proj",
    "wo": "o_proj",
    "attention_norm": "input_layernorm",
    "feed_forward": "mlp",
    "w1": "gate_proj",
    "w2": "down_proj",
    "w3": "up_proj",
    "ffn_norm": "post_attention_layernorm",
    "tok_embeddings": "model.embed_tokens",
    "output": "lm_head",
    "norm": "model.norm",
}

__init__

__init__(
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    layer_type: type[Module] = MistralDecoderLayer,
)
Source code in vllm/model_executor/models/mistral.py
def __init__(
    self,
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    layer_type: type[nn.Module] = MistralDecoderLayer,
):
    super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

_init_model

_init_model(
    vllm_config: VllmConfig,
    prefix: str = "",
    layer_type: type[Module] = MistralDecoderLayer,
)
Source code in vllm/model_executor/models/mistral.py
def _init_model(
    self,
    vllm_config: VllmConfig,
    prefix: str = "",
    layer_type: type[nn.Module] = MistralDecoderLayer,
):
    return MistralModel(
        vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
    )

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/mistral.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    loader = AutoWeightsLoader(
        self,
        skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
    )
    return loader.load_weights(
        self.maybe_remap_mistral(name, loaded_weight)
        for name, loaded_weight in weights
    )

maybe_remap_mistral

maybe_remap_mistral(
    name: str, loaded_weight: Tensor
) -> tuple[str, Tensor]
Source code in vllm/model_executor/models/mistral.py
def maybe_remap_mistral(
    self,
    name: str,
    loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]:
    def permute(w: torch.Tensor, n_heads: int, attn_out: int):
        attn_in = self.config.head_dim * n_heads

        return (
            w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
            .transpose(1, 2)
            .reshape(attn_in, attn_out)
        )

    mapping = self.mistral_mapping
    modules = name.split(".")

    # rotary embeds should be sliced
    # If using quantized model in mistral format,
    # quantization scales (qscale_weight) also need to be sliced
    if "wk" in modules and modules[-1] == "weight":
        loaded_weight = permute(
            loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
        )
    elif (
        "wk" in modules
        and modules[-1] == "qscale_weight"
        and loaded_weight.numel() > 1
    ):
        loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
    elif "wq" in modules and modules[-1] == "weight":
        loaded_weight = permute(
            loaded_weight, self.config.num_attention_heads, self.config.hidden_size
        )
    elif (
        "wq" in modules
        and modules[-1] == "qscale_weight"
        and loaded_weight.numel() > 1
    ):
        loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

    num_modules = len(modules)
    for i in range(num_modules):
        item = modules[i]
        next_item = modules[i + 1] if i < num_modules - 1 else None

        combined_item = f"{item}.{next_item}" if next_item is not None else None

        if combined_item in mapping:
            name = name.replace(combined_item, mapping[combined_item])
        elif item in mapping and mapping[item] not in name:
            name = name.replace(item, mapping[item])

    return name, loaded_weight
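
A minimal sketch of just the name-remapping loop above (weight permutation omitted), using a trimmed copy of mistral_mapping; the example key names are illustrative consolidated.safetensors keys, not read from a particular checkpoint.

mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
    "tok_embeddings": "model.embed_tokens",
    "norm": "model.norm",
}

def remap(name: str) -> str:
    modules = name.split(".")
    for i, item in enumerate(modules):
        next_item = modules[i + 1] if i + 1 < len(modules) else None
        combined = f"{item}.{next_item}" if next_item is not None else None
        if combined in mapping:
            name = name.replace(combined, mapping[combined])
        elif item in mapping and mapping[item] not in name:
            name = name.replace(item, mapping[item])
    return name

print(remap("layers.0.attention.wq.weight"))  # model.layers.0.self_attn.q_proj.weight
print(remap("tok_embeddings.weight"))         # model.embed_tokens.weight
print(remap("norm.weight"))                   # model.norm.weight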

MistralModel

Bases: LlamaModel

Source code in vllm/model_executor/models/mistral.py
@support_torch_compile
class MistralModel(LlamaModel):
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

__init__

__init__(
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    layer_type: type[Module] = MistralDecoderLayer,
)
Source code in vllm/model_executor/models/mistral.py
def __init__(
    self,
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    layer_type: type[nn.Module] = MistralDecoderLayer,
):
    super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)