Skip to content

Commit

Permalink
Merge pull request #9 from dinarior/dist_bug_fix
Browse files: browse the repository at this point in the history
Dist bug fix
  • Loading branch information
dinarior authored Oct 31, 2019
2 parents dc00e7e + d07dd99 commit 3708adf
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DPMMSubClusters"
uuid = "2841fd70-8698-11e9-176d-6dfa142d2ee7"
authors = ["Or Dinari <[email protected]>"]
version = "0.1.3"
version = "0.1.4"

[deps]
CatViews = "81a5f4ea-a946-549a-aa7e-2a7f63a27d31"
Expand Down
20 changes: 12 additions & 8 deletions src/dp-parallel-sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,31 +326,35 @@ function run_model(dp_model, first_iter, model_params="none", prev_time = 0)
group_step(dp_model.group, no_more_splits, final, i==1)
iter_time = time() - prev_time
push!(iter_count,iter_time)
push!(liklihood_history,calculate_posterior(dp_model))

push!(cluster_count_history,length(dp_model.group.local_clusters))
group_labels = Array(dp_model.group.labels)

if ground_truth != nothing
group_labels = Array(dp_model.group.labels)
push!(v_score_history, varinfo(Int.(ground_truth),group_labels))
push!(nmi_score_history, mutualinfo(Int.(ground_truth),group_labels,normed=true))
else
push!(v_score_history, "no gt")
push!(nmi_score_history, "no gt")
end
if use_verbose
push!(liklihood_history,calculate_posterior(dp_model))
println("Iteration: " * string(i) * " || Clusters count: " *
string(cluster_count_history[end]) *
" || Log posterior: " * string(liklihood_history[end]) *
" || Vi score: " * string(v_score_history[end]) *
" || NMI score: " * string(nmi_score_history[end]) *
" || Iter Time:" * string(iter_time) *
" || Total time:" * string(sum(iter_count)))
else
push!(liklihood_history,1)
end
if length(dp_model.group.local_clusters) > cur_parr_count
cur_parr_count += max(20,length(dp_model.group.local_clusters))
@sync for i in (nworkers()== 0 ? procs() : workers())
@spawnat i set_parr_worker(dp_model.group.labels,cur_parr_count)
end
end
# if length(dp_model.group.local_clusters) > cur_parr_count
# cur_parr_count += max(20,length(dp_model.group.local_clusters))
# @sync for i in (nworkers()== 0 ? procs() : workers())
# @spawnat i set_parr_worker(dp_model.group.labels,cur_parr_count)
# end
# end
if i % model_save_interval == 0 && should_save_model
println("Saving Model:")
# save_time = time()
Expand Down
25 changes: 13 additions & 12 deletions src/local_clusters_actions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ end


function sample_sub_clusters!(group::local_group)
@sync for i in (nworkers()== 0 ? procs() : workers())
for i in (nworkers()== 0 ? procs() : workers())
@spawnat i sample_sub_clusters_worker!(group.points, group.labels, group.labels_subcluster)
end
end
Expand Down Expand Up @@ -69,7 +69,7 @@ end
function sample_labels!(labels::AbstractArray{Int64,1},
points::AbstractArray{Float32,2},
final::Bool)
@sync for i in (nworkers()== 0 ? procs() : workers())
for i in (nworkers()== 0 ? procs() : workers())
@spawnat i sample_labels_worker!(labels,points,final)
end
end
Expand All @@ -96,10 +96,11 @@ function sample_labels_worker!(labels::AbstractArray{Int64,1},
pts = localpart(points)
log_weights = log.(clusters_weights)
parr = zeros(Float32,length(indices), length(clusters_vector))
@inbounds for (k,cluster) in enumerate(clusters_vector)
log_likelihood!(reshape((@view parr[:,k]),:,1), localpart(points),cluster.cluster_dist)
tic = time()
for (k,cluster) in enumerate(clusters_vector)
log_likelihood!(reshape((@view parr[:,k]),:,1), pts,cluster.cluster_dist)
end

# println("Time: "* string(time()-tic) * " size:" *string(size(pts)))
# parr = zeros(Float32,length(indices), length(clusters_vector))
# newx = copy(localpart(points)')
# @time log_likelihood!(parr, localpart(points),[c.cluster_dist for c in clusters_vector],log.(clusters_weights))
Expand Down Expand Up @@ -170,7 +171,7 @@ function create_suff_stats_dict_node_leader(group_pts, group_labels, group_subla
if indices == nothing
indices = collect(1:length(clusters_vector))
end
@sync for i in proc_ids
for i in proc_ids
workers_suff_dict[i] = remotecall(create_suff_stats_dict_worker,i,group_pts,
group_labels,
group_sublabels,
Expand Down Expand Up @@ -199,13 +200,13 @@ function create_suff_stats_dict_node_leader(group_pts, group_labels, group_subla
end


function update_suff_stats_posterior!(group::local_group,indices = nothing, use_leader::Bool = false)
function update_suff_stats_posterior!(group::local_group,indices = nothing, use_leader::Bool = true)
workers_suff_dict = Dict()
if indices == nothing
indices = collect(1:length(group.local_clusters))
end
if use_leader
@sync for i in collect(keys(leader_dict))
for i in collect(keys(leader_dict))
workers_suff_dict[i] = remotecall(create_suff_stats_dict_node_leader, i ,group.points,
group.labels,
group.labels_subcluster,
Expand All @@ -214,7 +215,7 @@ function update_suff_stats_posterior!(group::local_group,indices = nothing, use_
indices)
end
else
@sync for i in (nworkers()== 0 ? procs() : workers())
for i in (nworkers()== 0 ? procs() : workers())
workers_suff_dict[i] = @spawnat i create_suff_stats_dict_worker(group.points,
group.labels,
group.labels_subcluster,
Expand Down Expand Up @@ -358,7 +359,7 @@ function check_and_split!(group::local_group, final::Bool)
end
all_indices = vcat(indices,new_indices)
if length(indices) > 0
@sync for i in (nworkers()== 0 ? procs() : workers())
for i in (nworkers()== 0 ? procs() : workers())
@spawnat i split_cluster_local_worker!(group.labels,group.labels_subcluster,group.points,indices,new_indices)
end
end
Expand Down Expand Up @@ -387,7 +388,7 @@ function check_and_merge!(group::local_group, final::Bool)
mergable[1] = 0
end
end
@sync for i in (nworkers()== 0 ? procs() : workers())
for i in (nworkers()== 0 ? procs() : workers())
@spawnat i merge_clusters_worker!(group,indices,new_indices)
end
return indices
Expand Down Expand Up @@ -467,7 +468,7 @@ function reset_splitted_clusters!(group::local_group, bad_clusters::Vector{Int64
for i in bad_clusters
group.local_clusters[i].cluster_params.logsublikelihood_hist = [-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf,-Inf]
end
@sync for i in (nworkers()== 0 ? procs() : workers())
for i in (nworkers()== 0 ? procs() : workers())
@spawnat i reset_bad_clusters_worker!(bad_clusters,group.points, group.labels, group.labels_subcluster)
end
update_suff_stats_posterior!(group,bad_clusters)
Expand Down

0 comments on commit 3708adf

Please sign in to comment.