From 888431abd257ce9d47a52da981e9ceac9968ce3e Mon Sep 17 00:00:00 2001 From: Cong Date: Fri, 12 Aug 2022 10:41:02 +0900 Subject: [PATCH] Set test data SEED --- .gitignore | 1 + data.jl | 30 ++++++++----- data/input/lapros-testdata-n10-m3.csv | 11 ----- data/output/lapros-rankdata-n10-m3.csv | 7 --- lapros.jl | 61 ++------------------------ 5 files changed, 25 insertions(+), 85 deletions(-) create mode 100644 .gitignore delete mode 100644 data/input/lapros-testdata-n10-m3.csv delete mode 100644 data/output/lapros-rankdata-n10-m3.csv diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1269488 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +data diff --git a/data.jl b/data.jl index 74c2ddd..adf3ec7 100644 --- a/data.jl +++ b/data.jl @@ -1,5 +1,5 @@ ### A Pluto.jl notebook ### -# v0.19.9 +# v0.19.11 using Markdown using InteractiveUtils @@ -12,21 +12,30 @@ begin using DataFrames: DataFrame import Random using StatsBase: sample - Random.seed!(333) + SEED = 333 + Random.seed!(SEED) end; +# ╔═╡ 2f312588-71c9-4d12-8156-ac5a873fd0dc +data_folder = "data/input/" + # ╔═╡ 83ef4c9e-0e27-11ed-2fa2-1fab901f3b13 -n = 10*10^6 # number of samples +n = 10*10^0 # number of samples # ╔═╡ bef6e63d-205c-420a-b2f7-e9d45ef33c99 m = 3 # ╔═╡ 83b15f19-8f2c-4201-a0a3-23fb5c38b182 -labels = sample(1:m, n) +labels = let + Random.seed!(SEED) + sample(1:m, n) +end # ╔═╡ cd7ec936-09f6-4412-b5f4-5c31a4639a5b -probas = @chain rand(Float64, (n, m)) begin - _ ./ sum(_, dims=2) +probas = let + Random.seed!(SEED) + ps = rand(Float64, (n, m)) + ps ./ sum(ps, dims=2) end # ╔═╡ bbd796e5-d390-4b60-88bf-bd904e9a746d @@ -40,10 +49,10 @@ begin end # ╔═╡ 54e5c023-6849-4dbd-ae49-92b49836042d -begin +let csv_limit = 10^3 fp_postfix = if n ≤ csv_limit "csv" else "arrow" end - data_fp = "/opt/ml-data/lapros/lapros-testdata-n$(n)-m$(m).$(fp_postfix)" + data_fp = joinpath(data_folder, "lapros-testdata-n$(n)-m$(m).$(fp_postfix)") if endswith(data_fp, ".csv") CSV.write(data_fp, df) else @@ -523,11 +532,12 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" """ # ╔═╡ Cell order: +# ╠═2f312588-71c9-4d12-8156-ac5a873fd0dc # ╠═83ef4c9e-0e27-11ed-2fa2-1fab901f3b13 # ╠═bef6e63d-205c-420a-b2f7-e9d45ef33c99 # ╠═54e5c023-6849-4dbd-ae49-92b49836042d -# ╟─83b15f19-8f2c-4201-a0a3-23fb5c38b182 -# ╟─cd7ec936-09f6-4412-b5f4-5c31a4639a5b +# ╠═83b15f19-8f2c-4201-a0a3-23fb5c38b182 +# ╠═cd7ec936-09f6-4412-b5f4-5c31a4639a5b # ╠═bbd796e5-d390-4b60-88bf-bd904e9a746d # ╠═a7ebe81b-4a9a-46aa-b178-39a9b7fd0238 # ╟─00000000-0000-0000-0000-000000000001 diff --git a/data/input/lapros-testdata-n10-m3.csv b/data/input/lapros-testdata-n10-m3.csv deleted file mode 100644 index d51489b..0000000 --- a/data/input/lapros-testdata-n10-m3.csv +++ /dev/null @@ -1,11 +0,0 @@ -label,proba1,proba2,proba3 -1,0.15048253743933107,0.3064865186051135,0.5430309439555554 -2,0.8593108262431841,0.12815163042435118,0.012537543332464719 -3,0.014835081831322596,0.3987596690229098,0.5864052491457676 -3,0.5164609058808856,0.26564881699228043,0.2178902771268339 -3,0.3251811443028257,0.2029420837819364,0.4718767719152379 -1,0.4561526601098722,0.4271195769560638,0.11672776293406407 -1,0.6095060231908339,0.004991019908860957,0.38550295690030506 -2,0.5896541086861277,0.2486570360239721,0.16168885528990024 -1,0.4880635634900362,0.31229820034596484,0.19963823616399895 -2,0.4545821835102344,0.3498501399434297,0.19556767654633597 diff --git a/data/output/lapros-rankdata-n10-m3.csv b/data/output/lapros-rankdata-n10-m3.csv deleted file mode 100644 index 3eb6681..0000000 --- a/data/output/lapros-rankdata-n10-m3.csv +++ /dev/null @@ -1,7 +0,0 @@ -id,err -1,0.39320883651112964 -2,0.5473276018918989 -4,0.29791019875914637 -6,0.15479851077312565 -8,0.1571654787352215 -9,0.008066230782862699 diff --git a/lapros.jl b/lapros.jl index 42f309b..490e69a 100644 --- a/lapros.jl +++ b/lapros.jl @@ -1,5 +1,5 @@ ### A Pluto.jl notebook ### -# v0.19.9 +# v0.19.11 using Markdown using InteractiveUtils @@ -11,12 +11,9 @@ begin import CSV using DataFrames using DataStructures - using Formatting: format - import Random + using Formatting: format using SparseArrays: spzeros, findnz - using StatsBase: sample using Statistics: mean - Random.seed!(333) end; # ╔═╡ 29cd589c-b3e9-4a65-9641-352129a27c1e @@ -24,8 +21,8 @@ using PlutoUI; TableOfContents() # ╔═╡ e019b221-ff69-448b-baf3-99e142368a5f (data_fp, output_fp) = ( - # ("./data/input/lapros-testdata-n10-m3.csv", "./data/output/lapros-rankdata-n10-m3.csv") - ("/opt/ml-data/lapros/lapros-testdata-n10000000-m3.arrow", "/opt/ml-data/lapros/lapros-rankdata-n10000000-m3.arrow") + ("./data/input/lapros-testdata-n10-m3.csv", "./data/output/lapros-rankdata-n10-m3.csv") + # ("./data/input/lapros-testdata-n10000000-m3.arrow", "./data/output/lapros-rankdata-n10000000-m3.arrow") ) # ╔═╡ 1800ac8f-683d-416e-aef1-802af62ccc81 @@ -333,10 +330,8 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0" PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] Arrow = "~2.3.0" @@ -346,7 +341,6 @@ DataFrames = "~1.3.4" DataStructures = "~0.18.13" Formatting = "~0.4.2" PlutoUI = "~0.7.39" -StatsBase = "~0.33.19" """ # ╔═╡ 00000000-0000-0000-0000-000000000002 @@ -405,18 +399,6 @@ git-tree-sha1 = "8c4920235f6c561e401dfe569beb8b924adad003" uuid = "8be319e6-bccf-4806-a6f7-6fae938471bc" version = "0.5.0" -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "80ca332f6dcb2508adba68f22f551adb2d00a624" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.15.3" - -[[deps.ChangesOfVariables]] -deps = ["ChainRulesCore", "LinearAlgebra", "Test"] -git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8" -uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.4" - [[deps.CodecLz4]] deps = ["Lz4_jll", "TranscodingStreams"] git-tree-sha1 = "59fe0cb37784288d6b9f1baebddbf75457395d40" @@ -490,12 +472,6 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.6" - [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" @@ -558,22 +534,11 @@ version = "1.1.4" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "b3364212fb5d870f724876ffcd34dd8ec6d98918" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.7" - [[deps.InvertedIndices]] git-tree-sha1 = "bee5f1ef5bf65df56bdd2e40447590b272a5471f" uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" version = "1.1.0" -[[deps.IrrationalConstants]] -git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.1.1" - [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" @@ -618,12 +583,6 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" deps = ["Libdl", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -[[deps.LogExpFunctions]] -deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "09e4b894ce6a976c354a69041a04748180d43637" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.15" - [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -766,18 +725,6 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "2c11d7290036fe7aac9038ff312d3b3a2a5bf89e" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.4.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "472d044a1c8df2b062b23f222573ad6837a615ba" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.19" - [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"