In PyTorch there is a module called PixelShuffle.
I created a small implementation of it that currently only supports 4D tensors (I tried to follow the implementation format of the library) and wanted to share it to check whether it could someday be added to the crate.
This is the full code:

    use burn::{config::Config, module::Module, prelude::Backend, tensor::Tensor};

    #[derive(Config, Debug)]
    pub struct PixelShuffleConfig {
        #[config(default = "2")]
        upscale_factor: usize,
    }

    #[derive(Module, Debug, Clone)]
    pub struct PixelShuffle {
        upscale_factor: usize,
    }

    impl PixelShuffleConfig {
        pub fn init(&self) -> PixelShuffle {
            PixelShuffle {
                upscale_factor: self.upscale_factor,
            }
        }
    }

    impl PixelShuffle {
        pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
            let mut dims = input.dims();
            dims.reverse();
            let c = dims[2];
            let h = dims[1];
            let w = dims[0];

            if c % (self.upscale_factor * self.upscale_factor) != 0 {
                panic!(
                    "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of upscale_factor, but input.size(-3)={c} is not divisible by {}",
                    self.upscale_factor * self.upscale_factor
                )
            }

            let oc = c / (self.upscale_factor * self.upscale_factor);
            let oh = h * self.upscale_factor;
            let ow = w * self.upscale_factor;

            let x = input.reshape([dims[3], oc, self.upscale_factor, self.upscale_factor, h, w]);
            let x = x.permute([0, -5, -2, -4, -1, -3]);
            x.reshape([dims[3], oc, oh, ow])
        }
    }

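To make the reshape, permute, reshape sequence above easier to check by hand, here is a minimal sketch of the same index mapping written with plain loops over a flat, row-major NCHW buffer. It does not use Burn at all, and the name `pixel_shuffle_ref` is hypothetical (it is not part of the crate or of the code above). It assumes the mapping `out[n, k, y*r + i, x*r + j] = in[n, k*r*r + i*r + j, y, x]`, where `r` is the upscale factor.

```rust
/// Reference pixel shuffle over a flat, row-major NCHW buffer.
/// Hypothetical helper for illustration only; not part of Burn.
/// Returns the shuffled buffer together with its output shape.
fn pixel_shuffle_ref(input: &[f32], shape: [usize; 4], r: usize) -> (Vec<f32>, [usize; 4]) {
    let [n, c, h, w] = shape;
    assert_eq!(input.len(), n * c * h * w, "buffer length must match shape");
    assert!(r > 0 && c % (r * r) == 0, "channels must be divisible by r^2");

    let oc = c / (r * r);
    let (oh, ow) = (h * r, w * r);
    let mut out = vec![0.0; n * oc * oh * ow];

    for b in 0..n {
        for k in 0..oc {
            for i in 0..r {
                for j in 0..r {
                    // Input channel that feeds sub-pixel offset (i, j) of output channel k.
                    let ic = k * r * r + i * r + j;
                    for y in 0..h {
                        for x in 0..w {
                            let src = ((b * c + ic) * h + y) * w + x;
                            let dst = ((b * oc + k) * oh + (y * r + i)) * ow + (x * r + j);
                            out[dst] = input[src];
                        }
                    }
                }
            }
        }
    }

    (out, [n, oc, oh, ow])
}
```

For example, a `[1, 4, 1, 2]` input holding the values 0 through 7 with an upscale factor of 2 maps to shape `[1, 1, 2, 4]` with rows `[0, 2, 1, 3]` and `[4, 6, 5, 7]`, which can be compared against the output of the tensor-based `forward`.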
And here are some tests that I made comparing the outputs with the ones from the PyTorch implementation:
The original implementation in C++ is here