Skip to content

Commit

Permalink
Always build the CPU variant of controller_wrappers such that CUDA bu…
Browse files Browse the repository at this point in the history
…ild can still run on CPU
  • Loading branch information
basnijholt committed Jan 13, 2024
1 parent c533946 commit cc18d6f
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 22 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ recursive-include qiskit_aer *hpp
graft src
graft contrib
include qiskit_aer/backends/wrappers/CMakeLists.txt
include qiskit_aer/backends/wrappers/bindings.cc
include qiskit_aer/backends/wrappers/bindings.cc.in
include qiskit_aer/VERSION.txt
include qiskit_aer/library/instructions_table.csv
include CMakeLists.txt
Expand Down
29 changes: 29 additions & 0 deletions qiskit_aer/backends/controller_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import importlib


def try_import_backend(backend_module_suffix):
module_name = f".controller_wrappers_{backend_module_suffix}"
try:
return importlib.import_module(module_name, "qiskit_aer.backends")
except ImportError:
return None


IMPORTED_BACKEND = None
BACKENDS = ["cuda", "rocm", "cpu"]

for backend_suffix in BACKENDS:
backend_module = try_import_backend(backend_suffix)
if backend_module:
IMPORTED_BACKEND = backend_suffix
globals().update(
{
name: getattr(backend_module, name)
for name in dir(backend_module)
if not name.startswith("_")
}
)
break

if IMPORTED_BACKEND is None:
raise ImportError("No backend found for qiskit-aer.")
59 changes: 39 additions & 20 deletions qiskit_aer/backends/wrappers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,37 @@ if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_HOST_SYSTEM_PROCESSOR
endif()
endif()

set(AER_SIMULATOR_SOURCES "bindings.cc" "${SIMD_SOURCE_FILE}")
basic_pybind11_add_module(controller_wrappers "${AER_SIMULATOR_SOURCES}")
macro(configure_target target_name)
target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}
PRIVATE ${AER_SIMULATOR_CPP_SRC_DIR}
PRIVATE ${AER_SIMULATOR_CPP_EXTERNAL_LIBS})
target_link_libraries(${target_name} ${AER_LIBRARIES})
target_compile_definitions(${target_name} PRIVATE ${AER_COMPILER_DEFINITIONS})
install(TARGETS ${target_name} LIBRARY DESTINATION qiskit_aer/backends)
endmacro()

# Build the CPU backend
set(BACKEND_MODULE_NAME "controller_wrappers_cpu")
configure_file(bindings.cc.in bindings_cpu.cc)
basic_pybind11_add_module(controller_wrappers_cpu bindings_cpu.cc "${SIMD_SOURCE_FILE}")

if(DEFINED SIMD_SOURCE_FILE)
string(REPLACE ";" " " SIMD_FLAGS "${SIMD_FLAGS_LIST}")
set_source_files_properties(${SIMD_SOURCE_FILE} PROPERTIES COMPILE_FLAGS "${SIMD_FLAGS}")
endif()

set_target_properties(controller_wrappers_cpu PROPERTIES COMPILE_FLAGS "${AER_COMPILER_FLAGS}")
configure_target(controller_wrappers_cpu)

# Build the CUDA backend
if(AER_THRUST_BACKEND STREQUAL "CUDA")
set(BACKEND_MODULE_NAME "controller_wrappers_cuda")
configure_file(bindings.cc.in bindings_cuda.cc)
basic_pybind11_add_module(controller_wrappers_cuda bindings_cuda.cc "${SIMD_SOURCE_FILE}")

include(nvcc_add_compiler_options)
set_source_files_properties(bindings.cc PROPERTIES LANGUAGE CUDA)
set_source_files_properties(bindings.cc PROPERTIES COMPILE_FLAGS "${CUDA_NVCC_FLAGS}")
set_source_files_properties(bindings_cuda.cc PROPERTIES LANGUAGE CUDA)
set_source_files_properties(bindings_cuda.cc PROPERTIES COMPILE_FLAGS "${CUDA_NVCC_FLAGS}")

if(DEFINED SIMD_SOURCE_FILE)
set_source_files_properties(${SIMD_SOURCE_FILE} PROPERTIES LANGUAGE CUDA)
Expand All @@ -36,34 +60,29 @@ if(AER_THRUST_BACKEND STREQUAL "CUDA")

string(STRIP ${AER_COMPILER_FLAGS} AER_COMPILER_FLAGS_STRIPPED)
nvcc_add_compiler_options(${AER_COMPILER_FLAGS_STRIPPED} AER_COMPILER_FLAGS_OUT)
set_target_properties(controller_wrappers PROPERTIES COMPILE_FLAGS "${AER_COMPILER_FLAGS_OUT}")
set_target_properties(controller_wrappers_cuda PROPERTIES COMPILE_FLAGS "${AER_COMPILER_FLAGS_OUT}")
enable_language(CUDA)
configure_target(controller_wrappers_cuda)
# Build the ROCm backend
elseif(AER_THRUST_BACKEND STREQUAL "ROCM")
set(BACKEND_MODULE_NAME "controller_wrappers_rocm")
configure_file(bindings.cc.in bindings_rocm.cc)
basic_pybind11_add_module(controller_wrappers_rocm bindings_rocm.cc "${SIMD_SOURCE_FILE}")

if(NOT DEFINED SIMD_SOURCE_FILE)
message(FATAL_ERROR "ROCm supported target machines are expected to be SIMD-enabled.")
endif()

set_source_files_properties(
bindings.cc
bindings_rocm.cc
${SIMD_SOURCE_FILE}
PROPERTIES LANGUAGE CXX)

target_compile_options(controller_wrappers PRIVATE ${ROCM_EXTRA_FLAGS} ${SIMD_FLAGS_LIST})
target_compile_definitions(controller_wrappers PRIVATE ${ROCM_EXTRA_DEFS} ${AER_COMPILER_DEFINITIONS})
set_target_properties(controller_wrappers PROPERTIES COMPILE_FLAGS "${AER_COMPILER_FLAGS}")
else()
if(DEFINED SIMD_SOURCE_FILE)
string(REPLACE ";" " " SIMD_FLAGS "${SIMD_FLAGS_LIST}")
set_source_files_properties(${SIMD_SOURCE_FILE} PROPERTIES COMPILE_FLAGS "${SIMD_FLAGS}")
endif()
set_target_properties(controller_wrappers PROPERTIES COMPILE_FLAGS "${AER_COMPILER_FLAGS}")
target_compile_options(controller_wrappers_rocm PRIVATE ${ROCM_EXTRA_FLAGS} ${SIMD_FLAGS_LIST})
target_compile_definitions(controller_wrappers_rocm PRIVATE ${ROCM_EXTRA_DEFS} ${AER_COMPILER_DEFINITIONS})
set_target_properties(controller_wrappers_rocm PROPERTIES COMPILE_FLAGS "${AER_COMPILER_FLAGS}")
configure_target(controller_wrappers_rocm)
endif()
target_include_directories(controller_wrappers PRIVATE ${AER_SIMULATOR_CPP_SRC_DIR}
PRIVATE ${AER_SIMULATOR_CPP_EXTERNAL_LIBS})
target_link_libraries(controller_wrappers ${AER_LIBRARIES})
target_compile_definitions(controller_wrappers PRIVATE ${AER_COMPILER_DEFINITIONS})
install(TARGETS controller_wrappers LIBRARY DESTINATION qiskit_aer/backends)

# Install redistributable dependencies
install(FILES ${BACKEND_REDIST_DEPS} DESTINATION qiskit_aer/backends)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ DISABLE_WARNING_POP

using namespace AER;

PYBIND11_MODULE(controller_wrappers, m) {
PYBIND11_MODULE(@BACKEND_MODULE_NAME@, m) {

#ifdef AER_MPI
int prov;
Expand Down

0 comments on commit cc18d6f

Please sign in to comment.