Skip to content

Commit

Permalink
fix: obtain n_centers random indices
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Nov 4, 2023
1 parent 4c1ccec commit 6511009
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def find_centers(x)
# initialize centers randomly
n_samples = x.shape[0]
sub_rng = @rng.dup
rand_id = Array(0...n_samples).sample(n_centers, random: sub_rng)
rand_id = Array.new(n_centers) { |_v| sub_rng.rand(0...n_samples) }
@centers = x[rand_id, true].dup

# find centers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,62 @@
let(:n_samples) { x.shape[0] }
let(:n_features) { x.shape[1] }
let(:n_outputs) { y.shape[1] }
let(:estimator) { described_class.new(hidden_units: 64, reg_param: 1e4, random_seed: 1).fit(x, y) }
let(:hidden_units) { 64 }
let(:estimator) { described_class.new(hidden_units: hidden_units, reg_param: 1e4, random_seed: 1) }
let(:predicted) { estimator.predict(x) }
let(:score) { estimator.score(x, y) }

shared_examples 'regression' do
before { estimator.fit(x, y) }

it 'fits model for given dataset.', :aggregate_failures do
expect(predicted).to be_a(Numo::DFloat)
expect(predicted).to be_contiguous
expect(predicted.ndim).to eq(y.ndim)
expect(predicted.shape[0]).to eq(n_samples)
expect(estimator.centers).to be_a(Numo::DFloat)
expect(estimator.centers).to be_contiguous
expect(estimator.centers.ndim).to eq(2)
expect(estimator.centers.shape[0]).to eq(hidden_units)
expect(estimator.centers.shape[1]).to eq(n_features)
expect(estimator.weight_vec).to be_a(Numo::DFloat)
expect(estimator.weight_vec).to be_contiguous
expect(estimator.weight_vec.ndim).to eq(2)
expect(estimator.weight_vec.shape[0]).to eq(hidden_units)
expect(estimator.weight_vec.shape[1]).to eq(n_outputs)
expect(score).to be > 0.98
end
end

context 'when single regression problem' do
let(:y) { single_target }
context 'when the number of hidden units is less than the number of samples' do
context 'when single regression problem' do
let(:y) { single_target }
let(:n_outputs) { 1 }

it_behaves_like 'regression'
end

it_behaves_like 'regression'
context 'when multiple regression problem' do
let(:y) { multi_target }

it_behaves_like 'regression'
end
end

context 'when multiple regression problem' do
let(:y) { multi_target }
context 'when the number of hidden units is greater than the number of samples' do
let(:hidden_units) { 400 }

it_behaves_like 'regression'
context 'when single regression problem' do
let(:y) { single_target }
let(:n_outputs) { 1 }

it_behaves_like 'regression'
end

context 'when multiple regression problem' do
let(:y) { multi_target }

it_behaves_like 'regression'
end
end
end

0 comments on commit 6511009

Please sign in to comment.