Skip to content

Commit

Permalink
Improve post-processing performance (#10170)
Browse files Browse the repository at this point in the history
* Use multiplication instead of division
* Add fast path when denormalizing all or none of the images
  • Loading branch information
soof-golan authored Dec 10, 2024
1 parent c9e4fab commit 22d3a82
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions src/diffusers/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, to
`np.ndarray` or `torch.Tensor`:
The denormalized image array.
"""
return (images / 2 + 0.5).clamp(0, 1)
return (images * 0.5 + 0.5).clamp(0, 1)

@staticmethod
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
Expand Down Expand Up @@ -537,6 +537,26 @@ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:

return image

def _denormalize_conditionally(
self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
) -> torch.Tensor:
r"""
Denormalize a batch of images based on a condition list.
Args:
images (`torch.Tensor`):
The input image tensor.
do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
value of `do_normalize` in the `VaeImageProcessor` config.
"""
if do_denormalize is None:
return self.denormalize(images) if self.config.do_normalize else images

return torch.stack(
[self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]
)

def get_default_height_width(
self,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
Expand Down Expand Up @@ -752,12 +772,7 @@ def postprocess(
if output_type == "latent":
return image

if do_denormalize is None:
do_denormalize = [self.config.do_normalize] * image.shape[0]

image = torch.stack(
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
)
image = self._denormalize_conditionally(image, do_denormalize)

if output_type == "pt":
return image
Expand Down Expand Up @@ -966,12 +981,7 @@ def postprocess(
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if do_denormalize is None:
do_denormalize = [self.config.do_normalize] * image.shape[0]

image = torch.stack(
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
)
image = self._denormalize_conditionally(image, do_denormalize)

image = self.pt_to_numpy(image)

Expand Down

0 comments on commit 22d3a82

Please sign in to comment.