Skip to content

Commit

Permalink
feat: add ego information handlings
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 committed Apr 12, 2024
1 parent 4d54696 commit ffae194
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 20 deletions.
4 changes: 2 additions & 2 deletions perception/tensorrt_mtr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ if (CUDA_FOUND)
PATH_SUFFIXES lib lib64 bin
DOC "CUDNN library.")
else()
message(FAITAL_ERROR "Can not find CUDA")
message(FATAL_ERROR "Can not find CUDA")
endif()

list(APPEND TRT_PLUGINS "nvinfer")
list(APPEND TRT_PLUGINS "nvonnxparser")
list(APPEND TRT_PLUGINS "nvparsers")
foreach(libName ${TRT_PLUGINS})
find_library(${libName}_lib NAMES ${libName} "/usr" PATH_SUFFIES lib)
find_library(${libName}_lib NAMES ${libName} "/usr" PATH_SUFFIXES lib)
list(APPEND TRT_PLUGINS ${${libName}_lib})
endforeach()

Expand Down
11 changes: 10 additions & 1 deletion perception/tensorrt_mtr/include/tensorrt_mtr/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace trt_mtr
Expand All @@ -56,6 +57,11 @@ using autoware_auto_perception_msgs::msg::TrackedObject;
using autoware_auto_perception_msgs::msg::TrackedObjects;
using nav_msgs::msg::Odometry;

// TODO(ktro2828): use received ego size topic
constexpr float EGO_LENGTH = 4.0f;
constexpr float EGO_WIDTH = 2.0f;
constexpr float EGO_HEIGHT = 1.0f;

class PolylineTypeMap
{
public:
Expand Down Expand Up @@ -139,6 +145,8 @@ class MTRNode : public rclcpp::Node
void updateAgentHistory(
const float current_time, const TrackedObjects::ConstSharedPtr objects_msg);

AgentState extractNearestEgo(const float current_time) const;

/**
* @brief Extract target agents and return corresponding indices.
*
Expand Down Expand Up @@ -179,10 +187,11 @@ class MTRNode : public rclcpp::Node
tier4_autoware_utils::TransformListener transform_listener_;

// MTR parameters
std::unique_ptr<MtrConfig> config_ptr_;
std::unique_ptr<MTRConfig> config_ptr_;
std::unique_ptr<TrtMTR> model_ptr_;
PolylineTypeMap polyline_type_map_;
std::shared_ptr<PolylineData> polyline_ptr_;
std::vector<std::pair<float, AgentState>> ego_states_;
std::vector<float> timestamps_;
}; // class MTRNode
} // namespace trt_mtr
Expand Down
14 changes: 7 additions & 7 deletions perception/tensorrt_mtr/include/tensorrt_mtr/trt_mtr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ namespace trt_mtr
/**
* @brief A configuration of MTR.
*/
struct MtrConfig
struct MTRConfig
{
/**
* @brief Construct a new Mtr Config object
* @brief Construct a new instance.
*
* @param target_labels An array of target label names.
* @param num_mode The number of modes.
Expand All @@ -49,7 +49,7 @@ struct MtrConfig
* @param intention_point_filepath The path to intention points file.
* @param num_intention_point_cluster The number of clusters for intension points.
*/
MtrConfig(
MTRConfig(
const std::vector<std::string> & target_labels = {"VEHICLE", "PEDESTRIAN", "CYCLIST"},
const size_t num_past = 10, const size_t num_mode = 6, const size_t num_future = 80,
const size_t max_num_polyline = 768, const size_t max_num_point = 20,
Expand Down Expand Up @@ -81,7 +81,7 @@ struct MtrConfig
std::array<float, 2> offset_xy;
std::string intention_point_filepath;
size_t num_intention_point_cluster;
};
}; // struct MTRConfig

/**
* @brief A class to inference with MTR.
Expand All @@ -101,7 +101,7 @@ class TrtMTR
*/
TrtMTR(
const std::string & model_path, const std::string & precision,
const MtrConfig & config = MtrConfig(), const BatchConfig & batch_config = {1, 1, 1},
const MTRConfig & config = MTRConfig(), const BatchConfig & batch_config = {1, 1, 1},
const size_t max_workspace_size = (1ULL << 30),
const BuildConfig & build_config = BuildConfig());

Expand All @@ -122,7 +122,7 @@ class TrtMTR
*
* @return const MtrConfig& The model configuration which can not be updated.
*/
const MtrConfig & config() const { return config_; }
const MTRConfig & config() const { return config_; }

private:
/**
Expand Down Expand Up @@ -152,7 +152,7 @@ class TrtMTR
bool postProcess(AgentData & agent_data, std::vector<PredictedTrajectory> & trajectories);

// model parameters
MtrConfig config_;
MTRConfig config_;

std::unique_ptr<MTRBuilder> builder_;
cudaStream_t stream_{nullptr};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*
* @param B The number of target agents.
* @param M The number of modes.
* @param T The number of future timestmaps.
* @param T The number of future timestamps.
* @param inDim The number of input agent state dimensions.
* @param targetState Source target agent states at latest timestamp, in shape [B*inDim].
* @param outDim The number of output state dimensions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ __global__ void calculatePolylineCenterKernel(

/**
* @brief In cases of the number of batch polylines (L) is greater than K,
* extacts the topK elements.
* extracts the topK elements.
*
* @param L The number of source polylines.
* @param K The number of polylines expected as the model input.
Expand Down
2 changes: 1 addition & 1 deletion perception/tensorrt_mtr/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<package format="3">
<name>tensorrt_mtr</name>
<version>0.1.0</version>
<description>ROS 2 Node of Motion Transfomer(a.k.a MTR).</description>
<description>ROS 2 Node of Motion Transfromer(a.k.a MTR).</description>
<maintainer email="[email protected]">kotarouetake</maintainer>
<license>Apache-2.0</license>

Expand Down
57 changes: 51 additions & 6 deletions perception/tensorrt_mtr/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

#include <algorithm>
#include <cmath>
#include <utility>

namespace trt_mtr
{
Expand Down Expand Up @@ -182,7 +181,7 @@ MTRNode::MTRNode(const rclcpp::NodeOptions & node_options)
declare_parameter<std::string>("intention_point_filepath");
const auto num_intention_point_cluster =
static_cast<size_t>(declare_parameter<int>("num_intention_point_cluster"));
config_ptr_ = std::make_unique<MtrConfig>(
config_ptr_ = std::make_unique<MTRConfig>(
target_labels, num_past, num_mode, num_future, max_num_polyline, max_num_point,
point_break_distance, offset_xy, intention_point_filepath, num_intention_point_cluster);
model_ptr_ = std::make_unique<TrtMTR>(model_path, precision, *config_ptr_.get());
Expand Down Expand Up @@ -210,7 +209,7 @@ void MTRNode::callback(const TrackedObjects::ConstSharedPtr object_msg)
return; // No polyline
}

const auto current_time = rclcpp::Time(object_msg->header.stamp).seconds();
const auto current_time = static_cast<float>(rclcpp::Time(object_msg->header.stamp).seconds());

timestamps_.emplace_back(current_time);
// TODO(ktro2828): update timestamps
Expand Down Expand Up @@ -289,7 +288,30 @@ void MTRNode::onMap(const HADMapBin::ConstSharedPtr map_msg)

void MTRNode::onEgo(const Odometry::ConstSharedPtr ego_msg)
{
RCLCPP_INFO_STREAM(get_logger(), "Ego msg is received: " << ego_msg->header.frame_id);
const auto current_time = static_cast<float>(rclcpp::Time(ego_msg->header.stamp).seconds());
const auto & position = ego_msg->pose.pose.position;
const auto & twist = ego_msg->twist.twist;
const auto yaw = static_cast<float>(tf2::getYaw(ego_msg->pose.pose.orientation));
float ax = 0.0f, ay = 0.0f;
if (!ego_states_.empty()) {
const auto & latest_state = ego_states_.back();
const auto time_diff = current_time - latest_state.first;
ax = (static_cast<float>(twist.linear.x) - latest_state.second.vx()) / (time_diff + 1e-10f);
ay = static_cast<float>(twist.linear.y) - latest_state.second.vy() / (time_diff + 1e-10f);
}

// TODO(ktro2828): use received ego size topic
ego_states_.emplace_back(std::make_pair(
current_time,
AgentState(
static_cast<float>(position.x), static_cast<float>(position.y),
static_cast<float>(position.z), EGO_LENGTH, EGO_WIDTH, EGO_HEIGHT, yaw,
static_cast<float>(twist.linear.x), static_cast<float>(twist.linear.y), ax, ay, true)));

constexpr size_t max_buffer_size = 100;
if (max_buffer_size < ego_states_.size()) {
ego_states_.erase(ego_states_.begin(), ego_states_.begin());
}
}

bool MTRNode::convertLaneletToPolyline()
Expand Down Expand Up @@ -362,19 +384,24 @@ bool MTRNode::convertLaneletToPolyline()
void MTRNode::removeAncientAgentHistory(
const float current_time, const TrackedObjects::ConstSharedPtr objects_msg)
{
// TODO(ktro2828): use ego info
constexpr float time_threshold = 1.0f; // TODO(ktro2828): use parameter
for (const auto & object : objects_msg->objects) {
const auto & object_id = tier4_autoware_utils::toHexString(object.object_id);
if (agent_history_map_.count(object_id) == 0) {
continue;
}

constexpr float time_threshold = 1.0f; // TODO(ktro2828): use parameter
const auto & history = agent_history_map_.at(object_id);
if (history.is_ancient(current_time, time_threshold)) {
agent_history_map_.erase(object_id);
}
}

if (
agent_history_map_.count(EGO_ID) != 0 &&
agent_history_map_.at(EGO_ID).is_ancient(current_time, time_threshold)) {
agent_history_map_.erase(EGO_ID);
}
}

void MTRNode::updateAgentHistory(
Expand Down Expand Up @@ -402,6 +429,15 @@ void MTRNode::updateAgentHistory(
}
}

auto ego_state = extractNearestEgo(current_time);
if (agent_history_map_.count(EGO_ID) == 0) {
AgentHistory history(EGO_ID, AgentLabel::VEHICLE, config_ptr_->num_past);
history.update(current_time, ego_state);
} else {
agent_history_map_.at(EGO_ID).update(current_time, ego_state);
}
observed_ids.emplace_back(EGO_ID);

// update unobserved histories with empty
for (auto & [object_id, history] : agent_history_map_) {
if (std::find(observed_ids.cbegin(), observed_ids.cend(), object_id) != observed_ids.cend()) {
Expand All @@ -411,6 +447,15 @@ void MTRNode::updateAgentHistory(
}
}

Check warning on line 448 in perception/tensorrt_mtr/src/node.cpp

View check run for this annotation

CodeScene Delta Analysis / CodeScene Cloud Delta Analysis (main)

❌ New issue: Complex Method

MTRNode::updateAgentHistory has a cyclomatic complexity of 9, threshold = 9. This function has many conditional statements (e.g. if, for, while), leading to lower code health. Avoid adding more conditionals and code to it without refactoring.

Check warning on line 448 in perception/tensorrt_mtr/src/node.cpp

View check run for this annotation

CodeScene Delta Analysis / CodeScene Cloud Delta Analysis (main)

❌ New issue: Bumpy Road Ahead

MTRNode::updateAgentHistory has 2 blocks with nested conditional logic. Any nesting of 2 or deeper is considered. Threshold is one single, nested block per function. The Bumpy Road code smell is a function that contains multiple chunks of nested conditional logic. The deeper the nesting and the more bumps, the lower the code health.

AgentState MTRNode::extractNearestEgo(const float current_time) const
{
auto state = std::min_element(
ego_states_.cbegin(), ego_states_.cend(), [&](const auto & s1, const auto & s2) {
return std::abs(s1.first - current_time) < std::abs(s2.first - current_time);
});
return state->second;
}

std::vector<size_t> MTRNode::extractTargetAgent(const std::vector<AgentHistory> & histories)
{
std::vector<std::pair<size_t, float>> distances;
Expand Down
2 changes: 1 addition & 1 deletion perception/tensorrt_mtr/src/trt_mtr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace trt_mtr
{
TrtMTR::TrtMTR(
const std::string & model_path, const std::string & precision, const MtrConfig & config,
const std::string & model_path, const std::string & precision, const MTRConfig & config,
const BatchConfig & batch_config, const size_t max_workspace_size,
const BuildConfig & build_config)
: config_(config),
Expand Down

0 comments on commit ffae194

Please sign in to comment.