Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] concat operation on distributions #499

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@
"code"
]
},
{
"login": "sairevanth25",
"name": "Sai Revanth Gowravajhala",
"avatar_url": "https://avatars.githubusercontent.com/SaiRevanth25",
"profile": "https://github.com/SaiRevanth25",
"contributions": [
"code"
]
},
{
"login": "malikrafsan",
"name": "Malik Akbar Hashemi Rafsanjani",
Expand Down
4 changes: 4 additions & 0 deletions skpro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
"Beta",
"Binomial",
"ChiSquared",
"concat",
"ConcatDistr",
"Delta",
"Empirical",
"Exponential",
Expand Down Expand Up @@ -39,11 +41,13 @@
"Weibull",
]

from skpro.distributions._concat import ConcatDistr
from skpro.distributions.alpha import Alpha
from skpro.distributions.beta import Beta
from skpro.distributions.binomial import Binomial
from skpro.distributions.chi_squared import ChiSquared
from skpro.distributions.compose import IID
from skpro.distributions.concat import concat
from skpro.distributions.delta import Delta
from skpro.distributions.empirical import Empirical
from skpro.distributions.exponential import Exponential
Expand Down
82 changes: 82 additions & 0 deletions skpro/distributions/_concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)
"""Base class for concat operation."""

__author__ = ["SaiRevanth25"]

import pandas as pd
from skbase.base._meta import _MetaObjectMixin


class ConcatDistr(_MetaObjectMixin):
    """Concatenate the given distributions along specified axis.

    Each statistic (``mean``, ``var``, ``pdf``, ``cdf``) is computed per
    distribution and the per-distribution results are joined with
    ``pandas.concat``.

    Parameters
    ----------
    distributions : list
        list of distributions
    axis : {0/'index', 1/'columns'}, default 0
        The axis to concatenate along

    Attributes
    ----------
    distribution_names : list
        The ``name`` attribute of every entry of ``distributions``, in order.
    """

    def __init__(self, distributions, axis=0):
        """Initialize concat with list of distributions and axis for concatenation."""
        self.distributions = distributions
        self.axis = axis
        self.distribution_names = [dist.name for dist in distributions]

    def mean(self):
        """Calculate and concatenate means for each distribution."""
        return self._concat_parts([dist.mean() for dist in self.distributions])

    def var(self):
        """Calculate and concatenate variances for each distribution."""
        return self._concat_parts([dist.var() for dist in self.distributions])

    def pdf(self, x):
        """Concatenate PDFs of the distributions for a given value of `x`.

        Raises
        ------
        ValueError
            If any distribution raises ``ValueError`` while evaluating its pdf;
            the message names the offending distribution.
        """
        return self._evaluate_at("pdf", x)

    def cdf(self, x):
        """Concatenate CDFs for each distribution at a given value of `x`.

        Raises
        ------
        ValueError
            If any distribution raises ``ValueError`` while evaluating its cdf;
            the message names the offending distribution.
        """
        return self._evaluate_at("cdf", x)

    def _evaluate_at(self, method_name, x):
        """Call ``method_name(x)`` on every distribution and concatenate.

        Parameters
        ----------
        method_name : str
            Name of the distribution method to call, e.g. ``"pdf"`` or ``"cdf"``.
        x : object
            Argument forwarded unchanged to each distribution's method.

        Returns
        -------
        The concatenated per-distribution results, see ``_concat_parts``.
        """
        parts = []
        for dist in self.distributions:
            try:
                parts.append(getattr(dist, method_name)(x))
            except ValueError as e:
                # Chain explicitly so the original traceback stays attached
                # to the more descriptive error.
                raise ValueError(
                    f"Error in {method_name} computation for distribution "
                    f"{dist.name}: {str(e)}"
                ) from e
        return self._concat_parts(parts)

    def _concat_parts(self, parts):
        """Join per-distribution results along ``self.axis`` and relabel rows."""
        concatenated = pd.concat(parts, axis=self.axis, ignore_index=True)
        concatenated.index = self._generate_index(len(concatenated))
        return concatenated

    def _generate_index(self, length):
        """Generate index for concatenated result.

        Parameters
        ----------
        length : int
            Number of rows of the concatenated result.

        Returns
        -------
        list or pandas.RangeIndex
            ``self.distribution_names`` when concatenating along the rows and
            there is exactly one row per distribution, otherwise
            ``RangeIndex(0, length)``.
        """
        # Rows only map one-to-one onto distributions for row-wise
        # concatenation; along the columns a matching row count is a
        # coincidence and must not be labelled with distribution names.
        row_wise = self.axis in (0, "index")
        if row_wise and length == len(self.distribution_names):
            return self.distribution_names
        return pd.RangeIndex(start=0, stop=length)

    # todo: Constructing a new distribution when the two distributions are same.
    def _constr_distribution(self):
        """Construct a new distribution when the distributions are same.

        Not implemented yet; currently does nothing and returns ``None``.
        """
54 changes: 54 additions & 0 deletions skpro/distributions/concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)
"""Concat operation."""

__author__ = ["SaiRevanth25"]

from skpro.distributions._concat import ConcatDistr


def concat(objs, axis=0):
    """
    Concatenate a list of distributions into a ConcatDistr.

    The arguments are validated here; the concatenation itself is carried
    out by the returned ``ConcatDistr`` when one of its methods is called.

    Parameters
    ----------
    objs : list
        Non-empty list of distribution-like objects to concatenate.
    axis : {0/'index', 1/'columns'}, default 0
        Axis to concatenate along.

    Returns
    -------
    ConcatDistr
        An object representing the concatenation of the given distributions.

    Raises
    ------
    ValueError
        If ``objs`` is not a list, if ``objs`` is empty, or if ``axis`` is
        not one of ``0``, ``1``, ``'index'``, ``'columns'``.

    Examples
    --------
    >>> from skpro.distributions import Normal, concat
    >>> d1 = Normal(mu=[[1, 2], [3, 4]], sigma=1)
    >>> d2 = Normal(mu=0, sigma=[[2, 42]])
    >>> stacked = concat([d1, d2])
    >>> means = stacked.mean()
    >>> variances = stacked.var()
    """
    if not isinstance(objs, list):
        raise ValueError("`objs` must be a list of distribution-like objects.")
    # Reject an empty list up front: there is nothing to concatenate, and
    # failing here is clearer than a pandas error raised on first use.
    if not objs:
        raise ValueError("`objs` must contain at least one distribution.")
    if axis not in (0, 1, "index", "columns"):
        raise ValueError("`axis` must be one of {0, 1, 'index', 'columns'}.")

    return ConcatDistr(objs, axis=axis)