Skip to content

Commit

Permalink
fix(tensorrt): update tensorrt code of traffic_light_classifier (auto…
Browse files Browse the repository at this point in the history
…warefoundation#2325)

Signed-off-by: M. Fatih Cırıt <[email protected]>
Signed-off-by: kminoda <[email protected]>
  • Loading branch information
xmfcx authored and kminoda committed Jan 6, 2023
1 parent c2e2e24 commit a544503
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
16 changes: 10 additions & 6 deletions perception/traffic_light_classifier/utils/trt_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,16 @@ void TrtCommon::setup()
}

context_ = UniquePtr<nvinfer1::IExecutionContext>(engine_->createExecutionContext());
input_dims_ = engine_->getBindingDimensions(getInputBindingIndex());
output_dims_ = engine_->getBindingDimensions(getOutputBindingIndex());

#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
input_dims_ = engine_->getTensorShape(input_name_.c_str());
output_dims_ = engine_->getTensorShape(output_name_.c_str());
#else
// Deprecated since 8.5
input_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(input_name_.c_str()));
output_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(output_name_.c_str()));
#endif

is_initialized_ = true;
}

Expand Down Expand Up @@ -155,8 +163,4 @@ int TrtCommon::getNumOutput()
output_dims_.d, output_dims_.d + output_dims_.nbDims, 1, std::multiplies<int>());
}

int TrtCommon::getInputBindingIndex() { return engine_->getBindingIndex(input_name_.c_str()); }

int TrtCommon::getOutputBindingIndex() { return engine_->getBindingIndex(output_name_.c_str()); }

} // namespace Tn
2 changes: 0 additions & 2 deletions perception/traffic_light_classifier/utils/trt_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ class TrtCommon
bool isInitialized();
int getNumInput();
int getNumOutput();
int getInputBindingIndex();
int getOutputBindingIndex();

UniquePtr<nvinfer1::IExecutionContext> context_;

Expand Down

0 comments on commit a544503

Please sign in to comment.