From 5be2fffad78aa9be2328d6158061bb4693db6698 Mon Sep 17 00:00:00 2001 From: Josh Price Date: Tue, 30 Jul 2024 23:45:32 +1000 Subject: [PATCH] Upgrade Nx (#45) --- .github/workflows/test.yml | 2 +- README.md | 6 +- examples/distributed.exs | 6 +- examples/intro.exs | 6 +- examples/knapsack.exs | 4 +- examples/metrics.exs | 6 +- examples/one_max.exs | 4 +- examples/rastrigin.exs | 6 +- examples/tsp.exs | 8 +-- lib/meow/ops.ex | 4 +- lib/meow_nx/crossover.ex | 57 +++++++++--------- lib/meow_nx/init.ex | 12 ++-- lib/meow_nx/metric.ex | 13 ++-- lib/meow_nx/mutation.ex | 20 +++--- lib/meow_nx/ops.ex | 65 ++++++++++++-------- lib/meow_nx/ops/permutation.ex | 35 ++++++----- lib/meow_nx/permutation.ex | 81 +++++++++++++++---------- lib/meow_nx/representation_spec.ex | 6 +- lib/meow_nx/selection.ex | 18 +++--- lib/meow_nx/utils.ex | 97 ++++++++++++++++++++---------- mix.exs | 4 +- mix.lock | 22 ++++--- notebooks/metrics.livemd | 6 +- notebooks/rastrigin_intro.livemd | 6 +- test/meow_nx/crossover_test.exs | 16 +++-- 25 files changed, 301 insertions(+), 209 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7cab629..5add8bb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: uses: erlef/setup-beam@v1 with: otp-version: "24.0" - elixir-version: "1.13.0" + elixir-version: "1.14.0" - name: Cache Mix uses: actions/cache@v2 with: diff --git a/README.md b/README.md index 9488e79..b25f2a8 100644 --- a/README.md +++ b/README.md @@ -15,8 +15,8 @@ You can define the algorithm in a single Elixir script file like this: Mix.install([ {:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"}, - {:nx, "~> 0.3.0"}, - {:exla, "~> 0.3.0"} + {:nx, "~> 0.7.0"}, + {:exla, "~> 0.7.0"} ]) Nx.Defn.global_default_options(compiler: EXLA) @@ -32,7 +32,7 @@ defmodule Problem do defn evaluate_rastrigin(genomes) do sums = - (10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) + (10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) |> Nx.sum(axes: [1]) -sums diff --git a/examples/distributed.exs b/examples/distributed.exs index a51e81a..2604140 100644 --- a/examples/distributed.exs +++ b/examples/distributed.exs @@ -8,8 +8,8 @@ Mix.install([ {:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"}, - {:nx, "~> 0.3.0"}, - {:exla, "~> 0.3.0"} + {:nx, "~> 0.7.0"}, + {:exla, "~> 0.7.0"} ]) Nx.Defn.global_default_options(compiler: EXLA) @@ -23,7 +23,7 @@ defmodule Problem do defn evaluate_rastrigin(genomes) do sums = - (10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) + (10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) |> Nx.sum(axes: [1]) -sums diff --git a/examples/intro.exs b/examples/intro.exs index 762ab38..3c09dbf 100644 --- a/examples/intro.exs +++ b/examples/intro.exs @@ -2,8 +2,8 @@ Mix.install([ {:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"}, - {:nx, "~> 0.3.0"}, - {:exla, "~> 0.3.0"} + {:nx, "~> 0.7.0"}, + {:exla, "~> 0.7.0"} ]) Nx.Defn.global_default_options(compiler: EXLA) @@ -19,7 +19,7 @@ defmodule Problem do defn evaluate_rastrigin(genomes) do sums = - (10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) + (10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) |> Nx.sum(axes: [1]) -sums diff --git a/examples/knapsack.exs b/examples/knapsack.exs index 8d6ca2c..a1e6007 100644 --- a/examples/knapsack.exs +++ b/examples/knapsack.exs @@ -10,8 +10,8 @@ Mix.install([ {:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"}, - {:nx, "~> 0.3.0"}, - {:exla, "~> 0.3.0"} + {:nx, "~> 0.7.0"}, + {:exla, "~> 0.7.0"} ]) Nx.Defn.global_default_options(compiler: EXLA) diff --git a/examples/metrics.exs b/examples/metrics.exs index a2d56e5..91e10d6 100644 --- a/examples/metrics.exs +++ b/examples/metrics.exs @@ -1,7 +1,7 @@ Mix.install([ {:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"}, - {:nx, "~> 0.3.0"}, - {:exla, "~> 0.3.0"}, + {:nx, "~> 0.7.0"}, + {:exla, "~> 0.7.0"}, {:vega_lite, "~> 0.1.1"}, {:jason, "~> 1.4"} ]) @@ -17,7 +17,7 @@ defmodule Problem do defn evaluate_rastrigin(genomes) do sums = - (10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) + (10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) |> Nx.sum(axes: [1]) -sums diff --git a/examples/one_max.exs b/examples/one_max.exs index dcc9d70..04e1917 100644 --- a/examples/one_max.exs +++ b/examples/one_max.exs @@ -3,8 +3,8 @@ Mix.install([ {:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"}, - {:nx, "~> 0.3.0"}, - {:exla, "~> 0.3.0"} + {:nx, "~> 0.7.0"}, + {:exla, "~> 0.7.0"} ]) Nx.Defn.global_default_options(compiler: EXLA) diff --git a/examples/rastrigin.exs b/examples/rastrigin.exs index 4910d96..03753c9 100644 --- a/examples/rastrigin.exs +++ b/examples/rastrigin.exs @@ -2,8 +2,8 @@ Mix.install([ {:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"}, - {:nx, "~> 0.3.0"}, - {:exla, "~> 0.3.0"} + {:nx, "~> 0.7.0"}, + {:exla, "~> 0.7.0"} ]) Nx.Defn.global_default_options(compiler: EXLA) @@ -17,7 +17,7 @@ defmodule Problem do defn evaluate(genomes) do sums = - (10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) + (10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) |> Nx.sum(axes: [1]) -sums diff --git a/examples/tsp.exs b/examples/tsp.exs index 819eb54..551d325 100644 --- a/examples/tsp.exs +++ b/examples/tsp.exs @@ -8,9 +8,9 @@ Mix.install([ {:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"}, - {:nx, "~> 0.3.0"}, - {:exla, "~> 0.3.0"}, - {:req, "~> 0.3.0"} + {:nx, "~> 0.7.0"}, + {:exla, "~> 0.7.0"}, + {:req, "~> 0.5.0"} ]) Nx.Defn.global_default_options(compiler: EXLA) @@ -56,7 +56,7 @@ defmodule Problem do end {tsp_size, distances} = - Problem.load_distances("https://people.sc.fsu.edu/~jburkardt/datasets/tsp/fri26_d.txt") + Problem.load_distances("https://raw.githubusercontent.com/dada8397/mpi_tsp/master/fri26_d.txt") algorithm = Meow.objective(&Problem.evaluate(&1, distances)) diff --git a/lib/meow/ops.ex b/lib/meow/ops.ex index a1ee2af..04f464c 100644 --- a/lib/meow/ops.ex +++ b/lib/meow/ops.ex @@ -177,8 +177,8 @@ defmodule Meow.Ops do n when is_integer(n) -> n..n - min..max -> - min..max + min..max//step -> + min..max//step other -> raise ArgumentError, diff --git a/lib/meow_nx/crossover.ex b/lib/meow_nx/crossover.ex index 8779abd..27840f1 100644 --- a/lib/meow_nx/crossover.ex +++ b/lib/meow_nx/crossover.ex @@ -40,7 +40,7 @@ defmodule MeowNx.Crossover do * [On the Virtues of Parameterized Uniform Crossover](http://www.mli.gmu.edu/papers/91-95/91-18.pdf) """ - defn uniform(parents, opts \\ []) do + defn uniform(parents, prng_key, opts \\ []) do opts = keyword!(opts, probability: 0.5) probability = opts[:probability] @@ -49,8 +49,10 @@ defmodule MeowNx.Crossover do swapped_parents = Utils.swap_adjacent_rows(parents) + {random, _prng_key} = Nx.Random.uniform(prng_key, shape: {half_n, length}) + swap? = - Nx.random_uniform({half_n, length}) + random |> Nx.less_equal(probability) |> Utils.duplicate_rows() @@ -64,7 +66,7 @@ defmodule MeowNx.Crossover do a random split point and swaps all genes $x_i$, $y_i$ on one side of that split point. """ - defn single_point(parents) do + defn single_point(parents, prng_key) do {n, length} = Nx.shape(parents) half_n = div(n, 2) @@ -72,12 +74,11 @@ defmodule MeowNx.Crossover do # Generate n / 2 split points (like [5, 2, 3]), and replicate # them for adjacent parents (like [5, 5, 2, 2, 3, 3]) - split_idx = - Nx.random_uniform({half_n, 1}, 1, length) - |> Utils.duplicate_rows() + {random, _prng_key} = Nx.Random.uniform(prng_key, 1, length, shape: {half_n, 1}) + random = Nx.as_type(random, {:u, 32}) + split_idx = Utils.duplicate_rows(random) swap? = Nx.less_equal(split_idx, Nx.iota({1, length})) - Nx.select(swap?, swapped_parents, parents) end @@ -92,19 +93,14 @@ defmodule MeowNx.Crossover do * `:points` - the number of crossover points """ - defn multi_point(parents, opts \\ []) do + defn multi_point(parents, prng_key, opts \\ []) do opts = keyword!(opts, [:points]) points = opts[:points] {n, length} = Nx.shape(parents) half_n = div(n, 2) - transform({length, points}, fn {length, points} -> - unless Elixir.Kernel.<(points, length) do - raise ArgumentError, - "#{points}-point crossover is not valid for genome of length #{length}" - end - end) + validate_crossover(length, points) swapped_parents = Utils.swap_adjacent_rows(parents) @@ -112,8 +108,8 @@ defmodule MeowNx.Crossover do # then we convert each of them to 1-point crossover mask and finally # combine these masks using gen-wise XOR (sum modulo 2) - split_idx = - Utils.random_idx_without_replacement( + {split_idx, _prng_key} = + Utils.random_idx_without_replacement(prng_key, shape: {half_n, points, 1}, min: 1, max: length, @@ -129,6 +125,13 @@ defmodule MeowNx.Crossover do Nx.select(swap?, swapped_parents, parents) end + deftransformp validate_crossover(length, points) do + unless Elixir.Kernel.<(points, length) do + raise ArgumentError, + "#{points}-point crossover is not valid for genome of length #{length}" + end + end + @doc """ Performs blend-alpha crossover, also referred to as BLX-alpha. @@ -161,7 +164,7 @@ defmodule MeowNx.Crossover do * [Tackling Real-Coded Genetic Algorithms: Operators and Tools for Behavioural Analysis](https://sci2s.ugr.es/sites/default/files/files/ScientificImpact/AIRE12-1998.PDF), Section 4.3 * [Multiobjective Evolutionary Algorithms forElectric Power Dispatch Problem](https://www.researchgate.net/figure/Blend-crossover-operator-BLX_fig1_226044085), Fig. 1. """ - defn blend_alpha(parents, opts \\ []) do + defn blend_alpha(parents, prng_key, opts \\ []) do opts = keyword!(opts, alpha: 0.5) alpha = opts[:alpha] @@ -175,7 +178,8 @@ defmodule MeowNx.Crossover do # may be negative (for y < x), but then we shift from x # in the opposite direction, so it works as expected. - gamma = (1 + 2 * alpha) * Nx.random_uniform({half_n, length}) - alpha + {u, _prng_key} = Nx.Random.uniform(prng_key, shape: {half_n, length}) + gamma = (1 + 2 * alpha) * u - alpha gamma = Utils.duplicate_rows(gamma) x + gamma * (y - x) @@ -200,7 +204,7 @@ defmodule MeowNx.Crossover do * [Self-Adaptive Genetic Algorithms with Simulated Binary Crossover](https://eldorado.tu-dortmund.de/bitstream/2003/5370/1/ci61.pdf) * [Engineering Analysis and Design Using Genetic Algorithms / Lecture 4: Real-Coded Genetic Algorithms](https://engineering.purdue.edu/~sudhoff/ee630/Lecture04.pdf) """ - defn simulated_binary(parents, opts \\ []) do + defn simulated_binary(parents, prng_key, opts \\ []) do opts = keyword!(opts, [:eta]) eta = opts[:eta] @@ -209,17 +213,10 @@ defmodule MeowNx.Crossover do {x, y} = {parents, Utils.swap_adjacent_rows(parents)} - beta_base = - Nx.random_uniform({half_n, length}) - |> Nx.map(fn u -> - if Nx.less(u, 0.5) do - 2 * u - else - 1 / (2 * (1 - u)) - end - end) - - beta = Nx.power(beta_base, 1 / (eta + 1)) + {u, _prng_key} = Nx.Random.uniform(prng_key, shape: {half_n, length}) + beta_base = Nx.select(u < 0.5, 2 * u, 1 / (2 * (1 - u))) + + beta = Nx.pow(beta_base, 1 / (eta + 1)) beta = Utils.duplicate_rows(beta) 0.5 * ((1 + beta) * x + (1 - beta) * y) diff --git a/lib/meow_nx/init.ex b/lib/meow_nx/init.ex index 4106f3b..13c2760 100644 --- a/lib/meow_nx/init.ex +++ b/lib/meow_nx/init.ex @@ -22,14 +22,17 @@ defmodule MeowNx.Init do * `:max` - the maximum possible value of a gene. Required. """ - defn real_random_uniform(opts \\ []) do + defn real_random_uniform(prng_key, opts \\ []) do opts = keyword!(opts, [:n, :length, :min, :max]) n = opts[:n] length = opts[:length] min = opts[:min] max = opts[:max] - Nx.random_uniform({n, length}, min, max, type: {:f, 64}) + {random, _prng_key} = + Nx.Random.uniform(prng_key, min, max, shape: {n, length}, type: {:f, 64}) + + random end @doc """ @@ -42,11 +45,12 @@ defmodule MeowNx.Init do * `:length` - the length of a single genome. Required. """ - defn binary_random_uniform(opts \\ []) do + defn binary_random_uniform(prng_key, opts \\ []) do opts = keyword!(opts, [:n, :length]) n = opts[:n] length = opts[:length] - Nx.random_uniform({n, length}, 0, 2) |> Nx.as_type({:u, 8}) + {random, _prng_key} = Nx.Random.uniform(prng_key, 0, 2, shape: {n, length}) + Nx.as_type(random, {:u, 8}) end end diff --git a/lib/meow_nx/metric.ex b/lib/meow_nx/metric.ex index 78295e6..951a1df 100644 --- a/lib/meow_nx/metric.ex +++ b/lib/meow_nx/metric.ex @@ -64,15 +64,18 @@ defmodule MeowNx.Metric do defn fitness_entropy(_genomes, fitness, opts \\ []) do opts = keyword!(opts, [:precision]) - values = - transform({fitness, opts[:precision]}, fn - {fitness, nil} -> fitness - {fitness, precision} -> fitness |> Nx.divide(precision) |> Nx.round() - end) + values = transform_fitness(fitness, opts[:precision]) MeowNx.Utils.entropy(values) end + deftransformp transform_fitness(fitness, precision) do + case precision do + nil -> fitness + _ -> fitness |> Nx.divide(precision) |> Nx.round() + end + end + @doc """ Calculates the mean Euclidean distance between each pair of genomes. diff --git a/lib/meow_nx/mutation.ex b/lib/meow_nx/mutation.ex index caa4fff..c3ff2d5 100644 --- a/lib/meow_nx/mutation.ex +++ b/lib/meow_nx/mutation.ex @@ -35,7 +35,7 @@ defmodule MeowNx.Mutation do * `:min` - the upper bound of the range to draw from. Required. """ - defn replace_uniform(genomes, opts \\ []) do + defn replace_uniform(genomes, prng_key, opts \\ []) do opts = keyword!(opts, [:probability, :min, :max]) probability = opts[:probability] min = opts[:min] @@ -44,8 +44,9 @@ defmodule MeowNx.Mutation do shape = Nx.shape(genomes) # Mutate each gene separately with the given probability - mutate? = Nx.random_uniform(shape) |> Nx.less(probability) - mutated = Nx.random_uniform(shape, min, max) + {u, prng_key} = Nx.Random.uniform(prng_key, shape: shape) + mutate? = Nx.less(u, probability) + {mutated, _prng_key} = Nx.Random.uniform(prng_key, min, max, shape: shape) Nx.select(mutate?, mutated, genomes) end @@ -57,14 +58,15 @@ defmodule MeowNx.Mutation do * `:probability` - the probability of each gene getting mutated. Required. """ - defn bit_flip(genomes, opts \\ []) do + defn bit_flip(genomes, prng_key, opts \\ []) do opts = keyword!(opts, [:probability]) probability = opts[:probability] shape = Nx.shape(genomes) # Mutate each gene separately with the given probability - mutate? = Nx.random_uniform(shape) |> Nx.less(probability) + {u, _prng_key} = Nx.Random.uniform(prng_key, shape: shape) + mutate? = Nx.less(u, probability) mutated = Nx.subtract(1, genomes) Nx.select(mutate?, mutated, genomes) end @@ -91,7 +93,7 @@ defmodule MeowNx.Mutation do * [Adaptive Mutation Strategies for Evolutionary Algorithms](https://www.dynardo.de/fileadmin/Material_Dynardo/WOST/Paper/wost2.0/AdaptiveMutation.pdf), Section 3.1 """ - defn shift_gaussian(genomes, opts \\ []) do + defn shift_gaussian(genomes, prng_key, opts \\ []) do opts = keyword!(opts, [:probability, sigma: 1.0]) probability = opts[:probability] sigma = opts[:sigma] @@ -99,8 +101,10 @@ defmodule MeowNx.Mutation do shape = Nx.shape(genomes) # Mutate each gene separately with the given probability - mutate? = Nx.random_uniform(shape) |> Nx.less(probability) - mutated = genomes + Nx.random_normal(shape, 0.0, sigma) + {u, prng_key} = Nx.Random.uniform(prng_key, shape: shape) + mutate? = Nx.less(u, probability) + {noise, _prng_key} = Nx.Random.normal(prng_key, 0.0, sigma, shape: shape) + mutated = genomes + noise Nx.select(mutate?, mutated, genomes) end end diff --git a/lib/meow_nx/ops.ex b/lib/meow_nx/ops.ex index b02f90b..11b067f 100644 --- a/lib/meow_nx/ops.ex +++ b/lib/meow_nx/ops.ex @@ -21,7 +21,7 @@ defmodule MeowNx.Ops do @doc """ Builds a random initializer for the real representation. - See `MeowNx.Init.real_random_uniform/1` for more details. + See `MeowNx.Init.real_random_uniform/2` for more details. """ @doc type: :init @spec init_real_random_uniform(non_neg_integer(), non_neg_integer(), float(), float()) :: Op.t() @@ -36,7 +36,8 @@ defmodule MeowNx.Ops do out_representation: MeowNx.real_representation(), impl: fn population, _ctx -> Population.map_genomes(population, fn _genomes -> - Init.real_random_uniform(opts) + prng_key = MeowNx.Utils.prng_key() + Init.real_random_uniform(prng_key, opts) end) end } @@ -45,7 +46,7 @@ defmodule MeowNx.Ops do @doc """ Builds a random initializer for the binary representation. - See `MeowNx.Init.binary_random_uniform/1` for more details. + See `MeowNx.Init.binary_random_uniform/2` for more details. """ @doc type: :init @spec init_binary_random_uniform(non_neg_integer(), non_neg_integer()) :: Op.t() @@ -60,7 +61,8 @@ defmodule MeowNx.Ops do out_representation: MeowNx.binary_representation(), impl: fn population, _ctx -> Population.map_genomes(population, fn _genomes -> - Init.binary_random_uniform(opts) + prng_key = MeowNx.Utils.prng_key() + Init.binary_random_uniform(prng_key, opts) end) end } @@ -69,7 +71,7 @@ defmodule MeowNx.Ops do @doc """ Builds a tournament selection operation. - See `MeowNx.Selection.tournament/3` for more details. + See `MeowNx.Selection.tournament/4` for more details. """ @doc type: :selection @spec selection_tournament(non_neg_integer() | float()) :: Op.t() @@ -83,7 +85,8 @@ defmodule MeowNx.Ops do in_representations: @representations, impl: fn population, _ctx -> Population.map_genomes_and_fitness(population, fn genomes, fitness -> - Selection.tournament(genomes, fitness, opts) + prng_key = MeowNx.Utils.prng_key() + Selection.tournament(genomes, fitness, prng_key, opts) end) end } @@ -115,7 +118,7 @@ defmodule MeowNx.Ops do @doc """ Builds a roulette selection operation. - See `MeowNx.Selection.roulette/3` for more details. + See `MeowNx.Selection.roulette/4` for more details. """ @doc type: :selection @spec selection_roulette(non_neg_integer() | float()) :: Op.t() @@ -129,7 +132,8 @@ defmodule MeowNx.Ops do in_representations: @representations, impl: fn population, _ctx -> Population.map_genomes_and_fitness(population, fn genomes, fitness -> - Selection.roulette(genomes, fitness, opts) + prng_key = MeowNx.Utils.prng_key() + Selection.roulette(genomes, fitness, prng_key, opts) end) end } @@ -138,7 +142,7 @@ defmodule MeowNx.Ops do @doc """ Builds a stochastic universal sampling operation. - See `MeowNx.Selection.stochastic_universal_sampling/3` for more details. + See `MeowNx.Selection.stochastic_universal_sampling/4` for more details. """ @doc type: :selection @spec selection_stochastic_universal_sampling(non_neg_integer() | float()) :: Op.t() @@ -152,7 +156,8 @@ defmodule MeowNx.Ops do in_representations: @representations, impl: fn population, _ctx -> Population.map_genomes_and_fitness(population, fn genomes, fitness -> - Selection.stochastic_universal_sampling(genomes, fitness, opts) + prng_key = MeowNx.Utils.prng_key() + Selection.stochastic_universal_sampling(genomes, fitness, prng_key, opts) end) end } @@ -161,7 +166,7 @@ defmodule MeowNx.Ops do @doc """ Builds a uniform crossover operation. - See `MeowNx.Crossover.uniform/2` for more details. + See `MeowNx.Crossover.uniform/3` for more details. """ @doc type: :crossover @spec crossover_uniform(float()) :: Op.t() @@ -175,7 +180,8 @@ defmodule MeowNx.Ops do in_representations: @representations, impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Crossover.uniform(genomes, opts) + prng_key = MeowNx.Utils.prng_key() + Crossover.uniform(genomes, prng_key, opts) end) end } @@ -184,7 +190,7 @@ defmodule MeowNx.Ops do @doc """ Builds a single point crossover operation. - See `MeowNx.Crossover.single_point/1` for more details. + See `MeowNx.Crossover.single_point/2` for more details. """ @doc type: :crossover @spec crossover_single_point() :: Op.t() @@ -196,7 +202,8 @@ defmodule MeowNx.Ops do in_representations: @representations, impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Crossover.single_point(genomes) + prng_key = MeowNx.Utils.prng_key() + Crossover.single_point(genomes, prng_key) end) end } @@ -205,7 +212,7 @@ defmodule MeowNx.Ops do @doc """ Builds a multi point crossover operation. - See `MeowNx.Crossover.multi_point/1` for more details. + See `MeowNx.Crossover.multi_point/3` for more details. """ @doc type: :crossover @spec crossover_multi_point(pos_integer()) :: Op.t() @@ -219,7 +226,8 @@ defmodule MeowNx.Ops do in_representations: @representations, impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Crossover.multi_point(genomes, opts) + prng_key = MeowNx.Utils.prng_key() + Crossover.multi_point(genomes, prng_key, opts) end) end } @@ -228,7 +236,7 @@ defmodule MeowNx.Ops do @doc """ Builds a blend-alpha crossover operation. - See `MeowNx.Crossover.blend_alpha/2` for more details. + See `MeowNx.Crossover.blend_alpha/3` for more details. """ @doc type: :crossover @spec crossover_blend_alpha(float()) :: Op.t() @@ -242,7 +250,8 @@ defmodule MeowNx.Ops do in_representations: [MeowNx.real_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Crossover.blend_alpha(genomes, opts) + prng_key = MeowNx.Utils.prng_key() + Crossover.blend_alpha(genomes, prng_key, opts) end) end } @@ -251,7 +260,7 @@ defmodule MeowNx.Ops do @doc """ Builds a simulated binary crossover operation. - See `MeowNx.Crossover.simulated_binary/2` for more details. + See `MeowNx.Crossover.simulated_binary/3` for more details. """ @doc type: :crossover @spec crossover_simulated_binary(float()) :: Op.t() @@ -265,7 +274,8 @@ defmodule MeowNx.Ops do in_representations: [MeowNx.real_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Crossover.simulated_binary(genomes, opts) + prng_key = MeowNx.Utils.prng_key() + Crossover.simulated_binary(genomes, prng_key, opts) end) end } @@ -274,7 +284,7 @@ defmodule MeowNx.Ops do @doc """ Builds a uniform replacement mutation operation. - See `MeowNx.Mutation.replace_uniform/2` for more details. + See `MeowNx.Mutation.replace_uniform/3` for more details. """ @doc type: :mutation @spec mutation_replace_uniform(float(), float(), float()) :: Op.t() @@ -288,7 +298,8 @@ defmodule MeowNx.Ops do in_representations: [MeowNx.real_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Mutation.replace_uniform(genomes, opts) + prng_key = MeowNx.Utils.prng_key() + Mutation.replace_uniform(genomes, prng_key, opts) end) end } @@ -297,7 +308,7 @@ defmodule MeowNx.Ops do @doc """ Builds a bit flip mutation operation. - See `MeowNx.Mutation.replace_uniform/2` for more details. + See `MeowNx.Mutation.replace_uniform/3` for more details. """ @doc type: :mutation @spec mutation_bit_flip(float()) :: Op.t() @@ -311,7 +322,8 @@ defmodule MeowNx.Ops do in_representations: [MeowNx.binary_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Mutation.bit_flip(genomes, opts) + prng_key = MeowNx.Utils.prng_key() + Mutation.bit_flip(genomes, prng_key, opts) end) end } @@ -320,7 +332,7 @@ defmodule MeowNx.Ops do @doc """ Builds a Gaussian shift mutation operation. - See `MeowNx.Mutation.shift_gaussian/2` for more details. + See `MeowNx.Mutation.shift_gaussian/3` for more details. """ @doc type: :mutation @spec mutation_shift_gaussian(float(), keyword()) :: Op.t() @@ -334,7 +346,8 @@ defmodule MeowNx.Ops do in_representations: [MeowNx.real_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Mutation.shift_gaussian(genomes, opts) + prng_key = MeowNx.Utils.prng_key() + Mutation.shift_gaussian(genomes, prng_key, opts) end) end } diff --git a/lib/meow_nx/ops/permutation.ex b/lib/meow_nx/ops/permutation.ex index 6cb9be9..8a3a817 100644 --- a/lib/meow_nx/ops/permutation.ex +++ b/lib/meow_nx/ops/permutation.ex @@ -9,7 +9,7 @@ defmodule MeowNx.Ops.Permutation do @doc """ Builds a random initializer for the permutation representation. - See `MeowNx.Permutation.init_random/1` for more details. + See `MeowNx.Permutation.init_random/2` for more details. """ @doc type: :init @spec init_permutation_random(non_neg_integer(), non_neg_integer()) :: Op.t() @@ -24,7 +24,8 @@ defmodule MeowNx.Ops.Permutation do out_representation: MeowNx.permutation_representation(), impl: fn population, _ctx -> Meow.Population.map_genomes(population, fn _genomes -> - Permutation.init_random(opts) + prng_key = MeowNx.Utils.prng_key() + Permutation.init_random(prng_key, opts) end) end } @@ -33,7 +34,7 @@ defmodule MeowNx.Ops.Permutation do @doc """ Builds a single point crossover operation adopted for permutations. - See `MeowNx.Permutation.crossover_single_point/1` for more details. + See `MeowNx.Permutation.crossover_single_point/2` for more details. """ @doc type: :crossover @spec crossover_single_point() :: Op.t() @@ -45,7 +46,8 @@ defmodule MeowNx.Ops.Permutation do in_representations: [MeowNx.permutation_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Permutation.crossover_single_point(genomes) + prng_key = MeowNx.Utils.prng_key() + Permutation.crossover_single_point(genomes, prng_key) end) end } @@ -54,7 +56,7 @@ defmodule MeowNx.Ops.Permutation do @doc """ Builds an order crossover operation. - See `MeowNx.Permutation.crossover_order/1` for more details. + See `MeowNx.Permutation.crossover_order/2` for more details. """ @doc type: :crossover @spec crossover_order() :: Op.t() @@ -66,7 +68,8 @@ defmodule MeowNx.Ops.Permutation do in_representations: [MeowNx.permutation_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Permutation.crossover_order(genomes) + prng_key = MeowNx.Utils.prng_key() + Permutation.crossover_order(genomes, prng_key) end) end } @@ -75,7 +78,7 @@ defmodule MeowNx.Ops.Permutation do @doc """ Builds an position based crossover operation. - See `MeowNx.Permutation.crossover_position_based/1` for more details. + See `MeowNx.Permutation.crossover_position_based/2` for more details. """ @doc type: :crossover @spec crossover_position_based() :: Op.t() @@ -87,7 +90,8 @@ defmodule MeowNx.Ops.Permutation do in_representations: [MeowNx.permutation_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Permutation.crossover_position_based(genomes) + prng_key = MeowNx.Utils.prng_key() + Permutation.crossover_position_based(genomes, prng_key) end) end } @@ -96,7 +100,7 @@ defmodule MeowNx.Ops.Permutation do @doc """ Builds a linear order crossover operation. - See `MeowNx.Permutation.crossover_linear_order/1` for more details. + See `MeowNx.Permutation.crossover_linear_order/2` for more details. """ @doc type: :crossover @spec crossover_linear_order() :: Op.t() @@ -108,7 +112,8 @@ defmodule MeowNx.Ops.Permutation do in_representations: [MeowNx.permutation_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Permutation.crossover_linear_order(genomes) + prng_key = MeowNx.Utils.prng_key() + Permutation.crossover_linear_order(genomes, prng_key) end) end } @@ -117,7 +122,7 @@ defmodule MeowNx.Ops.Permutation do @doc """ Builds an inversion mutation operation. - See `MeowNx.Permutation.mutation_inversion/2` for more details. + See `MeowNx.Permutation.mutation_inversion/3` for more details. """ @doc type: :mutation @spec mutation_inversion(float()) :: Op.t() @@ -131,7 +136,8 @@ defmodule MeowNx.Ops.Permutation do in_representations: [MeowNx.permutation_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Permutation.mutation_inversion(genomes, opts) + prng_key = MeowNx.Utils.prng_key() + Permutation.mutation_inversion(genomes, prng_key, opts) end) end } @@ -140,7 +146,7 @@ defmodule MeowNx.Ops.Permutation do @doc """ Builds a swap mutation operation. - See `MeowNx.Permutation.mutation_swap/2` for more details. + See `MeowNx.Permutation.mutation_swap/3` for more details. """ @doc type: :mutation @spec mutation_swap(float()) :: Op.t() @@ -154,7 +160,8 @@ defmodule MeowNx.Ops.Permutation do in_representations: [MeowNx.permutation_representation()], impl: fn population, _ctx -> Population.map_genomes(population, fn genomes -> - Permutation.mutation_swap(genomes, opts) + prng_key = MeowNx.Utils.prng_key() + Permutation.mutation_swap(genomes, prng_key, opts) end) end } diff --git a/lib/meow_nx/permutation.ex b/lib/meow_nx/permutation.ex index ff58e08..3518935 100644 --- a/lib/meow_nx/permutation.ex +++ b/lib/meow_nx/permutation.ex @@ -102,12 +102,14 @@ defmodule MeowNx.Permutation do * `:length` - the length of a single genome. Required. """ - defn init_random(opts \\ []) do + defn init_random(prng_key, opts \\ []) do opts = keyword!(opts, [:n, :length]) n = opts[:n] length = opts[:length] - Nx.random_uniform({n, length}) + {u, _prng_key} = Nx.Random.uniform(prng_key, shape: {n, length}) + + u |> Nx.argsort(axis: 1) |> Nx.as_type({:u, 16}) end @@ -124,12 +126,14 @@ defmodule MeowNx.Permutation do * [Genetic Algorithms for Shop Scheduling Problems: A Survey](https://www.researchgate.net/publication/230724890_GENETIC_ALGORITHMS_FOR_SHOP_SCHEDULING_PROBLEMS_A_SURVEY), Fig. 8. """ - defn crossover_single_point(genomes) do + defn crossover_single_point(genomes, prng_key) do {n, length} = Nx.shape(genomes) half_n = div(n, 2) positions = permutations_to_positions(genomes) - split_position = Nx.random_uniform({half_n, 1}, 1, length) |> Utils.duplicate_rows() + {random, _prng_key} = Nx.Random.uniform(prng_key, 1, length, shape: {half_n, 1}) + random = Nx.as_type(random, {:u, 32}) + split_position = Utils.duplicate_rows(random) positions_single_point(positions, positions, split_position) end @@ -150,8 +154,8 @@ defmodule MeowNx.Permutation do * [Genetic Algorithms for Shop Scheduling Problems: A Survey](https://www.researchgate.net/publication/230724890_GENETIC_ALGORITHMS_FOR_SHOP_SCHEDULING_PROBLEMS_A_SURVEY), Fig. 9. """ - defn crossover_order(genomes) do - {offset, block_length} = random_genome_blocks(genomes, paired: true) + defn crossover_order(genomes, prng_key) do + {offset, block_length, _prng_key} = random_genome_blocks(genomes, prng_key, paired: true) positions = permutations_to_positions(genomes) shifted_positions = shift_positions(positions, -offset) @@ -183,15 +187,18 @@ defmodule MeowNx.Permutation do * [Genetic Algorithms for Shop Scheduling Problems: A Survey](https://www.researchgate.net/publication/230724890_GENETIC_ALGORITHMS_FOR_SHOP_SCHEDULING_PROBLEMS_A_SURVEY), Fig. 12. * [Crossover operators for permutations equivalence between position and order-based crossover](https://www.researchgate.net/publication/220245134_Crossover_operators_for_permutations_equivalence_between_position_and_order-based_crossover) """ - defn crossover_position_based(genomes) do + defn crossover_position_based(genomes, prng_key) do {n, length} = Nx.shape(genomes) half_n = div(n, 2) - fix_position? = Nx.random_uniform({half_n, length}, 0, 2) |> Utils.duplicate_rows() + {random, _prng_key} = Nx.Random.uniform(prng_key, 0, 2, shape: {half_n, length}) + random = Nx.as_type(random, {:u, 8}) + fix_position? = Utils.duplicate_rows(random) split_position = Nx.sum(fix_position?, axes: [1], keep_axes: true) mapping = fix_position? + |> Nx.shape() |> Nx.iota(axis: 1) |> Nx.add(Nx.negate(fix_position?) * length) |> relative_positions_to_permutations() @@ -216,15 +223,15 @@ defmodule MeowNx.Permutation do * [Genetic Algorithms for Shop Scheduling Problems: A Survey](https://www.researchgate.net/publication/230724890_GENETIC_ALGORITHMS_FOR_SHOP_SCHEDULING_PROBLEMS_A_SURVEY), Fig. 10. * [Non-Wrapping Order Crossover: An Order Preserving Crossover Operator that Respects Absolute Position](https://www.researchgate.net/publication/220739642_Non-wrapping_order_crossover_An_order_preserving_crossover_operator_that_respects_absolute_position) """ - defn crossover_linear_order(genomes) do - {offset, block_length} = random_genome_blocks(genomes, paired: true) + defn crossover_linear_order(genomes, prng_key) do + {offset, block_length, _prng_key} = random_genome_blocks(genomes, prng_key, paired: true) # A random block divides genome into three parts A B C (where B # is the random block). We want to rearrange the genes into parts # B A C, so that we can fix genes in B. We do this rearrangement # by generating an index mapping - idx = Nx.iota(genomes, axis: 1) + idx = Nx.iota(Nx.shape(genomes), axis: 1) mapping = idx + (idx < block_length) * offset - @@ -248,21 +255,21 @@ defmodule MeowNx.Permutation do * [Genetic Algorithms for Shop Scheduling Problems: A Survey](https://www.researchgate.net/publication/230724890_GENETIC_ALGORITHMS_FOR_SHOP_SCHEDULING_PROBLEMS_A_SURVEY), Fig. 15. """ - defn mutation_inversion(genomes, opts \\ []) do + defn mutation_inversion(genomes, prng_key, opts \\ []) do opts = keyword!(opts, [:probability]) probability = opts[:probability] - {offset, block_length} = random_genome_blocks(genomes, paired: false) + {offset, block_length, prng_key} = random_genome_blocks(genomes, prng_key, paired: false) a = offset b = offset + block_length - 1 - idx = Nx.iota(genomes, axis: 1) + idx = Nx.iota(Nx.shape(genomes), axis: 1) block? = a <= idx and idx <= b mapping = Nx.select(block?, b - (idx - a), idx) mutated = Nx.take_along_axis(genomes, mapping, axis: 1) - incorporate_mutated(genomes, mutated, probability) + incorporate_mutated(genomes, mutated, probability, prng_key) end @doc """ @@ -282,15 +289,17 @@ defmodule MeowNx.Permutation do * [Genetic Algorithms for Shop Scheduling Problems: A Survey](https://www.researchgate.net/publication/230724890_GENETIC_ALGORITHMS_FOR_SHOP_SCHEDULING_PROBLEMS_A_SURVEY), Fig. 14. """ - defn mutation_swap(genomes, opts \\ []) do + defn mutation_swap(genomes, prng_key, opts \\ []) do opts = keyword!(opts, [:probability]) probability = opts[:probability] {n, length} = Nx.shape(genomes) # Randomly generate two distinct positions - swap_position1 = Nx.random_uniform({n}, 0, length) - swap_position2 = Nx.random_uniform({n}, 0, length - 1) + {swap_position1, prng_key} = Nx.Random.uniform(prng_key, 0, length, shape: {n}) + swap_position1 = Nx.as_type(swap_position1, {:u, 32}) + {swap_position2, prng_key} = Nx.Random.uniform(prng_key, 0, length - 1, shape: {n}) + swap_position2 = Nx.as_type(swap_position2, {:u, 32}) swap_position2 = swap_position2 + (swap_position2 >= swap_position1) swap_positions = Nx.stack([swap_position1, swap_position2], axis: -1) @@ -306,13 +315,14 @@ defmodule MeowNx.Permutation do Nx.reshape(diff, {:auto}) ) - incorporate_mutated(genomes, mutated, probability) + incorporate_mutated(genomes, mutated, probability, prng_key) end - defnp incorporate_mutated(genomes, mutated_genomes, probability) do + defnp incorporate_mutated(genomes, mutated_genomes, probability, prng_key) do {n, length} = Nx.shape(genomes) - mutate? = Nx.random_uniform({n, 1}) |> Nx.less(probability) + {u, _prng_key} = Nx.Random.uniform(prng_key, shape: {n, 1}) + mutate? = u |> Nx.less(probability) mutate? |> Nx.broadcast({n, length}) @@ -348,27 +358,32 @@ defmodule MeowNx.Permutation do |> Nx.take_along_axis(reverse_mapping, axis: 1) end - defnp random_genome_blocks(genomes, opts \\ []) do + defnp random_genome_blocks(genomes, prng_key, opts \\ []) do opts = keyword!(opts, paired: false) + generate_blocks(genomes, prng_key, opts[:paired]) + end - transform({genomes, opts[:paired]}, fn - {genomes, false} -> + deftransformp generate_blocks(genomes, prng_key, paired) do + case paired do + false -> {n, length} = Nx.shape(genomes) - random_blocks(n: n, length: length) + random_blocks(prng_key, n: n, length: length) - {genomes, true} -> + true -> {n, length} = Nx.shape(genomes) - {offset, block_length} = random_blocks(n: div(n, 2), length: length) - {Utils.duplicate_rows(offset), Utils.duplicate_rows(block_length)} - end) + {offset, block_length, prng_key} = random_blocks(prng_key, n: div(n, 2), length: length) + {Utils.duplicate_rows(offset), Utils.duplicate_rows(block_length), prng_key} + end end - defnp random_blocks(opts \\ []) do + defnp random_blocks(prng_key, opts \\ []) do n = opts[:n] length = opts[:length] - idx1 = Nx.random_uniform({n, 1}, 0, length) - idx2 = Nx.random_uniform({n, 1}, 0, length) + {idx1, prng_key} = Nx.Random.uniform(prng_key, 0, length, shape: {n, 1}) + idx1 = Nx.as_type(idx1, {:u, 32}) + {idx2, prng_key} = Nx.Random.uniform(prng_key, 0, length, shape: {n, 1}) + idx2 = Nx.as_type(idx2, {:u, 32}) a = Nx.min(idx1, idx2) b = Nx.max(idx1, idx2) @@ -376,7 +391,7 @@ defmodule MeowNx.Permutation do offset = a block_length = b - a + 1 - {offset, block_length} + {offset, block_length, prng_key} end defnp shift_positions(positions, offset) do diff --git a/lib/meow_nx/representation_spec.ex b/lib/meow_nx/representation_spec.ex index 71c3a6e..1a4e3fa 100644 --- a/lib/meow_nx/representation_spec.ex +++ b/lib/meow_nx/representation_spec.ex @@ -40,7 +40,11 @@ defmodule MeowNx.RepresentationSpec do defn concatenate_tuple(populations) do populations - |> transform(&Tuple.to_list/1) + |> tuple_to_list() |> Nx.concatenate() end + + deftransformp tuple_to_list(tuple) do + Tuple.to_list(tuple) + end end diff --git a/lib/meow_nx/selection.ex b/lib/meow_nx/selection.ex index 1bdaf39..c30d48e 100644 --- a/lib/meow_nx/selection.ex +++ b/lib/meow_nx/selection.ex @@ -26,14 +26,16 @@ defmodule MeowNx.Selection do * `:n` - the number of individuals to select. Required. """ - defn tournament(genomes, fitness, opts \\ []) do + defn tournament(genomes, fitness, prng_key, opts \\ []) do opts = keyword!(opts, [:n]) n = MeowNx.Utils.resolve_n(opts[:n], genomes) {base_n, length} = Nx.shape(genomes) - idx1 = Nx.random_uniform({n}, 0, base_n, type: {:u, 32}) - idx2 = Nx.random_uniform({n}, 0, base_n, type: {:u, 32}) + {idx1, prng_key} = Nx.Random.uniform(prng_key, 0, base_n, shape: {n}) + idx1 = Nx.as_type(idx1, {:u, 32}) + {idx2, _prng_key} = Nx.Random.uniform(prng_key, 0, base_n, shape: {n}) + idx2 = Nx.as_type(idx2, {:u, 32}) parents1 = Nx.take(genomes, idx1) fitness1 = Nx.take(fitness, idx1) @@ -94,7 +96,7 @@ defmodule MeowNx.Selection do * [Fitness proportionate selection](https://en.wikipedia.org/wiki/Fitness_proportionate_selection) """ - defn roulette(genomes, fitness, opts \\ []) do + defn roulette(genomes, fitness, prng_key, opts \\ []) do opts = keyword!(opts, [:n]) n = MeowNx.Utils.resolve_n(opts[:n], genomes) @@ -102,7 +104,8 @@ defmodule MeowNx.Selection do fitness_sum = fitness_cumulative[-1] # Random points on the cumulative ruler - points = Nx.random_uniform({n, 1}, 0, fitness_sum) + {points, _prng_key} = Nx.Random.uniform(prng_key, 0, fitness_sum, shape: {n, 1}) + points = Nx.as_type(points, {:u, 32}) idx = cumulative_points_to_indices(fitness_cumulative, points) take_individuals(genomes, fitness, idx) @@ -125,7 +128,7 @@ defmodule MeowNx.Selection do * [Stochastic universal sampling](https://en.wikipedia.org/wiki/Stochastic_universal_sampling) """ - defn stochastic_universal_sampling(genomes, fitness, opts \\ []) do + defn stochastic_universal_sampling(genomes, fitness, prng_key, opts \\ []) do opts = keyword!(opts, [:n]) n = MeowNx.Utils.resolve_n(opts[:n], genomes) @@ -134,7 +137,8 @@ defmodule MeowNx.Selection do # Random points on the cumulative ruler, each in its own interval step = Nx.divide(fitness_sum, n) - start = Nx.random_uniform({}, 0, step) + {start, _prng_key} = Nx.Random.uniform(prng_key, 0, step, shape: {}) + start = Nx.as_type(start, {:u, 32}) points = Nx.iota({n, 1}) |> Nx.multiply(step) |> Nx.add(start) idx = cumulative_points_to_indices(fitness_cumulative, points) diff --git a/lib/meow_nx/utils.ex b/lib/meow_nx/utils.ex index 2fe4f86..130d920 100644 --- a/lib/meow_nx/utils.ex +++ b/lib/meow_nx/utils.ex @@ -51,13 +51,17 @@ defmodule MeowNx.Utils do """ defn duplicate_rows(t) do {n, m} = Nx.shape(t) - twice_n = transform(n, &(&1 * 2)) + twice_n = double_n(n) t |> Nx.tile([1, 2]) |> Nx.reshape({twice_n, m}) end + deftransformp double_n(n) do + n * 2 + end + @doc """ Returns the cumulative sum of elements in the given 1-dimensional tensor. @@ -88,7 +92,7 @@ defmodule MeowNx.Utils do The resulting tensor has `:shape` with random indices `:axis`. """ - defn random_idx_without_replacement(opts \\ []) do + defn random_idx_without_replacement(prng_key, opts \\ []) do opts = keyword!(opts, [:shape, :min, :max, :axis]) shape = opts[:shape] min = opts[:min] @@ -100,14 +104,26 @@ defmodule MeowNx.Utils do range = max - min - sample_size = transform(shape, &elem(&1, axis)) - random_shape = transform(shape, &put_elem(&1, axis, range)) + sample_size = get_sample_size(shape, axis) + random_shape = get_random_shape(shape, axis, range) + + {u, prng_key} = Nx.Random.uniform(prng_key, shape: random_shape) + + result = + u + |> Nx.argsort(axis: axis) + |> Nx.slice_along_axis(0, sample_size, axis: axis) + |> Nx.add(min) + + {result, prng_key} + end - random_shape - |> Nx.random_uniform() - |> Nx.argsort(axis: axis) - |> Nx.slice_along_axis(0, sample_size, axis: axis) - |> Nx.add(min) + deftransformp get_sample_size(shape, axis) do + elem(shape, axis) + end + + deftransformp get_random_shape(shape, axis, range) do + put_elem(shape, axis, range) end @doc """ @@ -126,7 +142,7 @@ defmodule MeowNx.Utils do t |> Nx.subtract(mean) - |> Nx.power(2) + |> Nx.pow(2) |> Nx.mean() |> Nx.sqrt() end @@ -254,24 +270,25 @@ defmodule MeowNx.Utils do empty = Nx.broadcast(Nx.tensor(0, type: type), permutations) - indices = - transform({permutations, axis}, fn {permutations, axis} -> - permutations - |> Nx.axes() - |> Enum.map(fn - ^axis -> permutations - axis -> Nx.iota(permutations, axis: axis) - end) - |> Nx.stack(axis: -1) - |> Nx.reshape({:auto, Nx.rank(permutations)}) - end) + indices = create_indices(permutations, axis) - iota = permutations |> Nx.iota(type: type, axis: axis) |> Nx.flatten() + iota = permutations |> Nx.shape() |> Nx.iota(type: type, axis: axis) |> Nx.flatten() # We use each permutation as indexing for 1-dimensional iota Nx.indexed_add(empty, indices, iota) end + deftransformp create_indices(permutations, axis) do + permutations + |> Nx.axes() + |> Enum.map(fn + ^axis -> permutations + axis -> Nx.iota(Nx.shape(permutations), axis: axis) + end) + |> Nx.stack(axis: -1) + |> Nx.reshape({:auto, Nx.rank(permutations)}) + end + @doc """ Shifts elements in `tensor` along the given axis. @@ -330,17 +347,17 @@ defmodule MeowNx.Utils do opts = keyword!(opts, axis: 0) axis = opts[:axis] - transform({tensor, offsets, axis}, &validate_shift!/1) + validate_shift!(tensor, offsets, axis) axis_size = Nx.axis_size(tensor, axis) offsets = Nx.new_axis(offsets, axis) - idx = Nx.iota(tensor, axis: axis) + idx = tensor |> Nx.shape() |> Nx.iota(axis: axis) shifted_idx = Nx.remainder(idx - offsets + axis_size, axis_size) Nx.take_along_axis(tensor, shifted_idx, axis: axis) end - defp validate_shift!({tensor, offsets, axis}) do + deftransformp validate_shift!(tensor, offsets, axis) do shape = Nx.shape(tensor) offsets_type = Nx.type(offsets) offsets_shape = Nx.shape(offsets) @@ -362,15 +379,18 @@ defmodule MeowNx.Utils do Asserts `left` has same shape as `right`. """ defn assert_shape!(left, right) do - transform({left, right}, fn {left, right} -> - left_shape = Nx.shape(left) - right_shape = Nx.shape(right) + do_assert_shape!(left, right) + {left, right} + end - unless Elixir.Kernel.==(left_shape, right_shape) do - raise ArgumentError, - "expected tensor shapes to match, but got #{inspect(left_shape)} and #{inspect(right_shape)}" - end - end) + deftransformp do_assert_shape!(left, right) do + left_shape = Nx.shape(left) + right_shape = Nx.shape(right) + + unless Elixir.Kernel.==(left_shape, right_shape) do + raise ArgumentError, + "expected tensor shapes to match, but got #{inspect(left_shape)} and #{inspect(right_shape)}" + end end @doc """ @@ -410,4 +430,15 @@ defmodule MeowNx.Utils do n end + + @doc """ + Returns a random PRNG key. + """ + def prng_key() do + # TODO: currently we use a random key for each opereation call. + # Ideally we would use a single seed to control the whole pipeline + # and pass keys around accordingly. This is not straightforward + # at the moment, because the pipeline is designed to be Nx agnostic + Nx.Random.key(:erlang.system_time()) + end end diff --git a/mix.exs b/mix.exs index cbf5737..81b127c 100644 --- a/mix.exs +++ b/mix.exs @@ -10,7 +10,7 @@ defmodule Meow.MixProject do version: @version, description: @description, name: "Meow", - elixir: "~> 1.13", + elixir: "~> 1.14", deps: deps(), docs: docs() ] @@ -24,7 +24,7 @@ defmodule Meow.MixProject do defp deps do [ - {:nx, "~> 0.3.0"}, + {:nx, "~> 0.7"}, {:vega_lite, "~> 0.1.1", optional: true}, {:jason, "~> 1.2", optional: true}, {:ex_doc, "~> 0.24", only: :dev, runtime: false} diff --git a/mix.lock b/mix.lock index 739995c..8c741b1 100644 --- a/mix.lock +++ b/mix.lock @@ -1,12 +1,14 @@ %{ - "complex": {:hex, :complex, "0.4.2", "923e5db0be13dbb3ea00cf8459d9f75f3afdd9ff5a82742ded21064330d28273", [:mix], [], "hexpm", "069a085ef820ce675a2619fd125b963ff4514af2102c7f7d7965128e5ec0a429"}, - "earmark_parser": {:hex, :earmark_parser, "1.4.28", "0bf6546eb7cd6185ae086cbc5d20cd6dbb4b428aad14c02c49f7b554484b4586", [:mix], [], "hexpm", "501cef12286a3231dc80c81352a9453decf9586977f917a96e619293132743fb"}, - "ex_doc": {:hex, :ex_doc, "0.28.5", "3e52a6d2130ce74d096859e477b97080c156d0926701c13870a4e1f752363279", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "d2c4b07133113e9aa3e9ba27efb9088ba900e9e51caa383919676afdf09ab181"}, - "jason": {:hex, :jason, "1.3.0", "fa6b82a934feb176263ad2df0dbd91bf633d4a46ebfdffea0c8ae82953714946", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "53fc1f51255390e0ec7e50f9cb41e751c260d065dcba2bf0d08dc51a4002c2ac"}, - "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, - "makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"}, - "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, - "nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"}, - "nx": {:hex, :nx, "0.3.0", "12448769619ed1442e22bd198faf4629ae3d0ee3ed565e42be49829cc42bf35f", [:mix], [{:complex, "~> 0.4.2", [hex: :complex, repo: "hexpm", optional: false]}], "hexpm", "d08d8962f4379ade54281aad84c45cb7eb0118b406fc836f07b94acc40df5859"}, - "vega_lite": {:hex, :vega_lite, "0.1.3", "38eeb47d66a881086d0e596b592e3aa75084115d00315c387716c131684f3513", [:mix], [], "hexpm", "ea6e8f951944144b15f26d0b33c255c2b1b553a65e3fffcd91784bfdc56ebb54"}, + "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, + "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, + "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, + "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, + "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, + "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, + "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, + "nx": {:hex, :nx, "0.7.3", "51ff45d9f9ff58b616f4221fa54ccddda98f30319bb8caaf86695234a469017a", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "5ff29af84f08db9bda66b8ef7ce92ab583ab4f983629fe00b479f1e5c7c705a6"}, + "table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"}, + "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, + "vega_lite": {:hex, :vega_lite, "0.1.9", "d7a288665f916181b68d0a3617f1b3611d16a4dcd5fafb51b847b71db1159d4c", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "c6a056e763162198e73ae6dfb46c09753bb0298474410fd085074e1cdcee7418"}, } diff --git a/notebooks/metrics.livemd b/notebooks/metrics.livemd index f3a9d05..072d3e7 100644 --- a/notebooks/metrics.livemd +++ b/notebooks/metrics.livemd @@ -3,8 +3,8 @@ ```elixir Mix.install([ {:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"}, - {:nx, "~> 0.3.0"}, - {:exla, "~> 0.3.0"}, + {:nx, "~> 0.7.0"}, + {:exla, "~> 0.7.0"}, {:vega_lite, "~> 0.1.6"}, {:kino_vega_lite, "~> 0.1.4"} ]) @@ -36,7 +36,7 @@ defmodule Problem do defn evaluate_rastrigin(genomes) do sums = - (10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) + (10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) |> Nx.sum(axes: [1]) -sums diff --git a/notebooks/rastrigin_intro.livemd b/notebooks/rastrigin_intro.livemd index 6a828ae..c70ad70 100644 --- a/notebooks/rastrigin_intro.livemd +++ b/notebooks/rastrigin_intro.livemd @@ -3,8 +3,8 @@ ```elixir Mix.install([ {:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"}, - {:nx, "~> 0.3.0"}, - {:exla, "~> 0.3.0"} + {:nx, "~> 0.7.0"}, + {:exla, "~> 0.7.0"} ]) Nx.Defn.global_default_options(compiler: EXLA) @@ -55,7 +55,7 @@ defmodule Problem do defn evaluate_rastrigin(genomes) do sums = - (10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) + (10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi)) |> Nx.sum(axes: [1]) -sums diff --git a/test/meow_nx/crossover_test.exs b/test/meow_nx/crossover_test.exs index 002775e..2d37150 100644 --- a/test/meow_nx/crossover_test.exs +++ b/test/meow_nx/crossover_test.exs @@ -8,7 +8,8 @@ defmodule MeowNx.CrossoverTest do genomes = Nx.tensor([[1, 2], [1, 2]]) assert_raise ArgumentError, "2-point crossover is not valid for genome of length 2", fn -> - Crossover.multi_point(genomes, points: 2) + prng_key = MeowNx.Utils.prng_key() + Crossover.multi_point(genomes, prng_key, points: 2) end end @@ -21,7 +22,9 @@ defmodule MeowNx.CrossoverTest do [-11, -22, -33, -44] ]) - assert Crossover.multi_point(genomes, points: 3) == + prng_key = MeowNx.Utils.prng_key() + + assert Crossover.multi_point(genomes, prng_key, points: 3) == Nx.tensor([ [1, -2, 3, -4], [-1, 2, -3, 4], @@ -31,10 +34,15 @@ defmodule MeowNx.CrossoverTest do end test "property: the number of crossover points matches the specified one" do - genomes = Nx.random_uniform({100, 100}) + {genomes, _seed} = + 0 + |> Nx.Random.key() + |> Nx.Random.uniform(shape: {100, 100}) + points = 10 - offsprings = Crossover.multi_point(genomes, points: points) + prng_key = MeowNx.Utils.prng_key() + offsprings = Crossover.multi_point(genomes, prng_key, points: points) same_gene? = Nx.equal(genomes, offsprings)