From 0d7caead36bae50524a957a4b405c24bd3daf5b9 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 13 Dec 2024 09:44:39 +0000 Subject: [PATCH 1/7] chore: bump transformers and optimum version --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 535a96f33..079412c3e 100644 --- a/setup.py +++ b/setup.py @@ -13,9 +13,9 @@ INSTALL_REQUIRES = [ - "transformers == 4.43.2", + "transformers == 4.46.2", "accelerate == 0.29.2", - "optimum ~= 1.22.0", + "optimum ~= 1.23.0", "huggingface_hub >= 0.20.1", "numpy>=1.22.2, <=1.25.2", "protobuf>=3.20.3, <4", From 9b911e95aa142e8ac2bdc287030a5ea99a45b564 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Tue, 24 Dec 2024 09:26:13 +0000 Subject: [PATCH 2/7] test(tgi): adjust sampling expectations --- .../tests/integration/test_generate.py | 4 ++-- text-generation-inference/tests/server/test_decode.py | 8 ++++---- text-generation-inference/tests/server/test_prefill.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/text-generation-inference/tests/integration/test_generate.py b/text-generation-inference/tests/integration/test_generate.py index 0f75a82ad..4271bc81d 100644 --- a/text-generation-inference/tests/integration/test_generate.py +++ b/text-generation-inference/tests/integration/test_generate.py @@ -47,8 +47,8 @@ async def test_model_single_request(tgi_service): ) sample_expectations = { "gpt2": "Deep Learning", - "llama": "Deep Learning", - "mistral": "Deep learning", + "llama": "Deep learning", + "mistral": "Deep Learning", "qwen2": "Deep Learning", } assert sample_expectations[service_name] in response diff --git a/text-generation-inference/tests/server/test_decode.py b/text-generation-inference/tests/server/test_decode.py index 7b69eae98..eadb7c43b 100644 --- a/text-generation-inference/tests/server/test_decode.py +++ b/text-generation-inference/tests/server/test_decode.py @@ -36,10 +36,10 @@ def _test_decode(config_name, generator, do_sample): assert output.finish_reason == 0 if do_sample: expected_text = { - "gpt2": " The sun was set", - "llama": "George Orwell, 1984", - "mistral": "The sky was", - "qwen2": " A young woman with", + "gpt2": " the wind was blowing", + "llama": "George Orwell", + "mistral": "The sky is black", + "qwen2": " I stood in the back yard", }[config_name] assert expected_text in output.text else: diff --git a/text-generation-inference/tests/server/test_prefill.py b/text-generation-inference/tests/server/test_prefill.py index 7c50fd6bf..aec02ffd5 100644 --- a/text-generation-inference/tests/server/test_prefill.py +++ b/text-generation-inference/tests/server/test_prefill.py @@ -35,10 +35,10 @@ def _test_prefill(config_name, generator, batch_size, do_sample): assert len(generations) == batch_size if do_sample: expectations = { - "gpt2": [383, " The"], + "gpt2": [632, " It"], "llama": [10058, " George"], "mistral": [450, " The"], - "qwen2": [362, " A"], + "qwen2": [358, " I"], }[config_name] else: expectations = { From bed6dddec3ca9eba4321d068c7b4dac2dd685bf6 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 13 Dec 2024 09:46:01 +0000 Subject: [PATCH 3/7] feat(decoder): add support for granite models (1) Using Llama modeling code copied from TnX, not including granite custom multipliers. As expected, outputs are incorrect. --- .../neuron/model_configs/decoder_configs.py | 7 + optimum/neuron/models/granite/__init__.py | 14 + optimum/neuron/models/granite/config.py | 32 + optimum/neuron/models/granite/hlo.py | 829 ++++++++++++++++++ optimum/neuron/models/granite/model.py | 301 +++++++ optimum/neuron/models/granite/modules.py | 87 ++ 6 files changed, 1270 insertions(+) create mode 100644 optimum/neuron/models/granite/__init__.py create mode 100644 optimum/neuron/models/granite/config.py create mode 100644 optimum/neuron/models/granite/hlo.py create mode 100644 optimum/neuron/models/granite/model.py create mode 100644 optimum/neuron/models/granite/modules.py diff --git a/optimum/exporters/neuron/model_configs/decoder_configs.py b/optimum/exporters/neuron/model_configs/decoder_configs.py index dd7f01d3b..30ddc808e 100644 --- a/optimum/exporters/neuron/model_configs/decoder_configs.py +++ b/optimum/exporters/neuron/model_configs/decoder_configs.py @@ -17,6 +17,7 @@ from optimum.exporters.tasks import TasksManager +from ....neuron.models.granite.model import GraniteForSampling from ....neuron.models.qwen2.model import Qwen2ForSampling from ..config import TextNeuronDecoderConfig @@ -63,3 +64,9 @@ class Qwen2NeuronConfig(TextNeuronDecoderConfig): NEURONX_CLASS = Qwen2ForSampling CONTINUOUS_BATCHING = True FUSE_QKV = False + + +@register_in_tasks_manager("granite", "text-generation") +class GraniteNeuronConfig(TextNeuronDecoderConfig): + NEURONX_CLASS = GraniteForSampling + CONTINUOUS_BATCHING = True diff --git a/optimum/neuron/models/granite/__init__.py b/optimum/neuron/models/granite/__init__.py new file mode 100644 index 000000000..fdc025786 --- /dev/null +++ b/optimum/neuron/models/granite/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/optimum/neuron/models/granite/config.py b/optimum/neuron/models/granite/config.py new file mode 100644 index 000000000..6eefd30a6 --- /dev/null +++ b/optimum/neuron/models/granite/config.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import PretrainedConfig +from transformers_neuronx.llama.config import LlamaConfig + + +class GraniteConfig(LlamaConfig): + """The Granite model uses the same configuration as the TnX LLama model""" + + def __init__( + self, config: PretrainedConfig, n_positions: int, batch_size: int, amp: str, tp_degree: int, **kwargs + ): + super().__init__(config, n_positions, batch_size, amp, tp_degree, **kwargs) + self.model_type = "granite" + # These are parameters specific to the granite modeling + self.attention_multiplier = config.attention_multiplier + self.embedding_multiplier = config.embedding_multiplier + self.logits_scaling = config.logits_scaling + self.residual_multiplier = config.residual_multiplier diff --git a/optimum/neuron/models/granite/hlo.py b/optimum/neuron/models/granite/hlo.py new file mode 100644 index 000000000..9b8e74b7f --- /dev/null +++ b/optimum/neuron/models/granite/hlo.py @@ -0,0 +1,829 @@ +# Copyright Amazon Web Services and its Affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import Optional + +from transformers_neuronx import constants, hlo, utils +from transformers_neuronx.config import NeuronConfig +from transformers_neuronx.constants import LAYOUT_BSH, LAYOUT_HSB +from transformers_neuronx.hlo import dequantize_kv_cache_direct_cast, quantize_kv_cache_direct_cast +from transformers_neuronx.layers import attention, attention_utils, flash_decoding, rotary, transformer +from transformers_neuronx.nki.compile import nki_call + +from .config import GraniteConfig + + +class GraniteForSamplingNoEmbeddingHlo: + + def __init__(self, config: GraniteConfig, neuron_config: Optional[NeuronConfig] = None): + self.config = config + self.neuron_config = neuron_config + self.n_positions = None + self.num_active_blocks = None + + @property + def shard_over_batch(self): + # Property access allows fallback configuration to be enabled after construction + return ( + self.neuron_config is not None + and self.neuron_config.group_query_attention == constants.GQA.SHARD_OVER_BATCH + ) + + def inputs(self, scribe, dtype, n_active_tokens, batch_size): + tensors, dims = transformer.inputs( + scribe, + dtype, + batch_size, + n_active_tokens, + self.config.hidden_size, + self.neuron_config, + self.config.tp_degree, + ) + + return tensors, dims + + def token_tree_inputs(self, scribe, dtype, n_active_tokens, batch_size): + tensors, dims = self.inputs(scribe, dtype, n_active_tokens, batch_size) + s32 = scribe.s32 + cache_2d = self.neuron_config and self.neuron_config.use_2d_cache_ids + # Allow tree based speculation inputs + if cache_2d: + position_sizes = batch_size, n_active_tokens + previous_cache_ids = s32[position_sizes].Parameter(parameter_number=4) + reorder_mapping = s32[position_sizes].Parameter(parameter_number=5) + else: + previous_cache_ids = s32[n_active_tokens].Parameter(parameter_number=4) + reorder_mapping = s32[n_active_tokens].Parameter(parameter_number=5) + seq_slice_dim = 1 if cache_2d else 0 + + return (*tensors, previous_cache_ids, reorder_mapping), (*dims, seq_slice_dim, seq_slice_dim) + + def embedding(self, input_ids, cache_ids, start_ids, last_token_id, *weights): + if self.neuron_config.shard_over_sequence and self.neuron_config.on_device_embedding: + *rst, embed_weight = weights + else: + embed_weight, *rst = weights + dtype = getattr(input_ids.scribe, self.config.amp) + if self.neuron_config.on_device_embedding and self.neuron_config.sequence_parallel_norm: + hidden = hlo.embedding(embed_weight, input_ids, tp_degree=1, dtype=dtype) + else: + hidden = hlo.embedding(embed_weight, input_ids, tp_degree=self.config.tp_degree, dtype=dtype) + if self.config.hidden_size % self.config.tp_degree != 0: + hidden = hlo.slice_along(hidden, dim=-1, limit=self.config.hidden_size, start=0) + if self.neuron_config.attention_layout == LAYOUT_HSB: + hidden = hlo.transpose210(hidden) + return hidden + + def token_tree_embedding( + self, input_ids, cache_ids, start_ids, last_token_id, previous_cache_ids, reorder_mapping, *weights + ): + return self.embedding(input_ids, cache_ids, start_ids, last_token_id, *weights) + + def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights): + # TODO: move this fallback calculation to decoder.py + if self.num_active_blocks is None and self.neuron_config.optimized_paged_attention: + max_model_len = self.neuron_config.continuous_batching.max_model_len + max_num_seqs = self.neuron_config.continuous_batching.max_num_seqs + block_size = self.neuron_config.continuous_batching.block_size + self.num_active_blocks = (max_model_len * max_num_seqs // block_size) - 2 + + if self.neuron_config.optimized_paged_attention and len(last_token_id.sizes) == 2: + # For decoding with multiple KV cache blocks: + # - cache_ids are used as context_lens + # - start_ids are used as slot_mapping + # - last_token_id is used as block_tables + # The function below transforms 2D block_tables into 1D active block table + last_token_id = attention_utils.active_block_tables( + block_tables=last_token_id, + context_lens=cache_ids, + num_active_blocks=self.num_active_blocks, + neuron_config=self.neuron_config, + ) + max_num_seqs = self.neuron_config.continuous_batching.max_num_seqs + block_size = self.neuron_config.continuous_batching.block_size + block_to_seq = attention_utils.block_to_seq_indexing( + context_lens=cache_ids, num_seqs=max_num_seqs, num_blocks=self.num_active_blocks, block_size=block_size + ) + else: + block_to_seq = None + + head_dim = self.config.attention_head_size + pos_embed = rotary.hlo_rotary_embedding( + hidden.dtype, + int(head_dim * self.config.rotary_percentage), + cache_ids, + base=self.config.rope_theta, + interpolation_factor=self.config.position_interpolation_factor, + rope_scaling=self.config.rope_scaling, + ) + core_id = None + + # flash decoding + if self.neuron_config.shard_over_sequence: + core_id, *rst = weights + n_kv_heads = ( + self.config.num_key_value_heads + if hasattr(self.config, "num_key_value_heads") + else self.config.num_attention_heads + ) + cores_per_kv_head = self.config.tp_degree // n_kv_heads + self.cores_per_kv_head = cores_per_kv_head if cores_per_kv_head > 1 else self.config.tp_degree + cache_ids, mask, active_mask = flash_decoding.convert_attn_mask_and_cache_id( + cache_ids, start_ids, core_id, self.n_positions, cores_per_kv_head=self.cores_per_kv_head + ) + else: + mask, active_mask = hlo.attention_mask( + cache_ids, + start_ids, + self.n_positions, + last_token_id=last_token_id, + num_active_blocks=self.num_active_blocks, + neuron_config=self.neuron_config, + ) + + return hidden, last_token_id, pos_embed, cache_ids, start_ids, block_to_seq, mask, active_mask, core_id + + def token_tree_pre_layer( + self, hidden, cache_ids, start_ids, last_token_id, previous_cache_ids, reorder_mapping, *weights + ): + hidden, last_token_id, pos_embed, cache_ids, start_ids, block_to_seq, mask, active_mask, core_id = ( + self.pre_layer(hidden, cache_ids, start_ids, last_token_id, *weights) + ) + if self.neuron_config.on_device_embedding: + embed_weight, token_tree_mask = weights + else: + token_tree_mask, *rst = weights + active_mask = hlo.token_tree_attention_mask(token_tree_mask, active_mask) + return ( + hidden, + last_token_id, + pos_embed, + cache_ids, + start_ids, + block_to_seq, + previous_cache_ids, + reorder_mapping, + mask, + active_mask, + core_id, + ) + + def layer( + self, + hidden, + last_token_id, + pos_embed, + cache_ids, + start_ids, + block_to_seq, + mask, + active_mask, + core_id, + attn_k_cache, + attn_v_cache, + pre_attn_ln_weight, + pre_attn_ln_bias, + fused_pre_attn_ln_qkv_weight, + attn_q_weight, + attn_q_scales, + attn_q_bias, + attn_k_weight, + attn_k_scales, + attn_k_bias, + attn_v_weight, + attn_v_scales, + attn_v_bias, + attn_out_weight, + attn_out_scales, + attn_out_bias, + post_attn_ln_weight, + post_attn_ln_bias, + pre_mlp_ln_weight, + pre_mlp_ln_bias, + mlp_in_weight, + mlp_in_scales, + mlp_in_bias, + mlp_out_weight, + mlp_out_scales, + mlp_out_bias, + post_mlp_ln_weight, + post_mlp_ln_bias, + in0_weight=None, + in0_scales=None, + in1_weight=None, + in1_scales=None, + out_weight=None, + out_scales=None, + ): + eps = self.config.rms_norm_eps + is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH + if self.neuron_config and self.neuron_config.fused_rmsnorm_qkv and active_mask is None: + assert fused_pre_attn_ln_qkv_weight is not None + attn_output, out_attn_k_cache, out_attn_v_cache = self.fused_rmsnorm_qkv( + hidden, + None, + eps, + cache_ids, + start_ids, + last_token_id, + block_to_seq, + pos_embed, + mask, + active_mask, + core_id, + attn_k_cache, + attn_v_cache, + fused_pre_attn_ln_qkv_weight, + attn_q_scales, + attn_q_bias, + attn_k_weight, + attn_k_scales, + attn_k_bias, # should be none + attn_v_weight, + attn_v_scales, + attn_v_bias, # should be none + attn_out_weight, + attn_out_scales, + attn_out_bias, + ) + else: + ln_hidden = ( + hlo.rms_norm( + hidden, pre_attn_ln_weight, eps, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree + ) + if is_bsh + else hlo.rms_norm( + hidden, + pre_attn_ln_weight, + eps, + dim=0, + neuron_config=self.neuron_config, + tp_degree=self.config.tp_degree, + ) + ) + attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( + ln_hidden, + cache_ids, + start_ids, + last_token_id, + block_to_seq, + pos_embed, + mask, + active_mask, + core_id, + attn_k_cache, + attn_v_cache, + attn_q_weight, + attn_q_scales, + attn_q_bias, + attn_k_weight, + attn_k_scales, + attn_k_bias, + attn_v_weight, + attn_v_scales, + attn_v_bias, + attn_out_weight, + attn_out_scales, + attn_out_bias, + ) + hidden = hlo.add(attn_output, hidden) + gated_mlp = hlo.gated_mlp_bsh if is_bsh else hlo.gated_mlp + rms_norm_dim = 2 if is_bsh else 0 + norm_hidden = hlo.rms_norm( + hidden, + pre_mlp_ln_weight, + eps, + dim=rms_norm_dim, + neuron_config=self.neuron_config, + tp_degree=self.config.tp_degree, + ) + if self.neuron_config.fuse_mlp: + assert all( + map(lambda x: not (x), [in0_weight, in1_weight, out_weight, in0_scales, in1_scales, out_scales]) + ), "in0, in1 and out weights have to be None" + in0_weight, in0_scales = mlp_in_weight, mlp_in_scales + out_weight, out_scales = mlp_out_weight, mlp_out_scales + + mlp_hidden = gated_mlp( + norm_hidden, + in0_weight, + in1_weight, + out_weight, + in0_scales=in0_scales, + in1_scales=in1_scales, + out_scales=out_scales, + activation_function="silu", + tp_degree=self.config.tp_degree, + neuron_config=self.neuron_config, + ) + res_hidden = hlo.add(mlp_hidden, hidden) + return res_hidden, out_attn_k_cache, out_attn_v_cache + + def token_tree_layer( + self, + hidden, + last_token_id, + pos_embed, + cache_ids, + start_ids, + block_to_seq, + previous_cache_ids, + reorder_mapping, + mask, + active_mask, + core_id, + attn_k_cache, + attn_v_cache, + pre_attn_ln_weight, + pre_attn_ln_bias, + fused_pre_attn_ln_qkv_weight, + attn_q_weight, + attn_q_scales, + attn_q_bias, + attn_k_weight, + attn_k_scales, + attn_k_bias, + attn_v_weight, + attn_v_scales, + attn_v_bias, + attn_out_weight, + attn_out_scales, + attn_out_bias, + post_attn_ln_weight, + post_attn_ln_bias, + pre_mlp_ln_weight, + pre_mlp_ln_bias, + mlp_in_weight, + mlp_in_scales, + mlp_in_bias, + mlp_out_weight, + mlp_out_scales, + mlp_out_bias, + post_mlp_ln_weight, + post_mlp_ln_bias, + in0_weight, + in0_scales, + in1_weight, + in1_scales, + out_weight, + out_scales, + ): + eps = self.config.rms_norm_eps + is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH + ln_hidden = ( + hlo.rms_norm( + hidden, pre_attn_ln_weight, eps, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree + ) + if is_bsh + else hlo.rms_norm( + hidden, + pre_attn_ln_weight, + eps, + dim=0, + neuron_config=self.neuron_config, + tp_degree=self.config.tp_degree, + ) + ) + reordered_attn_k_cache, reordered_attn_v_cache = attention.reorder_kv_cache( + attn_k_cache, attn_v_cache, previous_cache_ids, reorder_mapping, neuron_config=self.neuron_config + ) + attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( + ln_hidden, + cache_ids, + start_ids, + last_token_id, + block_to_seq, + pos_embed, + mask, + active_mask, + core_id, + reordered_attn_k_cache, + reordered_attn_v_cache, + attn_q_weight, + attn_q_scales, + attn_q_bias, + attn_k_weight, + attn_k_scales, + attn_k_bias, + attn_v_weight, + attn_v_scales, + attn_v_bias, + attn_out_weight, + attn_out_scales, + attn_out_bias, + ) + hidden = hlo.add(attn_output, hidden) + gated_mlp = hlo.gated_mlp_bsh if is_bsh else hlo.gated_mlp + rms_norm_dim = 2 if is_bsh else 0 + norm_hidden = hlo.rms_norm( + hidden, + pre_mlp_ln_weight, + eps, + dim=rms_norm_dim, + neuron_config=self.neuron_config, + tp_degree=self.config.tp_degree, + ) + mlp_hidden = gated_mlp( + norm_hidden, + in0_weight, + in1_weight, + out_weight, + in0_scales=in0_scales, + in1_scales=in1_scales, + out_scales=out_scales, + activation_function="silu", + tp_degree=self.config.tp_degree, + neuron_config=self.neuron_config, + ) + res_hidden = hlo.add(mlp_hidden, hidden) + return res_hidden, out_attn_k_cache, out_attn_v_cache + + def ln_lm_head( + self, hidden, last_token_id, rms_weight, unused_bias, lm_head_weight, lm_head_bias, return_all_outputs=True + ): + logits = transformer.rms_lm_head( + self.config.tp_degree, + hidden, + last_token_id, + rms_weight, + lm_head_weight, + lm_head_bias, + return_all_outputs, + eps=self.config.rms_norm_eps, + neuron_config=self.neuron_config, + ) + return logits + + def fused_rmsnorm_qkv( + self, + hidden, + pre_attn_ln_weight, + eps, + cache_ids, + start_ids, + last_token_id, + block_to_seq, + pos_embed, + mask, + active_mask, + core_id, + attn_k_cache, + attn_v_cache, + attn_q_weight, + attn_q_scales, + attn_q_bias, + attn_k_weight, + attn_k_scales, + attn_k_bias, # should be none + attn_v_weight, + attn_v_scales, + attn_v_bias, # should be none + attn_out_weight, + attn_out_scales, + attn_out_bias, + ): + # TODO: refactor below + from neuronxcc.nki._private_kernels.fused_linear import fused_rms_norm_qkv + + def _kernel(h, w, output): + return fused_rms_norm_qkv(h, w, output, eps=eps) + + n_seqs, n_active_tokens, _ = hidden.sizes + d_head = self.config.attention_head_size + tp_degree = self.config.tp_degree + + # Compute the expected number of KV heads (Used in case fused QKV is used) + n_kv_heads_tp = None + if self.config.num_key_value_heads is not None: + n_head = self.config.num_attention_heads + n_kv_head = self.config.num_key_value_heads + n_head, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config) + n_kv_heads_tp = n_kv_head_padded // tp_degree + + _, hidden_size_tp = attn_q_weight.sizes + + n_total_heads_tp = hidden_size_tp // d_head + n_heads_tp = n_total_heads_tp - 2 * n_kv_heads_tp + # Q hidden size + hidden_size_tp = d_head * n_heads_tp + + nki_output = nki_call( + _kernel, + hidden, + attn_q_weight, + output_HloShapes=[hidden.dtype[hidden.sizes[0], hidden.sizes[1], attn_q_weight.sizes[-1]]], + ) + slice_lim = nki_output.sizes[-1] // (n_heads_tp + 2 * n_kv_heads_tp) + query = hlo.slice_along(nki_output, -1, n_heads_tp * slice_lim, start=0) + key = hlo.slice_along(nki_output, -1, (n_heads_tp + n_kv_heads_tp) * slice_lim, start=n_heads_tp * slice_lim) + value = hlo.slice_along( + nki_output, + -1, + (n_heads_tp + 2 * n_kv_heads_tp) * slice_lim, + start=(n_heads_tp + n_kv_heads_tp) * slice_lim, + ) + + # shard over head (llama/hlo.py) + active_q_sizes = n_active_tokens, n_seqs, n_heads_tp, d_head + active_kv_sizes = n_active_tokens, n_seqs, n_kv_heads_tp, d_head + query = hlo.reshape(query, active_q_sizes) + key = hlo.reshape(key, active_kv_sizes) + value = hlo.reshape(value, active_kv_sizes) + assert all( + [ + attn_q_scales is None, + attn_q_bias is None, + attn_k_weight is None, + attn_k_scales is None, + attn_k_bias is None, + attn_v_weight is None, + attn_v_scales is None, + attn_v_bias is None, + ] + ) + + # Pass QKV tuple since it will not be computed in the attention block + attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( + nki_output, + cache_ids, + start_ids, + last_token_id, + block_to_seq, + pos_embed, + mask, + active_mask, + core_id, + attn_k_cache, + attn_v_cache, + attn_q_weight, + None, + None, + None, + None, + None, + None, + None, + None, + attn_out_weight, + attn_out_scales, + attn_out_bias, + qkv_tuple=(query, key, value), + ) + return attn_output, out_attn_k_cache, out_attn_v_cache + + def attention( + self, + hidden, + cache_ids, + start_ids, + last_token_id, + block_to_seq, + pos_embed, + mask, + active_mask, + core_id, + cached_keys, + cached_values, + q_weight, + q_scales, + q_bias, + k_weight, + k_scales, + k_bias, + v_weight, + v_scales, + v_bias, + out_weight, + out_scales, + out_bias, + qkv_tuple: tuple = None, + ): + d_head = self.config.attention_head_size + tp_degree = self.config.tp_degree + + # Compute the expected number of KV heads (Used in case fused QKV is used) + n_kv_heads_tp = None + if self.config.num_key_value_heads is not None: + n_head = self.config.num_attention_heads + n_kv_head = self.config.num_key_value_heads + n_head, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config) + n_kv_heads_tp = n_kv_head_padded // tp_degree + + # Q = (hidden @ wQ) + bQ + # K = (hidden @ wK) + bK + # V = (hidden @ wV) + bV + if qkv_tuple: + # If computed already, skip computation here + assert active_mask is None + query, key, value = qkv_tuple + else: + query, key, value = attention.query_key_value( + hidden, + q_weight, + q_scales, + q_bias, + k_weight, + k_scales, + k_bias, + v_weight, + v_scales, + v_bias, + d_head, + neuron_config=self.neuron_config, + tp_degree=tp_degree, # TODO: include tp_degree into neuron_config + shard_over_batch=self.shard_over_batch, + n_kv_heads_tp=n_kv_heads_tp, + ) + + # Q = Rotate(Q) + # K = Rotate(K) + query, key = rotary.rotate_half( + query, + key, + pos_embed, + self.config.rotary_percentage, + tp_degree=tp_degree, + shard_over_batch=self.shard_over_batch, + ) + + # Q = Q / sqrt(d_head) + query = attention.scale(query, d_head) + + # In BSH cache layout, the output of QKV linear projection is still kept as SBH for all QKV. + bsh_cache_layout = False + batch_dim = 1 + if self.neuron_config is not None: + bsh_cache_layout = self.neuron_config.cache_layout == constants.LAYOUT_BSH + if bsh_cache_layout: + query, key, value = attention_utils.transpose_qkv(query, key, value) + batch_dim = 0 + + # Single Token Generation ("Prefetch"-style) ans speculative forward + if active_mask is not None: + + n_active_tokens = key.sizes[1] if bsh_cache_layout else key.sizes[0] + if n_active_tokens > 1 and self.neuron_config and self.neuron_config.continuous_batching: + # For speculative forward + continuous batching, slice out samples in the batch size + # corresponding to the batch size of the speculative head + slice_sizes = [1] * len(cached_keys.sizes) + if cached_keys.sizes[batch_dim] == 1: + # Use hlo.select for batch size 1 as index select is prohibitively slow + # TODO: revert to hlo.index_select once its faster P126527643 + cached_keys_s = hlo.select( + cached_keys, batch_dim, hlo.reshape(start_ids, slice_sizes), keepdim=True + ) + cached_values_s = hlo.select( + cached_values, batch_dim, hlo.reshape(start_ids, slice_sizes), keepdim=True + ) + else: + cached_keys_s = hlo.index_select(cached_keys, batch_dim, start_ids) + cached_values_s = hlo.index_select(cached_values, batch_dim, start_ids) + if self.neuron_config and self.neuron_config.kv_cache_quant: + cached_keys_s = dequantize_kv_cache_direct_cast(cached_keys_s, self.neuron_config) + cached_values_s = dequantize_kv_cache_direct_cast(cached_values_s, self.neuron_config) + elif self.neuron_config and self.neuron_config.paged_attention: + # For decoding with multiple KV cache blocks, start_ids are used as block_tables + cached_keys_s = attention_utils.gather_blocks( + cached_keys, block_tables=last_token_id, neuron_config=self.neuron_config + ) + cached_values_s = attention_utils.gather_blocks( + cached_values, block_tables=last_token_id, neuron_config=self.neuron_config + ) + if self.neuron_config and self.neuron_config.kv_cache_quant: + cached_keys_s = dequantize_kv_cache_direct_cast(cached_keys_s, self.neuron_config) + cached_values_s = dequantize_kv_cache_direct_cast(cached_values_s, self.neuron_config) + elif self.neuron_config and self.neuron_config.kv_cache_quant: + cached_keys_s = dequantize_kv_cache_direct_cast(cached_keys, self.neuron_config) + cached_values_s = dequantize_kv_cache_direct_cast(cached_values, self.neuron_config) + else: + cached_keys_s = cached_keys + cached_values_s = cached_values + # Communication 1: all-gather query from cores + if (n_active_tokens != self.n_positions) and self.neuron_config.shard_over_sequence: + query = flash_decoding.gather_query_group(query, self.cores_per_kv_head, n_head, tp_degree) + + # Sp = Q @ Kp + prior_scores = attention.score( + query, + cached_keys_s, + n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, + block_to_seq=block_to_seq, + neuron_config=self.neuron_config, + ) + prior_scores = attention.mask( + prior_scores, mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch + ) + + # Sa = Q @ Ka + active_score = attention.score( + query, + key, + n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, + neuron_config=self.neuron_config, + ) + active_score = attention.mask( + active_score, active_mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch + ) + + # C = softmax(Sa, Sp) @ (Va, Vp) + if self.neuron_config.shard_over_sequence: + dtype = query.dtype + context = flash_decoding.context( + prior_scores, + active_score, + cached_values_s, + value, + core_id, + mask, + active_mask, + n_kv_heads=self.config.num_key_value_heads, + n_heads=n_head, + dtype=dtype, + tp_degree=tp_degree, + neuron_config=self.neuron_config, + shard_over_batch=self.shard_over_batch, + ) + cache_ids, value, key = flash_decoding.select_values_within_bound( + cache_ids, value, key, self.cores_per_kv_head, core_id, dim=0 + ) + + else: + context = attention.context( + prior_scores, + active_score, + cached_values_s, + value, + n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, + context_lens=cache_ids, + num_active_blocks=self.num_active_blocks, + block_to_seq=block_to_seq, + neuron_config=self.neuron_config, + ) + + # KCache[I], VCache[I] = K, V + updated_keys, updated_values = attention.fused_kv_update_cache( + cached_keys, cached_values, cache_ids, key, value, start_ids, neuron_config=self.neuron_config + ) + + # Multi-Token Context Encoding + else: + _, batch_size, _, _ = query.sizes + if self.neuron_config.lhs_aligned or batch_size == 1: + context = attention.flash_attention(query, key, value) + else: + # do not use flash attention for lhs padded (right aligned) batch > 1 case + # because it does not correctly take mask into account + context = None + + if context is None: + # S = Q @ K + + score = attention.score( + query, + key, + n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, + neuron_config=self.neuron_config, + ) + score = attention.mask(score, mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch) + context = attention.context_combined( + score, + value, + n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, + neuron_config=self.neuron_config, + ) + + if self.neuron_config.shard_over_sequence: + cache_ids, value, key = flash_decoding.select_values_within_bound( + cache_ids, value, key, self.cores_per_kv_head, core_id, dim=0 + ) + # KCache, VCache = K, V + if cached_keys.sizes == key.sizes: + if self.neuron_config and self.neuron_config.kv_cache_quant: + updated_keys = quantize_kv_cache_direct_cast(key, self.neuron_config) + updated_values = quantize_kv_cache_direct_cast(value, self.neuron_config) + else: + updated_keys, updated_values = key, value + else: + updated_keys, updated_values = attention.fused_kv_update_cache( + cached_keys, cached_values, cache_ids, key, value, start_ids, neuron_config=self.neuron_config + ) + + # O = (C @ wO) + bO + output = attention.output(context, out_weight, out_scales, out_bias, tp_degree, self.neuron_config) + return output, updated_keys, updated_values diff --git a/optimum/neuron/models/granite/model.py b/optimum/neuron/models/granite/model.py new file mode 100644 index 000000000..0e3739e4d --- /dev/null +++ b/optimum/neuron/models/granite/model.py @@ -0,0 +1,301 @@ +# Copyright Amazon Web Services and its Affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import warnings + +import torch +from transformers import PretrainedConfig +from transformers_neuronx import base, bucket, decoder, ops, utils +from transformers_neuronx.config import NeuronConfig +from transformers_neuronx.constants import KV_SHARD_PAD, LAYOUT_HSB + +from .config import GraniteConfig +from .hlo import GraniteForSamplingNoEmbeddingHlo +from .modules import GraniteForCausalLM + + +class GraniteForSampling(base.NeuronModelBase): + """The Granite model is a LLama model with 4 scalar multpliers that are applied to: + - the embeddings, + - the QK product in the attention (instead of the static 1/sqrt(num_heads)) + - the MLP outputs + - the lm_head logits + The implementation in this class is very similar to the one used for Llama in Tnx. + The only differences are: + - the config (GraniteConfig) and base model (GraniteForCausalLM) used in __init__, + - the multiplication of the logits by the logits multiplier + """ + + def __init__( + self, + config: PretrainedConfig, + *, + n_positions: int = 2048, + batch_size: int = 1, + amp: str = "f32", + tp_degree: int = 2, + context_length_estimate: int = None, + context_unroll: int = None, + unroll: int = None, + neuron_config: NeuronConfig = None, + prefixed_length: int = 0, + **kwargs, + ): + config = GraniteConfig(config, n_positions, batch_size, amp, tp_degree) + super().__init__(GraniteForCausalLM, config) + self.context_pre_hook = None + self.context_hook = None + self.config = config + self.neuron_config = neuron_config if neuron_config else NeuronConfig() + if self.neuron_config.shard_over_sequence: + n_kv_head = self.config.num_key_value_heads + kv_shard_degree = self.config.tp_degree // n_kv_head + assert kv_shard_degree <= KV_SHARD_PAD, "increase kv_shard degree is higher than default 128" + warnings.warn(f"shard over sequence enabled, increasing n_positions {n_positions} by 128") + if isinstance(n_positions, list): + npos = sorted(n_positions) + npos[-1] += KV_SHARD_PAD + else: + npos = n_positions + KV_SHARD_PAD + self.config.n_positions = npos + config.n_positions = npos + n_positions = npos + if self.neuron_config.on_device_generation: + self.neuron_config.on_device_generation.vocab_size = self.config.vocab_size + + self.layers_after_partition = self.neuron_config.auto_layer_partition(config.num_hidden_layers) + self.prefixed_length = prefixed_length + + if context_unroll is None: + context_unroll = len(self.layers_after_partition) + self.context_unroll = context_unroll + + if unroll is None: + unroll = len(self.layers_after_partition) + self.unroll = unroll + + self.token_buckets = bucket.token_sizes(n_positions) + self.context_buckets = bucket.context_sizes(context_length_estimate, self.token_buckets) + # input length should be divisable by tp_degree to activate seq paralle + if neuron_config and neuron_config.sequence_parallel_norm: + for bucket_size in self.context_buckets: + if ( + bucket_size > neuron_config.sequence_parallel_norm_threshold + and bucket_size % self.config.tp_degree != 0 + ): + raise ValueError( + f"Sequence parallel normalization requires the bucket size ({bucket_size}) to be divisible by the tensor parallel degree ({self.config.tp_degree})" + ) + self.window_context_buckets = [] + if prefixed_length: + if prefixed_length not in self.context_buckets: + self.context_buckets.append(prefixed_length) + self.context_buckets = sorted(self.context_buckets) + + self.batch_sizes = bucket.batch_sizes(batch_size) + self.context_batch_sizes = ( + [1] if self.neuron_config and self.neuron_config.continuous_batching else self.batch_sizes + ) + hlo_builder = GraniteForSamplingNoEmbeddingHlo(config, neuron_config=self.neuron_config) + self.decoder_param_set = decoder.DecoderLmHeadForSamplingNoEmbedding( + tp_degree=tp_degree, + n_positions_list=self.token_buckets, + n_active_tokens=1, + batch_size=self.batch_sizes, + attention_head_size=config.attention_head_size, + amp=amp, + num_layers=len(self.layers_after_partition), + n_head=config.num_attention_heads, + n_kv_head=config.num_key_value_heads, + unroll=unroll, + neuron_config=self.neuron_config, + allow_pad=True, + builder=hlo_builder, + ) + self.decoder_lm_head = self.decoder_param_set.init_token_decoder( + unroll=self.unroll, buckets=self.token_buckets, model_obj=self + ) + self.decoder_lm_head_for_context = self.decoder_param_set.init_context_decoder( + unroll=self.context_unroll, buckets=self.context_buckets, model_obj=self + ) + self.decoder_lm_head_for_speculation = {} + self.decoder_lm_head_for_window_context = {} + + def load_weights(self): + self.materialize_embeddings() + ops.init() + + for layer_id, layer in enumerate(self.chkpt_model.model.layers): + if layer_id not in self.layers_after_partition: + continue + layer.materialize() + attn = layer.self_attn + mlp = layer.mlp + if self.neuron_config and self.neuron_config.quant: + is_unit_scale = self.neuron_config.quant.is_unit_scale(layer_id) + else: + is_unit_scale = False + new_layer = self.decoder_lm_head.new_layer(is_unit_scale=is_unit_scale) + new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None) + new_layer.add_attention_query(attn.q_proj.weight.detach().T, None) + new_layer.add_attention_key(attn.k_proj.weight.detach().T, None) + new_layer.add_attention_value(attn.v_proj.weight.detach().T, None) + if self.neuron_config and self.neuron_config.attn_output_transposed: + new_layer.add_attention_output(attn.o_proj.weight.T.detach(), None, sharding=0, transposed=True) + else: + new_layer.add_attention_output(attn.o_proj.weight.detach(), None, sharding=1, transposed=False) + new_layer.add_pre_mlp_layer_norm(layer.post_attention_layernorm.weight.detach(), None) + + # Note: Automatic MLP padding is safe since zeros are *only* introduced to intermediary state + if self.neuron_config.fuse_mlp: + assert all( + getattr(mlp, attr, None) for attr in ["gate_proj", "up_proj"] + ), "fuse_mlp need to have gate and up proj weights" + assert all( + getattr(mlp, attr, None).weight.shape[0] % self.config.tp_degree == 0 + for attr in ["gate_proj", "up_proj"] + ), f" mlp weights are not divisible tp_degree {self.config.tp_degree}" + mlp_in_weight = utils.interleave_mlp( + mlp.gate_proj.weight, mlp.up_proj.weight, tp_degree=self.config.tp_degree, dim=0 + ) + new_layer.add_mlp_input(mlp_in_weight.T.detach(), None) + if self.neuron_config.mlp_out_weight_transpose: + new_layer.add_mlp_output( + mlp.down_proj.weight.T.detach(), + None, + sharding=0, + transposed=True, + ) + else: + new_layer.add_mlp_output( + mlp.down_proj.weight.detach(), + None, + sharding=1, + transposed=False, + ) + else: + new_layer.add_parameter( + mlp.gate_proj.weight.T, sharding=1, allow_pad=True, allow_quantize=True, allow_transform=True + ) + new_layer.add_parameter( + mlp.up_proj.weight.T, sharding=1, allow_pad=True, allow_quantize=True, allow_transform=True + ) + if self.neuron_config.weight_tiling: + new_layer.add_parameter( + mlp.down_proj.weight.T, sharding=0, allow_pad=True, allow_quantize=True, allow_transform=True + ) + else: + if self.neuron_config.mlp_out_weight_transpose: + new_layer.add_parameter( + mlp.down_proj.weight.T, sharding=0, allow_pad=True, allow_quantize=True + ) + else: + new_layer.add_parameter( + mlp.down_proj.weight, sharding=1, allow_pad=True, allow_quantize=True, out_feature_dim=0 + ) + new_layer.to_neuron() + layer.nullify() + if self.neuron_config.shard_over_sequence: + self.decoder_lm_head.add_pre_layer_parameter(torch.arange(self.config.tp_degree), sharding=0) + # For pipeline parallel, we need to load ln and lm_head for now even if the pipeline stage doesn't compute the, because + # 1) we need the ln_lm_head hlo for pp0 to get the logits shape and dtype + # 2) we don't needs these for intermediate pp stages, but to keep things simple, just include ln_lm_head for all pp stages for now + # 3) to get ln_lm_head hlo, we need to do weight loading and sharding + # 4) this will introduce extra memory allocation, but ln_lm_head i/o tensor is much smaller and we can get rid of it when we can construct hlo in init + ln_f = self.chkpt_model.model.norm + ln_f.materialize() + self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), None) + + lm_head = self.chkpt_model.lm_head + lm_head.materialize() + self.decoder_lm_head.add_lm_head(lm_head.weight.detach().T) + if self.neuron_config.on_device_embedding: + if self.neuron_config.sequence_parallel_norm: + self.decoder_lm_head.add_pre_layer_parameter( + self.chkpt_model.model.embed_tokens.weight, sharding=None, allow_pad=True + ) + else: + self.decoder_lm_head.add_pre_layer_parameter( + self.chkpt_model.model.embed_tokens.weight, sharding=1, allow_pad=True + ) + lm_head.nullify() + + self.decoder_lm_head.to_neuron() + self.init_rest_of_model() + + def materialize_embeddings(self): + # Materialize the embedding to CPU + self.chkpt_model.model.embed_tokens.materialize() + + def init_rest_of_model(self): + # Pipeline sparallel deosn't support executor right now + if not self.neuron_config.is_pp(): + self.decoder_lm_head.use_executor = True + + if self.context_buckets: + for context_length_estimate in self.context_buckets: + for batch_size in self.context_batch_sizes: + model = self.decoder_lm_head.build_weight_shared( + share_caches=True, new=self.decoder_lm_head_for_context[context_length_estimate, batch_size] + ) + # PERF: No latency improvement seen in multi-layer models from executor + # Pipeline parallel deosn't support executor right now + if self.context_unroll == self.config.num_hidden_layers and not self.neuron_config.is_pp(): + model.use_executor = True + self.decoder_lm_head_for_context[context_length_estimate, batch_size] = model + + if self.decoder_lm_head_for_speculation: + for i, k in enumerate(self.decoder_lm_head_for_speculation): + model = self.decoder_lm_head.build_weight_shared( + share_caches=True, + new=self.decoder_lm_head_for_speculation[k], + embed_weight=self.chkpt_model.model.embed_tokens.weight, + ) + self.decoder_lm_head_for_speculation[k] = model + + if self.decoder_lm_head_for_window_context: + for i, k in enumerate(self.decoder_lm_head_for_window_context): + model = self.decoder_lm_head.build_weight_shared( + share_caches=True, new=self.decoder_lm_head_for_window_context[k] + ) + self.decoder_lm_head_for_window_context[k] = model + + def set_prefixed(self, input_ids): + self.prefixed_input_ids = input_ids[:, : self.prefixed_length] + prefixed_length = self.prefixed_length + self.prefixed_length = 0 + self.forward(self.prefixed_input_ids) + self.prefixed_length = prefixed_length + + def preprocess_and_embed(self, input_ids, cache_ids=None, start_ids=None, **kwargs): + padded_inputs, *rst = self._preprocess(input_ids, start_ids=start_ids, cache_ids=cache_ids, **kwargs) + if not self.neuron_config.on_device_embedding: + input_embeddings = self.chkpt_model.model.embed_tokens(padded_inputs) + if self.neuron_config.attention_layout == LAYOUT_HSB: + input_embeddings = input_embeddings.transpose(0, -1).contiguous() + else: + # embedding layer is on device and will be computed as part of self._forward(), so don't compute here + input_embeddings = None + return padded_inputs, input_embeddings, *rst + + def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, input_embeddings=None, **kwargs): + if last_token_id is not None: # preprocess_and_embed() has already been invoked + rst = cache_ids, start_ids, last_token_id + else: # invoke preprocess_and_embed() + input_ids, input_embeddings, *rst = self.preprocess_and_embed(input_ids, cache_ids, start_ids, **kwargs) + # either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding) + inputs = input_embeddings if input_embeddings is not None else input_ids + logits = self._forward(inputs, *rst) + logits = self._postprocess(logits, start_ids=start_ids, **kwargs) + return logits diff --git a/optimum/neuron/models/granite/modules.py b/optimum/neuron/models/granite/modules.py new file mode 100644 index 000000000..4cbbcc9f3 --- /dev/null +++ b/optimum/neuron/models/granite/modules.py @@ -0,0 +1,87 @@ +# Copyright Amazon Web Services and its Affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from transformers_neuronx import dtypes, module, utils + +from .config import GraniteConfig + + +class GraniteForCausalLM(module.PretrainedModel): + + def __init__(self, config: GraniteConfig): + super().__init__() + dtype, _, _ = utils.parse_amp(config.amp) + dtype = dtypes.to_torch_dtype(dtype) + self.model = GraniteModel(config) + self.lm_head = module.LowMemoryLazyLinear(config.vocab_size, dtype=dtype, bias=False) + + def get_tied_parameters(self): + return [(self.model.embed_tokens.weight, self.lm_head.weight)] + + def get_base_model(self): + return self.model + + +class GraniteModel(module.LowMemoryModule): + + def __init__(self, config: GraniteConfig): + super().__init__() + self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size) + self.layers = module.LowMemoryModuleList( + [GraniteDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = GraniteRMSNorm(config) + + +class GraniteRMSNorm(module.LowMemoryModule): + + def __init__(self, config: GraniteConfig) -> None: + super().__init__() + self.weight = module.UninitializedParameter() + + +class GraniteDecoderLayer(module.LowMemoryModule): + + def __init__(self, config: GraniteConfig): + super().__init__() + self.self_attn = GraniteAttention(config) + self.mlp = GraniteMLP(config) + self.input_layernorm = GraniteRMSNorm(config) + self.post_attention_layernorm = GraniteRMSNorm(config) + + +class GraniteAttention(module.LowMemoryModule): + + def __init__(self, config: GraniteConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + dtype, _, _ = utils.parse_amp(config.amp) + dtype = dtypes.to_torch_dtype(dtype) + self.q_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) + self.k_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) + self.v_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) + self.o_proj = module.LowMemoryLazyLinear(self.hidden_size, bias=False, dtype=dtype) + + +class GraniteMLP(module.LowMemoryModule): + + def __init__(self, config: GraniteConfig): + super().__init__() + dtype, _, _ = utils.parse_amp(config.amp) + dtype = dtypes.to_torch_dtype(dtype) + self.gate_proj = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) + self.up_proj = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) + self.down_proj = module.LowMemoryLazyLinear(config.hidden_size, bias=False, dtype=dtype) From 9907df44096699099747f7f03264f177e960075f Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 13 Dec 2024 09:46:01 +0000 Subject: [PATCH 4/7] feat(decoder): add support for granite models (2) Apply granite specific multipliers. --- optimum/neuron/models/granite/hlo.py | 21 +++++++++++++++++++-- optimum/neuron/models/granite/model.py | 2 ++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/optimum/neuron/models/granite/hlo.py b/optimum/neuron/models/granite/hlo.py index 9b8e74b7f..bfeab595d 100644 --- a/optimum/neuron/models/granite/hlo.py +++ b/optimum/neuron/models/granite/hlo.py @@ -24,6 +24,16 @@ from .config import GraniteConfig +def scale_mul(t, scale): + """Multiply a tensor by a float scale""" + dtype = t.dtype + # Convert float to a constant scalar tensor of the target dtype + scale_t = dtype.Constant(constant_value=scale) + # Expand the scalar tensor to the target shape + scale_br_t = dtype[t.sizes].Broadcast(scale_t, dimensions=[]) + return dtype[t.sizes].Multiply(t, scale_br_t) + + class GraniteForSamplingNoEmbeddingHlo: def __init__(self, config: GraniteConfig, neuron_config: Optional[NeuronConfig] = None): @@ -118,6 +128,9 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights): else: block_to_seq = None + # Granite specific: embeddings are multiplied by embedding_multiplier + hidden = scale_mul(hidden, self.config.embedding_multiplier) + head_dim = self.config.attention_head_size pos_embed = rotary.hlo_rotary_embedding( hidden.dtype, @@ -297,6 +310,8 @@ def layer( attn_out_scales, attn_out_bias, ) + # Granite specific: attention output is multiplied by residual multiplier + attn_output = scale_mul(attn_output, self.config.residual_multiplier) hidden = hlo.add(attn_output, hidden) gated_mlp = hlo.gated_mlp_bsh if is_bsh else hlo.gated_mlp rms_norm_dim = 2 if is_bsh else 0 @@ -327,6 +342,8 @@ def layer( tp_degree=self.config.tp_degree, neuron_config=self.neuron_config, ) + # Granite specific: MLP output is multiplied by residual_multiplier + mlp_hidden = scale_mul(mlp_hidden, self.config.residual_multiplier) res_hidden = hlo.add(mlp_hidden, hidden) return res_hidden, out_attn_k_cache, out_attn_v_cache @@ -657,8 +674,8 @@ def attention( shard_over_batch=self.shard_over_batch, ) - # Q = Q / sqrt(d_head) - query = attention.scale(query, d_head) + # Granite specific: instead of dividing the QK product, multiply it by the attention_multiplier + query = scale_mul(query, self.config.attention_multiplier) # In BSH cache layout, the output of QKV linear projection is still kept as SBH for all QKV. bsh_cache_layout = False diff --git a/optimum/neuron/models/granite/model.py b/optimum/neuron/models/granite/model.py index 0e3739e4d..ddd3aecf2 100644 --- a/optimum/neuron/models/granite/model.py +++ b/optimum/neuron/models/granite/model.py @@ -297,5 +297,7 @@ def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, # either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding) inputs = input_embeddings if input_embeddings is not None else input_ids logits = self._forward(inputs, *rst) + # Granite specific: divide logits by scaling factor + logits = logits / self.config.logits_scaling logits = self._postprocess(logits, start_ids=start_ids, **kwargs) return logits From 00284fac2293688f22470557fe35e711f8f26a2f Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Mon, 23 Dec 2024 13:03:09 +0000 Subject: [PATCH 5/7] test(decoder): add granite unit tests --- tests/decoder/conftest.py | 4 ++++ tests/decoder/test_decoder_export.py | 1 + 2 files changed, 5 insertions(+) diff --git a/tests/decoder/conftest.py b/tests/decoder/conftest.py index 60d728945..677b8ffbf 100644 --- a/tests/decoder/conftest.py +++ b/tests/decoder/conftest.py @@ -37,6 +37,10 @@ "model_id": "Qwen/Qwen2.5-0.5B", "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "fp16"}, }, + "granite": { + "model_id": "ibm-granite/granite-3.1-2b-instruct", + "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"}, + }, "mistral": { "model_id": "optimum/mistral-1.1b-testing", "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"}, diff --git a/tests/decoder/test_decoder_export.py b/tests/decoder/test_decoder_export.py index 9224ecb22..61aa57481 100644 --- a/tests/decoder/test_decoder_export.py +++ b/tests/decoder/test_decoder_export.py @@ -31,6 +31,7 @@ "mixtral": "dacorvo/Mixtral-tiny", "opt": "hf-internal-testing/tiny-random-OPTForCausalLM", "qwen2": "yujiepan/qwen2.5-128k-tiny-random", + "granite": "hf-internal-testing/tiny-random-GraniteForCausalLM", } From c2c6557337e859d8a3f7a36306a0571b1572f98e Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Mon, 23 Dec 2024 13:06:58 +0000 Subject: [PATCH 6/7] test(tgi): add granite tests --- text-generation-inference/tests/fixtures/model.py | 4 ++++ text-generation-inference/tests/integration/test_generate.py | 3 +++ text-generation-inference/tests/server/test_decode.py | 2 ++ text-generation-inference/tests/server/test_prefill.py | 3 +++ 4 files changed, 12 insertions(+) diff --git a/text-generation-inference/tests/fixtures/model.py b/text-generation-inference/tests/fixtures/model.py index 73f633862..6fa63ce86 100644 --- a/text-generation-inference/tests/fixtures/model.py +++ b/text-generation-inference/tests/fixtures/model.py @@ -41,6 +41,10 @@ "model_id": "Qwen/Qwen2.5-0.5B", "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "fp16"}, }, + "granite": { + "model_id": "ibm-granite/granite-3.1-2b-instruct", + "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"}, + }, } diff --git a/text-generation-inference/tests/integration/test_generate.py b/text-generation-inference/tests/integration/test_generate.py index 4271bc81d..75c064a38 100644 --- a/text-generation-inference/tests/integration/test_generate.py +++ b/text-generation-inference/tests/integration/test_generate.py @@ -25,6 +25,7 @@ async def test_model_single_request(tgi_service): "llama": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use", "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that", "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on", + "granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art", } assert response.generated_text == greedy_expectations[service_name] @@ -50,6 +51,7 @@ async def test_model_single_request(tgi_service): "llama": "Deep learning", "mistral": "Deep Learning", "qwen2": "Deep Learning", + "granite": "Deep learning", } assert sample_expectations[service_name] in response @@ -84,6 +86,7 @@ async def test_model_multiple_requests(tgi_service, generate_load): "llama": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use", "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that", "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on", + "granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art", } expected = expectations[tgi_service.client.service_name] for r in responses: diff --git a/text-generation-inference/tests/server/test_decode.py b/text-generation-inference/tests/server/test_decode.py index eadb7c43b..5bfc6ca97 100644 --- a/text-generation-inference/tests/server/test_decode.py +++ b/text-generation-inference/tests/server/test_decode.py @@ -40,6 +40,7 @@ def _test_decode(config_name, generator, do_sample): "llama": "George Orwell", "mistral": "The sky is black", "qwen2": " I stood in the back yard", + "granite": "Aldous Huxley, Brave New World", }[config_name] assert expected_text in output.text else: @@ -49,5 +50,6 @@ def _test_decode(config_name, generator, do_sample): "llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story", "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.", "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a", + "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198", }[config_name] assert output.text == expected_text diff --git a/text-generation-inference/tests/server/test_prefill.py b/text-generation-inference/tests/server/test_prefill.py index aec02ffd5..7214a6b6a 100644 --- a/text-generation-inference/tests/server/test_prefill.py +++ b/text-generation-inference/tests/server/test_prefill.py @@ -39,6 +39,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample): "llama": [10058, " George"], "mistral": [450, " The"], "qwen2": [358, " I"], + "granite": [429, " -"], }[config_name] else: expectations = { @@ -46,6 +47,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample): "llama": [10058, " George"], "mistral": [13, "\n"], "qwen2": [358, " I"], + "granite": [203, "\n"], }[config_name] for g in generations: tokens = g.tokens @@ -80,6 +82,7 @@ def test_prefill_truncate(neuron_model_config): "llama": [" —", " The", " He", " He"], "mistral": [" He", "\n", " He", " He"], "qwen2": [" He", " The", " He", " He"], + "granite": ["\n", "\n", " I", " He"], }[config_name] for i, g in enumerate(generations): tokens = g.tokens From 93e7f31542e87b5341173cdb21420ae267887ca5 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Mon, 23 Dec 2024 15:37:13 +0000 Subject: [PATCH 7/7] fix(granite): apply ruff recommendation --- optimum/neuron/models/granite/hlo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/neuron/models/granite/hlo.py b/optimum/neuron/models/granite/hlo.py index bfeab595d..d66f12b8d 100644 --- a/optimum/neuron/models/granite/hlo.py +++ b/optimum/neuron/models/granite/hlo.py @@ -325,7 +325,7 @@ def layer( ) if self.neuron_config.fuse_mlp: assert all( - map(lambda x: not (x), [in0_weight, in1_weight, out_weight, in0_scales, in1_scales, out_scales]) + (not (x) for x in [in0_weight, in1_weight, out_weight, in0_scales, in1_scales, out_scales]) ), "in0, in1 and out weights have to be None" in0_weight, in0_scales = mlp_in_weight, mlp_in_scales out_weight, out_scales = mlp_out_weight, mlp_out_scales