RWKV changes #7

Open · wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion README.md
@@ -171,7 +171,7 @@ git clone https://github.com/ContextualAI/gritlm`
cd gritlm
pip install -e .
cd gritlm
-````
+```

Below are easy examples for getting started:

42 changes: 38 additions & 4 deletions evaluation/eval_mteb.py
@@ -121,6 +121,38 @@
'query': 'Given a query on COVID-19, retrieve documents that answer the query',
'corpus': '',
},
'LegalSummarization': {
'query': 'Given legal terms, retrieve a passage with more details about the legal terms',
'corpus': '',
},
'LegalBenchConsumerContractsQA': {
'query': 'Given a legal query by a consumer, retrieve relevant legal passages of contracts that answer the query',
'corpus': '',
},
'LegalBenchCorporateLobbying': {
'query': 'Given a one-sentence summary of a legal article, retrieve the legal article',
'corpus': '',
},
'AILACasedocs': {
'query': 'Given a legal scenario, retrieve the most relevant case document',
'corpus': '',
},
'AILAStatutes': {
'query': 'Given a legal situation, retrieve the most relevant statute',
'corpus': '',
},
'LeCaRDv2': {
'query': 'Given a legal scenario, retrieve the document that best matches or is most relevant to it',
'corpus': '',
},
'LegalQuAD': {
'query': 'Given a legal question in German, retrieve the most relevant legal document in German',
'corpus': '',
},
'GerDaLIRSmall': {
'query': 'Given a legal passage in German, retrieve the most relevant legal document in German',
'corpus': '',
},
},
'STS': {
'STS12': "Retrieve semantically similar text.",
@@ -1098,6 +1130,7 @@ def get_args():
parser.add_argument('--embedding_head', default=None, type=str)
parser.add_argument('--pooling_method', default='mean', type=str)
parser.add_argument('--save_qrels', action='store_true')
parser.add_argument('--second_to_last_hidden', action='store_true')
parser.add_argument('--top_k', default=10, type=int)
return parser.parse_args()

@@ -1121,8 +1154,9 @@ def get_args():
"torch_dtype": DTYPE_TO_TORCH_DTYPE.get(args.dtype, torch.bfloat16),
"mode": "embedding",
"pooling_method": args.pooling_method,
"attn_implementation": args.attn_implementation,
"attn": args.attn,
"attn_implementation": None if any(x in args.model_name_or_path.lower() for x in ["v5-Eagle-7B-HF", "rwkv"]) else args.attn_implementation,
"attn": None if any(x in args.model_name_or_path.lower() for x in ["v5-Eagle-7B-HF", "rwkv"]) else args.attn,
"second_to_last_hidden": args.second_to_last_hidden,
}

if args.pipeline_parallel:
@@ -1135,7 +1169,7 @@
elif any([x in args.model_name_or_path for x in ["bge"]]):
assert kwargs["pooling_method"] == "cls"

-if args.pooling_method == "lasttoken":
+if (args.pooling_method == "lasttoken") and ("v5-Eagle-7B-HF" not in args.model_name_or_path):
kwargs["embed_eos"] = "</e>"
if args.embedding_head:
kwargs["projection"] = args.embedding_head
@@ -1159,7 +1193,7 @@
kwargs["tasks"] = args.task_names.split(",")
elif args.task_types:
kwargs["task_types"] = args.task_types.split(",")
-tasks = [(t.description["name"], t.description["type"]) for t in MTEB(**kwargs).tasks]
+tasks = [(t.metadata.name, t.metadata.type) for t in MTEB(**kwargs).tasks]

if args.max_length is not None:
model.encode = partial(model.encode, max_length=args.max_length)
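For context, the attention-kwarg guard added above can be read as a small helper. A minimal sketch under the PR's own name-matching convention; `resolve_attn_kwargs` and `ATTENTION_FREE_MARKERS` are illustrative names, not part of this change:

```python
# RWKV-style checkpoints are recurrent and accept no attention kwargs, so both
# flags are nulled out when the checkpoint name matches an attention-free marker.
ATTENTION_FREE_MARKERS = ("v5-eagle-7b-hf", "rwkv")

def resolve_attn_kwargs(model_name_or_path: str, attn_implementation: str, attn: str) -> dict:
    name = model_name_or_path.lower()
    if any(marker in name for marker in ATTENTION_FREE_MARKERS):
        return {"attn_implementation": None, "attn": None}
    return {"attn_implementation": attn_implementation, "attn": attn}
```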
51 changes: 51 additions & 0 deletions evaluation/eval_mteb_fast.py
@@ -0,0 +1,51 @@
import argparse
import os
from functools import partial

from mteb import MTEB
from sentence_transformers import SentenceTransformer
import torch


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', default="GritLM/GritLM-7B", type=str)
parser.add_argument('--attn_implementation', default='sdpa', type=str, help="eager/sdpa/flash_attention_2")
parser.add_argument('--attn', default='bbcc', type=str, help="only first two letters matter for embedding")
parser.add_argument('--task_types', default=None, help="Comma separated. Default is None i.e. running all tasks")
parser.add_argument('--task_names', default=None, help="Comma separated. Default is None i.e. running all tasks")
parser.add_argument('--instruction_set', default="e5", type=str, help="Instructions to use")
parser.add_argument('--instruction_format', default="gritlm", type=str, help="Formatting to use")
parser.add_argument('--no_instruction', action='store_true', help="Do not use instructions")
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--max_length', default=None, type=int)
parser.add_argument('--num_shots', default=None, type=int)
parser.add_argument('--dtype', default='bfloat16', type=str)
parser.add_argument('--output_folder', default="results", type=str)
parser.add_argument('--overwrite_results', action='store_true')
parser.add_argument('--pipeline_parallel', action='store_true')
parser.add_argument('--embedding_head', default=None, type=str)
parser.add_argument('--pooling_method', default='mean', type=str)
parser.add_argument('--save_qrels', action='store_true')
parser.add_argument('--second_to_last_hidden', action='store_true')
parser.add_argument('--top_k', default=10, type=int)
return parser.parse_args()

if __name__ == '__main__':
args = get_args()

task_name = args.task_names
output_folder = args.output_folder


model = SentenceTransformer(args.model_name_or_path)
eval_splits = ["test" if task_name not in ['MSMARCO', 'Ko-miracl'] else 'dev']
evaluation = MTEB(tasks=[task_name], task_langs=['en'])
evaluation.run(
model,
output_folder=output_folder,
eval_splits=eval_splits,
batch_size=args.batch_size,
save_qrels=args.save_qrels,
overwrite_results=args.overwrite_results,
)
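For reference, the fast path is the stock SentenceTransformer/MTEB flow; a minimal sketch, assuming a SentenceTransformer-compatible checkpoint (the task name is illustrative):

```python
from mteb import MTEB
from sentence_transformers import SentenceTransformer

# Mirrors eval_mteb_fast.py for a single task; "SciFact" is illustrative.
model = SentenceTransformer("GritLM/GritLM-7B")
evaluation = MTEB(tasks=["SciFact"], task_langs=["en"])
evaluation.run(model, output_folder="results", eval_splits=["test"], batch_size=64)
```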
53 changes: 42 additions & 11 deletions gritlm/gritlm.py
@@ -17,6 +17,8 @@ def __init__(
is_inference: bool = True,
embed_eos: str = "",
attn: str = 'bbcc',
second_to_last_hidden: bool = False,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
**kwargs, # Passed to the model, e.g. `attn_implementation`, `torch_dtype` etc.
) -> None:
super().__init__()
@@ -36,6 +38,8 @@ def __init__(
self.embedding_attr = 'model'
elif hasattr(self.model, 'transformer'): # GPT-Neo & GPT-J
self.embedding_attr = 'transformer'
elif hasattr(self.model, 'rwkv'): # RWKV
self.embedding_attr = 'rwkv'
else:
raise ValueError("Could not find attribute to use for embedding: ", self.model)

@@ -46,8 +50,8 @@ def __init__(
) if projection is not None else None
self.normalized = normalized
self.pooling_method = pooling_method

-self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+self.second_to_last_hidden = second_to_last_hidden
+self.device = device
self.num_gpus = 1
self.embed_eos = embed_eos
self.attn = attn
@@ -58,7 +62,7 @@ def __init__(

if is_inference:
# Padding side right is necessary for `embed_instruction` to index correctly
-self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side='right')
+self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side='right', trust_remote_code=True)
if not(self.tokenizer.pad_token) and self.tokenizer.eos_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
print('Set pad token to eos token: ' + self.tokenizer.pad_token)
Expand All @@ -67,8 +71,8 @@ def __init__(
self.model.eval()
if not("device_map" in kwargs):
self.model.to(self.device)
-# Parallelize embedding model
-if mode == 'embedding':
+# Parallelize embedding model unless a specific device is specified, e.g. `cuda:1`
+if (mode == 'embedding') and (not isinstance(self.device, str) or ":" not in self.device):
self.num_gpus = torch.cuda.device_count()
if self.num_gpus > 1:
print(f"----------Using {self.num_gpus} data-parallel GPUs----------")
@@ -127,13 +131,24 @@ def encode(
).to(self.device)

if (self.attn is not None) and (self.attn[:2] == 'bb'):
inputs["is_causal"] = False
if (hasattr(self.model, "config")) and ("olmo" in self.model.config.model_type.lower()):
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
inputs["attention_bias"] = _prepare_4d_attention_mask_for_sdpa(
inputs["attention_mask"], inputs["input_ids"].dtype
)
else:
inputs["is_causal"] = False
if get_cache:
inputs['use_cache'] = True
if self.second_to_last_hidden:
inputs['output_hidden_states'] = True
outputs = (
getattr(self.model, self.embedding_attr) if self.embedding_attr else self.model
)(**inputs)
-last_hidden_state = outputs[0]
+if self.second_to_last_hidden:
+    last_hidden_state = outputs.hidden_states[-2]
+else:
+    last_hidden_state = outputs[0]
if get_cache:
# Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`
assert len(all_kv_caches) == 0, "Can only get cache for one batch at a time"
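As a sanity check on the indexing above: with `output_hidden_states=True`, Hugging Face models return the embedding-layer output plus one state per block, so `hidden_states[-2]` is the penultimate layer and `hidden_states[-1]` equals `outputs[0]`. A runnable sketch, using `gpt2` purely as a small stand-in checkpoint:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModel.from_pretrained("gpt2")
inputs = tok("a quick smoke test", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)
penultimate = out.hidden_states[-2]  # second-to-last layer, (batch, seq, hidden)
assert torch.equal(out.hidden_states[-1], out.last_hidden_state)
```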
@@ -152,6 +167,20 @@
)["input_ids"]
inputs['attention_mask'][:, :len(instruction_tokens)] = 0
embeddings = self.pooling(last_hidden_state, inputs['attention_mask'], recast=recast)
# Check if nan
if torch.isnan(embeddings).any():
print("NaN detected in embeddings")
print("Instruction: ", instruction)
print("Sentences: ", sentences_batch)
print("emb", embeddings)
raise ValueError("NaN detected in embeddings")
# Check if inf
if torch.isinf(embeddings).any():
print("Inf detected in embeddings")
print("Instruction: ", instruction)
print("Sentences: ", sentences_batch)
print("emb", embeddings)
raise ValueError("Inf detected in embeddings")
# Normalize can change the dtype (https://discuss.pytorch.org/t/tensor-in-float16-is-transformed-into-float32-after-torch-norm/110891)
if self.normalized:
in_dtype = embeddings.dtype
@@ -166,12 +195,12 @@
all_embeddings = (
torch.cat(all_embeddings, dim=0) if convert_to_tensor else np.concatenate(all_embeddings, axis=0)
)
# Check if nan
#if torch.isnan(all_embeddings).any():
# print("NaN detected in all embeddings")
if input_was_string:
all_embeddings = all_embeddings[0]
if get_cache:
# all_kv_caches = (
# torch.stack(all_kv_caches, dim=0) if convert_to_tensor else np.concatenate(all_kv_caches, axis=0)
# )
return all_embeddings, all_kv_caches
return all_embeddings
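If the two guards above are ever consolidated, `torch.isfinite` covers NaN and Inf in one pass; a sketch, with an illustrative helper name not present in this PR:

```python
import torch

def assert_finite(embeddings: torch.Tensor, context: str = "") -> None:
    # Single-pass replacement for separate isnan/isinf checks.
    if not torch.isfinite(embeddings).all():
        raise ValueError(f"Non-finite values detected in embeddings {context}")
```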

@@ -206,9 +235,11 @@ def pooling(
# as some indices (which shouldn't be attended to) may be 0 due to clamp, use mask to ignore them again
input_mask_expanded = attention_mask.unsqueeze(-1).expand((b, n, d)).float()
embedding = torch.gather(hidden_state * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
-elif self.pooling_method in ['mean', 'weightedmean']:
+elif self.pooling_method in ['mean', 'weightedmean', 'meannorm', 'weightedmeannorm']:
if self.pooling_method == 'weightedmean':
attention_mask *= attention_mask.cumsum(dim=1) # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
if self.pooling_method in ('meannorm', 'weightedmeannorm'):
hidden_state = torch.nn.functional.normalize(hidden_state, dim=-1)
s = torch.sum(hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
d = attention_mask.sum(dim=1, keepdim=True).float()
embedding = s / d
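The new `meannorm` variants L2-normalise each token vector before the masked mean, so every attended token contributes with unit magnitude. A standalone sketch of that pooling; the function name is illustrative:

```python
import torch
import torch.nn.functional as F

def meannorm_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    hidden = F.normalize(hidden, dim=-1)  # unit-length token vectors
    summed = (hidden * mask.unsqueeze(-1).float()).sum(dim=1)
    return summed / mask.sum(dim=1, keepdim=True).float()  # masked mean
```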
9 changes: 9 additions & 0 deletions gritlm/training/data.py
@@ -235,6 +235,7 @@ def __call__(self, features):
max_length=self.query_max_len,
return_tensors="pt",
add_special_tokens=False, # BOS / EOS is already in the prompt
return_token_type_ids=False,
)
features["passage"] = self.tokenizer(
passage,
@@ -243,7 +244,11 @@
max_length=self.passage_max_len,
return_tensors="pt",
add_special_tokens=False, # BOS / EOS is already in the prompt
return_token_type_ids=False,
)
for x in ["input_ids", "attention_mask"]:
features["query"][x] = features["query"][x][:, :self.query_max_len]
features["passage"][x] = features["passage"][x][:, :self.passage_max_len]

if generative[0] is not None:
features["generative"] = self.tokenizer(
Expand All @@ -253,7 +258,11 @@ def __call__(self, features):
max_length=self.generative_max_len,
return_tensors="pt",
add_special_tokens=False, # BOS / EOS is already in the prompt
return_token_type_ids=False,
)
features["generative"]["input_ids"] = features["generative"]["input_ids"][:, :self.generative_max_len]
features["generative"]["attention_mask"] = features["generative"]["attention_mask"][:, :self.generative_max_len]

features["generative"]["labels"] = features["generative"]["input_ids"].clone()
# Do not mask out the first token as it is always something & could be the pad token id (bos)
features["generative"]["labels"][:,1:][features["generative"]["labels"][:,1:] == self.tokenizer.pad_token_id] = -100
1 change: 1 addition & 0 deletions gritlm/training/gradcache_trainer.py
@@ -715,6 +715,7 @@ def model_call(self, model, model_input): return model(model_input)

elif self.mode == 'generative':
tr_loss_step = loss_gen
exit()
### MODIFIED END ###

if (