Skip to content

Commit

Permalink
Fix compilation when using TemplNet without any input derivatives ena…
Browse files Browse the repository at this point in the history
…bled (bug introduced in recent commits)
  • Loading branch information
Ithanil committed Nov 20, 2019
1 parent 35475b5 commit 423c47b
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions include/qnets/templ/TemplNet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class TemplNet


template <int ONIN = ORIG_N_IN>
typename std::enable_if<ONIN != NET_N_IN, void>::type _computeInputGradients()
typename std::enable_if<ONIN != NET_N_IN && (nd1 != 0 || nd2 != 0), void>::type _computeInputGradients()
{
throw std::runtime_error("[TemplNet::_processOrigInput] Original input derivatives require provided input-to-orig derivatives.");
}
Expand Down Expand Up @@ -245,28 +245,45 @@ class TemplNet
}

template <int ONIN = ORIG_N_IN>
typename std::enable_if<ONIN == NET_N_IN, void>::type _processOrigInput()
typename std::enable_if<ONIN == NET_N_IN && nd1 != 0, void>::type _processOrigInput()
{
// feed original input
std::get<0>(_layers).ForwardInput(_input, dflags);
this->_propagateLayers();
if (this->hasD1()) { this->_computeInputGradients(); }
}

template <int ONIN = ORIG_N_IN>
typename std::enable_if<ONIN == NET_N_IN && nd1 == 0, void>::type _processOrigInput()
{
// feed original input
std::get<0>(_layers).ForwardInput(_input, dflags);
this->_propagateLayers();
}

template <int ONIN = ORIG_N_IN>
typename std::enable_if<ONIN != NET_N_IN, void>::type _processOrigInput()
{
throw std::runtime_error("[TemplNet::_processOrigInput] Original input can't be fed directly, because it differs in size from network input.");
}

void _processDerivInput(const ValueT orig_d1[], const ValueT orig_d2[])
template <int N_D1 = nd1>
typename std::enable_if<N_D1 != 0, void>::type _processDerivInput(const ValueT orig_d1[], const ValueT orig_d2[])
{
// feed derived network input
std::get<0>(_layers).ForwardLayer(_input.data(), orig_d1, orig_d2, dflags);
this->_propagateLayers();
if (this->hasD1()) { this->_computeInputGradients(orig_d1); }
}

template <int N_D1 = nd1>
typename std::enable_if<N_D1 == 0, void>::type _processDerivInput(const ValueT orig_d1[], const ValueT orig_d2[])
{
// feed derived network input
std::get<0>(_layers).ForwardLayer(_input.data(), orig_d1, orig_d2, dflags);
this->_propagateLayers();
}

public:
explicit constexpr TemplNet(DynamicDFlags init_dflags = DynamicDFlags{DCONF}):
_out_begins(tupl::make_fcont<std::array<const ValueT *, nlayer>>(_layers, [](const auto &layer) { return &layer.out().front(); })),
Expand Down

0 comments on commit 423c47b

Please sign in to comment.