Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: enable reusing fft plans #271

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions .github/workflows/UnitTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ on:
branches:
- master
pull_request:
schedule:
- cron: '20 00 1 * *'

jobs:
test:
Expand All @@ -20,23 +18,18 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]

steps:
- uses: actions/checkout@v1.0.0
- uses: actions/checkout@v4
- name: "Set up Julia"
uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia-version }}

- name: Cache artifacts
uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/cache@v1

- run: |
import Pkg
Pkg.add(Pkg.PackageSpec(url="https://github.com/HolyLab/RFFT.jl", rev="ib/add_copy"))
shell: julia --project --color=yes {0}

- name: "Unit Test"
uses: julia-actions/julia-runtest@v1
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ author = ["Tim Holy <[email protected]>", "Jan Weidner <[email protected]>"]
version = "0.7.8"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
IanButterworth marked this conversation as resolved.
Show resolved Hide resolved
CatIndices = "aafaddc9-749c-510e-ac4f-586e18779b91"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -15,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RFFT = "3bd9afcd-55df-531a-9b34-dc642dce7b95"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -31,6 +33,7 @@ ImageCore = "0.10"
OffsetArrays = "1.9"
PrecompileTools = "1"
Reexport = "1.1"
RFFT = "0.1.1"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
Statistics = "1"
TiledIteration = "0.2, 0.3, 0.4, 0.5"
Expand Down
67 changes: 67 additions & 0 deletions demo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using ImageFiltering, FFTW, LinearAlgebra, Profile, Random
IanButterworth marked this conversation as resolved.
Show resolved Hide resolved
using ComputationalResources

# Thread counts come from the environment so the benchmark can be tuned without editing
# the script. FFTW defaults to one thread. BLAS defaults to half of Julia's threads,
# clamped to at least 1: `Threads.nthreads() ÷ 2` is 0 when Julia runs single-threaded,
# and 0 is not a meaningful thread count to request.
FFTW.set_num_threads(parse(Int, get(ENV, "FFTW_NUM_THREADS", "1")))
BLAS.set_num_threads(
    parse(Int, get(ENV, "BLAS_NUM_THREADS", string(max(1, Threads.nthreads() ÷ 2))))
)

# Benchmark FFT filtering with reusable plans.
#
# For every 3-d array in `mats` (processed in parallel with `Threads.@threads`), a plan is
# built once via `ImageFiltering.planned_fft` from the first frame, and then every frame
# along the third axis is filtered in place into the same output buffer.
#
# Returns `nothing`.
#
# The `return` sits after the threaded loop on purpose: the body of `Threads.@threads` is
# wrapped in a per-task closure, so a `return` inside the loop body would leave that
# closure and skip the remaining matrices assigned to the task.
function benchmark_new(mats)
    kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
    Threads.@threads for mat in mats
        # Range indexing already allocates a fresh array; no extra copy is required.
        frame_filtered = mat[:, :, 1]
        r_cached = CPU1(ImageFiltering.planned_fft(frame_filtered, kernel))
        for i in axes(mat, 3)
            frame = @view mat[:, :, i]
            imfilter!(r_cached, frame_filtered, frame, kernel)
        end
    end
    return nothing
end
# Benchmark FFT filtering without plan reuse (the baseline).
#
# For every 3-d array in `mats` (processed in parallel with `Threads.@threads`), each
# frame along the third axis is filtered in place with a plain `Algorithm.FFT()` resource,
# writing into one output buffer per matrix.
#
# Returns `nothing`.
#
# The `return` sits after the threaded loop on purpose: the body of `Threads.@threads` is
# wrapped in a per-task closure, so a `return` inside the loop body would leave that
# closure and skip the remaining matrices assigned to the task.
function benchmark_old(mats)
    kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
    Threads.@threads for mat in mats
        # Range indexing already allocates a fresh array; no extra copy is required.
        frame_filtered = mat[:, :, 1]
        r_noncached = CPU1(Algorithm.FFT())
        for i in axes(mat, 3)
            frame = @view mat[:, :, i]
            imfilter!(r_noncached, frame_filtered, frame, kernel)
        end
    end
    return nothing
end

# Check that planned (cached) and unplanned FFT filtering agree.
#
# For every 3-d array in `mats`, each frame along the third axis is filtered twice: once
# with a plain `Algorithm.FFT()` resource and once with a resource built by
# `ImageFiltering.planned_fft`. The first four output values of both results are shown and
# an error is raised as soon as the two outputs are not approximately equal.
#
# Returns `nothing`. The `return` is placed after the outer loop so that every matrix is
# checked, not only the first one.
function test(mats)
    kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
    for mat in mats
        # Range indexing already allocates a fresh array, so each buffer is independent.
        f1 = mat[:, :, 1]
        r_cached = CPU1(ImageFiltering.planned_fft(f1, kernel))
        f2 = mat[:, :, 1]
        r_noncached = CPU1(Algorithm.FFT())
        for i in axes(mat, 3)
            # Each call receives its own copy of the frame so that neither code path can
            # observe changes the other might make to its input.
            imfilter!(r_noncached, f2, mat[:, :, i], kernel)
            imfilter!(r_cached, f1, mat[:, :, i], kernel)
            @show f1[1:4] f2[1:4]
            f1 ≈ f2 || error("f1 !≈ f2")
        end
    end
    return nothing
end

# Entry point of the demo: build reproducible random input, time the planned and the
# unplanned FFT filtering paths, then cross-check their results.
function run()
    Random.seed!(1)
    nmats = 10
    # Each stack has 80-100 rows, 80-100 columns and 2000-3000 frames; the sizes are drawn
    # in that order before the data itself.
    mats = map(1:nmats) do _
        nrows = rand(80:100)
        ncols = rand(80:100)
        nframes = rand(2000:3000)
        rand(Float64, nrows, ncols, nframes)
    end

    # The first call of each benchmark is an untimed warm-up; three timed runs follow.
    benchmark_new(mats)
    for pass in 1:3
        @time "warm run of benchmark_new(mats)" benchmark_new(mats)
    end

    benchmark_old(mats)
    for pass in 1:3
        @time "warm run of benchmark_old(mats)" benchmark_old(mats)
    end

    return test(mats)
end

run()
13 changes: 11 additions & 2 deletions src/ImageFiltering.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
module ImageFiltering

using FFTW
using RFFT
using ImageCore, FFTViews, OffsetArrays, StaticArrays, ComputationalResources, TiledIteration
# Where possible we avoid a direct dependency to reduce the number of [compat] bounds
# using FixedPointNumbers: Normed, N0f8 # reexported by ImageCore
using ImageCore.MappedArrays
using Statistics, LinearAlgebra
using Base: Indices, tail, fill_to_length, @pure, depwarn, @propagate_inbounds
import Base: copy!
using OffsetArrays: IdentityUnitRange # using the one in OffsetArrays makes this work with multiple Julia versions
using SparseArrays # only needed to fix an ambiguity in borderarray
using Reexport
Expand All @@ -30,7 +32,8 @@ export Kernel, KernelFactors,
imgradients, padarray, centered, kernelfactors, reflect,
freqkernel, spacekernel,
findlocalminima, findlocalmaxima,
blob_LoG, BlobLoG
blob_LoG, BlobLoG,
planned_fft

FixedColorant{T<:Normed} = Colorant{T}
StaticOffsetArray{T,N,A<:StaticArray} = OffsetArray{T,N,A}
Expand All @@ -50,10 +53,16 @@ function Base.transpose(A::StaticOffsetArray{T,2}) where T
end

module Algorithm
import FFTW
# deliberately don't export these, but it's expected that they
# will be used as Algorithm.FFT(), etc.
abstract type Alg end
"Filter using the Fast Fourier Transform" struct FFT <: Alg end
"Filter using the Fast Fourier Transform" struct FFT <: Alg
plan1::Union{Function,Nothing}
plan2::Union{Function,Nothing}
plan3::Union{Function,Nothing}
end
FFT() = FFT(nothing, nothing, nothing)
"Filter using a direct algorithm" struct FIR <: Alg end
"Cache-efficient filtering using tiles" struct FIRTiled{N} <: Alg
tilesize::Dims{N}
Expand Down
54 changes: 51 additions & 3 deletions src/imfilter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@
for I in CartesianIndices(axes(kern))
krn[I] = kern[I]
end
Af = filtfft(A, krn)
Af = filtfft(A, krn, r.settings.plan1, r.settings.plan2, r.settings.plan3)
if map(first, axes(out)) == map(first, axes(Af))
R = CartesianIndices(axes(out))
copyto!(out, R, Af, R)
Expand All @@ -837,13 +837,61 @@
src = view(FFTView(Af), axes(dest)...)
copyto!(dest, src)
end
out
return out
end

function buffered_planned_rfft(a::AbstractArray{T}) where {T}
buf = RFFT.RCpair{T}(undef, size(a))
plan = RFFT.plan_rfft!(buf; flags=FFTW.MEASURE)
return function (arr::AbstractArray{T}) where {T}
copy!(buf, OffsetArrays.no_offset_view(arr))
return plan(buf)

Check warning on line 848 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L843-L848

Added lines #L843 - L848 were not covered by tests
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No test coverage here and below

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I hadn't noticed this. Sadly, it turns out the new functionality isn't passing yet.

end
function buffered_planned_irfft(a::AbstractArray{T}) where {T}
buf = RFFT.RCpair{T}(undef, size(a))
plan = RFFT.plan_irfft!(buf; flags=FFTW.MEASURE)
return function (arr::AbstractArray{T}) where {T}
copy!(buf, OffsetArrays.no_offset_view(arr))
return plan(buf)

Check warning on line 856 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L851-L856

Added lines #L851 - L856 were not covered by tests
end
end

function planned_fft(A::AbstractArray{T,N},

Check warning on line 860 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L860

Added line #L860 was not covered by tests
kernel::ProcessedKernel,
border::BorderSpecAny=Pad(:replicate)
) where {T<:AbstractFloat,N}
bord = border(kernel, A, Algorithm.FFT())
_A = padarray(T, A, bord)
bfp1 = buffered_planned_rfft(_A)
kern = samedims(_A, kernelconv(kernel...))
krn = FFTView(zeros(eltype(kern), map(length, axes(_A))))
bfp2 = buffered_planned_rfft(krn)
bfp3 = buffered_planned_irfft(_A)
return Algorithm.FFT(bfp1, bfp2, bfp3)

Check warning on line 871 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L864-L871

Added lines #L864 - L871 were not covered by tests
end
planned_fft(A::AbstractArray, kernel, border::AbstractString) = planned_fft(A, kernel, borderinstance(border))
planned_fft(A::AbstractArray, kernel::Union{ArrayLike,Laplacian}, border::BorderSpecAny) = planned_fft(A, factorkernel(kernel), border)

Check warning on line 874 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L873-L874

Added lines #L873 - L874 were not covered by tests

function filtfft(A, krn, planned_rfft1::Function, planned_rfft2::Function, planned_irfft::Function)
B = complex(planned_rfft1(A))
B .*= conj!(complex(planned_rfft2(krn)))
return real(planned_irfft(complex(B)))

Check warning on line 879 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L876-L879

Added lines #L876 - L879 were not covered by tests
end
# TODO: this does not work. See TODO below
function filtfft(A::AbstractArray{C}, krn, planned_rfft1::Function, planned_rfft2::Function, planned_irfft::Function) where {C<:Colorant}
Av, dims = channelview_dims(A)
kernrs = kreshape(C, krn)
B = complex(planned_rfft1(Av, dims)) # TODO: dims is not supported by planned_rfft1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think dims can be a point of flexibility: these plans are specific to the memory layout of the array. (The planning explores various implementations and picks the fastest discovered; performance is strongly dependent on memory layout, so the choice for one layout may not be the same as another.) You'd have to create a plan specific to the colorant array type.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added this comment to the TODO. I think I'd prefer to fix this in a follow-up PR, if that sounds OK.

B .*= conj!(complex(planned_rfft2(kernrs)))
Avf = real(planned_irfft(complex(B)))
return colorview(base_colorant_type(C){eltype(Avf)}, Avf)

Check warning on line 888 in src/imfilter.jl

View check run for this annotation

Codecov / codecov/patch

src/imfilter.jl#L882-L888

Added lines #L882 - L888 were not covered by tests
end
# Fallback for the case where no precomputed plans are supplied (all three plan
# arguments are `nothing`): defer to the two-argument `filtfft` method.
function filtfft(A, krn, ::Nothing, ::Nothing, ::Nothing)
    return filtfft(A, krn)
end
# Filter `A` with `krn` in the frequency domain without precomputed plans.
#
# The real-input transform of `A` is multiplied elementwise by the complex conjugate of
# the real-input transform of `krn`, and the product is transformed back. The length of
# the first axis of `A` is passed to `irfft` as the size of the real output along the
# first dimension.
#
# The inverse transform is computed exactly once; its result is the return value.
function filtfft(A, krn)
    B = rfft(A)
    # `conj!` works in place on the freshly computed kernel transform.
    B .*= conj!(rfft(krn))
    return irfft(B, length(axes(A, 1)))
end
function filtfft(A::AbstractArray{C}, krn) where {C<:Colorant}
Av, dims = channelview_dims(A)
Expand Down
40 changes: 25 additions & 15 deletions test/2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ using ImageFiltering: borderinstance
end
end

# Return the tuple of algorithms the tests should exercise for `img`.
#
# `FIR`, `FIRTiled` and plain `FFT` are always included. When the element type of `img`
# is a floating-point type, an `Algorithm.FFT` with precomputed plans, built by
# `planned_fft(img, kernel, border)`, is appended.
#
# The check uses `<:` because `eltype(img)` is a type, not a value; testing it with
# `isa AbstractFloat` is never true and would leave the planned path untested.
function supported_algs(img, kernel, border)
    algs = (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
    if eltype(img) <: AbstractFloat
        return (algs..., planned_fft(img, kernel, border))
    end
    # TODO: extend planned_fft to support other types
    return algs
end

@testset "FIR/FFT" begin
f32type(img) = f32type(eltype(img))
f32type(::Type{C}) where {C<:Colorant} = base_colorant_type(C){Float32}
Expand All @@ -50,6 +59,7 @@ end
# Dense inseparable kernel
kern = [0.1 0.2; 0.4 0.5]
kernel = OffsetArray(kern, -1:0, 1:2)
border = Inner()
for img in (imgf, imgi, imgg, imgc)
targetimg = zeros(typeof(img[1]*kern[1]), size(img))
targetimg[3:4,2:3] = rot180(kern) .* img[3,4]
Expand All @@ -66,7 +76,7 @@ end
@test @inferred(imfilter(f32type(img), img, kernel, border)) ≈ float32.(targetimg)
fill!(ret, zero(eltype(ret)))
@test @inferred(imfilter!(ret, img, kernel, border)) ≈ targetimg
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, border, alg)) ≈ targetimg
@test @inferred(imfilter(img, (kernel,), border, alg)) ≈ targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) ≈ float32.(targetimg)
Expand All @@ -76,12 +86,12 @@ end
@test_throws MethodError imfilter!(CPU1(Algorithm.FIR()), ret, img, kernel, border, Algorithm.FFT())
end
targetimg_inner = OffsetArray(targetimg[2:end, 1:end-2], 2:5, 1:5)
@test @inferred(imfilter(img, kernel, Inner())) ≈ targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, Inner())) ≈ float32.(targetimg_inner)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
@test @inferred(imfilter(img, kernel, Inner(), alg)) ≈ targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, Inner(), alg)) ≈ float32.(targetimg_inner)
@test @inferred(imfilter(CPU1(alg), img, kernel, Inner())) ≈ targetimg_inner
@test @inferred(imfilter(img, kernel, border)) ≈ targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, border)) ≈ float32.(targetimg_inner)
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, border, alg)) ≈ targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) ≈ float32.(targetimg_inner)
@test @inferred(imfilter(CPU1(alg), img, kernel, border)) ≈ targetimg_inner
end
end
# Factored kernel
Expand All @@ -96,7 +106,7 @@ end
for border in ("replicate", "circular", "symmetric", "reflect", Fill(zero(eltype(img))))
@test @inferred(imfilter(img, kernel, border)) ≈ targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border)) ≈ float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, border, alg)) ≈ targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) ≈ float32.(targetimg)
end
Expand All @@ -106,7 +116,7 @@ end
targetimg_inner = OffsetArray(targetimg[2:end, 1:end-2], 2:5, 1:5)
@test @inferred(imfilter(img, kernel, Inner())) ≈ targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, Inner())) ≈ float32.(targetimg_inner)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, Inner(), alg)) ≈ targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, Inner(), alg)) ≈ float32.(targetimg_inner)
end
Expand All @@ -122,7 +132,7 @@ end
for border in ("replicate", "circular", "symmetric", "reflect", Fill(zero(eltype(img))))
@test @inferred(imfilter(img, kernel, border)) ≈ targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border)) ≈ float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
if alg == Algorithm.FFT() && eltype(img) == Int
@test @inferred(imfilter(Float64, img, kernel, border, alg)) ≈ targetimg
else
Expand All @@ -134,7 +144,7 @@ end
targetimg_inner = OffsetArray(targetimg[2:end-1, 2:end-1], 2:4, 2:6)
@test @inferred(imfilter(img, kernel, Inner())) ≈ targetimg_inner
@test @inferred(imfilter(f32type(img), img, kernel, Inner())) ≈ float32.(targetimg_inner)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
if alg == Algorithm.FFT() && eltype(img) == Int
@test @inferred(imfilter(Float64, img, kernel, Inner(), alg)) ≈ targetimg_inner
else
Expand Down Expand Up @@ -184,7 +194,7 @@ end
targetimg = target1(img, kern, border)
@test @inferred(imfilter(img, kernel, border)) ≈ targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border)) ≈ float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, border, alg)) ≈ targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) ≈ float32.(targetimg)
end
Expand All @@ -195,7 +205,7 @@ end
targetimg = zerona!(copy(targetimg0))
@test @inferred(zerona!(imfilter(img, kernel, border))) ≈ targetimg
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border))) ≈ float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(zerona!(imfilter(img, kernel, border, alg), nanflag)) ≈ targetimg
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border, alg), nanflag)) ≈ float32.(targetimg)
end
Expand All @@ -208,7 +218,7 @@ end
targetimg = target1(img, kern, border)
@test @inferred(imfilter(img, kernel, border)) ≈ targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border)) ≈ float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(imfilter(img, kernel, border, alg)) ≈ targetimg
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) ≈ float32.(targetimg)
end
Expand All @@ -219,7 +229,7 @@ end
targetimg = zerona!(copy(targetimg0))
@test @inferred(zerona!(imfilter(img, kernel, border))) ≈ targetimg
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border))) ≈ float32.(targetimg)
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
for alg in supported_algs(img, kernel, border)
@test @inferred(zerona!(imfilter(img, kernel, border, alg), nanflag)) ≈ targetimg
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border, alg), nanflag)) ≈ float32.(targetimg)
end
Expand Down
14 changes: 9 additions & 5 deletions test/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ using ImageFiltering, Test, Statistics
size_y = 6*σy+1
γ = σx/σy
# zero size forces default kernel width, with warnings
@info "Four warnings are expected"
kernel = Kernel.gabor(0,0,σx,0,5,γ,0)
@test isequal(size(kernel[1]),(size_x,size_y))
kernel = Kernel.gabor(0,0,σx,π,5,γ,0)
@test isequal(size(kernel[1]),(size_x,size_y))

@test_logs (:warn, r"The input parameter size_") match_mode=:any begin
kernel = Kernel.gabor(0,0,σx,0,5,γ,0)
@test isequal(size(kernel[1]),(size_x,size_y))
end
@test_logs (:warn, r"The input parameter size_") match_mode=:any begin
kernel = Kernel.gabor(0,0,σx,π,5,γ,0)
@test isequal(size(kernel[1]),(size_x,size_y))
end
IanButterworth marked this conversation as resolved.
Show resolved Hide resolved

for x in 0:4, y in 0:4, z in 0:4, t in 0:4
σx1 = 2*x+1
Expand Down
Loading
Loading