Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove padding agents using interface entities #103

Merged
merged 51 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
6b0aff7
Init remove padding agents
aaravpandya May 5, 2024
6e37a75
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 5, 2024
f70350b
Testing
aaravpandya May 7, 2024
8904310
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 7, 2024
6265e41
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 7, 2024
63fa549
Refactor out the controlled state
aaravpandya May 8, 2024
e287a63
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 9, 2024
55e411f
Fix the tests
aaravpandya May 11, 2024
293f953
Use by reference
aaravpandya May 11, 2024
2b5384c
Temp fix for map export
aaravpandya May 11, 2024
754f04c
cleanup
aaravpandya May 11, 2024
da59cb1
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 14, 2024
adbb5f9
Fix merge issues
aaravpandya May 14, 2024
9328909
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 25, 2024
50ca28b
Fix merge issues
aaravpandya May 26, 2024
3c1afdd
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 26, 2024
835e8fe
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 29, 2024
79f3a41
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 30, 2024
1e46bbb
make consts 6000 to pass tests
aaravpandya May 30, 2024
0e5c70b
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 30, 2024
8d38136
Use init counts for BVH
aaravpandya May 31, 2024
7ab4983
Merge branch 'main' into ap_removePaddingAgents
aaravpandya May 31, 2024
5917867
Merge branch 'main' into ap_removePaddingAgents
aaravpandya Jun 2, 2024
2139693
Merge branch 'main' into ap_removePaddingAgents
aaravpandya Jun 10, 2024
6421861
Remove unused function
aaravpandya Jun 11, 2024
e88af6c
Merge branch 'main' into ap_removePaddingAgents
aaravpandya Jul 22, 2024
9bc9af9
Merge branch 'main' into ap_removePaddingAgents
aaravpandya Aug 16, 2024
7ac634b
Merge branch 'main' into ap_removePaddingAgents
aaravpandya Aug 20, 2024
92772a1
Pass all tests
aaravpandya Aug 20, 2024
974fd84
Set controlledstate
aaravpandya Aug 20, 2024
6170b84
Merge branch 'main' into ap_removePaddingAgents
aaravpandya Aug 20, 2024
6b67b2b
Zero out things
aaravpandya Aug 27, 2024
c57647b
Zero padded agents out
aaravpandya Aug 28, 2024
3dc766f
Add a test.py for easy debugging
aaravpandya Aug 28, 2024
e104d3d
Better check for map drawing.
aaravpandya Aug 28, 2024
ab3f3f7
Set agents to 128
aaravpandya Aug 28, 2024
4e3a427
Cycle through files in test.py
aaravpandya Aug 28, 2024
a1e19bb
Merge branch 'main' into ap_removePaddingAgents
aaravpandya Aug 29, 2024
5e7e162
Merge branch 'main' into ap_removePaddingAgents
aaravpandya Aug 31, 2024
1eab6d5
Rename InterfaceEntity to AgentInterfaceEntity
aaravpandya Aug 31, 2024
e10e2b3
Only print dones
aaravpandya Sep 5, 2024
67021b7
Remove debug checks
aaravpandya Sep 14, 2024
11cb74a
Separete out the observation systems
aaravpandya Sep 14, 2024
78bb617
Merge branch 'main' into ap_removePaddingAgents
aaravpandya Sep 16, 2024
60cc090
Sort interfaces
aaravpandya Sep 16, 2024
5979e50
remove debug script
aaravpandya Sep 16, 2024
b353f8d
Remove debug statements
aaravpandya Sep 16, 2024
e9f4c83
Merge branch 'main' into ap_removePaddingAgents
aaravpandya Sep 17, 2024
476e691
Minor improvements
aaravpandya Sep 19, 2024
d21a7e4
Info tensor interface matches main
aaravpandya Sep 19, 2024
7685237
Correctly export response type and trajectory
aaravpandya Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pygpudrive/env/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def draw_map(self, surf, map_info, world_render_idx=0):
"""Draw static map elements."""
for idx, map_obj in enumerate(map_info):

if map_obj[-1] == float(gpudrive.EntityType.Padding):
if map_obj[-1] == float(gpudrive.EntityType.Padding) or map_obj[-1] == float(gpudrive.EntityType._None):
continue

elif map_obj[-1] <= float(gpudrive.EntityType.RoadLane):
Expand Down
2 changes: 1 addition & 1 deletion src/consts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <limits>
#include <madrona/math.hpp>
#include <madrona/types.hpp>

namespace gpudrive {

namespace consts {
Expand Down
2 changes: 1 addition & 1 deletion src/dynamics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using namespace madrona::math;
namespace gpudrive
{

inline void forwardKinematics(Action &action, VehicleSize &size, Rotation &rotation, Position &position, Velocity &velocity)
inline void forwardKinematics(const Action &action, VehicleSize &size, Rotation &rotation, Position &position, Velocity &velocity)
{
const float maxSpeed{std::numeric_limits<float>::max()};
const float dt{0.1};
Expand Down
13 changes: 7 additions & 6 deletions src/headless.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,17 @@ int main(int argc, char *argv[])
// printf("Reward\n");
// rewardPrinter.print();

// printf("Done\n");
// donePrinter.print();
printf("Done\n");
donePrinter.print();

// printf("Controlled State\n");
// controlledStatePrinter.print();

printf("Agent Map Obs\n");
agent_map_obs_printer.print();
// printf("Agent Map Obs\n");
// agent_map_obs_printer.print();

printf("Info\n");
info_printer.print();
// printf("Info\n");
// info_printer.print();
};

auto worldToShape =
Expand All @@ -133,6 +133,7 @@ int main(int argc, char *argv[])
}
const auto end = std::chrono::steady_clock::now();
const std::chrono::duration<double> elapsed = end - start;
printObs();

float fps = (double)num_steps * (double)num_worlds / elapsed.count();
printf("FPS %f\n", fps);
Expand Down
3 changes: 2 additions & 1 deletion src/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ bool isObservationsValid(gpudrive::Engine &ctx,
sortedObservations.reserve(roadCount);

for (madrona::CountT roadIdx = 0; roadIdx < roadCount; ++roadIdx) {
auto &road_iface = ctx.get<gpudrive::RoadInterfaceEntity>(ctx.data().roads[roadIdx]).e;
const auto &currentObservation =
ctx.get<gpudrive::MapObservation>(ctx.data().roads[roadIdx]);
ctx.get<gpudrive::MapObservation>(road_iface);
sortedObservations.emplace_back(relativeObservation(
currentObservation, referenceRotation, referencePosition));
}
Expand Down
184 changes: 79 additions & 105 deletions src/level_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ static void registerRigidBodyEntity(
}

static inline void resetAgent(Engine &ctx, Entity agent) {
auto xCoord = ctx.get<Trajectory>(agent).positions[0].x;
auto yCoord = ctx.get<Trajectory>(agent).positions[0].y;
auto xVelocity = ctx.get<Trajectory>(agent).velocities[0].x;
auto yVelocity = ctx.get<Trajectory>(agent).velocities[0].y;
auto speed = ctx.get<Trajectory>(agent).velocities[0].length();
auto heading = ctx.get<Trajectory>(agent).headings[0];
auto agent_iface = ctx.get<AgentInterfaceEntity>(agent).e;
auto xCoord = ctx.get<Trajectory>(agent_iface).positions[0].x;
auto yCoord = ctx.get<Trajectory>(agent_iface).positions[0].y;
auto xVelocity = ctx.get<Trajectory>(agent_iface).velocities[0].x;
auto yVelocity = ctx.get<Trajectory>(agent_iface).velocities[0].y;
auto speed = ctx.get<Trajectory>(agent_iface).velocities[0].length();
auto heading = ctx.get<Trajectory>(agent_iface).headings[0];

ctx.get<Position>(agent) = Vector3{.x = xCoord, .y = yCoord, .z = 1};
ctx.get<Rotation>(agent) = Quat::angleAxis(heading, madrona::math::up);
Expand All @@ -37,30 +38,31 @@ static inline void resetAgent(Engine &ctx, Entity agent) {
{
case DynamicsModel::Classic:
{
ctx.get<Action>(agent) = Action{.classic = {0, 0, 0}};
ctx.get<Action>(agent_iface) = Action{.classic = {0, 0, 0}};
break;
}
case DynamicsModel::InvertibleBicycle:
{
ctx.get<Action>(agent) = Action{.classic = {0, 0, 0}};
ctx.get<Action>(agent_iface) = Action{.classic = {0, 0, 0}};
break;
}
case DynamicsModel::DeltaLocal:
{
ctx.get<Action>(agent) = Action{.delta{.dx = 0, .dy = 0, .dyaw = 0}};
ctx.get<Action>(agent_iface) = Action{.delta{.dx = 0, .dy = 0, .dyaw = 0}};
break;
}
case DynamicsModel::State:
{
ctx.get<Action>(agent) = Action{.state = {.position = Vector3{0, 0, 1}, .yaw = 0, .velocity = {.linear = Vector3::zero(), .angular = Vector3::zero()}}};
ctx.get<Action>(agent_iface) = Action{.state = {.position = Vector3{0, 0, 1}, .yaw = 0, .velocity = {.linear = Vector3::zero(), .angular = Vector3::zero()}}};
break;
}
}
ctx.get<StepsRemaining>(agent).t = consts::episodeLen;
ctx.get<Done>(agent).v = 0;
ctx.get<Reward>(agent).v = 0;
ctx.get<Info>(agent) = Info{};
ctx.get<Info>(agent).type = (int32_t)ctx.get<EntityType>(agent);
ctx.get<StepsRemaining>(agent_iface).t = consts::episodeLen;
ctx.get<Done>(agent_iface).v = 0;
ctx.get<Reward>(agent_iface).v = 0;
ctx.get<Info>(agent_iface) = Info{};
ctx.get<Info>(agent_iface).type = (int32_t)ctx.get<EntityType>(agent);
ctx.get<ResponseType>(agent_iface) = ctx.get<ResponseType>(agent);

if(ctx.get<ResponseType>(agent) == ResponseType::Static)
{
Expand All @@ -74,7 +76,8 @@ static inline void resetAgent(Engine &ctx, Entity agent) {
}

static inline void populateExpertTrajectory(Engine &ctx, const Entity &agent, const MapObject &agentInit) {
auto &trajectory = ctx.get<Trajectory>(agent);
const auto &agent_iface = ctx.get<AgentInterfaceEntity>(agent).e;
auto &trajectory = ctx.get<Trajectory>(agent_iface);
for(CountT i = 0; i < agentInit.numPositions; i++)
{
trajectory.positions[i] = Vector2{.x = agentInit.position[i].x - ctx.data().mean.x, .y = agentInit.position[i].y - ctx.data().mean.y};
Expand Down Expand Up @@ -141,21 +144,23 @@ static inline Entity createAgent(Engine &ctx, const MapObject &agentInit) {
assert(agentInit.type >= EntityType::Vehicle || agentInit.type == EntityType::None);
ctx.get<EntityType>(agent) = agentInit.type;

auto agent_iface = ctx.get<AgentInterfaceEntity>(agent).e = ctx.makeEntity<AgentInterface>();

ctx.get<Goal>(agent)= Goal{.position = Vector2{.x = agentInit.goalPosition.x - ctx.data().mean.x, .y = agentInit.goalPosition.y - ctx.data().mean.y}};
populateExpertTrajectory(ctx, agent, agentInit);
if(!ctx.data().params.isStaticAgentControlled && (ctx.get<Goal>(agent).position - ctx.get<Trajectory>(agent).positions[0]).length() < consts::staticThreshold)
if(!ctx.data().params.isStaticAgentControlled && (ctx.get<Goal>(agent).position - ctx.get<Trajectory>(agent_iface).positions[0]).length() < consts::staticThreshold)
{
ctx.get<ResponseType>(agent) = ResponseType::Static;
}

if(ctx.data().numControlledVehicles < ctx.data().params.maxNumControlledVehicles && agentInit.type == EntityType::Vehicle && agentInit.valid[0] && ctx.get<ResponseType>(agent) == ResponseType::Dynamic)
{
ctx.get<ControlledState>(agent) = ControlledState{.controlled = 1};
ctx.get<ControlledState>(agent_iface) = ControlledState{.controlled = 1};
ctx.data().numControlledVehicles++;
}
else
{
ctx.get<ControlledState>(agent) = ControlledState{.controlled = 0};
ctx.get<ControlledState>(agent_iface) = ControlledState{.controlled = 0};
}

// This is not stricly necessary since , but is kept here for consistency
Expand Down Expand Up @@ -189,7 +194,8 @@ static Entity makeRoadEdge(Engine &ctx, const MapVector2 &p1,
ctx.get<ObjectID>(road_edge) = ObjectID{(int32_t)SimObject::Cube};
registerRigidBodyEntity(ctx, road_edge, SimObject::Cube);
ctx.get<ResponseType>(road_edge) = ResponseType::Static;
ctx.get<MapObservation>(road_edge) = MapObservation{.position = ctx.get<Position>(road_edge).xy(),
auto road_iface = ctx.get<RoadInterfaceEntity>(road_edge).e = ctx.makeEntity<RoadInterface>();
ctx.get<MapObservation>(road_iface) = MapObservation{.position = ctx.get<Position>(road_edge).xy(),
.scale = ctx.get<Scale>(road_edge),
.heading = utils::quatToYaw(ctx.get<Rotation>(road_edge)),
.type = (float)type};
Expand Down Expand Up @@ -262,7 +268,8 @@ static Entity makeCube(Engine &ctx, const MapVector2 &p1, const MapVector2 &p2,
ctx.get<ObjectID>(speed_bump) = ObjectID{(int32_t)SimObject::SpeedBump};
registerRigidBodyEntity(ctx, speed_bump, SimObject::SpeedBump);
ctx.get<ResponseType>(speed_bump) = ResponseType::Static;
ctx.get<MapObservation>(speed_bump) = MapObservation{.position = ctx.get<Position>(speed_bump).xy(),
auto road_iface = ctx.get<RoadInterfaceEntity>(speed_bump).e = ctx.makeEntity<RoadInterface>();
ctx.get<MapObservation>(road_iface) = MapObservation{.position = ctx.get<Position>(speed_bump).xy(),
.scale = ctx.get<Scale>(speed_bump),
.heading = utils::quatToYaw(ctx.get<Rotation>(speed_bump)),
.type = (float)type};
Expand All @@ -281,7 +288,8 @@ static Entity makeStopSign(Engine &ctx, const MapVector2 &p1) {
ctx.get<ObjectID>(stop_sign) = ObjectID{(int32_t)SimObject::StopSign};
registerRigidBodyEntity(ctx, stop_sign, SimObject::StopSign);
ctx.get<ResponseType>(stop_sign) = ResponseType::Static;
ctx.get<MapObservation>(stop_sign) = MapObservation{.position = ctx.get<Position>(stop_sign).xy(),
auto road_iface = ctx.get<RoadInterfaceEntity>(stop_sign).e = ctx.makeEntity<RoadInterface>();
ctx.get<MapObservation>(road_iface) = MapObservation{.position = ctx.get<Position>(stop_sign).xy(),
.scale = ctx.get<Scale>(stop_sign),
.heading = utils::quatToYaw(ctx.get<Rotation>(stop_sign)),
.type = (float)EntityType::StopSign};
Expand All @@ -292,27 +300,36 @@ static inline void createRoadEntities(Engine &ctx, const MapRoad &roadInit, Coun
if (roadInit.type == EntityType::RoadEdge || roadInit.type == EntityType::RoadLine || roadInit.type == EntityType::RoadLane)
{
size_t numPoints = roadInit.numPoints;
for(size_t j = 1; j <= numPoints - 1; j++)
for (size_t j = 1; j <= numPoints - 1; j++)
{
if(idx >= consts::kMaxRoadEntityCount)
return;
ctx.data().roads[idx++] = makeRoadEdge(ctx, roadInit.geometry[j-1], roadInit.geometry[j], roadInit.type);
if (idx >= consts::kMaxRoadEntityCount)
return;
auto road = ctx.data().roads[idx] = makeRoadEdge(ctx, roadInit.geometry[j - 1], roadInit.geometry[j], roadInit.type);
ctx.data().road_ifaces[idx++] = ctx.get<RoadInterfaceEntity>(road).e;
}
} else if (roadInit.type == EntityType::SpeedBump || roadInit.type == EntityType::CrossWalk) {
assert(roadInit.numPoints >= 4);
// TODO: Speed Bump are not guranteed to have 4 points. Need to handle this case.
if(idx >= consts::kMaxRoadEntityCount)
return;
ctx.data().roads[idx++] = makeCube(ctx, roadInit.geometry[0], roadInit.geometry[1], roadInit.geometry[2], roadInit.geometry[3], roadInit.type);
} else if (roadInit.type == EntityType::StopSign ) {
assert(roadInit.numPoints >= 1);
// TODO: Stop Sign are not guranteed to have 1 point. Need to handle this case.
if(idx >= consts::kMaxRoadEntityCount)
return;
ctx.data().roads[idx++] = makeStopSign(ctx, roadInit.geometry[0]);
} else {
// TODO: Need to handle Cross Walk.
// assert(false);
}
else if (roadInit.type == EntityType::SpeedBump || roadInit.type == EntityType::CrossWalk)
{
assert(roadInit.numPoints >= 4);
// TODO: Speed Bump are not guranteed to have 4 points. Need to handle this case.
if (idx >= consts::kMaxRoadEntityCount)
return;
auto road = ctx.data().roads[idx] = makeCube(ctx, roadInit.geometry[0], roadInit.geometry[1], roadInit.geometry[2], roadInit.geometry[3], roadInit.type);
ctx.data().road_ifaces[idx++] = ctx.get<RoadInterfaceEntity>(road).e;
}
else if (roadInit.type == EntityType::StopSign)
{
assert(roadInit.numPoints >= 1);
// TODO: Stop Sign are not guranteed to have 1 point. Need to handle this case.
if (idx >= consts::kMaxRoadEntityCount)
return;
auto road = ctx.data().roads[idx] = makeStopSign(ctx, roadInit.geometry[0]);
ctx.data().road_ifaces[idx++] = ctx.get<RoadInterfaceEntity>(road).e;
}
else
{
// TODO: Need to handle Cross Walk.
// assert(false);
return;
}
}
Expand All @@ -330,58 +347,33 @@ static void createFloorPlane(Engine &ctx)
registerRigidBodyEntity(ctx, ctx.data().floorPlane, SimObject::Plane);
}

static inline Entity createAgentPadding(Engine &ctx) {
auto agent = ctx.makeRenderableEntity<Agent>();

ctx.get<Position>(agent) = consts::kPaddingPosition;
ctx.get<Rotation>(agent) = Quat::angleAxis(0, madrona::math::up);
ctx.get<Scale>(agent) = Diag3x3{.d0 = 0, .d1 = 0, .d2 = 0};
ctx.get<Velocity>(agent) = {Vector3::zero(), Vector3::zero()};
ctx.get<ObjectID>(agent) = ObjectID{(int32_t)SimObject::Agent};
ctx.get<ResponseType>(agent) = ResponseType::Static;
ctx.get<EntityType>(agent) = EntityType::Padding;
ctx.get<CollisionDetectionEvent>(agent).hasCollided.store_release(0);
ctx.get<Done>(agent).v = 0;
ctx.get<StepsRemaining>(agent).t = consts::episodeLen;
ctx.get<ControlledState>(agent) = ControlledState{.controlled = 0};

if (ctx.data().enableRender) {
render::RenderingSystem::attachEntityToView(ctx,
agent,
90.f, 0.001f,
1.5f * math::up);
}

return agent;
}

static inline Entity createPhysicsEntityPadding(Engine &ctx) {
auto physicsEntity = ctx.makeRenderableEntity<PhysicsEntity>();

ctx.get<Position>(physicsEntity) = consts::kPaddingPosition;
ctx.get<Rotation>(physicsEntity) = Quat::angleAxis(0, madrona::math::up);
ctx.get<Scale>(physicsEntity) = Diag3x3{.d0 = 0, .d1 = 0, .d2 = 0};
ctx.get<Velocity>(physicsEntity) = {Vector3::zero(), Vector3::zero()};
ctx.get<ObjectID>(physicsEntity) = ObjectID{(int32_t)SimObject::Cube};
ctx.get<ResponseType>(physicsEntity) = ResponseType::Static;
ctx.get<MapObservation>(physicsEntity) = MapObservation{.position = ctx.get<Position>(physicsEntity).xy(),
.scale = ctx.get<Scale>(physicsEntity),
.heading = utils::quatToYaw(ctx.get<Rotation>(physicsEntity)),
.type = float(EntityType::Padding)};
ctx.get<EntityType>(physicsEntity) = EntityType::Padding;

return physicsEntity;
}

void createPaddingEntities(Engine &ctx) {
for (CountT agentIdx = ctx.data().numAgents;
agentIdx < consts::kMaxAgentCount; ++agentIdx) {
ctx.data().agents[agentIdx] = createAgentPadding(ctx);
Entity &agent_iface = ctx.data().agent_ifaces[agentIdx] = ctx.makeEntity<AgentInterface>();
ctx.get<ControlledState>(agent_iface) = ControlledState{.controlled = 0};
ctx.get<Done>(agent_iface).v = 1;
ctx.get<Reward>(agent_iface).v = 0;
ctx.get<Info>(agent_iface) = Info::zero();
ctx.get<ResponseType>(agent_iface) = ResponseType::Static;
auto &agent_map_obs = ctx.get<AgentMapObservations>(agent_iface);
for (CountT i = 0; i < consts::kMaxAgentMapObservationsCount; i++) {
agent_map_obs.obs[i] = MapObservation::zero();
}
auto &self_obs = ctx.get<SelfObservation>(agent_iface);
self_obs = SelfObservation::zero();

auto &partner_obs = ctx.get<PartnerObservations>(agent_iface);
for (CountT i = 0; i < consts::kMaxAgentCount-1; i++) {
partner_obs.obs[i] = PartnerObservation::zero();
}

}

for (CountT roadIdx = ctx.data().numRoads;
roadIdx < consts::kMaxRoadEntityCount; ++roadIdx) {
ctx.data().roads[roadIdx] = createPhysicsEntityPadding(ctx);
Entity &e = ctx.data().road_ifaces[roadIdx] = ctx.makeEntity<RoadInterface>();
ctx.get<MapObservation>(e) = MapObservation::zero();
}
}

Expand Down Expand Up @@ -429,6 +421,7 @@ void createPersistentEntities(Engine &ctx, Map *map) {
}
auto agent = createAgent(
ctx, agentInit);
ctx.data().agent_ifaces[agentIdx] = ctx.get<AgentInterfaceEntity>(agent).e;
ctx.data().agents[agentIdx++] = agent;
}

Expand All @@ -451,24 +444,6 @@ void createPersistentEntities(Engine &ctx, Map *map) {
createPaddingEntities(ctx);
}

static void resetPaddingEntities(Engine &ctx) {
for (CountT agentIdx = ctx.data().numAgents;
agentIdx < consts::kMaxAgentCount; ++agentIdx) {
Entity agent = ctx.data().agents[agentIdx];
ctx.get<Done>(agent).v = 0;
ctx.get<StepsRemaining>(agent).t = consts::episodeLen;
ctx.get<Info>(agent) = Info{};
ctx.get<Info>(agent).type = (int32_t)ctx.get<EntityType>(agent);
registerRigidBodyEntity(ctx, agent, SimObject::Agent);
}

for (CountT roadIdx = ctx.data().numRoads;
roadIdx < consts::kMaxRoadEntityCount; ++roadIdx) {
Entity road = ctx.data().roads[roadIdx];
registerRigidBodyEntity(ctx, road, SimObject::Cube);
}
}

static void resetPersistentEntities(Engine &ctx)
{

Expand Down Expand Up @@ -516,7 +491,6 @@ static void resetPersistentEntities(Engine &ctx)
void generateWorld(Engine &ctx)
{
resetPersistentEntities(ctx);
resetPaddingEntities(ctx);
}

}
2 changes: 1 addition & 1 deletion src/mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ Tensor Manager::agentMapObservationsTensor() const
TensorElementType::Float32,
{
impl_->numWorlds,
consts::kMaxAgentCount,
consts::kMaxAgentCount,
consts::kMaxAgentMapObservationsCount,
AgentMapObservationExportSize,
});
Expand Down
Loading
Loading