Commit

cleanup
hugolatendresse committed Dec 16, 2024
1 parent 9bfe777 commit 4d39aa4
Showing 1 changed file with 9 additions and 28 deletions.
37 changes: 9 additions & 28 deletions src/ops/aggregate.cc
@@ -173,21 +173,18 @@ Aggregate::Aggregate(FFModel &model,
     dims[i] = inputs[FIXED_ARG_CNT]->dims[i];
   }
 
-  // TODO replace with inputs[0]->dims[num_dim - 2]
   ParallelDim topk_values_penultimate_dim;
   topk_values_penultimate_dim.size = 1;
   topk_values_penultimate_dim.degree = 1;
   topk_values_penultimate_dim.parallel_idx = -1;
   topk_values_penultimate_dim.is_replica_dim = false;
 
-  // TODO replace with inputs[0]->dims[num_dim - 1]
   ParallelDim topk_values_last_dim;
   topk_values_last_dim.size = 128;
   topk_values_last_dim.degree = 1;
   topk_values_last_dim.parallel_idx = -1;
   topk_values_last_dim.is_replica_dim = false;
 
-  // TODO this is all debugging stuff. Need to set for real
   dims[num_dim - 3] = topk_values_penultimate_dim;
   dims[num_dim - 2] = topk_values_last_dim;
   dims[num_dim - 1] = inputs[FIXED_ARG_CNT]->dims[num_dim - 1];
@@ -305,8 +302,6 @@ void Aggregate::init_inference(FFModel const &ff,
                         EXCLUSIVE,
                         batch_outputs[0]->region));
   launcher.add_field(n + FIXED_ARG_CNT, FID_DATA);
-  // launcher.add_field(FIXED_ARG_CNT, FID_DATA); // TODO undo when I do experts again
-
 
   FutureMap fm = runtime->execute_index_space(ctx, launcher);
   fm.wait_all_results();
@@ -339,7 +334,6 @@ OpMeta *Aggregate::init_task(Task const *task,
                              std::vector<PhysicalRegion> const &regions,
                              Context ctx,
                              Runtime *runtime) {
-  // printf("running Aggregate::init_task\n");
   Aggregate *agg = (Aggregate *)task->args;
   FFHandler handle = *((FFHandler *)task->local_args);
   Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc);
@@ -353,7 +347,6 @@ OpMeta *Aggregate::init_task(Task const *task,
   m->output_type[0] = agg->outputs[0]->data_type;
   std::strcpy(m->op_name, agg->name);
 
-  // TODO three instructions below are not in SigmoidSiluMulti::init_task
   m->profiling = agg->profiling;
   m->inference_debugging = agg->inference_debugging;
   std::strcpy(m->op_name, agg->name);
@@ -424,8 +417,7 @@ void Aggregate::forward(FFModel const &ff) {
                         WRITE_ONLY,
                         EXCLUSIVE,
                         outputs[0]->region));
-  // launcher.add_field(n + 2, FID_DATA);
-  launcher.add_field(n + FIXED_ARG_CNT, FID_DATA); // TODO undo when I do experts again
+  launcher.add_field(n + FIXED_ARG_CNT, FID_DATA);
 
 
   runtime->execute_index_space(ctx, launcher);
@@ -436,18 +428,14 @@ FutureMap Aggregate::inference(FFModel const &ff,
                                std::vector<ParallelTensor> const &batch_inputs,
                                std::vector<ParallelTensor> const &batch_outputs,
                                MachineView const *mv) {
-  // printf("running Aggregate::inference\n");
   ArgumentMap argmap;
   Context ctx = ff.config.lg_ctx;
   Runtime *runtime = ff.config.lg_hlr;
   parallel_is = batch_outputs[0]->parallel_is;
   MachineView const *view = mv ? mv : &batch_outputs[0]->machine_view;
   set_argumentmap_for_inference(ff, argmap, batch_outputs[0]);
   size_t machine_view_hash = view->hash();
-  // This gives segfault
-  // std::cout << "Aggregate op machine_view: " << *(MachineView const *)mv
-  //           << std::endl;
-  IndexLauncher launcher(AGGREGATE_FWD_TASK_ID, // TODO should we have a separate inference task?
+  IndexLauncher launcher(AGGREGATE_FWD_TASK_ID,
                          parallel_is,
                          TaskArgument(nullptr, 0),
                          argmap,
@@ -503,7 +491,7 @@ FutureMap Aggregate::inference(FFModel const &ff,
                         EXCLUSIVE,
                         batch_outputs[0]->region));
   // launcher.add_field(n + 2, FID_DATA);
-  launcher.add_field(n + FIXED_ARG_CNT, FID_DATA); // TODO undo when I do experts again
+  launcher.add_field(n + FIXED_ARG_CNT, FID_DATA);
 
   return runtime->execute_index_space(ctx, launcher);
 }
@@ -512,17 +500,10 @@ void Aggregate::forward_task(Task const *task,
                              std::vector<PhysicalRegion> const &regions,
                              Context ctx,
                              Runtime *runtime) {
-  // TODO in the end, create and place our changes in Aggregate::inference_task
-  // printf("running Aggregate::forward_task\n");
-  //
-
   BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
-  //
   AggregateMeta const *m = *((AggregateMeta **)task->local_args);
-  //
   int n = regions.size() - FIXED_ARG_CNT - 1; // Last region is for the output
-  //
-  // // get gate_pred, gate_assign, output
+  // get gate_pred, gate_assign, output
   AccessorRW<float, 4> const acc_gate_pred(regions[0], FID_DATA); // causes dynamic type mismatch
   AccessorRO<int, 4> const acc_gate_assign(regions[1], FID_DATA);
   AccessorWO<float, 4> const acc_output(regions[n + FIXED_ARG_CNT], FID_DATA);
@@ -533,33 +514,33 @@ void Aggregate::forward_task(Task const *task,
       ctx, task->regions[1].region.get_index_space());
   Rect<4> rect_output = runtime->get_index_space_domain(
       ctx, task->regions[n + FIXED_ARG_CNT].region.get_index_space());
-  //
+
   coord_t batch_size = rect_gate_pred.hi[1] - rect_gate_pred.lo[1] + 1;
   assert(batch_size == rect_gate_assign.hi[1] - rect_gate_assign.lo[1] + 1);
   assert(rect_gate_pred.hi[0] - rect_gate_pred.lo[0] ==
          rect_gate_assign.hi[0] - rect_gate_assign.lo[0]);
   assert(batch_size == rect_output.hi[1] - rect_output.lo[1] + 1);
   coord_t out_dim = rect_output.hi[0] - rect_output.lo[0] + 1;
 
-  // // get exp_preds
+  // get exp_preds
   float *exp_preds[n];
   // get first exp_pred and row and out_dim
   Domain exp_domain = runtime->get_index_space_domain(
       ctx, task->regions[FIXED_ARG_CNT].region.get_index_space());
   exp_preds[0] = helperGetTensorPointerWO<float>(regions[FIXED_ARG_CNT], task->regions[FIXED_ARG_CNT], FID_DATA, ctx, runtime);
   coord_t rows = exp_domain.hi()[1] - exp_domain.lo()[1] + 1;
   assert(out_dim == exp_domain.hi()[0] - exp_domain.lo()[0] + 1);
-  //
+
   for (int i = 1; i < n; i++) {
     exp_domain = runtime->get_index_space_domain(
         ctx, task->regions[i + FIXED_ARG_CNT].region.get_index_space());
     exp_preds[i] = helperGetTensorPointerWO<float>(
         regions[i + FIXED_ARG_CNT], task->regions[i + FIXED_ARG_CNT], FID_DATA, ctx, runtime);
-    //
+
     assert(rows == exp_domain.hi()[1] - exp_domain.lo()[1] + 1);
     assert(out_dim == exp_domain.hi()[0] - exp_domain.lo()[0] + 1);
   }
-  //
+
   int k = (int)(rect_gate_assign.hi[0] - rect_gate_assign.lo[0] + 1);
 
   Aggregate::forward_kernel_wrapper(m,