Skip to content

Commit

Permalink
allow providing fft plans
Browse files Browse the repository at this point in the history
  • Loading branch information
IanButterworth committed Mar 11, 2024
1 parent 66cf9d9 commit 25dbbbe
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 48 deletions.
21 changes: 7 additions & 14 deletions .github/workflows/UnitTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ on:
branches:
- master
pull_request:
schedule:
- cron: '20 00 1 * *'

jobs:
test:
Expand All @@ -20,23 +18,18 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]

steps:
- uses: actions/checkout@v1.0.0
- uses: actions/checkout@v4
- name: "Set up Julia"
uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia-version }}

- name: Cache artifacts
uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/cache@v1

- run: |
import Pkg
Pkg.add(Pkg.PackageSpec(url="https://github.com/HolyLab/RFFT.jl", rev="ib/add_copy"))
shell: julia --project --color=yes {0}
- name: "Unit Test"
uses: julia-actions/julia-runtest@v1
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ author = ["Tim Holy <[email protected]>", "Jan Weidner <[email protected]>"]
version = "0.7.8"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
CatIndices = "aafaddc9-749c-510e-ac4f-586e18779b91"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -15,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RFFT = "3bd9afcd-55df-531a-9b34-dc642dce7b95"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -31,6 +33,7 @@ ImageCore = "0.10"
OffsetArrays = "1.9"
PrecompileTools = "1"
Reexport = "1.1"
RFFT = "0.1.1"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
Statistics = "1"
TiledIteration = "0.2, 0.3, 0.4, 0.5"
Expand Down
79 changes: 79 additions & 0 deletions demo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using ImageFiltering, FFTW, LinearAlgebra, Profile, Random
# using ProfileView
using ComputationalResources

FFTW.set_num_threads(parse(Int, get(ENV, "FFTW_NUM_THREADS", "1")))
BLAS.set_num_threads(parse(Int, get(ENV, "BLAS_NUM_THREADS", string(Threads.nthreads() ÷ 2))))

function benchmark(mats)
kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
Threads.@threads for mat in mats
frame_filtered = deepcopy(mat[:, :, 1])
r_cached = CPU1(ImageFiltering.planned_fft(frame_filtered, kernel))
for i in axes(mat, 3)
frame = @view mat[:, :, i]
imfilter!(r_cached, frame_filtered, frame, kernel)
end
return
end
end

function test(mats)
kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
for mat in mats
f1 = deepcopy(mat[:, :, 1])
r_cached = CPU1(ImageFiltering.planned_fft(f1, kernel))
f2 = deepcopy(mat[:, :, 1])
r_noncached = CPU1(Algorithm.FFT())
for i in axes(mat, 3)
frame = @view mat[:, :, i]
@info "imfilter! noncached"
imfilter!(r_noncached, f2, frame, kernel)
@info "imfilter! cached"
imfilter!(r_cached, f1, frame, kernel)
@show f1[1:4] f2[1:4]
f1 f2 || error("f1 !≈ f2")
end
return
end
end

function profile()
Random.seed!(1)
nmats = 10
mats = [rand(Float32, rand(80:100), rand(80:100), rand(2000:3000)) for _ in 1:nmats]
GC.gc(true)

# benchmark(mats)

# for _ in 1:3
# @time "warm run of benchmark(mats)" benchmark(mats)
# end

test(mats)

# Profile.clear()
# @profile benchmark(mats)

# Profile.print(IOContext(stdout, :displaysize => (24, 200)); C=true, combine=true, mincount=100)
# # ProfileView.view()
# GC.gc(true)
end

profile()

using ImageFiltering
using ImageFiltering.RFFT

function mwe()
a = rand(Float64, 10, 10)
out1 = rfft(a)

buf = RFFT.RCpair{Float64}(undef, size(a))
rfft_plan = RFFT.plan_rfft!(buf)
copy!(buf, a)
out2 = complex(rfft_plan(buf))

return out1 out2
end
mwe()
13 changes: 11 additions & 2 deletions src/ImageFiltering.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
module ImageFiltering

using FFTW
using RFFT
using ImageCore, FFTViews, OffsetArrays, StaticArrays, ComputationalResources, TiledIteration
# Where possible we avoid a direct dependency to reduce the number of [compat] bounds
# using FixedPointNumbers: Normed, N0f8 # reexported by ImageCore
using ImageCore.MappedArrays
using Statistics, LinearAlgebra
using Base: Indices, tail, fill_to_length, @pure, depwarn, @propagate_inbounds
import Base: copy!
using OffsetArrays: IdentityUnitRange # using the one in OffsetArrays makes this work with multiple Julia versions
using SparseArrays # only needed to fix an ambiguity in borderarray
using Reexport
Expand All @@ -30,7 +32,8 @@ export Kernel, KernelFactors,
imgradients, padarray, centered, kernelfactors, reflect,
freqkernel, spacekernel,
findlocalminima, findlocalmaxima,
blob_LoG, BlobLoG
blob_LoG, BlobLoG,
planned_fft

FixedColorant{T<:Normed} = Colorant{T}
StaticOffsetArray{T,N,A<:StaticArray} = OffsetArray{T,N,A}
Expand All @@ -50,10 +53,16 @@ function Base.transpose(A::StaticOffsetArray{T,2}) where T
end

module Algorithm
import FFTW
# deliberately don't export these, but it's expected that they
# will be used as Algorithm.FFT(), etc.
abstract type Alg end
"Filter using the Fast Fourier Transform" struct FFT <: Alg end
"Filter using the Fast Fourier Transform" struct FFT <: Alg
plan1::Union{Function,Nothing}
plan2::Union{Function,Nothing}
plan3::Union{Function,Nothing}
end
FFT() = FFT(nothing, nothing, nothing)
"Filter using a direct algorithm" struct FIR <: Alg end
"Cache-efficient filtering using tiles" struct FIRTiled{N} <: Alg
tilesize::Dims{N}
Expand Down
54 changes: 51 additions & 3 deletions src/imfilter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ function _imfilter_fft!(r::AbstractCPU{FFT},
for I in CartesianIndices(axes(kern))
krn[I] = kern[I]
end
Af = filtfft(A, krn)
Af = filtfft(A, krn, r.settings.plan1, r.settings.plan2, r.settings.plan3)
if map(first, axes(out)) == map(first, axes(Af))
R = CartesianIndices(axes(out))
copyto!(out, R, Af, R)
Expand All @@ -837,13 +837,61 @@ function _imfilter_fft!(r::AbstractCPU{FFT},
src = view(FFTView(Af), axes(dest)...)
copyto!(dest, src)
end
out
return out
end

function buffered_planned_rfft(a::AbstractArray{T}) where {T}
buf = RFFT.RCpair{T}(undef, size(a))
plan = RFFT.plan_rfft!(buf; flags=FFTW.MEASURE)
return function (arr::AbstractArray{T}) where {T}
copy!(buf, OffsetArrays.no_offset_view(arr))
return plan(buf)

Check warning on line 848 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L843-L848

Added lines #L843 - L848 were not covered by tests
end
end
function buffered_planned_irfft(a::AbstractArray{T}) where {T}
buf = RFFT.RCpair{T}(undef, size(a))
plan = RFFT.plan_irfft!(buf; flags=FFTW.MEASURE)
return function (arr::AbstractArray{T}) where {T}
copy!(buf, OffsetArrays.no_offset_view(arr))
return plan(buf)

Check warning on line 856 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L851-L856

Added lines #L851 - L856 were not covered by tests
end
end

function planned_fft(A::AbstractArray{T,N},

Check warning on line 860 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L860

Added line #L860 was not covered by tests
kernel::ProcessedKernel,
border::BorderSpecAny=Pad(:replicate)
) where {T<:AbstractFloat,N}
bord = border(kernel, A, Algorithm.FFT())
_A = padarray(T, A, bord)
bfp1 = buffered_planned_rfft(_A)
kern = samedims(_A, kernelconv(kernel...))
krn = FFTView(zeros(eltype(kern), map(length, axes(_A))))
bfp2 = buffered_planned_rfft(krn)
bfp3 = buffered_planned_irfft(_A)
return Algorithm.FFT(bfp1, bfp2, bfp3)

Check warning on line 871 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L864-L871

Added lines #L864 - L871 were not covered by tests
end
planned_fft(A::AbstractArray, kernel, border::AbstractString) = planned_fft(A, kernel, borderinstance(border))
planned_fft(A::AbstractArray, kernel::Union{ArrayLike,Laplacian}, border::BorderSpecAny) = planned_fft(A, factorkernel(kernel), border)

Check warning on line 874 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L873-L874

Added lines #L873 - L874 were not covered by tests

function filtfft(A, krn, planned_rfft1::Function, planned_rfft2::Function, planned_irfft::Function)
B = complex(planned_rfft1(A))
B .*= conj!(complex(planned_rfft2(krn)))
return real(planned_irfft(complex(B)))

Check warning on line 879 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L876-L879

Added lines #L876 - L879 were not covered by tests
end
# TODO: this does not work. See TODO below
function filtfft(A::AbstractArray{C}, krn, planned_rfft1::Function, planned_rfft2::Function, planned_irfft::Function) where {C<:Colorant}
Av, dims = channelview_dims(A)
kernrs = kreshape(C, krn)
B = complex(planned_rfft1(Av, dims)) # TODO: dims is not supported by planned_rfft1
B .*= conj!(complex(planned_rfft2(kernrs)))
Avf = real(planned_irfft(complex(B)))
return colorview(base_colorant_type(C){eltype(Avf)}, Avf)

Check warning on line 888 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L882-L888

Added lines #L882 - L888 were not covered by tests
end
filtfft(A, krn, ::Nothing, ::Nothing, ::Nothing) = filtfft(A, krn)
function filtfft(A, krn)
B = rfft(A)
B .*= conj!(rfft(krn))
irfft(B, length(axes(A, 1)))
return irfft(B, length(axes(A, 1)))
end
function filtfft(A::AbstractArray{C}, krn) where {C<:Colorant}
Av, dims = channelview_dims(A)
Expand Down
40 changes: 25 additions & 15 deletions test/2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ using ImageFiltering: borderinstance
end
end

function supported_algs(img, kernel, border)
if eltype(img) isa AbstractFloat
(Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT(), planned_fft(img, kernel, border))
else
# TODO: extend planned_fft to support other types
(Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
end
end

@testset "FIR/FFT" begin
f32type(img) = f32type(eltype(img))
f32type(::Type{C}) where {C<:Colorant} = base_colorant_type(C){Float32}
Expand All @@ -50,6 +59,7 @@ end
# Dense inseparable kernel
kern = [0.1 0.2; 0.4 0.5]
kernel = OffsetArray(kern, -1:0, 1:2)
border = Inner()
for img in (imgf, imgi, imgg, imgc)
targetimg = zeros(typeof(img[1]*kern[1]), size(img))
targetimg[3:4,2:3] = rot180(kern) .* img[3,4]
Expand All @@ -66,7 +76,7 @@ end
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg)
fill!(ret, zero(eltype(ret)))
@test @inferred(imfilter!(ret, img, kernel, border)) targetimg
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, border, alg)) targetimg
@test @inferred(imfilter(img, (kernel,), border, alg)) targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) float32.(targetimg)
Expand All @@ -76,12 +86,12 @@ end
@test_throws MethodError imfilter!(CPU1(Algorithm.FIR()), ret, img, kernel, border, Algorithm.FFT())
end
targetimg_inner = OffsetArray(targetimg[2:end, 1:end-2], 2:5, 1:5)
@test @inferred(imfilter(img, kernel, Inner())) targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, Inner())) float32.(targetimg_inner)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
@test @inferred(imfilter(img, kernel, Inner(), alg)) targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, Inner(), alg)) float32.(targetimg_inner)
@test @inferred(imfilter(CPU1(alg), img, kernel, Inner())) targetimg_inner
@test @inferred(imfilter(img, kernel, border)) targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg_inner)
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, border, alg)) targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) float32.(targetimg_inner)
@test @inferred(imfilter(CPU1(alg), img, kernel, border)) targetimg_inner
end
end
# Factored kernel
Expand All @@ -96,7 +106,7 @@ end
for border in ("replicate", "circular", "symmetric", "reflect", Fill(zero(eltype(img))))
@test @inferred(imfilter(img, kernel, border)) targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, border, alg)) targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) float32.(targetimg)
end
Expand All @@ -106,7 +116,7 @@ end
targetimg_inner = OffsetArray(targetimg[2:end, 1:end-2], 2:5, 1:5)
@test @inferred(imfilter(img, kernel, Inner())) targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, Inner())) float32.(targetimg_inner)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, Inner(), alg)) targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, Inner(), alg)) float32.(targetimg_inner)
end
Expand All @@ -122,7 +132,7 @@ end
for border in ("replicate", "circular", "symmetric", "reflect", Fill(zero(eltype(img))))
@test @inferred(imfilter(img, kernel, border)) targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
if alg == Algorithm.FFT() && eltype(img) == Int
@test @inferred(imfilter(Float64, img, kernel, border, alg)) targetimg
else
Expand All @@ -134,7 +144,7 @@ end
targetimg_inner = OffsetArray(targetimg[2:end-1, 2:end-1], 2:4, 2:6)
@test @inferred(imfilter(img, kernel, Inner())) targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, Inner())) float32.(targetimg_inner)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
if alg == Algorithm.FFT() && eltype(img) == Int
@test @inferred(imfilter(Float64, img, kernel, Inner(), alg)) targetimg_inner
else
Expand Down Expand Up @@ -184,7 +194,7 @@ end
targetimg = target1(img, kern, border)
@test @inferred(imfilter(img, kernel, border)) targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, border, alg)) targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) float32.(targetimg)
end
Expand All @@ -195,7 +205,7 @@ end
targetimg = zerona!(copy(targetimg0))
@test @inferred(zerona!(imfilter(img, kernel, border))) targetimg
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border))) float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(zerona!(imfilter(img, kernel, border, alg), nanflag)) targetimg
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border, alg), nanflag)) float32.(targetimg)
end
Expand All @@ -208,7 +218,7 @@ end
targetimg = target1(img, kern, border)
@test @inferred(imfilter(img, kernel, border)) targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, border, alg)) targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) float32.(targetimg)
end
Expand All @@ -219,7 +229,7 @@ end
targetimg = zerona!(copy(targetimg0))
@test @inferred(zerona!(imfilter(img, kernel, border))) targetimg
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border))) float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(zerona!(imfilter(img, kernel, border, alg), nanflag)) targetimg
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border, alg), nanflag)) float32.(targetimg)
end
Expand Down
Loading

0 comments on commit 25dbbbe

Please sign in to comment.