Skip to content

Commit

Permalink
WIP: use RFFT.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
IanButterworth committed Jan 23, 2024
1 parent 20df36d commit 17e9b0a
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 25 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ author = ["Tim Holy <[email protected]>", "Jan Weidner <[email protected]>"]
version = "0.7.8"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
CatIndices = "aafaddc9-749c-510e-ac4f-586e18779b91"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -13,8 +14,8 @@ ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -29,8 +30,8 @@ FFTW = "0.3, 1"
ImageBase = "0.1.5"
ImageCore = "0.10"
OffsetArrays = "1.9"
Reexport = "1.1"
PrecompileTools = "1"
Reexport = "1.1"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
TiledIteration = "0.2, 0.3, 0.4, 0.5"
julia = "1.6"
Expand Down
37 changes: 28 additions & 9 deletions demo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,34 @@ using ImageFiltering, FFTW, LinearAlgebra, Profile, Random
# using ProfileView
using ComputationalResources

FFTW.set_num_threads(parse(Int, ENV["FFTW_NUM_THREADS"]))
BLAS.set_num_threads(parse(Int, ENV["BLAS_NUM_THREADS"]))
FFTW.set_num_threads(parse(Int, get(ENV, "FFTW_NUM_THREADS", "1")))
BLAS.set_num_threads(parse(Int, get(ENV, "BLAS_NUM_THREADS", string(Threads.nthreads() ÷ 2))))

function benchmark(mats)
kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
Threads.@threads for mat in mats
frame_filtered = similar(mat[:, :, 1])
r = CPU1(ImageFiltering.planned_fft(frame_filtered, kernel))
frame_filtered = deepcopy(mat[:, :, 1])
r_cached = CPU1(ImageFiltering.planned_fft(frame_filtered, kernel))
for i in axes(mat, 3)
frame = @view mat[:, :, i]
imfilter!(r, frame_filtered, frame, kernel)
imfilter!(r_cached, frame_filtered, frame, kernel)
end
return
end
end

function test(mats)
kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
for mat in mats
f1 = deepcopy(mat[:, :, 1])
r_cached = CPU1(ImageFiltering.planned_fft(f1, kernel))
f2 = deepcopy(mat[:, :, 1])
r_noncached = CPU1(Algorithm.FFT())
for i in axes(mat, 3)
frame = @view mat[:, :, i]
imfilter!(r_cached, f1, frame, kernel)
imfilter!(r_noncached, f2, frame, kernel)
all(f1 .≈ f2) || error("f1 !≈ f2")
end
return
end
Expand All @@ -24,11 +41,13 @@ function profile()
mats = [rand(Float32, rand(80:100), rand(80:100), rand(2000:3000)) for _ in 1:nmats]
GC.gc(true)

benchmark(mats)
# benchmark(mats)

for _ in 1:3
@time "warm run of benchmark(mats)" benchmark(mats)
end
# for _ in 1:3
# @time "warm run of benchmark(mats)" benchmark(mats)
# end

test(mats)

# Profile.clear()
# @profile benchmark(mats)
Expand Down
20 changes: 16 additions & 4 deletions src/ImageFiltering.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
module ImageFiltering

using FFTW
include("RFFT.jl") # TODO: Register RFFT.jl on General and add as a dependency
using ImageCore, FFTViews, OffsetArrays, StaticArrays, ComputationalResources, TiledIteration
# Where possible we avoid a direct dependency to reduce the number of [compat] bounds
# using FixedPointNumbers: Normed, N0f8 # reexported by ImageCore
using ImageCore.MappedArrays
using Statistics, LinearAlgebra
using Base: Indices, tail, fill_to_length, @pure, depwarn, @propagate_inbounds
import Base: copy!
using OffsetArrays: IdentityUnitRange # using the one in OffsetArrays makes this work with multiple Julia versions
using SparseArrays # only needed to fix an ambiguity in borderarray
using Reexport
Expand Down Expand Up @@ -51,13 +53,23 @@ end

module Algorithm
import FFTW
import ..RFFT
struct BufferedFFTPlan{T<:AbstractFloat}
plan::Function
buf::RFFT.RCpair{T}
end
function BufferedFFTPlan(a::AbstractArray{T}) where {T<:AbstractFloat}
buf = RFFT.RCpair{T}(undef, size(a))
plan = RFFT.plan_rfft!(buf)
BufferedFFTPlan(plan, buf)
end
# deliberately don't export these, but it's expected that they
# will be used as Algorithm.FFT(), etc.
abstract type Alg end
"Filter using the Fast Fourier Transform" struct FFT <: Alg
plan1::Union{FFTW.rFFTWPlan,Nothing}
plan2::Union{FFTW.rFFTWPlan,Nothing}
plan3::Union{FFTW.AbstractFFTs.ScaledPlan,Nothing}
plan1::Union{BufferedFFTPlan,Nothing}
plan2::Union{BufferedFFTPlan,Nothing}
plan3::Union{BufferedFFTPlan,Nothing}
end
FFT() = FFT(nothing, nothing, nothing)
"Filter using a direct algorithm" struct FIR <: Alg end
Expand All @@ -69,7 +81,7 @@ module Algorithm

FIRTiled() = FIRTiled(())
end
using .Algorithm: Alg, FFT, FIR, FIRTiled, IIR, Mixed
using .Algorithm: Alg, FFT, FIR, FIRTiled, IIR, Mixed, BufferedFFTPlan

Alg(r::AbstractResource{A}) where {A<:Alg} = r.settings

Expand Down
74 changes: 74 additions & 0 deletions src/RFFT.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
module RFFT

using FFTW, LinearAlgebra

export RCpair, plan_rfft!, plan_irfft!, rfft!, irfft!, normalization

import Base: real, complex, copy, copy!

mutable struct RCpair{T<:AbstractFloat,N,RType<:AbstractArray{T,N},CType<:AbstractArray{Complex{T},N}}
R::RType
C::CType
region::Vector{Int}
end

function RCpair{T}(::UndefInitializer, realsize::Dims{N}, region=1:length(realsize)) where {T<:AbstractFloat,N}
sz = [realsize...]
firstdim = region[1]
sz[firstdim] = realsize[firstdim]>>1 + 1
sz2 = copy(sz)
sz2[firstdim] *= 2
R = Array{T,N}(undef, (sz2...,)::Dims{N})
C = unsafe_wrap(Array, convert(Ptr{Complex{T}}, pointer(R)), (sz...,)::Dims{N}) # work around performance problems of reinterpretarray
RCpair(view(R, map(n->1:n, realsize)...), C, [region...])
end

RCpair(A::Array{T}, region=1:ndims(A)) where {T<:AbstractFloat} = copy!(RCpair{T}(undef, size(A), region), A)

real(RC::RCpair) = RC.R
complex(RC::RCpair) = RC.C

copy!(RC::RCpair, A::AbstractArray{T}) where {T<:Real} = (copy!(RC.R, A); RC)
function copy(RC::RCpair{T,N}) where {T,N}
C = copy(RC.C)
R = reshape(reinterpret(T, C), size(parent(RC.R)))
RCpair(view(R, RC.R.indices...), C, copy(RC.region))
end

# New API
rplan_fwd(R, C, region, flags, tlim) =
FFTW.rFFTWPlan{eltype(R),FFTW.FORWARD,true,ndims(R)}(R, C, region, flags, tlim)
rplan_inv(R, C, region, flags, tlim) =
FFTW.rFFTWPlan{eltype(R),FFTW.BACKWARD,true,ndims(R)}(R, C, region, flags, tlim)
function plan_rfft!(RC::RCpair{T}; flags::Integer = FFTW.ESTIMATE, timelimit::Real = FFTW.NO_TIMELIMIT) where T
p = rplan_fwd(RC.R, RC.C, RC.region, flags, timelimit)
return Z::RCpair -> begin
FFTW.assert_applicable(p, Z.R, Z.C)
FFTW.unsafe_execute!(p, Z.R, Z.C)
return Z
end
end
function plan_irfft!(RC::RCpair{T}; flags::Integer = FFTW.ESTIMATE, timelimit::Real = FFTW.NO_TIMELIMIT) where T
p = rplan_inv(RC.C, RC.R, RC.region, flags, timelimit)
return Z::RCpair -> begin
FFTW.assert_applicable(p, Z.C, Z.R)
FFTW.unsafe_execute!(p, Z.C, Z.R)
rmul!(Z.R, 1 / prod(size(Z.R)[Z.region]))
return Z
end
end
function rfft!(RC::RCpair{T}) where T
p = rplan_fwd(RC.R, RC.C, RC.region, FFTW.ESTIMATE, FFTW.NO_TIMELIMIT)
FFTW.unsafe_execute!(p, RC.R, RC.C)
return RC
end
function irfft!(RC::RCpair{T}) where T
p = rplan_inv(RC.C, RC.R, RC.region, FFTW.ESTIMATE, FFTW.NO_TIMELIMIT)
FFTW.unsafe_execute!(p, RC.C, RC.R)
rmul!(RC.R, 1 / prod(size(RC.R)[RC.region]))
return RC
end

@deprecate RCpair(realtype::Type{T}, realsize, region=1:length(realsize)) where T<:AbstractFloat RCpair{T}(undef, realsize, region)

end
29 changes: 19 additions & 10 deletions src/imfilter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -840,25 +840,34 @@ function _imfilter_fft!(r::AbstractCPU{FFT},
out
end

copy!(p::BufferedFFTPlan, a::AbstractArray{T}) where {T} = copy!(p.buf, a)
function updaterun!(p::BufferedFFTPlan, a::AbstractArray{T}) where {T}
copy!(p.buf, OffsetArrays.no_offset_view(a))
p.plan(p.buf)
end

function planned_fft(A::AbstractArray{T,N},
kernel::Tuple{AbstractArray,Vararg{AbstractArray}},
border::BorderSpecAny=Pad(:replicate)) where {T,N}
bord = border(kernel, A, Algorithm.FFT())
_A = padarray(T, A, bord)
p1 = plan_rfft(_A)
bfp1 = BufferedFFTPlan(_A)
B = real(updaterun!(bfp1, _A)) * FFTW.AbstractFFTs.to1(_A)
kern = samedims(_A, kernelconv(kernel...))
krn = FFTView(zeros(eltype(kern), map(length, axes(_A))))
p2 = plan_rfft(krn)
B = p1 * _A
B .*= conj!(p2 * krn)
p3 = plan_irfft(B, length(axes(_A, 1)))
return Algorithm.FFT(p1, p2, p3)
for I in CartesianIndices(axes(kern))
krn[I] = kern[I]
end
bfp2 = BufferedFFTPlan(krn)
B .*= conj!(real(updaterun!(bfp2, krn)) * FFTW.AbstractFFTs.to1(krn))
bfp3 = BufferedFFTPlan(B)
return Algorithm.FFT(bfp1, bfp2, bfp3)
end

function filtfft(A, krn, plan_A::FFTW.rFFTWPlan, plan_krn::FFTW.rFFTWPlan, plan_B::FFTW.AbstractFFTs.ScaledPlan)
B = plan_A * A
B .*= conj!(plan_krn * krn)
plan_B * B
function filtfft(A, krn, bfp1::BufferedFFTPlan, bfp2::BufferedFFTPlan, bfp3::BufferedFFTPlan)
B = real(updaterun!(bfp1, A)) * FFTW.AbstractFFTs.to1(A)
B .*= conj!(real(updaterun!(bfp2, krn)) * FFTW.AbstractFFTs.to1(krn))
return real(updaterun!(bfp3, B)) * B
end
filtfft(A, krn, ::Nothing, ::Nothing, ::Nothing) = filtfft(A, krn)
function filtfft(A, krn)
Expand Down

0 comments on commit 17e9b0a

Please sign in to comment.