Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Group Query Attention support with OV base OPs #28163

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

sgbihu
Copy link

@sgbihu sgbihu commented Dec 20, 2024

Details:

  • Try to enable LLM based on onnxruntime. (Phi3, Llama3 is working on CPU, Phi3 can work with iGPU)

Test scripts

import onnxruntime as rt
import os
import numpy as np
import time

import onnxruntime.tools.add_openvino_win_libs as utils
utils.add_openvino_libs_to_path()
from transformers import PreTrainedTokenizerFast


test_lama3 = False
test_phi3 = True
if test_phi3:
    modelPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'model.onnx')
    tokenizerPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'tokenizer.json')

if test_lama3:
    modelPath = os.path.join('D:\\', 'models', 'llm', 'llama3.1-8B-instruct-onnx', 'model.onnx')

so = rt.SessionOptions()
# so.log_severity_level = 3

# sess = rt.InferenceSession(modelPath, so, providers=['CPUExecutionProvider'])
sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU", 'cache_dir': "cache"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "NPU"}])
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizerPath)

# print(sess.get_device())
# for name in sess.get_inputs():
#     print(f"Name: {name.name}, Shape: {name.shape}, Type: {name.type}")
outputs = sess.get_outputs()
output_names = list(map(lambda output: output.name, outputs))


# Assuming the model has 32 layers and each layer has a key and value state
# Phi3
def get_phi3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 32
    sequence_length = 2048
    hidden_size = 96
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

# lama
def get_llama3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 8
    sequence_length = 2048
    hidden_size = 128
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

if test_phi3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_phi3_param()

if test_lama3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_llama3_param()

# Initialize past_key_values with zeros
cpu_array = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)

# print("Output names: ", outputs[0].type.data)

def create_present_state_binding(binding, outputs):
    outputMap={}
    for output in outputs:
        shapes = []
        for item in output.shape:
            if isinstance(item, str):
                if 'batch_size' in item:
                    shapes.append(batch_size)
                elif 'sequence_length' in item:
                    if output.name == 'logits':
                        shapes.append(len(inputToken))
                    else:
                        shapes.append(sequence_length)
                elif 'hidden_size' in item:
                    shapes.append(hidden_size)
                elif 'num_heads' in item:
                    shapes.append(num_heads)
                else:
                    raise ValueError(f"Unknown dimension: {item}")
            else:
                shapes.append(item)
            
        present_state = rt.OrtValue.ortvalue_from_shape_and_type(shapes, np.float32)
        binding.bind_ortvalue_output(output.name, present_state)
        outputMap[output.name] = present_state
    return outputMap

def rebind_inputs(lastOutput, binding):
    for index in range(num_layers):
        binding.bind_ortvalue_input(f'past_key_values.{index}.key', lastOutput[f'present.{index}.key'])
        binding.bind_ortvalue_input(f'past_key_values.{index}.value', lastOutput[f'present.{index}.value'])
    return binding

def init_input_with_binding(binding):
    for index in range(num_layers):
        key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        binding.bind_ortvalue_input(f'past_key_values.{index}.key', key_state)
        binding.bind_ortvalue_input(f'past_key_values.{index}.value', value_state)
    return binding

def reinit_input_bindings(bindings, lastOutput):
    newOutput = create_present_state_binding(bindings, lastOutput)
    binding = rebind_inputs(lastOutput, bindings)
    return binding, newOutput

def create_numpy_inputs(inputToken):
    tokenLen = len(inputToken)
    npinput_ids = np.array([inputToken], dtype=np.int64)
    npattention_mask = np.array([[1] * (tokenLen)], dtype=np.int64)
    return npinput_ids, npattention_mask


def init_ortinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    ids, mask = create_numpy_inputs(inputToken)
    flattened_past_key_values['input_ids'] = rt.OrtValue.ortvalue_from_numpy(ids)
    flattened_past_key_values['attention_mask'] = rt.OrtValue.ortvalue_from_numpy(mask)
    return flattened_past_key_values

def init_npinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        key_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
        value_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    flattened_past_key_values['input_ids'], flattened_past_key_values['attention_mask'] = create_numpy_inputs(inputToken)
    return flattened_past_key_values

def init_bindinginput(inputToken):
    binding = sess.io_binding()
    binding = init_input_with_binding(binding)
    
    ids, mask = create_numpy_inputs(inputToken)
    binding.bind_ortvalue_input(f'attention_mask', rt.OrtValue.ortvalue_from_numpy(mask))
    binding.bind_ortvalue_input(f'input_ids',  rt.OrtValue.ortvalue_from_numpy(ids))
    return binding


# Question
# The Sun is yellow because

# Phi3
if test_phi3:
    # 450 8991 5692
    # inputToken = [32010, 29871, 13]
    inputToken = [32010, 29871, 13, 1576, 8991, 338, 13328, 1363, 29871, 32007, 13, 32001]
    # inputToken = [32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010]
# lama3
if test_lama3:
    # 315 1202 7479
    inputToken = [128000, 27, 91, 882, 91, 397, 791, 8219, 374, 14071, 1606, 83739, 408, 91, 397, 27, 91, 78191, 91, 29]
    # inputToken = [315]
history_tokens = inputToken

flattened_past_key_values = init_npinput(inputToken)

# flattened_past_key_values = init_ortinput(inputToken)

# binding = init_bindinginput(inputToken)
# lastoutput = create_present_state_binding(binding, outputs)

lastTokenLen = len(inputToken)


# roption = rt.RunOptions()
# roption.add_run_config_entry("gpu_graph_id", "-1")

before = time.time()
results = sess.run(output_names, flattened_past_key_values)
# results = sess.run_with_iobinding(binding)
# results = sess.run_with_ort_values(output_names, flattened_past_key_values)
after = time.time()
print("Time cost in ms: ", (after - before) * 1000)

# print(np.argmax(results[0].numpy(), axis=-1)[-1])
print(np.argmax(results[0], axis=-1)[-1])

# print(results[0])
# print(output_names[1])
# print(results[1][0][0][0])
# print(results[1][0][0][1])
# print(results[1][0][0][2])
# # print(results[1][0][0][14])
# # print(results[1])
# print(output_names[2])
# # print(results[2])
# print(results[2][0][0][0])
# print(results[2][0][0][1])
# print(results[2][0][0][2])
# print(results[2][0][0][14])
# inputToken.append(450)

# rebind_inputs(lastOutput, binding)

def update_kvcache(inputsMap, results):
    for index in range(len(output_names)):
        if not output_names[index].startswith('present'):
            continue
        # print(f'{output_names[index]}: {results[index].shape}')
        outputname = output_names[index]
        inputname = outputname.replace('present', 'past_key_values')
        inputsMap[inputname] = results[index]
    return inputsMap
# lastOutput = create_present_state_binding(binding, sess.get_outputs())

# flattened_past_key_values = update_kvcache(flattened_past_key_values, results)

for index in range(len(output_names)):
    if not output_names[index].startswith('present'):
        continue
    # print(f'{output_names[index]}: {results[index].shape}')
    outputname = output_names[index]
    inputname = outputname.replace('present', 'past_key_values')
    flattened_past_key_values[inputname] = results[index]
if test_phi3:
    inputToken = [450]

if test_lama3:
    inputToken = [315]
history_tokens += inputToken

npinput_ids = np.array([inputToken], dtype=np.int64)
npattention_mask = np.array([[1] * (lastTokenLen+1)], dtype=np.int64)
print(f"lastTokenLen:{lastTokenLen}")

# attention_mask = rt.OrtValue.ortvalue_from_numpy(npattention_mask)
# input_ids = rt.OrtValue.ortvalue_from_numpy(npinput_ids)
# binding.bind_ortvalue_input(f'attention_mask', attention_mask)
# binding.bind_ortvalue_input(f'input_ids', input_ids)
# flattened_past_key_values[f'attention_mask'].update_inplace(npattention_mask)
# flattened_past_key_values[f'input_ids'].update_inplace(npinput_ids)
# flattened_past_key_values[f'attention_mask'] = attention_mask
# flattened_past_key_values[f'input_ids'] = input_ids
flattened_past_key_values[f'attention_mask'] = npattention_mask
flattened_past_key_values[f'input_ids'] = npinput_ids
# print(flattened_past_key_values)

before = time.time()
results = sess.run(output_names, flattened_past_key_values)
# results = sess.run_with_iobinding(binding)
# results = sess.run_with_ort_values(output_names, flattened_past_key_values)
after = time.time()
print("Time cost in ms: ", (after - before) * 1000)

# Results:  [np.int32(450), np.int32(8991), np.int32(5692), np.int32(13328), np.int32(304), np.int32(502), np.int32(19434), np.int32(2861), np.int32(304), np.int32(9596), np.int32(280), np.int32(1141), np.int32(14801), np.int32(292), np.int32(29889), np.int32(1932), np.int32(6575), np.int32(4366), np.int32(14517), np.int32(1549), np.int32(278), np.int32(11563), np.int32(29915), np.int32(29879), np.int32(25005), np.int32(29892), np.int32(278), np.int32(20511), np.int32(7254), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(526), np.int32(29574), np.int32(297), np.int32(599), np.int32(18112), np.int32(491), np.int32(278), np.int32(330), np.int32(2129), np.int32(322), np.int32(17105), np.int32(297), np.int32(278), np.int32(4799), np.int32(29889), np.int32(910), np.int32(14801), np.int32(292), np.int32(9946), np.int32(278), np.int32(14744), np.int32(304), np.int32(1106), np.int32(7254), np.int32(29889), np.int32(2398), np.int32(29892), np.int32(278), np.int32(5520), np.int32(2654), np.int32(322), np.int32(13328), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(1209), np.int32(1549), np.int32(278), np.int32(25005), np.int32(901), np.int32(5948), np.int32(322), np.int32(526), np.int32(3109), np.int32(29574), np.int32(29889), np.int32(1932), np.int32(591), np.int32(1106), np.int32(472), np.int32(278), np.int32(8991), np.int32(29892), np.int32(591), np.int32(1074), np.int32(372), np.int32(408), np.int32(263), np.int32(13328), np.int32(470), np.int32(24841), np.int32(8086), np.int32(1363), np.int32(278), np.int32(7254), np.int32(3578), np.int32(338), np.int32(29574), np.int32(714), np.int32(310), np.int32(1749), np.int32(1196), np.int32(310), np.int32(11126), np.int32(29892), np.int32(322), np.int32(278), np.int32(9886), np.int32(3578), np.int32(393), np.int32(22170), np.int32(1749), np.int32(5076), np.int32(338), np.int32(758), np.int32(24130), np.int32(10835), np.int32(13328), np.int32(322), np.int32(2654), np.int32(29889), np.int32(32000)]
# index = 0
# for result in results:
#     print(f'{output_names[index]}: {result.shape}, {result.dtype}')
#     index += 1
print(np.argmax(results[0], axis=-1)[-1])
# print(np.argmax(results[0].numpy(), axis=-1)[-1])


# golden results
# Time cost in ms:  1255.2332878112793
# [30751    13    13  1494  1731   263 29889   372    13 24380    13   450]
# lastTokenLen:12
# Time cost in ms:  1006.781816482544
# [8991]

last_generated_token = np.argmax(results[0], axis=-1)[-1][-1]
history_tokens.append(last_generated_token)
NUM_INFERENCE = 15
for i in range(NUM_INFERENCE):
    # update kvcahe
    for index in range(len(output_names)):
        if not output_names[index].startswith('present'):
            continue
        # print(f'{output_names[index]}: {results[index].shape}')
        outputname = output_names[index]
        inputname = outputname.replace('present', 'past_key_values')
        flattened_past_key_values[inputname] = results[index]

    # update input token
    flattened_past_key_values[f'input_ids'] = np.array([[last_generated_token]], dtype=np.int64)
    flattened_past_key_values[f'attention_mask'] = np.array([[1] * len(history_tokens)], dtype=np.int64)

    before = time.time()
    results = sess.run(output_names, flattened_past_key_values)
    after = time.time()
    print("Time cost in ms: ", (after - before) * 1000)

    last_generated_token = np.argmax(results[0], axis=-1)[-1][-1]
    history_tokens.append(last_generated_token)

print(tokenizer.decode(history_tokens))

Tickets:

  • ticket-id

@github-actions github-actions bot added the category: ONNX FE OpenVINO ONNX FrontEnd label Dec 20, 2024
@sys-openvino-ci sys-openvino-ci added the ExternalIntelPR External contributor from Intel label Dec 20, 2024
@slyalin
Copy link
Contributor

slyalin commented Dec 20, 2024

How is it related to #27648?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
category: ONNX FE OpenVINO ONNX FrontEnd ExternalIntelPR External contributor from Intel
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants