Skip to content

Commit

Permalink
Add feature flag
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanKazemkhani committed May 8, 2024
1 parent 8dceca2 commit 92c91f7
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
7 changes: 7 additions & 0 deletions src/init.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ namespace gpudrive
Ignore
};

enum class FindRoadObservationsWith {
KNearestEntitiesWithRadiusFiltering,
AllEntitiesWithRadiusFiltering
};

struct Parameters
{
float polylineReductionThreshold;
Expand All @@ -104,6 +109,8 @@ namespace gpudrive
CollisionBehaviour collisionBehaviour = CollisionBehaviour::AgentStop; // Default: AgentStop
uint32_t maxNumControlledVehicles = 10000; // Arbitrary high number to by default control all vehicles
bool IgnoreNonVehicles = false; // Default: false
FindRoadObservationsWith roadObservationAlgorithm{
FindRoadObservationsWith::KNearestEntitiesWithRadiusFiltering};
};

struct WorldInit
Expand Down
14 changes: 14 additions & 0 deletions src/mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,17 @@ static std::vector<std::string> getMapFiles(const Manager::Config &cfg)
return mapFiles;
}

bool isRoadObservationAlgorithmValid(FindRoadObservationsWith algo) {
madrona::CountT roadObservationsCount =
sizeof(AgentMapObservations) / sizeof(MapObservation);

return algo ==
FindRoadObservationsWith::KNearestEntitiesWithRadiusFiltering ||
(algo ==
FindRoadObservationsWith::KNearestEntitiesWithRadiusFiltering &&
roadObservationsCount == consts::kMaxRoadEntityCount);
}

Manager::Impl * Manager::Impl::init(
const Manager::Config &mgr_cfg)
{
Expand All @@ -445,6 +456,9 @@ Manager::Impl * Manager::Impl::init(

std::vector<std::string> mapFiles = getMapFiles(mgr_cfg);

assert(isRoadObservationAlgorithmValid(
mgr_cfg.params.roadObservationAlgorithm));

switch (mgr_cfg.execMode) {
case ExecMode::CUDA: {
#ifdef MADRONA_CUDA_SUPPORT
Expand Down
30 changes: 28 additions & 2 deletions src/sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,34 @@ inline void collectObservationsSystem(Engine &ctx,
arrIndex++;
}

selectKNearestRoadEntities<consts::kMaxAgentMapObservationsCount>(
ctx, rot, model.position, map_obs.obs);
const auto alg = ctx.data().params.roadObservationAlgorithm;
if (alg == FindRoadObservationsWith::KNearestEntitiesWithRadiusFiltering) {
selectKNearestRoadEntities<consts::kMaxAgentMapObservationsCount>(
ctx, rot, model.position, map_obs.obs);
return;
}

assert(alg == FindRoadObservationsWith::AllEntitiesWithRadiusFiltering);
arrIndex = 0; CountT roadIdx = 0;
while(roadIdx < ctx.data().numRoads) {
Entity road = ctx.data().roads[roadIdx++];
Vector2 relative_pos = Vector2{ctx.get<Position>(road).x, ctx.get<Position>(road).y} - model.position;
relative_pos = rot.inv().rotateVec({relative_pos.x, relative_pos.y, 0}).xy();
if(relative_pos.length() > ctx.data().params.observationRadius)
{
continue;
}
map_obs.obs[arrIndex] = ctx.get<MapObservation>(road);
map_obs.obs[arrIndex].position = map_obs.obs[arrIndex].position - model.position;
arrIndex++;
}
while (arrIndex < consts::kMaxRoadEntityCount)
{
map_obs.obs[arrIndex].position = Vector2{0.f, 0.f};
map_obs.obs[arrIndex].heading = 0.f;
map_obs.obs[arrIndex].type = (float)EntityType::None;
arrIndex++;
}
}

inline void movementSystem(Engine &e,
Expand Down

0 comments on commit 92c91f7

Please sign in to comment.