diff --git a/test/generator/alias_aottest.cpp b/test/generator/alias_aottest.cpp index 6ddc1d911c69..a9f21ba26914 100644 --- a/test/generator/alias_aottest.cpp +++ b/test/generator/alias_aottest.cpp @@ -52,6 +52,7 @@ int main(int argc, char **argv) { output.fill(0); output.copy_to_host(); alias_Mullapudi2016(input, output); + output.copy_to_host(); input.for_each_element([=](int x) { assert(output(x) == input(x) + 2016); }); diff --git a/test/generator/autograd_aottest.cpp b/test/generator/autograd_aottest.cpp index b90616964dc8..70f2c881b9c1 100644 --- a/test/generator/autograd_aottest.cpp +++ b/test/generator/autograd_aottest.cpp @@ -110,6 +110,17 @@ int main(int argc, char **argv) { exit(1); } + grad_loss_out_wrt_a.copy_to_host(); + grad_loss_out_wrt_b.copy_to_host(); + grad_loss_out_wrt_c.copy_to_host(); + dummy_grad_loss_output_wrt_lut.copy_to_host(); + dummy_grad_loss_output_wrt_lut_indices.copy_to_host(); + dummy_grad_loss_output_lut_wrt_input_a.copy_to_host(); + dummy_grad_loss_output_lut_wrt_input_b.copy_to_host(); + dummy_grad_loss_output_lut_wrt_input_c.copy_to_host(); + grad_loss_output_lut_wrt_lut.copy_to_host(); + grad_loss_output_lut_wrt_lut_indices.copy_to_host(); + // Although the values are float, all should be exact results, // so we don't need to worry about comparing vs. an epsilon grad_loss_out_wrt_a.for_each_element([&](int x) { @@ -118,18 +129,21 @@ int main(int argc, char **argv) { float actual = grad_loss_out_wrt_a(x); assert(expected == actual); }); + grad_loss_out_wrt_b.for_each_element([&](int x) { // ∂𝐿/∂b = b * 44 * L float expected = L(x) * b(x) * 44.f; float actual = grad_loss_out_wrt_b(x); assert(expected == actual); }); + grad_loss_out_wrt_c.for_each_element([&](int x) { // ∂𝐿/∂c = 11 * L float expected = L(x) * 11.f; float actual = grad_loss_out_wrt_c(x); assert(expected == actual); }); + dummy_grad_loss_output_wrt_lut.for_each_value([](float f) { assert(f == 0.f); }); dummy_grad_loss_output_wrt_lut_indices.for_each_value([](float f) { assert(f == 0.f); }); dummy_grad_loss_output_lut_wrt_input_a.for_each_value([](float f) { assert(f == 0.f); });