feat: add RBFClassifier
yoshoku committed Nov 5, 2023
1 parent 6511009 commit b7fa9e3
Showing 3 changed files with 184 additions and 0 deletions.
1 change: 1 addition & 0 deletions rumale-neural_network/lib/rumale/neural_network.rb
@@ -8,4 +8,5 @@
require_relative 'neural_network/mlp_classifier'
require_relative 'neural_network/mlp_regressor'
require_relative 'neural_network/base_rbf'
require_relative 'neural_network/rbf_classifier'
require_relative 'neural_network/rbf_regressor'
107 changes: 107 additions & 0 deletions rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb
@@ -0,0 +1,107 @@
# frozen_string_literal: true

require 'rumale/base/classifier'
require 'rumale/utils'
require 'rumale/validation'
require 'rumale/neural_network/base_rbf'

module Rumale
module NeuralNetwork
# RBFClassifier is a class that implements a classifier based on (k-means) radial basis function (RBF) networks.
#
# @example
# require 'numo/tiny_linalg'
# Numo::Linalg = Numo::TinyLinalg
#
# require 'rumale/neural_network/rbf_classifier'
#
# estimator = Rumale::NeuralNetwork::RBFClassifier.new(hidden_units: 128, reg_param: 100.0)
#   estimator.fit(training_samples, training_labels)
# results = estimator.predict(testing_samples)
#
# *Reference*
# - Bugmann, G., "Normalized Gaussian Radial Basis Function networks," Neurocomputing, vol. 20, pp. 97--110, 1998.
# - Que, Q., and Belkin, M., "Back to the Future: Radial Basis Function Networks Revisited," Proc. of AISTATS'16, pp. 1375--1383, 2016.
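#
# In outline (the concrete implementation lives in BaseRBF, which this class inherits from):
# the hidden layer computes Gaussian activations around k-means centers, roughly
# h_j(x) = exp(-gamma * ||x - c_j||^2), and the class scores are a regularized linear
# readout of those activations; with normalize: true, each activation is divided by the
# sum over all centers (the normalized Gaussian RBF of the first reference above).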
class RBFClassifier < BaseRBF
include ::Rumale::Base::Classifier

# Return the class labels.
# @return [Numo::Int32] (size: n_classes)
attr_reader :classes

# Return the centers in the hidden layer of RBF network.
# @return [Numo::DFloat] (shape: [n_centers, n_features])
attr_reader :centers

# Return the weight vector.
# @return [Numo::DFloat] (shape: [n_centers, n_classes])
attr_reader :weight_vec

# Return the random generator.
# @return [Random]
attr_reader :rng

# Create a new classifier with (k-means) RBF networks.
#
# @param hidden_units [Integer] The number of units in the hidden layer.
# @param gamma [Float] The parameter for the radial basis function; if nil, it is set to 1 / n_features.
# @param reg_param [Float] The regularization parameter.
# @param normalize [Boolean] The flag indicating whether to normalize the hidden layer output or not.
# @param max_iter [Integer] The maximum number of iterations for finding centers.
# @param tol [Float] The tolerance of termination criterion for finding centers.
# @param random_seed [Integer] The seed value used to initialize the random generator.
def initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false,
max_iter: 50, tol: 1e-4, random_seed: nil)
super
end

# Fit the model with given training data.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
# @return [RBFClassifier] The learned classifier itself.
def fit(x, y)
x = ::Rumale::Validation.check_convert_sample_array(x)
y = ::Rumale::Validation.check_convert_label_array(y)
::Rumale::Validation.check_sample_size(x, y)
raise 'RBFClassifier#fit requires Numo::Linalg but that is not loaded.' unless enable_linalg?(warning: false)

@classes = Numo::NArray[*y.to_a.uniq.sort]

partial_fit(x, one_hot_encode(y))

self
end

# Calculate confidence scores for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence score per sample.
def decision_function(x)
x = ::Rumale::Validation.check_convert_sample_array(x)

h = hidden_output(x)
h.dot(@weight_vec)
end

# Predict class labels for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
# @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
def predict(x)
x = ::Rumale::Validation.check_convert_sample_array(x)

scores = decision_function(x)
n_samples, n_classes = scores.shape
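# max_index(axis: 1) returns indices into the flattened scores array, so subtract
# each row's offset (row index * n_classes) to recover the within-row class index.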
label_ids = scores.max_index(axis: 1) - Numo::Int32.new(n_samples).seq * n_classes
@classes[label_ids].dup
end

private

def one_hot_encode(y)
Numo::DFloat.cast(::Rumale::Utils.binarize_labels(y))
end
end
end
end
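
The methods doing the real work here, hidden_output and partial_fit, are inherited from BaseRBF (base_rbf.rb) and are not part of this diff. The following is a rough sketch of the underlying technique rather than the library's exact implementation: the hidden layer is a Gaussian kernel against the k-means centers, and the output weights come from a ridge-regularized least-squares fit against the one-hot labels. The helper names are illustrative only, and the exact mapping between reg_param and the ridge term may differ from what BaseRBF actually does.

require 'numo/narray'
require 'numo/tiny_linalg'
Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)

# Gaussian RBF activations of the samples x around fixed centers:
#   h[i, j] = exp(-gamma * ||x[i, true] - centers[j, true]||^2)
def rbf_hidden_output(x, centers, gamma, normalize: false)
  # pairwise squared distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
  sq_dists = (x**2).sum(axis: 1).expand_dims(1) + (centers**2).sum(axis: 1) - 2.0 * x.dot(centers.transpose)
  h = Numo::NMath.exp(-gamma * sq_dists)
  normalize ? h / h.sum(axis: 1).expand_dims(1) : h
end

# Ridge-regularized least-squares readout from activations H to one-hot targets Y:
#   W = (H'H + lambda * I)^-1 H'Y
def rbf_fit_weights(h, y_onehot, ridge_lambda)
  n_units = h.shape[1]
  Numo::Linalg.solve(h.transpose.dot(h) + ridge_lambda * Numo::DFloat.eye(n_units), h.transpose.dot(y_onehot))
end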
76 changes: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# frozen_string_literal: true

require 'spec_helper'

require 'numo/tiny_linalg'
Numo::Linalg = Numo::TinyLinalg

RSpec.describe Rumale::NeuralNetwork::RBFClassifier do
let(:x) { dataset[0] }
let(:y) { dataset[1] }
let(:classes) { y.to_a.uniq.sort }
let(:n_samples) { x.shape[0] }
let(:n_features) { x.shape[1] }
let(:n_classes) { classes.size }
let(:hidden_units) { 64 }
let(:estimator) { described_class.new(hidden_units: hidden_units, reg_param: 1e4, random_seed: 1) }
let(:predicted) { estimator.predict(x) }
let(:score) { estimator.score(x, y) }

shared_examples 'classification' do
before { estimator.fit(x, y) }

it 'classifies the given dataset.', :aggregate_failures do
expect(estimator.classes).to be_a(Numo::Int32)
expect(estimator.classes).to be_contiguous
expect(estimator.classes.ndim).to eq(1)
expect(estimator.classes.shape[0]).to eq(n_classes)
expect(estimator.centers).to be_a(Numo::DFloat)
expect(estimator.centers).to be_contiguous
expect(estimator.centers.ndim).to eq(2)
expect(estimator.centers.shape[0]).to eq(hidden_units)
expect(estimator.centers.shape[1]).to eq(n_features)
expect(estimator.weight_vec).to be_a(Numo::DFloat)
expect(estimator.weight_vec).to be_contiguous
expect(estimator.weight_vec.ndim).to eq(2)
expect(estimator.weight_vec.shape[0]).to eq(hidden_units)
expect(estimator.weight_vec.shape[1]).to eq(n_classes)
expect(predicted).to be_a(Numo::Int32)
expect(predicted).to be_contiguous
expect(predicted.ndim).to eq(1)
expect(predicted.shape[0]).to eq(n_samples)
expect(predicted).to eq(y)
expect(score).to eq(1.0)
end
end

context 'when the number of hidden units is less than the number of samples' do
context 'when binary classification problem' do
let(:dataset) { xor_dataset }

it_behaves_like 'classification'
end

context 'when multiclass classification problem' do
let(:dataset) { three_clusters_dataset }

it_behaves_like 'classification'
end
end

context 'when the number of hidden units is greater than the number of samples' do
let(:hidden_units) { 512 }

context 'when binary classification problem' do
let(:dataset) { xor_dataset }

it_behaves_like 'classification'
end

context 'when multiclass classification problem' do
let(:dataset) { three_clusters_dataset }

it_behaves_like 'classification'
end
end
end
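
The xor_dataset and three_clusters_dataset helpers used above come from the project's spec_helper and are not part of this diff. A self-contained sketch of the same idea on a hand-rolled XOR-style dataset, which a linear classifier cannot separate (the dataset construction here is illustrative, not the spec helper's):

require 'numo/tiny_linalg'
Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)

require 'rumale/neural_network/rbf_classifier'

# XOR-style toy data: label 1 when the two coordinates share the same sign.
n_samples = 200
x = Numo::DFloat.new(n_samples, 2).rand(-1.0, 1.0)
y = Numo::Int32.zeros(n_samples)
y[(x[true, 0] * x[true, 1] >= 0).where] = 1

estimator = Rumale::NeuralNetwork::RBFClassifier.new(hidden_units: 64, reg_param: 100.0, random_seed: 1)
estimator.fit(x, y)
predicted = estimator.predict(x)
puts format('training accuracy: %.3f', predicted.eq(y).count.fdiv(n_samples))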
