-
Notifications
You must be signed in to change notification settings - Fork 3
/
CMakeLists.txt
107 lines (87 loc) · 4.03 KB
/
CMakeLists.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
cmake_minimum_required(VERSION 3.19...3.30)
# The project name comes from SKBUILD_PROJECT_NAME, which is presumably
# provided by scikit-build-core when building the Python package.
project(${SKBUILD_PROJECT_NAME} LANGUAGES C CXX)
message(STATUS "Using CMake version: ${CMAKE_VERSION}")
# Debugging tip: for cuda-gdb support and verbose PTXAS output, pass
#   -DCMAKE_CUDA_FLAGS="-g -G -Xptxas -v"
# at configure time instead of editing this file.
# OpenMP is requested by default but only used when the C++ compiler supports
# it. The outcome is stored in FINUFFT_USE_OPENMP, which this file later turns
# into a FINUFFT_USE_OPENMP compile definition on the CPU extension; it is
# presumably also read by the vendored finufft subproject (name-based guess).
option(JAX_FINUFFT_USE_OPENMP "Enable OpenMP" ON)
set(FINUFFT_USE_OPENMP OFF)
if(NOT JAX_FINUFFT_USE_OPENMP)
  message(STATUS "jax_finufft: OpenMP support was not requested")
else()
  find_package(OpenMP)
  if(NOT OpenMP_CXX_FOUND)
    message(STATUS "jax_finufft: OpenMP not found")
  else()
    message(STATUS "jax_finufft: OpenMP found")
    set(FINUFFT_USE_OPENMP ON)
  endif()
endif()
# CUDA support is off by default. When it is requested, a CUDA compiler is
# mandatory: configuration stops with an error rather than silently producing
# a CPU-only build. On success this sets FINUFFT_USE_CUDA=ON, enables the CUDA
# language, and forwards the chosen architectures via
# FINUFFT_CUDA_ARCHITECTURES.
option(JAX_FINUFFT_USE_CUDA "Enable CUDA build" OFF)
if(JAX_FINUFFT_USE_CUDA)
  include(CheckLanguage)
  check_language(CUDA)
  if(NOT CMAKE_CUDA_COMPILER)
    message(FATAL_ERROR "jax_finufft: No CUDA compiler found! Please ensure the "
      "CUDA Toolkit is installed, or set JAX_FINUFFT_USE_CUDA=OFF to disable "
      "GPU support.")
  endif()
  message(STATUS "jax_finufft: CUDA compiler found; compiling with GPU support")
  set(FINUFFT_USE_CUDA ON)
  # Default to the GPU(s) on the build machine when the user chose nothing,
  # either with -DCMAKE_CUDA_ARCHITECTURES or with the CUDAARCHS environment
  # variable. "native" is only understood by CMake >= 3.24. On older CMake the
  # variable is left unset, so enable_language(CUDA) fills in the compiler's
  # default architectures instead of failing on an invalid value.
  if(NOT CMAKE_CUDA_ARCHITECTURES
     AND NOT DEFINED ENV{CUDAARCHS}
     AND CMAKE_VERSION VERSION_GREATER_EQUAL 3.24)
    set(CMAKE_CUDA_ARCHITECTURES "native")
  endif()
  # This must run after the default above is chosen; otherwise
  # enable_language() initializes CMAKE_CUDA_ARCHITECTURES to the compiler
  # default first.
  enable_language(CUDA)
  message(STATUS "jax_finufft: CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
  # Propagate to finufft, because it doesn't look at CMAKE_CUDA_ARCHITECTURES
  # by default. This is done after enable_language() so it also covers the
  # compiler-default case.
  set(FINUFFT_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
else()
  message(STATUS "jax_finufft: GPU support was not requested")
  set(FINUFFT_USE_CUDA OFF)
endif()
# Default single-config builds to Release. This happens before the vendored
# FINUFFT project is added, so that subproject sees the same build type when it
# is configured. Multi-config generators (Visual Studio, Xcode, Ninja
# Multi-Config) choose the configuration at build time and are left alone.
# GENERATOR_IS_MULTI_CONFIG is used instead of CMAKE_CONFIGURATION_TYPES,
# because single-config generators ignore the latter even when it is set.
get_property(jax_finufft_is_multi_config GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if(NOT jax_finufft_is_multi_config AND NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
  set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()

# Build the vendored FINUFFT with position-independent code so it can be
# linked into the Python extension modules below.
set(FINUFFT_POSITION_INDEPENDENT_CODE ON)
add_subdirectory("${CMAKE_CURRENT_LIST_DIR}/vendor/finufft")

# Find Python and nanobind
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
find_package(nanobind CONFIG REQUIRED)
# CPU XLA bindings: a nanobind extension module linked against the vendored
# finufft target and installed directly into the install prefix.
nanobind_add_module(jax_finufft_cpu ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cpu.cc)
target_link_libraries(jax_finufft_cpu PRIVATE finufft)
# NOTE(review): FFTW_INCLUDE_DIRS is not set in this file. It presumably comes
# from the vendored finufft build and may well be empty here; confirm.
target_include_directories(jax_finufft_cpu PRIVATE ${FFTW_INCLUDE_DIRS})
# Pass the OpenMP decision (FINUFFT_USE_OPENMP) to the C++ sources as a macro.
target_compile_definitions(jax_finufft_cpu PRIVATE
  $<$<BOOL:${FINUFFT_USE_OPENMP}>:FINUFFT_USE_OPENMP>
)
install(TARGETS jax_finufft_cpu LIBRARY DESTINATION .)
# GPU XLA bindings. This is built only when FINUFFT_USE_CUDA is ON, which
# happens only after a CUDA compiler was found and enable_language(CUDA) ran
# earlier in this file, so the language does not need to be enabled again.
if(FINUFFT_USE_CUDA)
  nanobind_add_module(jax_finufft_gpu
    ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
    ${CMAKE_CURRENT_LIST_DIR}/lib/cufinufft_wrapper.cc
    ${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
  # Enable relocatable device code on this target only, instead of setting the
  # directory-wide CMAKE_CUDA_SEPARABLE_COMPILATION default.
  set_target_properties(jax_finufft_gpu PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
  # TODO(dfm): The ${CUFINUFFT_INCLUDE_DIRS} variable doesn't seem to get set
  # properly when FINUFFT is included as a submodule (maybe because of the use
  # of ${PROJECT_SOURCE_DIR}). The vendored paths below are copied from there,
  # pointing at the appropriate vendored directories.
  # PRIVATE: an extension module has no consumers that need these headers.
  target_include_directories(jax_finufft_gpu PRIVATE
    ${CUFINUFFT_INCLUDE_DIRS}
    ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include
    ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib
    ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include/cufinufft/contrib/cuda_samples)
  target_link_libraries(jax_finufft_gpu PRIVATE cufinufft)
  install(TARGETS jax_finufft_gpu LIBRARY DESTINATION .)
endif()