Skip to content

Commit

Permalink
Add --lost-dist-impl argument to pick different distributed loss impl…
Browse files Browse the repository at this point in the history
…ementations
  • Loading branch information
rwightman committed Dec 4, 2024
1 parent 9451ba8 commit aeaf2a0
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/open_clip/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,9 @@ def create_loss(args):
return SigLipLoss(
rank=args.rank,
world_size=args.world_size,
dist_impl=args.loss_dist_impl, # siglip has multiple distributed implementations to choose from
)

return ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
Expand Down
39 changes: 22 additions & 17 deletions src/open_clip/loss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
Expand Down Expand Up @@ -102,8 +104,14 @@ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
def get_logits(self, image_features, text_features, logit_scale):
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
image_features, text_features,
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
image_features,
text_features,
local_loss=self.local_loss,
gather_with_grad=self.gather_with_grad,
rank=self.rank,
world_size=self.world_size,
use_horovod=self.use_horovod,
)

if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
Expand Down Expand Up @@ -158,12 +166,11 @@ def __init__(
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):

clip_loss = torch.tensor(0)

if self.clip_loss_weight:
clip_loss = super().forward(image_features, text_features, logit_scale)
clip_loss = self.clip_loss_weight * clip_loss
else:
clip_loss = torch.tensor(0, device=logits.device)

caption_loss = self.caption_loss(
logits.permute(0, 2, 1),
Expand Down Expand Up @@ -316,19 +323,17 @@ class SigLipLoss(nn.Module):
"""
def __init__(
self,
cache_labels=False,
rank=0,
world_size=1,
use_horovod=False,
impl='bidir',
cache_labels: bool = False,
rank: int = 0,
world_size: int = 1,
dist_impl: Optional[str] = None,
):
super().__init__()
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
assert not use_horovod # FIXME need to look at hvd ops for ring transfers
self.use_horovod = use_horovod
self.impl = impl
self.dist_impl = dist_impl or 'bidir' # default to bidir exchange for now, this will likely change
assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather')

# cache state FIXME cache not currently used, worthwhile?
self.prev_num_logits = 0
Expand Down Expand Up @@ -361,7 +366,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
loss = self._loss(image_features, text_features, logit_scale, logit_bias)

if self.world_size > 1:
if self.impl == 'bidir':
if self.dist_impl == 'bidir':
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
text_features_to_right = text_features_to_left = text_features
Expand Down Expand Up @@ -396,7 +401,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
logit_bias,
negative_only=True,
)
elif self.impl == "shift":
elif self.dist_impl == "shift":
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
text_features_to_right = text_features
Expand All @@ -414,7 +419,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
negative_only=True,
)
text_features_to_right = text_features_from_left
elif self.impl == "reduce":
elif self.dist_impl == "reduce":
for i in range(self.world_size):
text_from_other = torch.distributed.nn.all_reduce(
text_features * (self.rank == i),
Expand All @@ -427,7 +432,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
logit_bias,
negative_only=True,
)
elif self.impl == "gather":
elif self.dist_impl == "gather":
all_text = torch.distributed.nn.all_gather(text_features)
for i in range(self.world_size):
loss += float(i != self.rank) * self._loss(
Expand Down
6 changes: 6 additions & 0 deletions src/open_clip_train/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,12 @@ def parse_args(args):
action="store_true",
help='Use SigLip (sigmoid) loss.'
)
parser.add_argument(
"--loss-dist-impl",
default=None,
type=str,
help='A string to specify a specific distributed loss implementation.'
)

args = parser.parse_args(args)

Expand Down

0 comments on commit aeaf2a0

Please sign in to comment.