add Hierarchical Clustering & some docstring fixes #9

Merged
merged 12 commits into from
Sep 6, 2022
126 changes: 124 additions & 2 deletions src/MLJClusteringInterface.jl
@@ -16,7 +16,7 @@ using Distances

# ===================================================================
## EXPORTS
export KMeans, KMedoids, DBSCAN
export KMeans, KMedoids, DBSCAN, HierarchicalClustering

# ===================================================================
## CONSTANTS
@@ -31,6 +31,7 @@ const PKG = "MLJClusteringInterface"
@mlj_model mutable struct KMeans <: MMI.Unsupervised
    k::Int = 3::(_ ≥ 2)
    metric::SemiMetric = SqEuclidean()
    init = :kmpp
end

function MMI.fit(model::KMeans, verbosity::Int, X)
@@ -169,10 +170,51 @@ end
MMI.reporting_operations(::Type{<:DBSCAN}) = (:predict,)


# # HierarchicalClustering
@mlj_model mutable struct HierarchicalClustering <: MMI.Static
    linkage::Symbol = :single :: (_ ∈ (:single, :average, :complete, :ward, :ward_presquared))
    metric::SemiMetric = SqEuclidean()
    branchorder::Symbol = :r :: (_ ∈ (:r, :barjoseph, :optimal))
    h::Union{Nothing,Float64} = nothing
    k::Int = 3
end
"""
    struct DendrogramCutter{T}
        dendrogram::T
    end

Callable object to cut a dendrogram.
"""
struct DendrogramCutter{T}
    dendrogram::T
end
"""
    (cutter::DendrogramCutter)(; h = nothing, k = 3)

Cuts the dendrogram at height `h` or, if `h == nothing`, such that `k` clusters are obtained.
"""
function (cutter::DendrogramCutter)(; h = nothing, k = 3)
    MMI.categorical(Cl.cutree(cutter.dendrogram, k = k, h = h))
end
function Base.show(io::IO, ::DendrogramCutter)
    print(io, "Dendrogram Cutter.")
end

function MMI.predict(model::HierarchicalClustering, ::Nothing, X)
    Xarray = MMI.matrix(X)
    d = pairwise(model.metric, Xarray, dims = 1) # n x n
    dendrogram = Cl.hclust(d, linkage = model.linkage, branchorder = model.branchorder)
    cutter = DendrogramCutter(dendrogram)
    yhat = cutter(h = model.h, k = model.k)
    return yhat, (; cutter, dendrogram)
end

MMI.reporting_operations(::Type{<:HierarchicalClustering}) = (:predict,)

# # METADATA

metadata_pkg.(
(KMeans, KMedoids, DBSCAN),
(KMeans, KMedoids, DBSCAN, HierarchicalClustering),
name="Clustering",
uuid="aaaa29a8-35af-508c-8bc3-b662a17a0fe5",
url="https://github.com/JuliaStats/Clustering.jl",
@@ -205,6 +247,12 @@ metadata_model(
path = "$(PKG).DBSCAN"
)

metadata_model(
    HierarchicalClustering,
    human_name = "hierarchical clusterer",
    input = MMI.Table(Continuous),
    path = "$(PKG).HierarchicalClustering"
)

"""
$(MMI.doc_header(KMeans))
@@ -477,4 +525,78 @@ scatter(points, color=colors)
"""
DBSCAN

"""
$(MMI.doc_header(HierarchicalClustering))

[Hierarchical Clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) is a
clustering algorithm that organizes the data in a dendrogram based on distances between
groups of points and computes cluster assignments by cutting the dendrogram at a given
height. More information is available at the [Clustering.jl
documentation](https://juliastats.org/Clustering.jl/stable/index.html). Use `predict` to
get cluster assignments. The dendrogram and the dendrogram cutter are accessed from the
machine report (see below).

This is a static implementation, i.e., it does not generalize to new data instances, and
there is no training data. For clusterers that do generalize, see [`KMeans`](@ref) or
[`KMedoids`](@ref).

In MLJ or MLJBase, create a machine with

mach = machine(model)

# Hyper-parameters

- `linkage = :single`: linkage method (:single, :average, :complete, :ward, :ward_presquared)

- `metric = SqEuclidean`: metric (see `Distances.jl` for available metrics)

- `branchorder = :r`: branchorder (:r, :barjoseph, :optimal)

- `h = nothing`: height at which the dendrogram is cut

- `k = 3`: number of clusters.

If both `k` and `h` are specified, it is guaranteed that the number of clusters is not less than `k` and their height is not above `h`.


# Operations

- `predict(mach, X)`: return cluster label assignments, as an unordered
  `CategoricalVector`. Here `X` is any table of input features (e.g., a `DataFrame`) whose
  columns are of scitype `Continuous`; check column scitypes with `schema(X)`.


# Report

After calling `predict(mach, X)`, the fields of `report(mach)` are:

- `dendrogram`: the dendrogram that was computed when calling `predict`.

- `cutter`: a dendrogram cutter that can be called with a height `h` or a number of clusters `k`, to obtain a new assignment of the data points to clusters (see example below).

# Examples

```
using MLJ

X, labels = make_moons(400, noise=0.09, rng=1) # synthetic data with 2 clusters

HierarchicalClustering = @load HierarchicalClustering pkg=Clustering
model = HierarchicalClustering(linkage = :complete)
mach = machine(model)

# compute and output cluster assignments for observations in `X`:
yhat = predict(mach, X)

# plot dendrogram:
using StatsPlots
plot(report(mach).dendrogram)

# make new predictions by cutting the dendrogram at another height
report(mach).cutter(h = 2.5)
```
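
The same computation can be sketched without the MLJ wrapper. The snippet below is a
minimal illustration, assuming only `Clustering.hclust`, `Clustering.cutree` and
`Distances.pairwise`, the calls this interface wraps. The toy data `X` and the names
`by_k` and `by_h` are made up for the example and are not part of the interface.

```julia
using Clustering, Distances

# Hypothetical toy data: six observations of one feature, in three tight pairs.
X = reshape([0.0, 0.1, 5.0, 5.1, 10.0, 10.1], :, 1)   # 6 x 1 matrix

# Pairwise distances between observations (rows), as an n x n matrix.
d = Distances.pairwise(Euclidean(), X, dims = 1)

# Build the dendrogram from the distance matrix.
dendrogram = Clustering.hclust(d, linkage = :complete, branchorder = :r)

# Cut the tree either into a fixed number of clusters or at a fixed height.
by_k = Clustering.cutree(dendrogram, k = 3)
by_h = Clustering.cutree(dendrogram, h = 1.0)
```

Both cuts should separate the three pairs here, because each pair merges at a height of
about 0.1 while the next merge only happens above 5.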

"""
HierarchicalClustering

end # module
19 changes: 18 additions & 1 deletion test/runtests.jl
@@ -95,8 +95,25 @@ end

end

# # HierarchicalClustering

@testset "HierarchicalClustering" begin
    h = Inf; k = 1; linkage = :complete; bo = :optimal;
    metric = Distances.Euclidean()
    mach = machine(HierarchicalClustering(h = h, k = k, metric = metric,
                                          linkage = linkage, branchorder = bo))
    yhat = predict(mach, X)
    @test length(union(yhat)) == 1 # uses h = Inf
    cutter = report(mach).cutter
    @test length(union(cutter(k = 4))) == 4 # uses k = 4
    dendro = Clustering.hclust(Distances.pairwise(metric, hcat(X...), dims = 1),
                               linkage = linkage, branchorder = bo)
    @test cutter(k = 2) == Clustering.cutree(dendro, k = 2)
    @test report(mach).dendrogram.heights == dendro.heights
end

@testset "MLJ interface" begin
models = [KMeans, KMedoids, DBSCAN]
models = [KMeans, KMedoids, DBSCAN, HierarchicalClustering]
failures, summary = MLJTestIntegration.test(
models,
X;