UNet

A TensorRT port of the PyTorch UNet model from Pytorch-UNet.


Requirements

TensorRT 8.x is now supported. The earlier bug was caused mainly by an incorrect stride setting on the pooling layers.
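To see why the pooling stride matters, here is a minimal Python sketch (not code from this repository) that tracks one spatial axis through four 2x2 max-pooling stages. It assumes the encoder mirrors Pytorch-UNet's `nn.MaxPool2d(2)`, whose stride defaults to the kernel size, and uses the 816x672 input size from the benchmark table; the helpers `pooled_size` and `downsample_chain` are hypothetical.

```python
def pooled_size(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Output length of one axis after pooling (no padding, no dilation)."""
    return (size - kernel) // stride + 1


def downsample_chain(size: int, stride: int, stages: int = 4) -> list[int]:
    """Spatial size before and after each of `stages` 2x2 pooling layers."""
    sizes = [size]
    for _ in range(stages):
        sizes.append(pooled_size(sizes[-1], kernel=2, stride=stride))
    return sizes


if __name__ == "__main__":
    # Both axes of the benchmark input; which one is height is an assumption.
    for axis in (672, 816):
        print(axis, "stride 2:", downsample_chain(axis, stride=2))
        print(axis, "stride 1:", downsample_chain(axis, stride=1))
```

With stride 2 each stage halves the map (672, 336, 168, 84, 42); with stride 1 it shrinks by only one pixel per stage, so the decoder's 2x upsampled maps would no longer line up with the encoder's skip connections.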

Build and Run

  1. Generate .wts
cp {path-of-tensorrtx}/unet/gen_wts.py Pytorch-UNet/
cd Pytorch-UNet/
wget https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth
python gen_wts.py unet_carvana_scale0.5_epoch2.pth
  2. Generate TensorRT engine
cd tensorrtx/unet/
mkdir build
cd build
cmake ..
make
cp {path-of-Pytorch-UNet}/unet.wts .
./unet -s    # build the engine from unet.wts and serialize it
  3. Run inference
wget https://raw.githubusercontent.com/wang-xinyu/tensorrtx/f60dcc7bec28846cd973fc95ac829c4e57a11395/unet/samples/0cdf5b5d0ce1_01.jpg
./unet -d 0cdf5b5d0ce1_01.jpg    # load the serialized engine and run inference on the image
  4. Check result.jpg
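The .wts file produced in the first step is plain text. The sketch below writes and re-reads that layout as we understand it from tensorrtx's gen_wts.py scripts: a first line with the tensor count, then one line per tensor holding its name, its element count, and each value as a big-endian float32 in hex. It is an illustration, not the repository's script: `write_wts` and `read_wts` are hypothetical helpers that take plain Python lists so they run without PyTorch (the real script works from the checkpoint's state dict), and the tensor names used in testing are illustrative only.

```python
import struct


def write_wts(path: str, weights: dict[str, list[float]]) -> None:
    """Write flat float lists in the assumed tensorrtx .wts text layout:
    line 1: tensor count
    then per tensor: <name> <element count> <hex float32> <hex float32> ...
    """
    with open(path, "w") as f:
        f.write(f"{len(weights)}\n")
        for name, values in weights.items():
            # Each value is packed as a big-endian float32 and written as 8 hex digits.
            hex_values = " ".join(struct.pack(">f", float(v)).hex() for v in values)
            f.write(f"{name} {len(values)} {hex_values}\n")


def read_wts(path: str) -> dict[str, list[float]]:
    """Parse a .wts file back into flat float lists, checking element counts."""
    weights: dict[str, list[float]] = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            name, n, *hex_values = f.readline().split()
            if len(hex_values) != int(n):
                raise ValueError(f"{name}: expected {n} values, got {len(hex_values)}")
            weights[name] = [struct.unpack(">f", bytes.fromhex(h))[0] for h in hex_values]
    return weights
```

To sanity-check a generated file before running ./unet -s, one could call read_wts("unet.wts") and compare its tensor names and element counts with the PyTorch state dict; a truncated or mismatched file raises a ValueError.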

Benchmark

Metric       PyTorch   TensorRT FP32          TensorRT FP16
Input size   816x672   816x672                816x672
Latency      58 ms     43 ms (batch size 8)   14 ms (batch size 8)

More Information

See the README on the tensorrtx home page.