Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Use custom type to pass CGF around instead of std::function #16668

Merged
merged 4 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,16 @@ template <typename LCRangeT, typename LCPropertiesT> struct LaunchConfigAccess {
template <typename CommandGroupFunc, typename PropertiesT>
void submit_impl(queue &Q, PropertiesT Props, CommandGroupFunc &&CGF,
const sycl::detail::code_location &CodeLoc) {
Q.submit_without_event(Props, std::forward<CommandGroupFunc>(CGF), CodeLoc);
Q.submit_without_event<__SYCL_USE_FALLBACK_ASSERT>(
Props, detail::type_erased_cgfo_ty{CGF}, CodeLoc);
}

template <typename CommandGroupFunc, typename PropertiesT>
event submit_with_event_impl(queue &Q, PropertiesT Props,
CommandGroupFunc &&CGF,
const sycl::detail::code_location &CodeLoc) {
return Q.submit_with_event(Props, std::forward<CommandGroupFunc>(CGF),
nullptr, CodeLoc);
return Q.submit_with_event<__SYCL_USE_FALLBACK_ASSERT>(
Props, detail::type_erased_cgfo_ty{CGF}, nullptr, CodeLoc);
}
} // namespace detail

Expand Down
32 changes: 32 additions & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,38 @@ class graph_impl;
} // namespace ext::oneapi::experimental::detail
namespace detail {

/// Non-owning, type-erased reference to a command group function object.
///
/// From SYCL 2020, command group function object:
/// A type which is callable with operator() that takes a reference to a
/// command group handler, that defines a command group which can be submitted
/// by a queue. The function object can be a named type, lambda function or
/// std::function.
///
/// Only the address of the callable is stored, together with a pointer to a
/// function that knows the callable's static type. Nothing is copied or
/// allocated, so the wrapped callable must outlive this object. Copying and
/// moving are disabled.
class type_erased_cgfo_ty {
  // Recovers the static type `T` of the erased callable and invokes it.
  template <typename T> struct invoker {
    static void call(void *object, handler &cgh) {
      (*static_cast<T *>(object))(cgh);
    }
  };
  // Address of the wrapped callable; not owned.
  void *object;
  using invoker_ty = void (*)(void *, handler &);
  // `invoker<T>::call` for the type the constructor was instantiated with.
  const invoker_ty invoker_f;

public:
  /// Wraps the callable \p f by address. Kept non-explicit so that a callable
  /// can be passed directly where a `const type_erased_cgfo_ty &` is expected.
  template <class T>
  type_erased_cgfo_ty(T &f)
      // NOTE: Even if `T` is a pointer to a function, `&f` is a pointer to a
      // pointer to a function and as such can be cast to `void *` (a pointer
      // to a function cannot be cast).
      //
      // `T` is deduced as const-qualified when `f` is a const lvalue, and a
      // direct `static_cast<void *>(&f)` would then be ill-formed because it
      // drops the qualifier. Go through `const void *` instead;
      // `invoker<T>::call` casts back to `T *`, which restores the qualifier
      // before the callable is invoked.
      : object(const_cast<void *>(static_cast<const void *>(&f))),
        invoker_f(&invoker<T>::call) {}
  ~type_erased_cgfo_ty() = default;

  type_erased_cgfo_ty(const type_erased_cgfo_ty &) = delete;
  type_erased_cgfo_ty(type_erased_cgfo_ty &&) = delete;
  type_erased_cgfo_ty &operator=(const type_erased_cgfo_ty &) = delete;
  type_erased_cgfo_ty &operator=(type_erased_cgfo_ty &&) = delete;

  /// Invokes the wrapped callable with \p cgh.
  void operator()(sycl::handler &cgh) const { invoker_f(object, cgh); }
};

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I first saw this I was wondering if we could use std::invocable instead, but I think that'd change the ABI, plus it might have the same too-much-templating problem as std::function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how it's applicable here. std::invocable is a concept, not a data type, so it is more like the std::is_invocable_r trait that we are already using.

class kernel_bundle_impl;
class work_group_memory_impl;
class handler_impl;
Expand Down
111 changes: 65 additions & 46 deletions sycl/include/sycl/queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,8 @@ auto get_native(const SyclObjectT &Obj)
namespace detail {
class queue_impl;

#if __SYCL_USE_FALLBACK_ASSERT
inline event submitAssertCapture(queue &, event &, queue *,
const detail::code_location &);
#endif

// Function to postprocess submitted command
// Arguments:
Expand Down Expand Up @@ -375,8 +373,9 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
std::enable_if_t<std::is_invocable_r_v<void, T, handler &>, event> submit(
T CGF,
const detail::code_location &CodeLoc = detail::code_location::current()) {
return submit_with_event(
sycl::ext::oneapi::experimental::empty_properties_t{}, CGF,
return submit_with_event<__SYCL_USE_FALLBACK_ASSERT>(
sycl::ext::oneapi::experimental::empty_properties_t{},
detail::type_erased_cgfo_ty{CGF},
/*SecondaryQueuePtr=*/nullptr, CodeLoc);
}

Expand All @@ -395,9 +394,9 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
std::enable_if_t<std::is_invocable_r_v<void, T, handler &>, event> submit(
T CGF, queue &SecondaryQueue,
const detail::code_location &CodeLoc = detail::code_location::current()) {
return submit_with_event(
sycl::ext::oneapi::experimental::empty_properties_t{}, CGF,
&SecondaryQueue, CodeLoc);
return submit_with_event<__SYCL_USE_FALLBACK_ASSERT>(
sycl::ext::oneapi::experimental::empty_properties_t{},
detail::type_erased_cgfo_ty{CGF}, &SecondaryQueue, CodeLoc);
}

/// Prevents any commands submitted afterward to this queue from executing
Expand Down Expand Up @@ -2786,6 +2785,7 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {

#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
/// TODO: Unused. Remove these when ABI-break window is open.
/// Not using `type_erased_cgfo_ty` on purpose.
event submit_impl(std::function<void(handler &)> CGH,
const detail::code_location &CodeLoc);
event submit_impl(std::function<void(handler &)> CGH,
Expand Down Expand Up @@ -2815,16 +2815,28 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
std::function<void(handler &)> CGH, queue secondQueue,
const detail::code_location &CodeLoc,
const detail::SubmitPostProcessF &PostProcess, bool IsTopCodeLoc);

// Old versions from when `std::function<void(handler &)>` was used in place
// of `detail::type_erased_cgfo_ty`.
event submit_with_event_impl(std::function<void(handler &)> CGH,
const detail::SubmissionInfo &SubmitInfo,
const detail::code_location &CodeLoc,
bool IsTopCodeLoc);

void submit_without_event_impl(std::function<void(handler &)> CGH,
const detail::SubmissionInfo &SubmitInfo,
const detail::code_location &CodeLoc,
bool IsTopCodeLoc);
#endif // __INTEL_PREVIEW_BREAKING_CHANGES

/// A template-free version of submit.
event submit_with_event_impl(std::function<void(handler &)> CGH,
event submit_with_event_impl(const detail::type_erased_cgfo_ty &CGH,
const detail::SubmissionInfo &SubmitInfo,
const detail::code_location &CodeLoc,
bool IsTopCodeLoc);

/// A template-free version of submit_without_event.
void submit_without_event_impl(std::function<void(handler &)> CGH,
void submit_without_event_impl(const detail::type_erased_cgfo_ty &CGH,
const detail::SubmissionInfo &SubmitInfo,
const detail::code_location &CodeLoc,
bool IsTopCodeLoc);
Expand All @@ -2836,32 +2848,35 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
/// \param CGF is a function object containing command group.
/// \param CodeLoc is the code location of the submit call (default argument)
/// \return a SYCL event object for the submitted command group.
template <typename T, typename PropertiesT>
std::enable_if_t<std::is_invocable_r_v<void, T, handler &>, event>
submit_with_event(
PropertiesT Props, T CGF, queue *SecondaryQueuePtr,
//
// `UseFallbackAssert` as a template parameter vs `#if` in the function body is
// necessary to prevent an ODR violation between TUs built with different
// fallback assert modes.
Comment on lines +2852 to +2854
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@steffenlarsen , any idea if

/// NOTE: Function is dependent to prevent the fallback kernels from
/// materializing without the use of the function.
is somehow related?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "fallback kernel" in this case is not referring to fallback asserts. It's because the implementation of ext_oneapi_memcpy2d has a kernel it uses if the operation isn't natively supported, but we don't want that kernel to pop up unless the user calls ext_oneapi_memcpy2d in their code, so we make it dependent.

template <bool UseFallbackAssert, typename PropertiesT>
event submit_with_event(
PropertiesT Props, const detail::type_erased_cgfo_ty &CGF,
queue *SecondaryQueuePtr,
const detail::code_location &CodeLoc = detail::code_location::current()) {
detail::tls_code_loc_t TlsCodeLocCapture(CodeLoc);
detail::SubmissionInfo SI{};
ProcessSubmitProperties(Props, SI);
if (SecondaryQueuePtr)
SI.SecondaryQueue() = detail::getSyclObjImpl(*SecondaryQueuePtr);
#if __SYCL_USE_FALLBACK_ASSERT
SI.PostProcessorFunc() =
[this, &SecondaryQueuePtr,
&TlsCodeLocCapture](bool IsKernel, bool KernelUsesAssert, event &E) {
if (IsKernel && !device_has(aspect::ext_oneapi_native_assert) &&
KernelUsesAssert && !device_has(aspect::accelerator)) {
// __devicelib_assert_fail isn't supported by Device-side Runtime
// Linking against fallback impl of __devicelib_assert_fail is
// performed by program manager class
// Fallback assert isn't supported for FPGA
submitAssertCapture(*this, E, SecondaryQueuePtr,
TlsCodeLocCapture.query());
}
};
#endif // __SYCL_USE_FALLBACK_ASSERT
return submit_with_event_impl(std::move(CGF), SI, TlsCodeLocCapture.query(),
if constexpr (UseFallbackAssert)
SI.PostProcessorFunc() =
[this, &SecondaryQueuePtr,
&TlsCodeLocCapture](bool IsKernel, bool KernelUsesAssert, event &E) {
if (IsKernel && !device_has(aspect::ext_oneapi_native_assert) &&
KernelUsesAssert && !device_has(aspect::accelerator)) {
// __devicelib_assert_fail isn't supported by Device-side Runtime
// Linking against fallback impl of __devicelib_assert_fail is
// performed by program manager class
// Fallback assert isn't supported for FPGA
submitAssertCapture(*this, E, SecondaryQueuePtr,
TlsCodeLocCapture.query());
}
};
return submit_with_event_impl(CGF, SI, TlsCodeLocCapture.query(),
TlsCodeLocCapture.isToplevel());
}

Expand All @@ -2871,21 +2886,25 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
/// \param Props is a property list with submission properties.
/// \param CGF is a function object containing command group.
/// \param CodeLoc is the code location of the submit call (default argument)
template <typename T, typename PropertiesT>
std::enable_if_t<std::is_invocable_r_v<void, T, handler &>, void>
submit_without_event(PropertiesT Props, T CGF,
const detail::code_location &CodeLoc) {
#if __SYCL_USE_FALLBACK_ASSERT
// If post-processing is needed, fall back to the regular submit.
// TODO: Revisit whether we can avoid this.
submit_with_event(Props, CGF, nullptr, CodeLoc);
#else
detail::tls_code_loc_t TlsCodeLocCapture(CodeLoc);
detail::SubmissionInfo SI{};
ProcessSubmitProperties(Props, SI);
submit_without_event_impl(CGF, SI, TlsCodeLocCapture.query(),
TlsCodeLocCapture.isToplevel());
#endif // __SYCL_USE_FALLBACK_ASSERT
//
// `UseFallbackAssert` is a template parameter rather than an `#if` in the
// function body: this is necessary to prevent an ODR violation between TUs
// built with different fallback assert modes.
template <bool UseFallbackAssert, typename PropertiesT>
void submit_without_event(PropertiesT Props,
                          const detail::type_erased_cgfo_ty &CGF,
                          const detail::code_location &CodeLoc) {
  if constexpr (!UseFallbackAssert) {
    // No post-processing requested: capture the code location, translate the
    // properties into a SubmissionInfo and hand off to the template-free
    // implementation.
    detail::tls_code_loc_t CodeLocScope(CodeLoc);
    detail::SubmissionInfo Info{};
    ProcessSubmitProperties(Props, Info);
    submit_without_event_impl(CGF, Info, CodeLocScope.query(),
                              CodeLocScope.isToplevel());
  } else {
    // If post-processing is needed, fall back to the regular submit; the
    // event it returns is discarded.
    // TODO: Revisit whether we can avoid this.
    submit_with_event<UseFallbackAssert>(Props, CGF, nullptr, CodeLoc);
  }
}

/// parallel_for_impl with a kernel represented as a lambda + range that
Expand Down Expand Up @@ -3114,10 +3133,10 @@ event submitAssertCapture(queue &Self, event &Event, queue *SecondaryQueue,
});
};

CopierEv = Self.submit_with_event(
CopierEv = Self.submit_with_event<true>(
sycl::ext::oneapi::experimental::empty_properties_t{}, CopierCGF,
SecondaryQueue, CodeLoc);
CheckerEv = Self.submit_with_event(
CheckerEv = Self.submit_with_event<true>(
sycl::ext::oneapi::experimental::empty_properties_t{}, CheckerCGF,
SecondaryQueue, CodeLoc);

Expand Down
35 changes: 18 additions & 17 deletions sycl/source/detail/queue_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ void queue_impl::addSharedEvent(const event &Event) {
MEventsShared.push_back(Event);
}

event queue_impl::submit_impl(const std::function<void(handler &)> &CGF,
event queue_impl::submit_impl(const detail::type_erased_cgfo_ty &CGF,
const std::shared_ptr<queue_impl> &Self,
const std::shared_ptr<queue_impl> &PrimaryQueue,
const std::shared_ptr<queue_impl> &SecondaryQueue,
Expand Down Expand Up @@ -402,10 +402,13 @@ event queue_impl::submit_impl(const std::function<void(handler &)> &CGF,
// We don't want stream flushing to be blocking operation that is why submit
// a host task to print stream buffer. It will fire up as soon as the kernel
// finishes execution.
event FlushEvent = submit_impl(
[&](handler &ServiceCGH) { Stream->generateFlushCommand(ServiceCGH); },
Self, PrimaryQueue, SecondaryQueue, /*CallerNeedsEvent*/ true, Loc,
IsTopCodeLoc, {});
auto L = [&](handler &ServiceCGH) {
Stream->generateFlushCommand(ServiceCGH);
};
detail::type_erased_cgfo_ty CGF{L};
event FlushEvent =
submit_impl(CGF, Self, PrimaryQueue, SecondaryQueue,
/*CallerNeedsEvent*/ true, Loc, IsTopCodeLoc, {});
EventImpl->attachEventToCompleteWeak(detail::getSyclObjImpl(FlushEvent));
registerStreamServiceEvent(detail::getSyclObjImpl(FlushEvent));
}
Expand All @@ -419,21 +422,19 @@ event queue_impl::submitWithHandler(const std::shared_ptr<queue_impl> &Self,
bool CallerNeedsEvent,
HandlerFuncT HandlerFunc) {
SubmissionInfo SI{};
auto L = [&](handler &CGH) {
CGH.depends_on(DepEvents);
HandlerFunc(CGH);
};
detail::type_erased_cgfo_ty CGF{L};

if (!CallerNeedsEvent && supportsDiscardingPiEvents()) {
submit_without_event(
[&](handler &CGH) {
CGH.depends_on(DepEvents);
HandlerFunc(CGH);
},
Self, SI, /*CodeLoc*/ {}, /*IsTopCodeLoc*/ true);
submit_without_event(CGF, Self, SI,
/*CodeLoc*/ {}, /*IsTopCodeLoc*/ true);
return createDiscardedEvent();
}
return submit_with_event(
[&](handler &CGH) {
CGH.depends_on(DepEvents);
HandlerFunc(CGH);
},
Self, SI, /*CodeLoc*/ {}, /*IsTopCodeLoc*/ true);
return submit_with_event(CGF, Self, SI,
/*CodeLoc*/ {}, /*IsTopCodeLoc*/ true);
}

template <typename HandlerFuncT, typename MemOpFuncT, typename... MemOpArgTs>
Expand Down
8 changes: 4 additions & 4 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ class queue_impl {
/// \param StoreAdditionalInfo makes additional info be stored in event_impl
/// \return a SYCL event object, which corresponds to the queue the command
/// group is being enqueued on.
event submit(const std::function<void(handler &)> &CGF,
event submit(const detail::type_erased_cgfo_ty &CGF,
const std::shared_ptr<queue_impl> &Self,
const std::shared_ptr<queue_impl> &SecondQueue,
const detail::code_location &Loc, bool IsTopCodeLoc,
Expand All @@ -362,7 +362,7 @@ class queue_impl {
/// \param Loc is the code location of the submit call (default argument)
/// \param StoreAdditionalInfo makes additional info be stored in event_impl
/// \return a SYCL event object for the submitted command group.
event submit_with_event(const std::function<void(handler &)> &CGF,
event submit_with_event(const detail::type_erased_cgfo_ty &CGF,
const std::shared_ptr<queue_impl> &Self,
const SubmissionInfo &SubmitInfo,
const detail::code_location &Loc, bool IsTopCodeLoc) {
Expand All @@ -387,7 +387,7 @@ class queue_impl {
return discard_or_return(ResEvent);
}

void submit_without_event(const std::function<void(handler &)> &CGF,
void submit_without_event(const detail::type_erased_cgfo_ty &CGF,
const std::shared_ptr<queue_impl> &Self,
const SubmissionInfo &SubmitInfo,
const detail::code_location &Loc,
Expand Down Expand Up @@ -855,7 +855,7 @@ class queue_impl {
/// \param Loc is the code location of the submit call (default argument)
/// \param SubmitInfo is additional optional information for the submission.
/// \return a SYCL event representing submitted command group.
event submit_impl(const std::function<void(handler &)> &CGF,
event submit_impl(const detail::type_erased_cgfo_ty &CGF,
const std::shared_ptr<queue_impl> &Self,
const std::shared_ptr<queue_impl> &PrimaryQueue,
const std::shared_ptr<queue_impl> &SecondaryQueue,
Expand Down
16 changes: 15 additions & 1 deletion sycl/source/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ event queue::submit_impl_and_postprocess(
return impl->submit(CGH, impl, SecondQueue.impl, CodeLoc, IsTopCodeLoc,
&PostProcess);
}
#endif // __INTEL_PREVIEW_BREAKING_CHANGES

event queue::submit_with_event_impl(std::function<void(handler &)> CGH,
const detail::SubmissionInfo &SubmitInfo,
Expand All @@ -270,6 +269,21 @@ void queue::submit_without_event_impl(std::function<void(handler &)> CGH,
bool IsTopCodeLoc) {
impl->submit_without_event(CGH, impl, SubmitInfo, CodeLoc, IsTopCodeLoc);
}
#endif // __INTEL_PREVIEW_BREAKING_CHANGES

/// Template-free entry point taking the type-erased command group function.
/// Forwards every argument to `impl->submit_with_event`, passing `impl` itself
/// as the second argument, and returns the resulting event.
event queue::submit_with_event_impl(const detail::type_erased_cgfo_ty &CGH,
                                    const detail::SubmissionInfo &SubmitInfo,
                                    const detail::code_location &CodeLoc,
                                    bool IsTopCodeLoc) {
  event Submitted =
      impl->submit_with_event(CGH, impl, SubmitInfo, CodeLoc, IsTopCodeLoc);
  return Submitted;
}

/// Template-free entry point taking the type-erased command group function.
/// Forwards every argument to `impl->submit_without_event`, passing `impl`
/// itself as the second argument; no event is produced.
void queue::submit_without_event_impl(const detail::type_erased_cgfo_ty &CGH,
                                      const detail::SubmissionInfo &SubmitInfo,
                                      const detail::code_location &CodeLoc,
                                      bool IsTopCodeLoc) {
  return impl->submit_without_event(CGH, impl, SubmitInfo, CodeLoc,
                                    IsTopCodeLoc);
}

void queue::wait_proxy(const detail::code_location &CodeLoc) {
impl->wait(CodeLoc);
Expand Down
22 changes: 22 additions & 0 deletions sycl/test-e2e/Basic/submit_fn_ptr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <sycl/detail/core.hpp>
#include <sycl/usm.hpp>

// Pointer shared between main(), which points it at a USM shared allocation,
// and foo(), which reads it when building the command group.
int *p = nullptr;

// Command group function written as a plain function, so that main() can hand
// a function (rather than a lambda or functor) to queue::submit().
void foo(sycl::handler &cgh) {
  // Take a local copy of the global pointer: the kernel lambda captures this
  // local by value instead of referring to the global `p`.
  int *target = p;
  cgh.single_task([target]() { *target = 42; });
}

// Submits the free function foo() as a command group and verifies that the
// kernel it enqueues wrote 42 through the shared pointer `p`.
//
// Returns 0 on success and 1 on a wrong result. The result is checked
// explicitly rather than with assert(): this file does not include <cassert>,
// and assert() is compiled out under NDEBUG, which would let the test pass
// without checking anything.
int main() {
  sycl::queue q;
  p = sycl::malloc_shared<int>(1, q);
  *p = 0;
  q.submit(foo).wait();
  // Read the result before releasing the allocation.
  const bool ok = (*p == 42);
  sycl::free(p, q);
  return ok ? 0 : 1;
}
Loading
Loading