Skip to content

Commit

Permalink
Update to SpecialFunctions 0.8
Browse files Browse the repository at this point in the history
  • Loading branch information
dinarior committed Jan 13, 2020
1 parent 7f4640d commit bd09c02
Show file tree
Hide file tree
Showing 9 changed files with 15 additions and 19 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DPMMSubClusters"
uuid = "2841fd70-8698-11e9-176d-6dfa142d2ee7"
authors = ["Or Dinari <[email protected]>"]
version = "0.1.6"
version = "0.1.7"

[deps]
CatViews = "81a5f4ea-a946-549a-aa7e-2a7f63a27d31"
Expand All @@ -20,10 +20,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
CatViews = "1"
Clustering = "0.13.3"
DistributedArrays = "0, 1"
Distributions = "0.21.3"
Distributions = "0, 1"
JLD2 = "0, 1"
NPZ = "0, 1"
SpecialFunctions = "0, 1"
SpecialFunctions = "0.8,0.9"
StatsBase = "0,1"
julia = "1"

Expand Down
4 changes: 0 additions & 4 deletions src/DPMMSubClusters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ include("global_params.jl")
include("dp-parallel-sampling.jl")
include("data_generators.jl")

if length(procs()) == 1
new_procs = addprocs(1)
end


export generate_gaussian_data, generate_mnmm_data, dp_parallel_sampling, dp_parallel, run_model_from_checkpoint, save_model, calculate_posterior, fit, get_labels_histogram

Expand Down
4 changes: 2 additions & 2 deletions src/dp-parallel-sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -444,15 +444,15 @@ end


function calculate_posterior(model::dp_parallel_sampling)
log_posterior = lgamma(model.model_hyperparams.α) - lgamma(size(model.group.points,2)+model.model_hyperparams.α)
log_posterior = logabsgamma(model.model_hyperparams.α)[1] - logabsgamma(size(model.group.points,2)+model.model_hyperparams.α)[1]
for cluster in model.group.local_clusters
if cluster.cluster_params.cluster_params.suff_statistics.N == 0
continue
end
log_posterior += log_marginal_likelihood(cluster.cluster_params.cluster_params.hyperparams,
cluster.cluster_params.cluster_params.posterior_hyperparams,
cluster.cluster_params.cluster_params.suff_statistics)
log_posterior += log(model.model_hyperparams.α) + lgamma(cluster.cluster_params.cluster_params.suff_statistics.N)
log_posterior += log(model.model_hyperparams.α) + logabsgamma(cluster.cluster_params.cluster_params.suff_statistics.N)[1]
end
return log_posterior
end
Expand Down
4 changes: 2 additions & 2 deletions src/global_params.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ data_prefix = "data_prefix" #If the data file name is bob.npy, this should be '
iterations = 100
hard_clustering = false #Soft or hard assignments
initial_clusters = 1
argmax_sample_stop = 0 #Change to hard assignment from soft at iterations - argmax_sample_stop
split_stop = 0 #Stop split/merge moves at iterations - split_stop
argmax_sample_stop = 5 #Change to hard assignment from soft at iterations - argmax_sample_stop
split_stop = 5#Stop split/merge moves at iterations - split_stop

random_seed = nothing #When nothing, a random seed will be used.

Expand Down
6 changes: 3 additions & 3 deletions src/local_clusters_actions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,9 @@ function should_split_local!(should_split::AbstractArray{Float32,1},
log_likihood = log_marginal_likelihood(cp.hyperparams, post, cp.suff_statistics)

log_HR = log(α) +
lgamma(cpl.suff_statistics.N) + log_likihood_l +
lgamma(cpr.suff_statistics.N) + log_likihood_r -
(lgamma(cp.suff_statistics.N) + log_likihood)
logabsgamma(cpl.suff_statistics.N)[1] + log_likihood_l +
logabsgamma(cpr.suff_statistics.N)[1] + log_likihood_r -
(logabsgamma(cp.suff_statistics.N)[1] + log_likihood)
if log_HR > log(rand())
should_split .= 1
end
Expand Down
2 changes: 1 addition & 1 deletion src/priors/multinomial_prior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end
function log_marginal_likelihood(hyper::multinomial_hyper, posterior_hyper::multinomial_hyper, suff_stats::multinomial_sufficient_statistics)
D = length(suff_stats.points_sum)
logpi = log(pi)
val = lgamma(sum(hyper.α)) -lgamma(sum(posterior_hyper.α)) + sum(lgamma.(posterior_hyper.α) - lgamma.(hyper.α))
val = logabsgamma(sum(hyper.α))[1] -logabsgamma(sum(posterior_hyper.α))[1] + sum((x-> logabsgamma(x)[1]).(posterior_hyper.α) - (x-> logabsgamma(x)[1]).(hyper.α))
return val
end

Expand Down
6 changes: 3 additions & 3 deletions src/shared_actions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ function should_merge!(should_merge::AbstractArray{Float32,1}, cpl::cluster_para
log_likihood_l = log_marginal_likelihood(cpl.hyperparams, cpl.posterior_hyperparams, cpl.suff_statistics)
log_likihood_r = log_marginal_likelihood(cpr.hyperparams, cpr.posterior_hyperparams, cpr.suff_statistics)
log_likihood = log_marginal_likelihood(cp.hyperparams, cp.posterior_hyperparams, cp.suff_statistics)
log_HR = (-log(α) + lgamma(α) -2*lgamma(0.5*α) + lgamma(cp.suff_statistics.N) -lgamma(cp.suff_statistics.N + α) +
lgamma(cpl.suff_statistics.N + 0.5*α)-lgamma(cpl.suff_statistics.N) - lgamma(cpr.suff_statistics.N) +
lgamma(cpr.suff_statistics.N + 0.5*α)+ log_likihood- log_likihood_l- log_likihood_r)
log_HR = (-log(α) + logabsgamma(α)[1] -2*logabsgamma(0.5*α)[1] + logabsgamma(cp.suff_statistics.N)[1] -logabsgamma(cp.suff_statistics.N + α)[1] +
logabsgamma(cpl.suff_statistics.N + 0.5*α)[1]-logabsgamma(cpl.suff_statistics.N)[1] - logabsgamma(cpr.suff_statistics.N)[1] +
logabsgamma(cpr.suff_statistics.N + 0.5*α)[1]+ log_likihood- log_likihood_l- log_likihood_r)
# log_HR = -(log(α) +
# lgamma(cpl.suff_statistics.N) + log_likihood_l +
# lgamma(cpr.suff_statistics.N) + log_likihood_r -
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ end
function log_multivariate_gamma(x::Number, D::Number)
res::Float32 = D*(D-1)/4*log(pi)
for d = 1:D
res += lgamma(x+(1-d)/2)
res += logabsgamma(x+(1-d)/2)[1]
end
return res
end
Expand Down
Binary file modified test/save_load_test/checkpoint_20.jld2
Binary file not shown.

0 comments on commit bd09c02

Please sign in to comment.