diff --git a/rumale-neural_network/lib/rumale/neural_network/base_rbf.rb b/rumale-neural_network/lib/rumale/neural_network/base_rbf.rb index ba6af5e8..633689c4 100644 --- a/rumale-neural_network/lib/rumale/neural_network/base_rbf.rb +++ b/rumale-neural_network/lib/rumale/neural_network/base_rbf.rb @@ -64,7 +64,7 @@ def find_centers(x) # initialize centers randomly n_samples = x.shape[0] sub_rng = @rng.dup - rand_id = Array(0...n_samples).sample(n_centers, random: sub_rng) + rand_id = Array.new(n_centers) { |_v| sub_rng.rand(0...n_samples) } @centers = x[rand_id, true].dup # find centers diff --git a/rumale-neural_network/spec/rumale/neural_network/rbf_regressor_spec.rb b/rumale-neural_network/spec/rumale/neural_network/rbf_regressor_spec.rb index 1dd9eb55..7eda5f31 100644 --- a/rumale-neural_network/spec/rumale/neural_network/rbf_regressor_spec.rb +++ b/rumale-neural_network/spec/rumale/neural_network/rbf_regressor_spec.rb @@ -12,29 +12,62 @@ let(:n_samples) { x.shape[0] } let(:n_features) { x.shape[1] } let(:n_outputs) { y.shape[1] } - let(:estimator) { described_class.new(hidden_units: 64, reg_param: 1e4, random_seed: 1).fit(x, y) } + let(:hidden_units) { 64 } + let(:estimator) { described_class.new(hidden_units: hidden_units, reg_param: 1e4, random_seed: 1) } let(:predicted) { estimator.predict(x) } let(:score) { estimator.score(x, y) } shared_examples 'regression' do + before { estimator.fit(x, y) } + it 'fits model for given dataset.', :aggregate_failures do expect(predicted).to be_a(Numo::DFloat) expect(predicted).to be_contiguous expect(predicted.ndim).to eq(y.ndim) expect(predicted.shape[0]).to eq(n_samples) + expect(estimator.centers).to be_a(Numo::DFloat) + expect(estimator.centers).to be_contiguous + expect(estimator.centers.ndim).to eq(2) + expect(estimator.centers.shape[0]).to eq(hidden_units) + expect(estimator.centers.shape[1]).to eq(n_features) + expect(estimator.weight_vec).to be_a(Numo::DFloat) + expect(estimator.weight_vec).to be_contiguous + expect(estimator.weight_vec.ndim).to eq(2) + expect(estimator.weight_vec.shape[0]).to eq(hidden_units) + expect(estimator.weight_vec.shape[1]).to eq(n_outputs) expect(score).to be > 0.98 end end - context 'when single regression problem' do - let(:y) { single_target } + context 'when the number of hidden units is less than the number of samples' do + context 'when single regression problem' do + let(:y) { single_target } + let(:n_outputs) { 1 } + + it_behaves_like 'regression' + end - it_behaves_like 'regression' + context 'when multiple regression problem' do + let(:y) { multi_target } + + it_behaves_like 'regression' + end end - context 'when multiple regression problem' do - let(:y) { multi_target } + context 'when the number of hidden units is greater than the number of samples' do + let(:hidden_units) { 400 } - it_behaves_like 'regression' + context 'when single regression problem' do + let(:y) { single_target } + let(:n_outputs) { 1 } + + it_behaves_like 'regression' + end + + context 'when multiple regression problem' do + let(:y) { multi_target } + + it_behaves_like 'regression' + end end end