Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use free_table as a mask tensor #1086

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions optimum/exporters/ipex/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
batch_size, -1
)
self.free_blocks = torch.arange(self.num_blocks, device=device)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device)
self.max_cache_len = max_cache_len
self.num_kv_heads = config.num_key_value_heads
self.num_hidden_layers = config.num_hidden_layers
Expand Down Expand Up @@ -88,12 +88,10 @@ def update_for_prefill(
all_slot_offsets = []
num_blocks = (input_lens + self.block_size - 1) // self.block_size
for i in range(batch_size):
for b_idx in range(num_blocks[i]):
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks[0]
self.free_blocks = self.free_blocks[1:]

nb = num_blocks[i]
block_table = self.free_blocks.nonzero().view(-1)[0:nb]
self.block_tables[i][0:nb] = block_table
self.free_blocks[block_table] = 0
slots_range = torch.arange(input_lens[i], device=key_states.device)
block_indices = slots_range // self.block_size
slot_offsets = slots_range % self.block_size
Expand All @@ -103,7 +101,6 @@ def update_for_prefill(
all_block_indices = torch.cat(all_block_indices)
all_slot_offsets = torch.cat(all_slot_offsets)
self.slots = all_block_indices * self.block_size + all_slot_offsets

# Update the cache
PagedAttention.reshape_and_cache(
key_states,
Expand All @@ -127,16 +124,16 @@ def update_for_decode(
):
if layer_idx == 0:
start_block_idx = self._seen_tokens // self.block_size
num_blocks = (self._seen_tokens + self.block_size) // self.block_size
slot_offset_in_block = (self._seen_tokens) % self.block_size
self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
for i in range(batch_size):
for b_idx in range(start_block_idx[i], num_blocks[i]):
if slot_offset_in_block[i] == 0:
# need a new block:
b_idx = start_block_idx[i]
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks[0]
self.free_blocks = self.free_blocks[1:]

self.block_tables[i][b_idx] = self.free_blocks.nonzero().view(-1)[0:1]
self.free_blocks[self.block_tables[i][b_idx]] = 0
self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
# Update the cache
PagedAttention.reshape_and_cache(
Expand Down Expand Up @@ -196,7 +193,7 @@ def reset(self):
"""Resets the cache values while preserving the objects"""
self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
self.block_tables.fill_(-1)
self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
self.max_seq_len = 0

def reorder_cache(self, beam_idx: torch.LongTensor):
Expand All @@ -206,16 +203,18 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
num_blocks = mask.cumsum(-1)[:, -1]
updated_table = []
updated_table = torch.zeros_like(beam_idx)
for i in range(beam_idx.shape[0]):
self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
updated_table = torch.cat(tuple(updated_table), dim=0)
nb = num_blocks[i]
self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1]
updated_table[i] = self.block_tables[i][nb - 1]
for layer_idx in range(self.num_hidden_layers):
self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
self.free_blocks = torch.cat((self.free_blocks, free_table))
for i in free_table:
if not (self.block_tables == i).any():
self.free_blocks[i] = 1

def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
Expand All @@ -235,4 +234,6 @@ def crop(self, maximum_length: int):
self._seen_tokens[bs] = new_tokens
self.max_seq_len, _ = self._seen_tokens.max(dim=0)
free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
self.free_blocks = torch.cat((self.free_blocks, free_table))
for i in free_table:
if not (self.block_tables == i).any():
self.free_blocks[i] = 1
Loading