diff --git a/rumale-neural_network/lib/rumale/neural_network.rb b/rumale-neural_network/lib/rumale/neural_network.rb
index ff849d34..adbad8a2 100644
--- a/rumale-neural_network/lib/rumale/neural_network.rb
+++ b/rumale-neural_network/lib/rumale/neural_network.rb
@@ -11,4 +11,5 @@
 require_relative 'neural_network/rbf_classifier'
 require_relative 'neural_network/rbf_regressor'
 require_relative 'neural_network/base_rvfl'
+require_relative 'neural_network/rvfl_classifier'
 require_relative 'neural_network/rvfl_regressor'
diff --git a/rumale-neural_network/lib/rumale/neural_network/rvfl_classifier.rb b/rumale-neural_network/lib/rumale/neural_network/rvfl_classifier.rb
new file mode 100644
index 00000000..fec6f3aa
--- /dev/null
+++ b/rumale-neural_network/lib/rumale/neural_network/rvfl_classifier.rb
@@ -0,0 +1,112 @@
+# frozen_string_literal: true
+
+require 'rumale/base/classifier'
+require 'rumale/neural_network/base_rvfl'
+require 'rumale/utils'
+require 'rumale/validation'
+
+module Rumale
+  module NeuralNetwork
+    # RVFLClassifier is a class that implements a classifier based on random vector functional link (RVFL) network.
+    # The current implementation uses the sigmoid function as the activation function.
+    #
+    # @example
+    #   require 'numo/tiny_linalg'
+    #   Numo::Linalg = Numo::TinyLinalg
+    #
+    #   require 'rumale/neural_network/rvfl_classifier'
+    #
+    #   estimator = Rumale::NeuralNetwork::RVFLClassifier.new(hidden_units: 128, reg_param: 100.0)
+    #   estimator.fit(training_samples, training_labels)
+    #   results = estimator.predict(testing_samples)
+    #
+    # *Reference*
+    # - Malik, A. K., Gao, R., Ganaie, M. A., Tanveer, M., and Suganthan, P. N., "Random vector functional link network: recent developments, applications, and future directions," Applied Soft Computing, vol. 143, 2023.
+    # - Zhang, L., and Suganthan, P. N., "A comprehensive evaluation of random vector functional link networks," Information Sciences, vol. 367--368, pp. 1094--1105, 2016.
+    class RVFLClassifier < BaseRVFL
+      include ::Rumale::Base::Classifier
+
+      # Return the class labels.
+      # @return [Numo::Int32] (size: n_classes)
+      attr_reader :classes
+
+      # Return the weight vector in the hidden layer of RVFL network.
+      # @return [Numo::DFloat] (shape: [n_features, n_hidden_units])
+      attr_reader :random_weight_vec
+
+      # Return the bias vector in the hidden layer of RVFL network.
+      # @return [Numo::DFloat] (shape: [n_hidden_units])
+      attr_reader :random_bias
+
+      # Return the weight vector.
+      # @return [Numo::DFloat] (shape: [n_features + n_hidden_units, n_classes])
+      attr_reader :weight_vec
+
+      # Return the random generator.
+      # @return [Random]
+      attr_reader :rng
+
+      # Create a new classifier with RVFL network.
+      #
+      # @param hidden_units [Integer] The number of units in the hidden layer.
+      # @param reg_param [Float] The regularization parameter.
+      # @param scale [Float] The scale parameter for random weight and bias.
+      # @param random_seed [Integer] The seed value used to initialize the random generator.
+      def initialize(hidden_units: 128, reg_param: 100.0, scale: 1.0, random_seed: nil)
+        super
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+      # @return [RVFLClassifier] The learned classifier itself.
+      # @raise [RuntimeError] If enable_linalg? reports that Numo::Linalg is not available.
+      def fit(x, y)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+        y = ::Rumale::Validation.check_convert_label_array(y)
+        ::Rumale::Validation.check_sample_size(x, y)
+        raise 'RVFLClassifier#fit requires Numo::Linalg but that is not loaded.' unless enable_linalg?(warning: false)
+
+        @classes = Numo::NArray[*y.to_a.uniq.sort]
+
+        partial_fit(x, one_hot_encode(y))
+
+        self
+      end
+
+      # Calculate confidence scores for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence score per sample.
+      def decision_function(x)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+
+        h = hidden_output(x)
+        h.dot(@weight_vec)
+      end
+
+      # Predict class labels for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
+      def predict(x)
+        x = ::Rumale::Validation.check_convert_sample_array(x)
+
+        scores = decision_function(x)
+        n_samples, n_classes = scores.shape
+        # max_index yields positions in the flattened score matrix, so subtract each
+        # row's starting offset (row * n_classes) to recover the per-row class index.
+        label_ids = scores.max_index(axis: 1) - Numo::Int32.new(n_samples).seq * n_classes
+        @classes[label_ids].dup
+      end
+
+      private
+
+      # Cast the output of Rumale::Utils.binarize_labels to Numo::DFloat; fit passes it to partial_fit as the target.
+      def one_hot_encode(y)
+        Numo::DFloat.cast(::Rumale::Utils.binarize_labels(y))
+      end
+    end
+  end
+end
diff --git a/rumale-neural_network/spec/rumale/neural_network/rvfl_classifier_spec.rb b/rumale-neural_network/spec/rumale/neural_network/rvfl_classifier_spec.rb
new file mode 100644
index 00000000..037aba4d
--- /dev/null
+++ b/rumale-neural_network/spec/rumale/neural_network/rvfl_classifier_spec.rb
@@ -0,0 +1,80 @@
+# frozen_string_literal: true
+
+require 'spec_helper'
+
+require 'numo/tiny_linalg'
+Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
+
+RSpec.describe Rumale::NeuralNetwork::RVFLClassifier do
+  let(:x) { dataset[0] }
+  let(:y) { dataset[1] }
+  let(:classes) { y.to_a.uniq.sort }
+  let(:n_samples) { x.shape[0] }
+  let(:n_features) { x.shape[1] }
+  let(:n_classes) { classes.size }
+  let(:hidden_units) { 64 }
+  let(:estimator) { described_class.new(hidden_units: hidden_units, reg_param: 1e4, random_seed: 1) }
+  let(:predicted) { estimator.predict(x) }
+  let(:score) { estimator.score(x, y) }
+
+  shared_examples 'classification' do
+    before { estimator.fit(x, y) }
+
+    it 'classifies given dataset.', :aggregate_failures do
+      expect(estimator.classes).to be_a(Numo::Int32)
+      expect(estimator.classes).to be_contiguous
+      expect(estimator.classes.ndim).to eq(1)
+      expect(estimator.classes.shape[0]).to eq(n_classes)
+      expect(estimator.random_weight_vec).to be_a(Numo::DFloat)
+      expect(estimator.random_weight_vec).to be_contiguous
+      expect(estimator.random_weight_vec.ndim).to eq(2)
+      expect(estimator.random_weight_vec.shape[0]).to eq(n_features)
+      expect(estimator.random_weight_vec.shape[1]).to eq(hidden_units)
+      expect(estimator.random_bias).to be_a(Numo::DFloat)
+      expect(estimator.random_bias).to be_contiguous
+      expect(estimator.random_bias.ndim).to eq(1)
+      expect(estimator.random_bias.shape[0]).to eq(hidden_units)
+      expect(estimator.weight_vec).to be_a(Numo::DFloat)
+      expect(estimator.weight_vec).to be_contiguous
+      expect(estimator.weight_vec.ndim).to eq(2)
+      expect(estimator.weight_vec.shape[0]).to eq(n_features + hidden_units)
+      expect(estimator.weight_vec.shape[1]).to eq(n_classes)
+      expect(predicted).to be_a(Numo::Int32)
+      expect(predicted).to be_contiguous
+      expect(predicted.ndim).to eq(1)
+      expect(predicted.shape[0]).to eq(n_samples)
+      expect(predicted).to eq(y)
+      expect(score).to eq(1.0)
+    end
+  end
+
+  context 'when the number of hidden units is less than the number of samples' do
+    context 'when binary classification problem' do
+      let(:dataset) { xor_dataset }
+
+      it_behaves_like 'classification'
+    end
+
+    context 'when multiclass classification problem' do
+      let(:dataset) { three_clusters_dataset }
+
+      it_behaves_like 'classification'
+    end
+  end
+
+  context 'when the number of hidden units is greater than the number of samples' do
+    let(:hidden_units) { 512 }
+
+    context 'when binary classification problem' do
+      let(:dataset) { xor_dataset }
+
+      it_behaves_like 'classification'
+    end
+
+    context 'when multiclass classification problem' do
+      let(:dataset) { three_clusters_dataset }
+
+      it_behaves_like 'classification'
+    end
+  end
+end