Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP cleanlab compatible mode #9

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 45 additions & 42 deletions lapros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ p̂

# ╔═╡ 8fadfa66-a730-4f99-aa70-e25fec26cb09
"Compute the average model confidence for samples in each class.
If there is no sample in some specific class,
we take the avarage over all samples."
If there is no sample in some specific class, take null (missing)."
function avg_confidence(
ps::Array{Float64,2},
ys::Array{Int64,1},
Expand All @@ -109,12 +108,12 @@ function avg_confidence(
m = size(ps, 2)
t = zeros(Float64, m)
for i ∈ 1:m
x_yi = ys .== i
x_yi = (ys .== i)
# @show i, x_yi
if any(x_yi)
t[i] = mean(ps[x_yi, i])
else
t[i] = mean(ps[:, i])
t[i] = missing
end
end
t
Expand Down Expand Up @@ -152,31 +151,39 @@ md"""
Danh sách nhãn đáng tin nhất đối với từng mẫu là như sau, trong đó $0$ đánh dấu trường hợp không có nhãn phù hợp.
"""

# ╔═╡ 69e425b0-7b6b-4c2f-afba-70ea787bde54
# ╔═╡ 9f29abe2-f98a-44af-8420-895f2888780a
t2p

# ╔═╡ ce883c1c-0d6c-4f02-aae1-c5049904b4bd
"Find possible labels for each sample."
function find_likely_label(ps::Array{Float64,2}, t2p::Array{Float64,2})
pos_idx = (t2p .> 0)
ps = ifelse.(pos_idx, ps, missing)
am = argmax(ps, dims=2)
likely_labels = last.(Tuple.(am))
ll = ifelse.(any(t2p .≥ 0, dims=2)[:], likely_labels, missing)
vec(ll)
end

# ╔═╡ 591f5a84-b437-4b59-9961-bc110e34eeae
"Find the most likely labels for each sample.

## Params:

- mask_negative: For some specific sample, if the normalized probabilities are all negative then we use 0 to mark that there is no likely class label for the sample."
function find_likely_label(t2p::Array{Float64,2}, mask_negative::Bool=false)
For some specific sample, if the normalized probabilities are all negative then we mark that there is no likely class label for the sample."
function find_likely_label(t2p::Array{Float64,2})
# @show t2p
am = argmax(t2p, dims=2)
# @show am
# @time am = am[:]
# @show am
likely_labels = last.(Tuple.(am))
ll = if mask_negative
ifelse.(any(t2p .≥ 0, dims=2)[:], likely_labels, 0)
else
likely_labels
end
ll = ifelse.(any(t2p .≥ 0, dims=2)[:], likely_labels, missing)
vec(ll)
# @show ll
end

# ╔═╡ 60163432-e144-466d-8595-828ad715eb03
find_likely_label(p̂, t2p)

# ╔═╡ 5921b759-9474-414f-a25a-416aec299afb
@time l̂ = find_likely_label(t2p)

Expand All @@ -185,14 +192,12 @@ md"""
Xếp các mẫu vào ma trận có hàng thể hiện nhãn đã quan sát $\tilde{y}$, còn cột thể hiện nhãn đáng tin nhất $\hat{l}.$
"""

# ╔═╡ 45b1feb2-45d3-46df-973a-8b95a70ea164
# Xỹẏ = partition_X(l̂, ỹ, 1:m)

# ╔═╡ 2a49d3be-d6ae-43fd-8fe8-a14800c023c9
function partition_X(
ls::Array{Int64,1},
ys::Array{Int64,1},
M::Array{Int64,1})
M::UnitRange{Int64},
)
# @show ys
# @show ls
X_partition = [[] for i ∈ M, j ∈ M]
Expand All @@ -204,6 +209,12 @@ function partition_X(
X_partition
end

# ╔═╡ 45b1feb2-45d3-46df-973a-8b95a70ea164
# ╠═╡ disabled = true
#=╠═╡
Xỹẏ = partition_X(l̂, ỹ, 1:m)
╠═╡ =#

# ╔═╡ 9c063c91-0497-40c2-8eb1-bdf8060aaf70
md"""
## Độ khả nghi
Expand All @@ -214,6 +225,9 @@ md"""
# ╔═╡ 14d5fa77-7393-41ef-95e2-6848976c3346
t2p

# ╔═╡ eed60900-7662-4833-978e-b7cc160ce195
[missing] .!= [1]

# ╔═╡ f44f3c50-2c8f-4ea6-b712-07fa9f951ffb
function rank_suspicious(
ps::Array{Float64,2},
Expand All @@ -224,7 +238,7 @@ function rank_suspicious(
# @show ys
n,m = size(ps)
e = spzeros(Float64, n)
ids = (ls.≠ys) .&& (ls.≠0)
ids = (ls.≠ys) # .&& .!ismissing.(ls)
for k in (1:n)[ids]
e[k] = ps[k, ls[k]] - ps[k, ys[k]]
# @show k, ls[k], ys[k], e[k]
Expand All @@ -247,7 +261,8 @@ function lapros(
t = avg_confidence(p, y)
t2p = p .- t'
@time ll = find_likely_label(t2p)
@time rank = rank_suspicious(t2p, ll, y)
# @time rank = rank_suspicious(t2p, ll, y)
@time rank = rank_suspicious(p, ll, y)
end

# ╔═╡ 2c9a31f4-6abd-40ef-a3b3-f6eb39eda059
Expand All @@ -257,7 +272,10 @@ end
@assert all(errs .≥ 0)

# ╔═╡ a7dfaa5e-7376-49c7-b025-f16e4c5b8f56
@time rank = rank_suspicious(t2p, l̂, ỹ)
@time rank_t = rank_suspicious(t2p, l̂, ỹ)

# ╔═╡ 9f48fa46-20dd-4877-abcd-2b53090dc4c7
@time rank = rank_suspicious(p̂, l̂, ỹ)

# ╔═╡ 9bb332cb-fc80-42b3-a4d9-e4354bb4720f
@assert all(errs .≈ rank)
Expand Down Expand Up @@ -347,9 +365,8 @@ PlutoUI = "~0.7.39"
PLUTO_MANIFEST_TOML_CONTENTS = """
# This file is machine-generated - editing it directly is not advised

julia_version = "1.8.0"
julia_version = "1.7.3"
manifest_format = "2.0"
project_hash = "4b6af47aac154a439bb522a11f7292ebe69c4ff4"

[[deps.AbstractPlutoDingetjes]]
deps = ["Pkg"]
Expand All @@ -359,7 +376,6 @@ version = "1.1.4"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"

[[deps.Arrow]]
deps = ["ArrowTypes", "BitIntegers", "CodecLz4", "CodecZstd", "DataAPI", "Dates", "Mmap", "PooledArrays", "SentinelArrays", "Tables", "TimeZones", "UUIDs"]
Expand Down Expand Up @@ -434,7 +450,6 @@ version = "3.45.0"
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.5.2+0"

[[deps.Crayons]]
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
Expand Down Expand Up @@ -478,7 +493,6 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"

[[deps.ExprTools]]
git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d"
Expand Down Expand Up @@ -567,12 +581,10 @@ uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
version = "0.6.3"

[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "7.84.0+0"

[[deps.LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
Expand All @@ -581,7 +593,6 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
version = "1.10.2+0"

[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand All @@ -606,7 +617,6 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.0+0"

[[deps.Missings]]
deps = ["DataAPI"]
Expand All @@ -625,16 +635,13 @@ version = "0.7.3"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2022.2.1"

[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.20+0"

[[deps.OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
Expand All @@ -650,7 +657,6 @@ version = "2.3.2"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.8.0"

[[deps.PlutoUI]]
deps = ["AbstractPlutoDingetjes", "Base64", "ColorTypes", "Dates", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "Markdown", "Random", "Reexport", "UUIDs"]
Expand Down Expand Up @@ -700,7 +706,6 @@ version = "1.2.2"

[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"

[[deps.Scratch]]
deps = ["Dates"]
Expand Down Expand Up @@ -741,7 +746,6 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
version = "1.0.0"

[[deps.TableTraits]]
deps = ["IteratorInterfaceExtensions"]
Expand All @@ -758,7 +762,6 @@ version = "1.7.0"
[[deps.Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
version = "1.10.0"

[[deps.Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
Expand Down Expand Up @@ -797,7 +800,6 @@ version = "1.4.2"
[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.12+3"

[[deps.Zstd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand All @@ -808,17 +810,14 @@ version = "1.5.2+0"
[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl", "OpenBLAS_jll"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.1.1+0"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
version = "1.48.0+0"

[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+0"
"""

# ╔═╡ Cell order:
Expand Down Expand Up @@ -850,15 +849,19 @@ version = "17.4.0+0"
# ╟─21b37653-4dc8-48c6-bc31-cd0449b24bc2
# ╠═f8cbcd15-d2d0-4945-a80b-0e0e615d22e5
# ╟─5be2ee06-178d-42ec-97b2-b51dbe189633
# ╠═69e425b0-7b6b-4c2f-afba-70ea787bde54
# ╠═9f29abe2-f98a-44af-8420-895f2888780a
# ╠═60163432-e144-466d-8595-828ad715eb03
# ╠═ce883c1c-0d6c-4f02-aae1-c5049904b4bd
# ╠═5921b759-9474-414f-a25a-416aec299afb
# ╠═591f5a84-b437-4b59-9961-bc110e34eeae
# ╟─451f8698-2c6c-4cf1-8e1a-ede8cc1c535b
# ╠═45b1feb2-45d3-46df-973a-8b95a70ea164
# ╠═2a49d3be-d6ae-43fd-8fe8-a14800c023c9
# ╟─9c063c91-0497-40c2-8eb1-bdf8060aaf70
# ╠═a7dfaa5e-7376-49c7-b025-f16e4c5b8f56
# ╠═9f48fa46-20dd-4877-abcd-2b53090dc4c7
# ╠═14d5fa77-7393-41ef-95e2-6848976c3346
# ╠═eed60900-7662-4833-978e-b7cc160ce195
# ╠═f44f3c50-2c8f-4ea6-b712-07fa9f951ffb
# ╠═7edc3678-03fe-4eb8-ab39-f7033a06e93d
# ╠═b33c68aa-0a24-4622-8e4d-d9abbce885a5
Expand Down