Skip to content

Commit

Permalink
feat: add RVFLClassifier
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Nov 6, 2023
1 parent 0983cc0 commit 71a7978
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 0 deletions.
1 change: 1 addition & 0 deletions rumale-neural_network/lib/rumale/neural_network.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
require_relative 'neural_network/rbf_classifier'
require_relative 'neural_network/rbf_regressor'
require_relative 'neural_network/base_rvfl'
require_relative 'neural_network/rvfl_classifier'
require_relative 'neural_network/rvfl_regressor'
108 changes: 108 additions & 0 deletions rumale-neural_network/lib/rumale/neural_network/rvfl_classifier.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# frozen_string_literal: true

require 'rumale/base/classifier'
require 'rumale/neural_network/base_rvfl'
require 'rumale/utils'
require 'rumale/validation'

module Rumale
module NeuralNetwork
# RVFLClassifier is a class that implements classifier based on random vector functional link (RVFL) network.
# The current implementation uses sigmoid function as activation function.
#
# @example
# require 'numo/tiny_linalg'
# Numo::Linalg = Numo::TinyLinalg
#
# require 'rumale/neural_network/rvfl_classifier'
#
# estimator = Rumale::NeuralNetwork::RVFLClassifier.new(hidden_units: 128, reg_param: 100.0)
# estimator.fit(training_samples, traininig_labels)
# results = estimator.predict(testing_samples)
#
# *Reference*
# - Malik, A. K., Gao, R., Ganaie, M. A., Tanveer, M., and Suganthan, P. N., "Random vector functional link network: recent developments, applications, and future directions," Applied Soft Computing, vol. 143, 2023.
# - Zhang, L., and Suganthan, P. N., "A comprehensive evaluation of random vector functional link networks," Information Sciences, vol. 367--368, pp. 1094--1105, 2016.
class RVFLClassifier < BaseRVFL
include ::Rumale::Base::Classifier

# Return the class labels.
# @return [Numo::Int32] (size: n_classes)
attr_reader :classes

# Return the weight vector in the hidden layer of RVFL network.
# @return [Numo::DFloat] (shape: [n_hidden_units, n_features])
attr_reader :random_weight_vec

# Return the bias vector in the hidden layer of RVFL network.
# @return [Numo::DFloat] (shape: [n_hidden_units])
attr_reader :random_bias

# Return the weight vector.
# @return [Numo::DFloat] (shape: [n_features + n_hidden_units, n_classes])
attr_reader :weight_vec

# Return the random generator.
# @return [Random]
attr_reader :rng

# Create a new classifier with RVFL network.
#
# @param hidden_units [Array] The number of units in the hidden layer.
# @param reg_param [Float] The regularization parameter.
# @param scale [Float] The scale parameter for random weight and bias.
# @param random_seed [Integer] The seed value using to initialize the random generator.
def initialize(hidden_units: 128, reg_param: 100.0, scale: 1.0, random_seed: nil)
super
end

# Fit the model with given training data.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
# @return [RVFLClassifier] The learned classifier itself.
def fit(x, y)
x = ::Rumale::Validation.check_convert_sample_array(x)
y = ::Rumale::Validation.check_convert_label_array(y)
::Rumale::Validation.check_sample_size(x, y)
raise 'RVFLClassifier#fit requires Numo::Linalg but that is not loaded.' unless enable_linalg?(warning: false)

@classes = Numo::NArray[*y.to_a.uniq.sort]

partial_fit(x, one_hot_encode(y))

self
end

# Calculate confidence scores for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence score per sample.
def decision_function(x)
x = ::Rumale::Validation.check_convert_sample_array(x)

h = hidden_output(x)
h.dot(@weight_vec)
end

# Predict class labels for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
# @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
def predict(x)
x = ::Rumale::Validation.check_convert_sample_array(x)

scores = decision_function(x)
n_samples, n_classes = scores.shape
label_ids = scores.max_index(axis: 1) - Numo::Int32.new(n_samples).seq * n_classes
@classes[label_ids].dup
end

private

def one_hot_encode(y)
Numo::DFloat.cast(::Rumale::Utils.binarize_labels(y))
end
end
end
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# frozen_string_literal: true

require 'spec_helper'

require 'numo/tiny_linalg'
Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)

RSpec.describe Rumale::NeuralNetwork::RVFLClassifier do
let(:x) { dataset[0] }
let(:y) { dataset[1] }
let(:classes) { y.to_a.uniq.sort }
let(:n_samples) { x.shape[0] }
let(:n_features) { x.shape[1] }
let(:n_classes) { classes.size }
let(:hidden_units) { 64 }
let(:estimator) { described_class.new(hidden_units: hidden_units, reg_param: 1e4, random_seed: 1) }
let(:predicted) { estimator.predict(x) }
let(:score) { estimator.score(x, y) }

shared_examples 'classification' do
before { estimator.fit(x, y) }

it 'classifies given dataset.', :aggregate_failures do
expect(estimator.classes).to be_a(Numo::Int32)
expect(estimator.classes).to be_contiguous
expect(estimator.classes.ndim).to eq(1)
expect(estimator.classes.shape[0]).to eq(n_classes)
expect(estimator.random_weight_vec).to be_a(Numo::DFloat)
expect(estimator.random_weight_vec).to be_contiguous
expect(estimator.random_weight_vec.ndim).to eq(2)
expect(estimator.random_weight_vec.shape[0]).to eq(n_features)
expect(estimator.random_weight_vec.shape[1]).to eq(hidden_units)
expect(estimator.random_bias).to be_a(Numo::DFloat)
expect(estimator.random_bias).to be_contiguous
expect(estimator.random_bias.ndim).to eq(1)
expect(estimator.random_bias.shape[0]).to eq(hidden_units)
expect(estimator.weight_vec).to be_a(Numo::DFloat)
expect(estimator.weight_vec).to be_contiguous
expect(estimator.weight_vec.ndim).to eq(2)
expect(estimator.weight_vec.shape[0]).to eq(n_features + hidden_units)
expect(estimator.weight_vec.shape[1]).to eq(n_classes)
expect(predicted).to be_a(Numo::Int32)
expect(predicted).to be_contiguous
expect(predicted.ndim).to eq(1)
expect(predicted.shape[0]).to eq(n_samples)
expect(predicted).to eq(y)
expect(score).to eq(1.0)
end
end

context 'when the number of hidden units is less than the number of samples' do
context 'when binary classification problem' do
let(:dataset) { xor_dataset }

it_behaves_like 'classification'
end

context 'when multiclass classification problem' do
let(:dataset) { three_clusters_dataset }

it_behaves_like 'classification'
end
end

context 'when the number of hidden units is greater than the number of samples' do
let(:hidden_units) { 512 }

context 'when binary classification problem' do
let(:dataset) { xor_dataset }

it_behaves_like 'classification'
end

context 'when multiclass classification problem' do
let(:dataset) { three_clusters_dataset }

it_behaves_like 'classification'
end
end
end

0 comments on commit 71a7978

Please sign in to comment.