Skip to content

Commit

Permalink
Export the entity type of the agent in Info Tensor (#98)
Browse files Browse the repository at this point in the history
* Export the entity type of the agent in Info Tensor

* Export Entity Type bindings
  • Loading branch information
aaravpandya authored May 2, 2024
1 parent e25da92 commit 630ffbb
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
13 changes: 13 additions & 0 deletions src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,19 @@ namespace gpudrive
.value("AgentRemoved", CollisionBehaviour::AgentRemoved)
.value("Ignore", CollisionBehaviour::Ignore);

nb::enum_<EntityType>(m, "EntityType")
.value("_None", EntityType::None)
.value("RoadEdge", EntityType::RoadEdge)
.value("RoadLine", EntityType::RoadLine)
.value("RoadLane", EntityType::RoadLane)
.value("CrossWalk", EntityType::CrossWalk)
.value("SpeedBump", EntityType::SpeedBump)
.value("StopSign", EntityType::StopSign)
.value("Vehicle", EntityType::Vehicle)
.value("Pedestrian", EntityType::Pedestrian)
.value("Cyclist", EntityType::Cyclist)
.value("Padding", EntityType::Padding)
.value("NumTypes", EntityType::NumTypes);

// Bindings for Manager class
nb::class_<Manager>(m, "SimManager")
Expand Down
2 changes: 2 additions & 0 deletions src/level_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static inline void resetAgent(Engine &ctx, Entity agent) {
ctx.get<Done>(agent).v = 0;
ctx.get<Reward>(agent).v = 0;
ctx.get<Info>(agent) = Info{};
ctx.get<Info>(agent).type = (int32_t)ctx.get<EntityType>(agent);

#ifndef GPUDRIVE_DISABLE_NARROW_PHASE
ctx.get<CollisionDetectionEvent>(agent).hasCollided.store_release(0);
Expand Down Expand Up @@ -354,6 +355,7 @@ static void resetPaddingEntities(Engine &ctx) {
ctx.get<Done>(agent).v = 0;
ctx.get<StepsRemaining>(agent).t = consts::episodeLen;
ctx.get<Info>(agent) = Info{};
ctx.get<Info>(agent).type = (int32_t)ctx.get<EntityType>(agent);
registerRigidBodyEntity(ctx, agent, SimObject::Agent);
}

Expand Down
3 changes: 2 additions & 1 deletion src/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ struct Info{
int collidedWithVehicle;
int collidedWithNonVehicle;
int reachedGoal;
int type;
};

const size_t InfoExportSize = 4;
const size_t InfoExportSize = 5;

static_assert(sizeof(Info) == sizeof(int) * InfoExportSize);

Expand Down

0 comments on commit 630ffbb

Please sign in to comment.