Skip to content

Commit

Permalink
Delete agents
Browse files Browse the repository at this point in the history
  • Loading branch information
aaravpandya committed Jan 18, 2025
1 parent 1701596 commit 25fe4e8
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 10 deletions.
15 changes: 14 additions & 1 deletion src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,20 @@ namespace gpudrive
.def("expert_trajectory_tensor", &Manager::expertTrajectoryTensor)
.def("set_maps", &Manager::setMaps)
.def("world_means_tensor", &Manager::worldMeansTensor)
.def("metadata_tensor", &Manager::metadataTensor);
.def("metadata_tensor", &Manager::metadataTensor)
.def("deleteAgents", [](Manager &self, nb::dict py_agents_to_delete) {
std::unordered_map<int32_t, std::vector<int32_t>> agents_to_delete;

// Convert Python dict to C++ unordered_map
for (auto item : py_agents_to_delete) {
int32_t key = nb::cast<int32_t>(item.first);
std::vector<int32_t> value = nb::cast<std::vector<int32_t>>(item.second);
agents_to_delete[key] = value;
}

self.deleteAgents(agents_to_delete);
})
.def("deleted_agents_tensor", &Manager::deletedAgentsTensor);
}

}
9 changes: 9 additions & 0 deletions src/level_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,15 @@ static inline bool shouldAgentBeCreated(Engine &ctx, const MapObject &agentInit)
return false;
}

auto& deletedAgents = ctx.singleton<DeletedAgents>().deletedAgents;
for (CountT i = 0; i < consts::kMaxAgentCount; i++)
{
if(deletedAgents[i] == agentInit.id)
{
return false;
}
}

return true;
}

Expand Down
25 changes: 21 additions & 4 deletions src/mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,13 +643,27 @@ void Manager::setMaps(const std::vector<std::string> &maps)
reset(worldIndices);
}

void Manager::setToDelete(const std::unordered_map<int32_t, std::vector<int32_t>> &agentsToDelete)
Tensor Manager::deletedAgentsTensor() const
{
return impl_->exportTensor(ExportID::DeletedAgents, TensorElementType::Int32,
{
impl_->numWorlds,
consts::kMaxAgentCount,
});
}

void Manager::deleteAgents(const std::unordered_map<int32_t, std::vector<int32_t>> &agentsToDelete)
{

ResetMap resetmap{
1,
};

if (impl_->cfg.execMode == madrona::ExecMode::CUDA)
{
#ifdef MADRONA_CUDA_SUPPORT
auto &gpu_exec = static_cast<CUDAImpl *>(impl_.get())->gpuExec;
auto agentsToDeleteDevicePtr = (int32_t *)gpu_exec.getExported((uint32_t)ExportID::DeleteAgents);
auto agentsToDeleteDevicePtr = (int32_t *)gpu_exec.getExported((uint32_t)ExportID::DeletedAgents);
for (const auto &[worldIdx, agents] : agentsToDelete)
{
assert(worldIdx < impl_->cfg.scenes.size());
Expand All @@ -659,6 +673,8 @@ void Manager::setToDelete(const std::unordered_map<int32_t, std::vector<int32_t>
{
REQ_CUDA(cudaMemcpy(agentsToDeletePtr + i, &agents[i], sizeof(int32_t), cudaMemcpyHostToDevice));
}
auto resetMapPtr = (ResetMap *)gpu_exec.getExported((uint32_t)ExportID::ResetMap) + worldIdx;
REQ_CUDA(cudaMemcpy(resetMapPtr, &resetmap, sizeof(ResetMap), cudaMemcpyHostToDevice));
}
#else
// Handle the case where CUDA support is not available
Expand All @@ -668,7 +684,7 @@ void Manager::setToDelete(const std::unordered_map<int32_t, std::vector<int32_t>
else
{
auto &cpu_exec = static_cast<CPUImpl *>(impl_.get())->cpuExec;
auto agentsToDeleteDevicePtr = (int32_t *)cpu_exec.getExported((uint32_t)ExportID::DeleteAgents);
auto agentsToDeleteDevicePtr = (int32_t *)cpu_exec.getExported((uint32_t)ExportID::DeletedAgents);
for (const auto &[worldIdx, agents] : agentsToDelete)
{
assert(worldIdx < impl_->cfg.scenes.size());
Expand All @@ -678,13 +694,14 @@ void Manager::setToDelete(const std::unordered_map<int32_t, std::vector<int32_t>
{
memcpy(agentsToDeletePtr + i, &agents[i], sizeof(int32_t));
}
auto resetMapPtr = (ResetMap *)cpu_exec.getExported((uint32_t)ExportID::ResetMap) + worldIdx;
memcpy(resetMapPtr, &resetmap, sizeof(ResetMap));
}
}

std::vector<int32_t> worldIndices(impl_->cfg.scenes.size());
std::iota(worldIndices.begin(), worldIndices.end(), 0);
reset(worldIndices);

}


Expand Down
3 changes: 2 additions & 1 deletion src/mgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Manager {
MGR_EXPORT madrona::py::Tensor expertTrajectoryTensor() const;
MGR_EXPORT madrona::py::Tensor worldMeansTensor() const;
MGR_EXPORT madrona::py::Tensor metadataTensor() const;
MGR_EXPORT madrona::py::Tensor deletedAgentsTensor() const;
madrona::py::Tensor rgbTensor() const;
madrona::py::Tensor depthTensor() const;
// These functions are used by the viewer to control the simulation
Expand All @@ -79,7 +80,7 @@ class Manager {
float headAngle);
MGR_EXPORT void setMaps(const std::vector<std::string> &maps);

MGR_EXPORT void setToDelete(const std::unordered_map<int32_t, std::vector<int32_t>> &agentsToDelete);
MGR_EXPORT void deleteAgents(const std::unordered_map<int32_t, std::vector<int32_t>> &agentsToDelete);

// TODO: remove parameters
MGR_EXPORT std::vector<Shape>
Expand Down
3 changes: 2 additions & 1 deletion src/sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void Sim::registerTypes(ECSRegistry &registry, const Config &cfg)
registry.registerSingleton<Map>();
registry.registerSingleton<ResetMap>();
registry.registerSingleton<WorldMeans>();
registry.registerSingleton<DeletedAgents>();

registry.registerArchetype<Agent>();
registry.registerArchetype<PhysicsEntity>();
Expand All @@ -74,7 +75,7 @@ void Sim::registerTypes(ECSRegistry &registry, const Config &cfg)
registry.exportSingleton<Map>((uint32_t)ExportID::Map);
registry.exportSingleton<ResetMap>((uint32_t)ExportID::ResetMap);
registry.exportSingleton<WorldMeans>((uint32_t)ExportID::WorldMeans);
registry.exportSingleton<DeleteAgents>((uint32_t)ExportID::DeleteAgents);
registry.exportSingleton<DeletedAgents>((uint32_t)ExportID::DeletedAgents);

registry.exportColumn<AgentInterface, Action>(
(uint32_t)ExportID::Action);
Expand Down
2 changes: 1 addition & 1 deletion src/sim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ enum class ExportID : uint32_t {
ResetMap,
WorldMeans,
MetaData,
DeleteAgents,
DeletedAgents,
NumExports
};

Expand Down
4 changes: 2 additions & 2 deletions src/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ namespace gpudrive
int32_t reset;
};

struct DeleteAgents {
int32_t deleteAgents[consts::kMaxAgentCount];
struct DeletedAgents {
int32_t deletedAgents[consts::kMaxAgentCount];
};

struct WorldMeans {
Expand Down

0 comments on commit 25fe4e8

Please sign in to comment.