Skip to content

Commit

Permalink
fix(tensorrt): update tensorrt code of traffic_light_ssd_fine_detector
Browse files: browse the repository at this point in the history
Signed-off-by: M. Fatih Cırıt <[email protected]>
  • Loading branch information
M. Fatih Cırıt committed Dec 20, 2022
1 parent 2be41d7 commit f8b44e8
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class Net
void save(const std::string & path);

// Infer using pre-allocated GPU buffers {data, scores, boxes}
void infer(std::vector<void *> & buffers, const int batch_size);
void infer(const int batch_size);

// Get (c, h, w) size of the fixed input
std::vector<int> getInputSize();
Expand All @@ -90,6 +90,8 @@ class Net
unique_ptr<nvinfer1::IHostMemory> plan_ = nullptr;
unique_ptr<nvinfer1::ICudaEngine> engine_ = nullptr;
unique_ptr<nvinfer1::IExecutionContext> context_ = nullptr;
std::string name_tensor_in_;
std::string name_tensor_out_;
cudaStream_t stream_ = nullptr;

void load(const std::string & path);
Expand Down
24 changes: 15 additions & 9 deletions perception/traffic_light_ssd_fine_detector/lib/src/trt_ssd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ Net::Net(const std::string & path, bool verbose)
runtime_ = unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(logger));
load(path);
prepare();
name_tensor_in_ = engine_->getIOTensorName(0);
name_tensor_out_ = engine_->getIOTensorName(engine_->getNbIOTensors() - 1);
}

Net::~Net()
Expand Down Expand Up @@ -155,6 +157,8 @@ Net::Net(
std::cout << "Fail to create context" << std::endl;
return;
}
name_tensor_in_ = engine_->getIOTensorName(0);
name_tensor_out_ = engine_->getIOTensorName(engine_->getNbIOTensors() - 1);
}

void Net::save(const std::string & path)
Expand All @@ -164,35 +168,37 @@ void Net::save(const std::string & path)
file.write(reinterpret_cast<const char *>(plan_->data()), plan_->size());
}

void Net::infer(std::vector<void *> & buffers, const int batch_size)
void Net::infer(const int batch_size)
{
if (!context_) {
throw std::runtime_error("Fail to create context");
}
auto input_dims = engine_->getBindingDimensions(0);
context_->setBindingDimensions(
0, nvinfer1::Dims4(batch_size, input_dims.d[1], input_dims.d[2], input_dims.d[3]));
context_->enqueueV2(buffers.data(), stream_, nullptr);
const auto input_dims = engine_->getTensorShape(name_tensor_in_.c_str());
context_->setInputShape(
name_tensor_in_.c_str(),
nvinfer1::Dims4(batch_size, input_dims.d[1], input_dims.d[2], input_dims.d[3]));
context_->enqueueV3(stream_);
cudaStreamSynchronize(stream_);
}

std::vector<int> Net::getInputSize()
{
auto dims = engine_->getBindingDimensions(0);
const auto dims = engine_->getTensorShape(name_tensor_in_.c_str());
return {dims.d[1], dims.d[2], dims.d[3]};
}

std::vector<int> Net::getOutputScoreSize()
{
auto dims = engine_->getBindingDimensions(1);
const auto dims = engine_->getTensorShape(name_tensor_out_.c_str());
return {dims.d[1], dims.d[2]};
}

int Net::getMaxBatchSize()
{
return engine_->getProfileDimensions(0, 0, nvinfer1::OptProfileSelector::kMAX).d[0];
return engine_->getProfileShape(name_tensor_in_.c_str(), 0, nvinfer1::OptProfileSelector::kMAX)
.d[0];
}

int Net::getMaxDetections() { return engine_->getBindingDimensions(1).d[1]; }
int Net::getMaxDetections() { return engine_->getTensorShape(name_tensor_out_.c_str()).d[1]; }

} // namespace ssd
3 changes: 1 addition & 2 deletions perception/traffic_light_ssd_fine_detector/src/nodelet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ void TrafficLightSSDFineDetectorNodelet::callback(
auto data_d = cuda::make_unique<float[]>(num_infer * channel_ * width_ * height_);
auto scores_d = cuda::make_unique<float[]>(num_infer * detection_per_class_ * class_num_);
auto boxes_d = cuda::make_unique<float[]>(num_infer * detection_per_class_ * 4);
std::vector<void *> buffers = {data_d.get(), scores_d.get(), boxes_d.get()};
std::vector<cv::Point> lts, rbs;
std::vector<cv::Mat> cropped_imgs;

Expand All @@ -168,7 +167,7 @@ void TrafficLightSSDFineDetectorNodelet::callback(
cudaMemcpy(data_d.get(), data.data(), data.size() * sizeof(float), cudaMemcpyHostToDevice);

try {
net_ptr_->infer(buffers, num_infer);
net_ptr_->infer(num_infer);
} catch (std::exception & e) {
RCLCPP_ERROR(this->get_logger(), "%s", e.what());
return;
Expand Down

0 comments on commit f8b44e8

Please sign in to comment.