diff --git a/python_bindings/src/halide/halide_/PyExpr.cpp b/python_bindings/src/halide/halide_/PyExpr.cpp index 1403e2db9a72..c05eef6c351c 100644 --- a/python_bindings/src/halide/halide_/PyExpr.cpp +++ b/python_bindings/src/halide/halide_/PyExpr.cpp @@ -22,22 +22,10 @@ void define_expr(py::module &m) { auto expr_class = py::class_(m, "Expr") + // Default ctor .def(py::init<>()) - .def(py::init([](bool b) { - return Internal::make_bool(b); - })) - // PyBind11 searches in declared order, - // int should be tried before float conversion - .def(py::init()) - .def(py::init()) - // Python float is implemented by double - // But Halide prohibits implicitly construct by double. - .def(py::init([](double v) { - return double_to_expr_check(v); - })) - .def(py::init()) - // for implicitly_convertible + // For implicitly_convertible .def(py::init([](const FuncRef &f) -> Expr { return f; })) .def(py::init([](const FuncTupleElementRef &f) -> Expr { return f; })) .def(py::init([](const Param<> &p) -> Expr { return p; })) @@ -45,6 +33,31 @@ void define_expr(py::module &m) { .def(py::init([](const RVar &r) -> Expr { return r; })) .def(py::init([](const Var &v) -> Expr { return v; })) + // Weird types + .def(py::init()) + + // Numeric types. + // This is tricky. PyBind11 tries the conversions in declared order, + // and we generally want to prefer int over float conversion (to avoid + // accidental promotion). However, we want to keep a float32 as a float32 + // (e.g. specified via numpy.float32()) and the implicit Expr conversion + // will confuse PyBind, hence the apparently wrong order. + .def(py::init()) + .def(py::init([](bool b) { + return Internal::make_bool(b); + })) + .def(py::init()) + .def(py::init()) + // Most scalar fp values we get from Python will actually be doubles; + // for efficiency, we want to store these as float32 instead of float64. + // This may not always be the right decision -- e.g., if someone + // constructs something via numpy.float64() they will be unhappy -- + // but changing the behavior now would likely cause lots of subtle + // regressions. + .def(py::init([](double v) { + return double_to_expr_check(v); + })) + .def("__bool__", to_bool) .def("__nonzero__", to_bool) diff --git a/python_bindings/test/correctness/basics.py b/python_bindings/test/correctness/basics.py index 204966418e41..b1c2b1e9df5f 100644 --- a/python_bindings/test/correctness/basics.py +++ b/python_bindings/test/correctness/basics.py @@ -450,6 +450,17 @@ def test_implicit_convert_int64(): assert (hl.i32(0) + (0x7fffffff+1)).type() == hl.Int(64) +def test_explicit_expr_ctors(): + assert (hl.Expr(np.bool_(0))).type() == hl.Bool() + assert (hl.Expr(np.int32(0))).type() == hl.Int(32), (hl.Expr(np.int32(0))).type() + assert (hl.Expr(np.int64(0x7fffffff+1))).type() == hl.Int(64), (hl.Expr(np.int64(0x7fffffff+1))).type() + assert (hl.Expr(np.float32(0))).type() == hl.Float(32), (hl.Expr(np.float32(0))).type() + # Note that this is deliberate: we have aggressively downscaled scalar + # float64 values from Python into float32, and we aren't going to change + # that now. + assert (hl.Expr(np.float64(0))).type() == hl.Float(32) + + if __name__ == "__main__": test_compiletime_error() test_runtime_error() @@ -469,3 +480,4 @@ def test_implicit_convert_int64(): test_bool_conversion() test_requirements() test_implicit_convert_int64() + test_explicit_expr_ctors()