Skip to content

Commit

Permalink
fix: resolve runtime error in preprocess and postprocess
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 committed Apr 19, 2024
1 parent a65a8b8 commit 3fef41c
Show file tree
Hide file tree
Showing 13 changed files with 529 additions and 521 deletions.
1 change: 0 additions & 1 deletion perception/tensorrt_mtr/config/tensorrt_mtr.param.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
max_num_polyline: 798
max_num_point: 20
point_break_distance: 1.0
offset_xy: [30.0, 0.0]
intention_point_filepath: "$(var data_path)/intention_point.csv"
num_intention_point_cluster: 64
polyline_label_path: "$(var data_path)/polyline_label.txt"
141 changes: 68 additions & 73 deletions perception/tensorrt_mtr/include/tensorrt_mtr/agent.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ enum AgentLabel { VEHICLE = 0, PEDESTRIAN = 1, CYCLIST = 2 };
*/
struct AgentState
{
/**
* @brief Construct a new instance filling all elements by `0.0f`.
*/
// Construct a new instance filling all elements by `0.0f`.
AgentState() : data_({0.0f}) {}

/**
Expand Down Expand Up @@ -84,33 +82,49 @@ struct AgentState
std::copy(itr, itr + dim(), data_.begin());
}

static size_t dim() { return AgentStateDim; }

/**
* @brief Construct a new instance filling all elements by `0.0f`.
*
* @return AgentState
*/
// Construct a new instance filling all elements by `0.0f`.
static AgentState empty() noexcept { return AgentState(); }

/**
* @brief Return the address pointer of data array.
*
* @return float*
*/
// Return the agent state dimensions `D`.
static size_t dim() { return AgentStateDim; }

// Return the address pointer of data array.
float * data_ptr() noexcept { return data_.data(); }

// Return the x position.
float x() const { return x_; }

// Return the y position.
float y() const { return y_; }

// Return the z position.
float z() const { return z_; }

// Return the length of object size.
float length() const { return length_; }

// Return the width of object size.
float width() const { return width_; }

// Return the height of object size.
float height() const { return height_; }

// Return the yaw angle `[rad]`.
float yaw() const { return yaw_; }

// Return the x velocity.
float vx() const { return vx_; }

// Return the y velocity.
float vy() const { return vy_; }

// Return the x acceleration.
float ax() const { return ax_; }

// Return the y acceleration.
float ay() const { return ay_; }

// Return `true` if the value is `1.0`.
bool is_valid() const { return is_valid_ == 1.0f; }

private:
Expand Down Expand Up @@ -163,34 +177,19 @@ struct AgentHistory
{
}

static size_t state_dim() { return AgentStateDim; }

/**
* @brief Return the history length.
*
* @return size_t History length.
*/
// Return the history time length `T`.
size_t length() const { return max_time_length_; }

/**
* @brief Return the data size of history `T * D`.
*
* @return size_t
*/
// Return the number of agent state dimensions `D`.
static size_t state_dim() { return AgentStateDim; }

// Return the data size of history `T * D`.
size_t size() const { return max_time_length_ * state_dim(); }

/**
* @brief Return the shape of history matrix ordering in `(T, D)`.
*
* @return std::tuple<size_t, size_t>
*/
// Return the shape of history matrix ordering in `(T, D)`.
std::tuple<size_t, size_t> shape() const { return {max_time_length_, state_dim()}; }

/**
* @brief Return the object id.
*
* @return const std::string&
*/
// Return the object id.
const std::string & object_id() const { return object_id_; }

size_t label_index() const { return label_index_; }
Expand Down Expand Up @@ -220,10 +219,7 @@ struct AgentHistory
latest_time_ = current_time;
}

/**
* @brief Update history with all-zeros state, but latest time is not updated.
*
*/
// Update history with all-zeros state, but latest time is not updated.
void update_empty() noexcept
{
// remove the state at the oldest timestamp
Expand All @@ -235,11 +231,7 @@ struct AgentHistory
}
}

/**
* @brief Return the address pointer of data array.
*
* @return float* The pointer of data array.
*/
// Return the address pointer of data array.
float * data_ptr() noexcept { return data_.data(); }

/**
Expand All @@ -264,11 +256,7 @@ struct AgentHistory
*/
bool is_valid_latest() const { return get_latest_state().is_valid(); }

/**
* @brief Get the latest agent state.
*
* @return AgentState
*/
// Get the latest agent state at `T`.
AgentState get_latest_state() const
{
const auto & latest_itr = (data_.begin() + state_dim() * (max_time_length_ - 1));
Expand Down Expand Up @@ -338,49 +326,55 @@ struct AgentData
}
}

static size_t state_dim() { return AgentStateDim; }
// Return the number of classes `C`.
static size_t num_class() { return 3; } // TODO(ktro2828): Do not use magic number.

// Return the number of target agents `B`.
size_t num_target() const { return num_target_; }

// Return the number of agents `N`.
size_t num_agent() const { return num_agent_; }

// Return the timestamp length `T`.
size_t time_length() const { return time_length_; }

// Return the number of agent state dimensions `D`.
static size_t state_dim() { return AgentStateDim; }

// Return the index of ego.
int sdc_index() const { return sdc_index_; }

// Return the vector of indices of target agents, in shape `[B]`.
const std::vector<size_t> & target_index() const { return target_index_; }

// Return the vector of label indices of all agents, in shape `[N]`.
const std::vector<size_t> & label_index() const { return label_index_; }

// Return the vector of label indices of target agents, in shape `[B]`.
const std::vector<size_t> & target_label_index() const { return target_label_index_; }

// Return the vector of timestamps in shape `[T]`.
const std::vector<float> & timestamps() const { return timestamps_; }

// Return the number of all elements `N*T*D`.
size_t size() const { return num_agent_ * time_length_ * state_dim(); }
size_t input_dim() const { return num_agent_ + time_length_ + num_class() + 3; }

/**
* @brief Return the data shape ordering in (N, T, D).
*
* @return std::tuple<size_t, size_t, size_t>
*/
// Return the number of state dimensions of MTR input `T+C+D+3`.
size_t input_dim() const { return time_length_ + state_dim() + num_class() + 3; }

// Return the data shape ordering in (N, T, D).
std::tuple<size_t, size_t, size_t> shape() const
{
return {num_agent_, time_length_, state_dim()};
}

/**
* @brief Return the address pointer of data array.
*
* @return float* The pointer of data array.
*/
// Return the address pointer of data array.
float * data_ptr() noexcept { return data_.data(); }

/**
* @brief Return the address pointer of data array for target agents.
*
* @return float* The pointer of data array for target agents.
*/
// Return the address pointer of data array for target agents.
float * target_data_ptr() noexcept { return target_data_.data(); }

/**
* @brief Return the address pointer of data array for ego vehicle.
*
* @return float* The pointer of data array for ego vehicle.
*/
// Return the address pointer of data array for ego vehicle.
float * ego_data_ptr() noexcept { return ego_data_.data(); }

private:
Expand All @@ -397,6 +391,7 @@ struct AgentData
std::vector<float> ego_data_;
};

// Get label names from label indices.
std::vector<std::string> getLabelNames(const std::vector<size_t> & label_index)
{
std::vector<std::string> label_names;
Expand Down
12 changes: 2 additions & 10 deletions perception/tensorrt_mtr/include/tensorrt_mtr/intention_point.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,10 @@ struct IntentionPoint
return points;
}

/**
* @brief Return the size of intension point K * D.
*
* @return size_t
*/
// Return the size of intension point `K*D`.
size_t size() const { return num_cluster_ * state_dim(); }

/**
* @brief Return the number of clusters contained in intention points.
*
* @return size_t
*/
// Return the number of clusters contained in intention points `K`.
size_t num_cluster() const { return num_cluster_; }

private:
Expand Down
76 changes: 15 additions & 61 deletions perception/tensorrt_mtr/include/tensorrt_mtr/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,8 @@ class PolylineTypeMap
}
}

/**
* @brief Return the ID of the corresponding label type.
* If specified type is not contained in map, return -1.
*
* @param type
* @return int
*/
// Return the ID of the corresponding label type. If specified type is not contained in map,
// return `-1`.
int getTypeID(const std::string & type) const
{
return label_map_.count(type) == 0 ? -1 : label_map_.at(type);
Expand All @@ -109,83 +104,42 @@ class MTRNode : public rclcpp::Node
public:
explicit MTRNode(const rclcpp::NodeOptions & node_options);

// Object ID for ego vehicle
// Object ID of the ego vehicle
const std::string EGO_ID{"EGO"};

private:
/**
* @brief Main callback being invoked when the tracked objects topic is subscribed.
*
* @param object_msg
*/
// Main callback being invoked when the tracked objects topic is subscribed.
void callback(const TrackedObjects::ConstSharedPtr object_msg);

/**
* @brief Callback being invoked when the HD map topic is subscribed.
*
* @param map_msg
*/
// Callback being invoked when the HD map topic is subscribed.
void onMap(const HADMapBin::ConstSharedPtr map_msg);

/**
* @brief Callback being invoked when the Ego's odometry topic is subscribed.
*
* @param ego_msg
*/
// Callback being invoked when the Ego's odometry topic is subscribed.
void onEgo(const Odometry::ConstSharedPtr ego_msg);

/**
* @brief Converts lanelet2 to polylines.
*
* @return true
*/
// Convert Lanelet to `PolylineData`.
bool convertLaneletToPolyline();

/**
* @brief Remove ancient agent histories.
*
* @param current_time
* @param objects_msg
*/
// Remove ancient agent histories.
void removeAncientAgentHistory(
const float current_time, const TrackedObjects::ConstSharedPtr objects_msg);

/**
* @brief Appends new states to history.
*
* @param current_time
* @param objects_msg
*/
// Appends new states to history.
void updateAgentHistory(
const float current_time, const TrackedObjects::ConstSharedPtr objects_msg);

// Extract ego state stored in the buffer which has the nearest timestamp from current timestamp.
AgentState extractNearestEgo(const float current_time) const;

/**
* @brief Extract target agents and return corresponding indices.
*
* NOTE: Extract targets in order of proximity, closest first.
*
* @param histories
* @return std::vector<size_t>
*/
// Extract target agents and return corresponding indices.
// NOTE: Extract targets in order of proximity, closest first.
std::vector<size_t> extractTargetAgent(const std::vector<AgentHistory> & histories);

/**
* @brief Return the timestamps relative from the first element.Return the timestamps relative
* from the first element.
*
* @return std::vector<float>
*/
// Return the timestamps relative from the first element.Return the timestamps relative from the
// first element.
std::vector<float> getRelativeTimestamps() const;

/**
* @brief Generate `PredictedObject` from `PredictedTrajectory`.
*
* @param object
* @param trajectory
* @return PredictedObject
*/
// Generate `PredictedObject` from `PredictedTrajectory`.
PredictedObject generatePredictedObject(
const TrackedObject & object, const PredictedTrajectory & trajectory);

Expand Down
Loading

0 comments on commit 3fef41c

Please sign in to comment.