Fix UB-introducing rewrite in FindIntrinsics #8539

Merged (1 commit) on Dec 27, 2024
63 changes: 58 additions & 5 deletions src/FindIntrinsics.cpp
@@ -101,11 +101,10 @@ class FindIntrinsics : public IRMutator {
Scope<ConstantInterval> let_var_bounds;

Expr lossless_cast(Type t, const Expr &e) {
return Halide::Internal::lossless_cast(t, e, &bounds_cache);
return Halide::Internal::lossless_cast(t, e, let_var_bounds, &bounds_cache);
}

ConstantInterval constant_integer_bounds(const Expr &e) {
// TODO: Use the scope - add let visitors
return Halide::Internal::constant_integer_bounds(e, let_var_bounds, &bounds_cache);
}

@@ -210,6 +209,51 @@ class FindIntrinsics : public IRMutator {
return Expr();
}

template<typename LetOrLetStmt>
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
struct Frame {
const LetOrLetStmt *orig;
Expr new_value;
ScopedBinding<ConstantInterval> bind;
Frame(const LetOrLetStmt *orig,
Expr &&new_value,
ScopedBinding<ConstantInterval> &&bind)
: orig(orig), new_value(std::move(new_value)), bind(std::move(bind)) {
}
};
std::vector<Frame> frames;
decltype(op->body) body;
while (op) {
Expr v = mutate(op->value);
ConstantInterval b = constant_integer_bounds(v);
frames.emplace_back(op,
std::move(v),
ScopedBinding<ConstantInterval>(let_var_bounds, op->name, b));
body = op->body;
op = body.template as<LetOrLetStmt>();
}

body = mutate(body);

for (const auto &f : reverse_view(frames)) {
if (f.new_value.same_as(f.orig->value) && body.same_as(f.orig->body)) {
body = f.orig;
} else {
body = LetOrLetStmt::make(f.orig->name, f.new_value, body);
}
}

return body;
}

Expr visit(const Let *op) override {
return visit_let(op);
}

Stmt visit(const LetStmt *op) override {
return visit_let(op);
}

Expr visit(const Add *op) override {
if (!find_intrinsics_for_type(op->type)) {
return IRMutator::visit(op);
@@ -697,7 +741,12 @@ class FindIntrinsics : public IRMutator {
bool is_saturated = op->value.as<Max>() || op->value.as<Min>();
Expr a = lossless_cast(op->type, shift->args[0]);
Expr b = lossless_cast(op->type.with_code(shift->args[1].type().code()), shift->args[1]);
if (a.defined() && b.defined()) {
// Doing the shift in the narrower type might introduce UB where
// there was no UB before, so we need to make sure b is bounded.
auto b_bounds = constant_integer_bounds(b);
const int max_shift = op->type.bits() - 1;

if (a.defined() && b.defined() && b_bounds >= -max_shift && b_bounds <= max_shift) {
if (!is_saturated ||
(shift->is_intrinsic(Call::rounding_shift_right) && can_prove(b >= 0)) ||
(shift->is_intrinsic(Call::rounding_shift_left) && can_prove(b <= 0))) {
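
For context (not part of the diff), a minimal plain-C++ sketch of the hazard this hunk guards against; the function names are illustrative only:

#include <cstdint>

// A shift count that is fine for a wide operand becomes undefined behaviour
// once the operands are narrowed, because shifting by >= the bit width of the
// (promoted) operand type is UB in C++.
uint64_t shift_wide(uint64_t x, int y) {
    return x >> y;  // defined for y in [0, 63]
}
uint32_t shift_narrow(uint32_t x, int y) {
    return x >> y;  // undefined behaviour for y >= 32
}
// The b_bounds test above therefore only permits the narrowing rewrite when
// the shift amount provably lies within [-(bits - 1), bits - 1] of the
// narrower type.
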
@@ -1118,8 +1167,12 @@ class SubstituteInWideningLets : public IRMutator {
std::string name;
Expr new_value;
ScopedBinding<Expr> bind;
Frame(const std::string &name, const Expr &new_value, ScopedBinding<Expr> &&bind)
: name(name), new_value(new_value), bind(std::move(bind)) {
Frame(const std::string &name,
const Expr &new_value,
ScopedBinding<Expr> &&bind)
: name(name),
new_value(new_value),
bind(std::move(bind)) {
}
};
std::vector<Frame> frames;
57 changes: 30 additions & 27 deletions src/IROperator.cpp
@@ -427,17 +427,20 @@ Expr const_false(int w) {
return make_zero(UInt(1, w));
}

Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare> *cache) {
Expr lossless_cast(Type t,
Expr e,
const Scope<ConstantInterval> &scope,
std::map<Expr, ConstantInterval, ExprCompare> *cache) {
if (!e.defined() || t == e.type()) {
return e;
} else if (t.can_represent(e.type())) {
return cast(t, std::move(e));
} else if (const Cast *c = e.as<Cast>()) {
if (c->type.can_represent(c->value.type())) {
return lossless_cast(t, c->value, cache);
return lossless_cast(t, c->value, scope, cache);
}
} else if (const Broadcast *b = e.as<Broadcast>()) {
Expr v = lossless_cast(t.element_of(), b->value, cache);
Expr v = lossless_cast(t.element_of(), b->value, scope, cache);
if (v.defined()) {
return Broadcast::make(v, b->lanes);
}
@@ -456,7 +459,7 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
} else if (const Shuffle *shuf = e.as<Shuffle>()) {
std::vector<Expr> vecs;
for (const auto &vec : shuf->vectors) {
vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, cache));
vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, scope, cache));
if (!vecs.back().defined()) {
return Expr();
}
@@ -465,73 +468,73 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
} else if (t.is_int_or_uint()) {
// Check the bounds. If they're small enough, we can throw narrowing
// casts around e, or subterms.
ConstantInterval ci = constant_integer_bounds(e, Scope<ConstantInterval>::empty_scope(), cache);
ConstantInterval ci = constant_integer_bounds(e, scope, cache);

if (t.can_represent(ci)) {
// There are certain IR nodes where if the result is expressible
// using some type, and the args are expressible using that type,
// then the operation can just be done in that type.
if (const Add *op = e.as<Add>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
return Add::make(a, b);
}
} else if (const Sub *op = e.as<Sub>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
return Sub::make(a, b);
}
} else if (const Mul *op = e.as<Mul>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
return Mul::make(a, b);
}
} else if (const Min *op = e.as<Min>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
debug(0) << a << " " << b << "\n";
return Min::make(a, b);
}
} else if (const Max *op = e.as<Max>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
return Max::make(a, b);
}
} else if (const Mod *op = e.as<Mod>()) {
Expr a = lossless_cast(t, op->a, cache);
Expr b = lossless_cast(t, op->b, cache);
Expr a = lossless_cast(t, op->a, scope, cache);
Expr b = lossless_cast(t, op->b, scope, cache);
if (a.defined() && b.defined()) {
return Mod::make(a, b);
}
} else if (const Call *op = Call::as_intrinsic(e, {Call::widening_add, Call::widen_right_add})) {
Expr a = lossless_cast(t, op->args[0], cache);
Expr b = lossless_cast(t, op->args[1], cache);
Expr a = lossless_cast(t, op->args[0], scope, cache);
Expr b = lossless_cast(t, op->args[1], scope, cache);
if (a.defined() && b.defined()) {
return Add::make(a, b);
}
} else if (const Call *op = Call::as_intrinsic(e, {Call::widening_sub, Call::widen_right_sub})) {
Expr a = lossless_cast(t, op->args[0], cache);
Expr b = lossless_cast(t, op->args[1], cache);
Expr a = lossless_cast(t, op->args[0], scope, cache);
Expr b = lossless_cast(t, op->args[1], scope, cache);
if (a.defined() && b.defined()) {
return Sub::make(a, b);
}
} else if (const Call *op = Call::as_intrinsic(e, {Call::widening_mul, Call::widen_right_mul})) {
Expr a = lossless_cast(t, op->args[0], cache);
Expr b = lossless_cast(t, op->args[1], cache);
Expr a = lossless_cast(t, op->args[0], scope, cache);
Expr b = lossless_cast(t, op->args[1], scope, cache);
if (a.defined() && b.defined()) {
return Mul::make(a, b);
}
} else if (const Call *op = Call::as_intrinsic(e, {Call::shift_left, Call::widening_shift_left,
Call::shift_right, Call::widening_shift_right})) {
Expr a = lossless_cast(t, op->args[0], cache);
Expr b = lossless_cast(t, op->args[1], cache);
Expr a = lossless_cast(t, op->args[0], scope, cache);
Expr b = lossless_cast(t, op->args[1], scope, cache);
if (a.defined() && b.defined()) {
ConstantInterval cb = constant_integer_bounds(b, Scope<ConstantInterval>::empty_scope(), cache);
ConstantInterval cb = constant_integer_bounds(b, scope, cache);
if (cb > -t.bits() && cb < t.bits()) {
if (op->is_intrinsic({Call::shift_left, Call::widening_shift_left})) {
return a << b;
@@ -544,7 +547,7 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
if (op->op == VectorReduce::Add ||
op->op == VectorReduce::Min ||
op->op == VectorReduce::Max) {
Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, cache);
Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, scope, cache);
if (v.defined()) {
return VectorReduce::make(op->op, v, op->type.lanes());
}
19 changes: 12 additions & 7 deletions src/IROperator.h
Expand Up @@ -11,7 +11,9 @@
#include <map>
#include <optional>

#include "ConstantInterval.h"
#include "Expr.h"
#include "Scope.h"
#include "Target.h"
#include "Tuple.h"

@@ -150,13 +152,16 @@ Expr const_false(int lanes = 1);
/** Attempt to cast an expression to a smaller type while provably not losing
* information. If it can't be done, return an undefined Expr.
*
* Optionally accepts a map that gives the constant bounds of exprs already
* analyzed to avoid redoing work across many calls to lossless_cast. It is not
* safe to use this optional map in contexts where the same Expr object may
* take on a different value. For example:
* (let x = 4 in some_expr_object) + (let x = 5 in the_same_expr_object)).
* It is safe to use it after uniquify_variable_names has been run. */
Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare> *cache = nullptr);
* Optionally accepts a scope giving the constant bounds of any variables, and a
* map that gives the constant bounds of exprs already analyzed to avoid redoing
* work across many calls to lossless_cast. It is not safe to use this optional
* map in contexts where the same Expr object may take on a different value. For
* example: (let x = 4 in some_expr_object) + (let x = 5 in
* the_same_expr_object)). It is safe to use it after uniquify_variable_names
* has been run. */
Expr lossless_cast(Type t, Expr e,
const Scope<ConstantInterval> &scope = Scope<ConstantInterval>::empty_scope(),
std::map<Expr, ConstantInterval, ExprCompare> *cache = nullptr);
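
A hedged sketch of what the new scope parameter carries (illustrative only, not part of the diff; the variable name and bounds are assumptions). The scope maps variable names to constant bounds, which the internal constant_integer_bounds queries then use instead of the type's full range:

// Assuming we are inside namespace Halide::Internal.
Scope<ConstantInterval> scope;
scope.push("x", ConstantInterval(0, 4));  // e.g. gathered from an enclosing Let
Expr x = Variable::make(UInt(8), "x");

// Without the scope, x is only known to lie in [0, 255]; with it, the bounds
// of x + 3 tighten to [3, 7], which is what lets lossless_cast (and the
// shift-count checks in FindIntrinsics) prove more rewrites safe.
ConstantInterval ci = constant_integer_bounds(x + 3, scope);
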

/** Attempt to negate x without introducing new IR and without overflow.
* If it can't be done, return an undefined Expr. */
18 changes: 9 additions & 9 deletions test/correctness/intrinsics.cpp
@@ -255,17 +255,17 @@ int main(int argc, char **argv) {
check((u64(u32x) + 8) / 16, u64(rounding_shift_right(u32x, 4)));
check(u16(min((u64(u32x) + 8) / 16, 65535)), u16_sat(rounding_shift_right(u32x, 4)));

// And with variable shifts.
check(i8(widening_add(i8x, (i8(1) << u8y) / 2) >> u8y), rounding_shift_right(i8x, u8y));
// And with variable shifts. These won't match unless Halide can statically
// prove it's not an out-of-range shift.
Expr u8yc = Min::make(u8y, make_const(u8y.type(), 7));
Expr i8yc = Max::make(Min::make(i8y, make_const(i8y.type(), 7)), make_const(i8y.type(), -7));
check(i8(widening_add(i8x, (i8(1) << u8yc) / 2) >> u8yc), rounding_shift_right(i8x, u8yc));
check((i32x + (i32(1) << u32y) / 2) >> u32y, rounding_shift_right(i32x, u32y));

check(i8(widening_add(i8x, (i8(1) << max(i8y, 0)) / 2) >> i8y), rounding_shift_right(i8x, i8y));
check(i8(widening_add(i8x, (i8(1) << max(i8yc, 0)) / 2) >> i8yc), rounding_shift_right(i8x, i8yc));
check((i32x + (i32(1) << max(i32y, 0)) / 2) >> i32y, rounding_shift_right(i32x, i32y));

check(i8(widening_add(i8x, (i8(1) >> min(i8y, 0)) / 2) << i8y), rounding_shift_left(i8x, i8y));
check(i8(widening_add(i8x, (i8(1) >> min(i8yc, 0)) / 2) << i8yc), rounding_shift_left(i8x, i8yc));
check((i32x + (i32(1) >> min(i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y));

check(i8(widening_add(i8x, (i8(1) << -min(i8y, 0)) / 2) << i8y), rounding_shift_left(i8x, i8y));
check(i8(widening_add(i8x, (i8(1) << -min(i8yc, 0)) / 2) << i8yc), rounding_shift_left(i8x, i8yc));
check((i32x + (i32(1) << -min(i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y));
check((i32x + (i32(1) << max(-i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y));
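
A hedged illustration (assumed variable, not part of the test) of why the clamps above are needed: for an 8-bit result the rewrite to a rounding shift only fires when the shift count's constant bounds fit in [-7, 7]:

Expr u8y = Variable::make(UInt(8), "u8y");
ConstantInterval unclamped = constant_integer_bounds(u8y);        // [0, 255]: rewrite refused
ConstantInterval clamped = constant_integer_bounds(min(u8y, 7));  // [0, 7]: rewrite allowed
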

@@ -372,7 +372,7 @@ int main(int argc, char **argv) {
f(x) = cast<uint8_t>(x);
f.compute_root();

g(x) = rounding_shift_right(x, 0) + rounding_shift_left(x, 8);
g(x) = rounding_shift_right(f(x), 0) + u8(rounding_shift_left(u16(f(x)), 11));

g.compile_jit();
}