Skip to content

Commit

Permalink
fix: single_inference node
Browse files Browse the repository at this point in the history
Signed-off-by: tzhong518 <[email protected]>
  • Loading branch information
tzhong518 committed Mar 12, 2024
1 parent 356be27 commit b986903
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 216 deletions.
24 changes: 12 additions & 12 deletions perception/lidar_centerpoint/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,18 @@ if(TRT_AVAIL AND CUDA_AVAIL AND CUDNN_AVAIL)
)

## single inference node ##
# ament_auto_add_library(single_inference_lidar_centerpoint_component SHARED
# src/single_inference_node.cpp
# )

# target_link_libraries(single_inference_lidar_centerpoint_component
# centerpoint_lib
# )

# rclcpp_components_register_node(single_inference_lidar_centerpoint_component
# PLUGIN "centerpoint::SingleInferenceLidarCenterPointNode"
# EXECUTABLE single_inference_lidar_centerpoint_node
# )
ament_auto_add_library(single_inference_lidar_centerpoint_component SHARED
src/single_inference_node.cpp
)

target_link_libraries(single_inference_lidar_centerpoint_component
centerpoint_lib
)

rclcpp_components_register_node(single_inference_lidar_centerpoint_component
PLUGIN "centerpoint::SingleInferenceLidarCenterPointNode"
EXECUTABLE single_inference_lidar_centerpoint_node
)

install(PROGRAMS
scripts/lidar_centerpoint_visualizer.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class SingleInferenceLidarCenterPointNode : public rclcpp::Node
float score_threshold_{0.0};
std::vector<std::string> class_names_;
bool rename_car_to_truck_and_bus_{false};
bool has_variance_{false};
bool has_twist_{false};

DetectionClassRemapper detection_class_remapper_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ __global__ void generateBoxes3D_kernel(
const float vel_x_variance = out_vel[down_grid_size * 2 + idx];
const float vel_y_variance = out_vel[down_grid_size * 3 + idx];

det_boxes3d[idx].x_variance = expf(offset_x_variance);
det_boxes3d[idx].y_variance = expf(offset_y_variance);
det_boxes3d[idx].x_variance = voxel_size_x * downsample_factor * expf(offset_x_variance);
det_boxes3d[idx].y_variance = voxel_size_x * downsample_factor * expf(offset_y_variance);
det_boxes3d[idx].z_variance = expf(z_variance);
det_boxes3d[idx].length_variance = expf(l_variance);
det_boxes3d[idx].width_variance = expf(w_variance);
Expand Down
200 changes: 0 additions & 200 deletions perception/lidar_centerpoint/src/node copy.cpp

This file was deleted.

5 changes: 3 additions & 2 deletions perception/lidar_centerpoint/src/single_inference_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ SingleInferenceLidarCenterPointNode::SingleInferenceLidarCenterPointNode(
const std::string head_onnx_path = this->declare_parameter<std::string>("head_onnx_path");
const std::string head_engine_path = this->declare_parameter<std::string>("head_engine_path");
class_names_ = this->declare_parameter<std::vector<std::string>>("class_names");
has_variance_ = this->declare_parameter("has_variance", false);
has_twist_ = this->declare_parameter("has_twist", false);
const std::size_t point_feature_size =
static_cast<std::size_t>(this->declare_parameter<std::int64_t>("point_feature_size"));
Expand Down Expand Up @@ -96,7 +97,7 @@ SingleInferenceLidarCenterPointNode::SingleInferenceLidarCenterPointNode(
CenterPointConfig config(
class_names_.size(), point_feature_size, max_voxel_size, point_cloud_range, voxel_size,
downsample_factor, encoder_in_feature_size, score_threshold, circle_nms_dist_threshold,
yaw_norm_thresholds);
yaw_norm_thresholds, has_variance_);
detector_ptr_ =
std::make_unique<CenterPointTRT>(encoder_param, head_param, densification_param, config);

Expand Down Expand Up @@ -173,7 +174,7 @@ void SingleInferenceLidarCenterPointNode::detect(
output_msg.header = msg.header;
for (const auto & box3d : det_boxes3d) {
autoware_auto_perception_msgs::msg::DetectedObject obj;
box3DToDetectedObject(box3d, class_names_, has_twist_, obj);
box3DToDetectedObject(box3d, class_names_, has_twist_, has_variance_, obj);
output_msg.objects.emplace_back(obj);
}

Expand Down

0 comments on commit b986903

Please sign in to comment.