Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow async nestings that violate read after write dependencies #7868

Merged
merged 18 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,23 +109,74 @@ class NoOpCollapsingMutator : public IRMutator {
class GenerateProducerBody : public NoOpCollapsingMutator {
const string &func;
vector<Expr> sema;
std::set<string> producers_dropped;
bool found_producer = false;

using NoOpCollapsingMutator::visit;

// Emits the user-facing error for an illegal async nesting.
//
// producer:       name of the Func reported as being consumed by the async
//                 Func; interpolated into the message twice.
// async_consumer: name of the async Func that consumes it.
//
// NOTE(review): user_error is declared outside this view; presumably it
// terminates (or throws) once the message has been streamed, so callers
// should not depend on this function returning — confirm.
void bad_producer_nesting_error(const string &producer, const string &async_consumer) {
user_error
<< "The Func " << producer << " is consumed by async Func " << async_consumer
<< " and has a compute_at location in between the store_at "
<< "location and the compute_at location of " << async_consumer
<< ". This is only legal when " << producer
<< " is both async and has a store_at location outside the store_at location of the consumer.";
}

// Preserve produce nodes and add synchronization
Stmt visit(const ProducerConsumer *op) override {
if (op->name == func && op->is_producer) {
found_producer = true;

// Add post-synchronization
internal_assert(!sema.empty()) << "Duplicate produce node: " << op->name << "\n";
Stmt body = op->body;

// We don't currently support waiting on producers to the producer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"to the producer" -> "in the producer"?

// half of the fork node. Or rather, if you want to do that you have
// to schedule those Funcs as async too. Check for any consume nodes
// where the producer has gone to the consumer side of the fork
// node.
class FindBadConsumeNodes : public IRVisitor {
const std::set<string> &producers_dropped;
using IRVisitor::visit;

void visit(const ProducerConsumer *op) override {
if (!op->is_producer && producers_dropped.count(op->name)) {
found = op->name;
}
}

public:
string found;
FindBadConsumeNodes(const std::set<string> &p)
: producers_dropped(p) {
}
} finder(producers_dropped);
body.accept(&finder);
if (!finder.found.empty()) {
bad_producer_nesting_error(finder.found, func);
}

while (!sema.empty()) {
Expr release = Call::make(Int(32), "halide_semaphore_release", {sema.back(), 1}, Call::Extern);
body = Block::make(body, Evaluate::make(release));
sema.pop_back();
}
return ProducerConsumer::make_produce(op->name, body);
} else {
if (op->is_producer) {
producers_dropped.insert(op->name);
}
bool found_producer_before = found_producer;
Stmt body = mutate(op->body);
if (!op->is_producer && producers_dropped.count(op->name) &&
found_producer && !found_producer_before) {
// We've found a consume node wrapping our async producer where
// the corresponding producer node was dropped from this half of
// the fork.
bad_producer_nesting_error(op->name, func);
}
if (is_no_op(body) || op->is_producer) {
return body;
} else {
Expand Down
5 changes: 0 additions & 5 deletions src/StorageFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,11 +825,6 @@ class AttemptStorageFoldingOfFunction : public IRMutator {
to_release = max_required - max_required_next; // This is the last time we use these entries
}

if (provided.used.defined()) {
to_acquire = select(provided.used, to_acquire, 0);
}
// We should always release the required region, even if we don't use it.

// On the first iteration, we need to acquire the extent of the region shared
// between the producer and consumer, and we need to release it on the last
// iteration.
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ tests(GROUPS correctness
align_bounds.cpp
argmax.cpp
async_device_copy.cpp
async_order.cpp
autodiff.cpp
bad_likely.cpp
bit_counting.cpp
Expand Down
94 changes: 94 additions & 0 deletions test/correctness/async_order.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
if (get_jit_target_from_environment().arch == Target::WebAssembly) {
printf("[SKIP] WebAssembly does not support async() yet.\n");
return 0;
}

{
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer1(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.compute_at(consumer, y);
producer2.compute_at(consumer, y).async();

consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize({16, 16});

out.for_each_element([&](int x, int y) {
int correct = 2 * (x + y);
if (out(x, y) != correct) {
printf("out(%d, %d) = %d instead of %d\n",
x, y, out(x, y), correct);
exit(-1);
}
});
}
{
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer1(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.compute_root();
producer2.store_root().compute_at(consumer, y).async();

consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize({16, 16});

out.for_each_element([&](int x, int y) {
int correct = 2 * (x + y);
if (out(x, y) != correct) {
printf("out(%d, %d) = %d instead of %d\n",
x, y, out(x, y), correct);
exit(-1);
}
});
}

{
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer1(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.store_root().compute_at(consumer, y).async();
producer2.store_root().compute_at(consumer, y).async();

consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize({16, 16});

out.for_each_element([&](int x, int y) {
int correct = 2 * (x + y);
if (out(x, y) != correct) {
printf("out(%d, %d) = %d instead of %d\n",
x, y, out(x, y), correct);
exit(-1);
}
});
}

printf("Success!\n");
return 0;
}
2 changes: 2 additions & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ tests(GROUPS error
auto_schedule_no_parallel.cpp
auto_schedule_no_reorder.cpp
autodiff_unbounded.cpp
bad_async_producer.cpp
bad_async_producer_2.cpp
bad_bound.cpp
bad_bound_storage.cpp
bad_compute_at.cpp
Expand Down
31 changes: 31 additions & 0 deletions test/error/bad_async_producer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

#include "Halide.h"

using namespace Halide;

// Builds the chain f -> g -> h, applies a schedule the test author describes
// as an error, then realizes h and checks the output values.
int main(int argc, char **argv) {
    Func f{"f"}, g{"g"}, h{"h"};
    Var x;

    f(x) = cast<uint8_t>(x + 7);
    g(x) = f(x);
    h(x) = g(x);

    // This schedule is deliberately an error. Per the original author, the
    // legal form would be:
    //   f.store_root().compute_at(g, Var::outermost());
    // so that f ends up nested inside its consumer.
    f.store_root().compute_at(h, x);
    g.store_root().compute_at(h, x).async();

    Buffer<uint8_t> buf = h.realize({32});

    // Only reached if realize() did not raise an error: check buf(i) == i + 7.
    const int extent = buf.dim(0).extent();
    for (int i = 0; i < extent; i++) {
        const uint8_t expected = i + 7;
        if (buf(i) == expected) {
            continue;
        }
        printf("buf(%d) = %d instead of %d\n", i, buf(i), expected);
        return 1;
    }

    return 0;
}
23 changes: 23 additions & 0 deletions test/error/bad_async_producer_2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "Halide.h"

using namespace Halide;

// From https://github.com/halide/Halide/issues/5201
// Entry point: builds a two-producer pipeline and realizes it under a
// schedule in which both producers are async but only producer2 has its
// storage hoisted to root.
// NOTE(review): this file is registered in the error test group, so the
// realize() call is presumably expected to fail with a user error rather
// than reach `return 0` — confirm against the test harness.
int main(int argc, char **argv) {
Func producer1, producer2, consumer;
Var x, y;

// producer2 reads producer1 pointwise; consumer reads producer2 one row
// above and one row below. consumer never calls producer1 directly.
producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer2(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

// producer1: async, computed per row of consumer, no store_root().
// producer2: async, computed per row of consumer, storage at root.
producer1.compute_at(consumer, y).async();
producer2.store_root().compute_at(consumer, y).async();

consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize({16, 16});
return 0;
}
33 changes: 26 additions & 7 deletions test/performance/async_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Expr expensive(Expr x, int c) {
if (c <= 0) {
return x;
} else {
return expensive(fast_pow(x, x + 1), c - 1);
return expensive(x * (x + 1), c - 1);
}
}

Expand All @@ -31,11 +31,12 @@ int main(int argc, char **argv) {
}

double times[2];
uint32_t correct = 0;
for (int use_async = 0; use_async < 2; use_async++) {
Var x, y, t, xi, yi;

ImageParam in(Float(32), 3);
Func cpu, gpu;
ImageParam in(UInt(32), 3);
Func cpu("cpu"), gpu("gpu");

// We have a two-stage pipeline that processes frames. We want
// to run the first stage on the GPU and the second stage on
Expand All @@ -50,26 +51,44 @@ int main(int argc, char **argv) {

// Assume GPU memory is limited, and compute the GPU stage one
// frame at a time. Hoist the allocation to the top level.
gpu.compute_at(cpu, t).store_root().gpu_tile(x, y, xi, yi, 8, 8);
gpu.compute_at(gpu.in(), Var::outermost()).store_root().gpu_tile(x, y, xi, yi, 8, 8);

// Stage the copy-back of the GPU result into a host-side
// double-buffer.
gpu.in().copy_to_host().compute_at(cpu, t).store_root().fold_storage(t, 2);

if (use_async) {
// gpu.async();
gpu.in().async();
gpu.async();
}

in.set(Buffer<float>(800, 800, 16));
Buffer<float> out(800, 800, 16);
Buffer<uint32_t> in_buf(800, 800, 16);
in_buf.fill(17);
in.set(in_buf);
Buffer<uint32_t> out(800, 800, 16);

cpu.compile_jit();

times[use_async] = benchmark(10, 1, [&]() {
cpu.realize(out);
});

if (!use_async) {
correct = out(0, 0, 0);
} else {
for (int t = 0; t < out.dim(2).extent(); t++) {
for (int y = 0; y < out.dim(1).extent(); y++) {
for (int x = 0; x < out.dim(0).extent(); x++) {
if (out(x, y, t) != correct) {
printf("Async output at (%d, %d, %d) is %u instead of %u\n",
x, y, t, out(x, y, t), correct);
return 1;
}
}
}
}
}

printf("%s: %f\n",
use_async ? "with async" : "without async",
times[use_async]);
Expand Down