Skip to content

Commit

Permalink
fix(tensorrt): update tensorrt code of lidar_centerpoint
Browse files Browse the repository at this point in the history
Signed-off-by: M. Fatih Cırıt <[email protected]>
  • Loading branch information
M. Fatih Cırıt committed Dec 16, 2022
1 parent eb7bb1b commit e43919e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class TensorRTWrapper
const std::string & onnx_path, const std::string & engine_path, const std::string & precision);

unique_ptr<nvinfer1::IExecutionContext> context_ = nullptr;
unique_ptr<nvinfer1::ICudaEngine> engine_ = nullptr;

protected:
virtual bool setProfile(
Expand All @@ -86,7 +87,6 @@ class TensorRTWrapper

unique_ptr<nvinfer1::IRuntime> runtime_ = nullptr;
unique_ptr<nvinfer1::IHostMemory> plan_ = nullptr;
unique_ptr<nvinfer1::ICudaEngine> engine_ = nullptr;
};

} // namespace centerpoint
Expand Down
23 changes: 10 additions & 13 deletions perception/lidar_centerpoint/lib/centerpoint_trt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ CenterPointTRT::CenterPointTRT(
encoder_trt_ptr_ = std::make_unique<VoxelEncoderTRT>(config_, verbose_);
encoder_trt_ptr_->init(
encoder_param.onnx_path(), encoder_param.engine_path(), encoder_param.trt_precision());
encoder_trt_ptr_->context_->setBindingDimensions(
0,
std::string name_tensor_encoder_in = encoder_trt_ptr_->engine_->getIOTensorName(0);
encoder_trt_ptr_->context_->setInputShape(
name_tensor_encoder_in.c_str(),
nvinfer1::Dims3(
config_.max_voxel_size_, config_.max_point_in_voxel_size_, config_.encoder_in_feature_size_));

Expand All @@ -49,10 +50,11 @@ CenterPointTRT::CenterPointTRT(
config_.head_out_dim_size_, config_.head_out_rot_size_, config_.head_out_vel_size_};
head_trt_ptr_ = std::make_unique<HeadTRT>(out_channel_sizes, config_, verbose_);
head_trt_ptr_->init(head_param.onnx_path(), head_param.engine_path(), head_param.trt_precision());
head_trt_ptr_->context_->setBindingDimensions(
0, nvinfer1::Dims4(
config_.batch_size_, config_.encoder_out_feature_size_, config_.grid_size_y_,
config_.grid_size_x_));
std::string name_tensor_head_in = head_trt_ptr_->engine_->getIOTensorName(0);
head_trt_ptr_->context_->setInputShape(
name_tensor_head_in.c_str(), nvinfer1::Dims4(
config_.batch_size_, config_.encoder_out_feature_size_,
config_.grid_size_y_, config_.grid_size_x_));

initPtr();

Expand Down Expand Up @@ -166,8 +168,7 @@ void CenterPointTRT::inference()
}

// pillar encoder network
std::vector<void *> encoder_buffers{encoder_in_features_d_.get(), pillar_features_d_.get()};
encoder_trt_ptr_->context_->enqueueV2(encoder_buffers.data(), stream_, nullptr);
encoder_trt_ptr_->context_->enqueueV3(stream_);

// scatter
CHECK_CUDA_ERROR(scatterFeatures_launch(
Expand All @@ -176,11 +177,7 @@ void CenterPointTRT::inference()
spatial_features_d_.get(), stream_));

// head network
std::vector<void *> head_buffers = {spatial_features_d_.get(), head_out_heatmap_d_.get(),
head_out_offset_d_.get(), head_out_z_d_.get(),
head_out_dim_d_.get(), head_out_rot_d_.get(),
head_out_vel_d_.get()};
head_trt_ptr_->context_->enqueueV2(head_buffers.data(), stream_, nullptr);
head_trt_ptr_->context_->enqueueV3(stream_);
}

void CenterPointTRT::postProcess(std::vector<Box3D> & det_boxes3d)
Expand Down

0 comments on commit e43919e

Please sign in to comment.