Skip to content

Commit

Permalink
Upgrade Nx (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshprice authored Jul 30, 2024
1 parent bb1c046 commit 5be2fff
Show file tree
Hide file tree
Showing 25 changed files with 301 additions and 209 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
uses: erlef/setup-beam@v1
with:
otp-version: "24.0"
elixir-version: "1.13.0"
elixir-version: "1.14.0"
- name: Cache Mix
uses: actions/cache@v2
with:
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ You can define the algorithm in a single Elixir script file like this:

Mix.install([
{:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"},
{:nx, "~> 0.3.0"},
{:exla, "~> 0.3.0"}
{:nx, "~> 0.7.0"},
{:exla, "~> 0.7.0"}
])

Nx.Defn.global_default_options(compiler: EXLA)
Expand All @@ -32,7 +32,7 @@ defmodule Problem do

defn evaluate_rastrigin(genomes) do
sums =
(10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi))
(10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi))
|> Nx.sum(axes: [1])

-sums
Expand Down
6 changes: 3 additions & 3 deletions examples/distributed.exs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

Mix.install([
{:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"},
{:nx, "~> 0.3.0"},
{:exla, "~> 0.3.0"}
{:nx, "~> 0.7.0"},
{:exla, "~> 0.7.0"}
])

Nx.Defn.global_default_options(compiler: EXLA)
Expand All @@ -23,7 +23,7 @@ defmodule Problem do

defn evaluate_rastrigin(genomes) do
sums =
(10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi))
(10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi))
|> Nx.sum(axes: [1])

-sums
Expand Down
6 changes: 3 additions & 3 deletions examples/intro.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

Mix.install([
{:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"},
{:nx, "~> 0.3.0"},
{:exla, "~> 0.3.0"}
{:nx, "~> 0.7.0"},
{:exla, "~> 0.7.0"}
])

Nx.Defn.global_default_options(compiler: EXLA)
Expand All @@ -19,7 +19,7 @@ defmodule Problem do

defn evaluate_rastrigin(genomes) do
sums =
(10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi))
(10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi))
|> Nx.sum(axes: [1])

-sums
Expand Down
4 changes: 2 additions & 2 deletions examples/knapsack.exs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

Mix.install([
{:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"},
{:nx, "~> 0.3.0"},
{:exla, "~> 0.3.0"}
{:nx, "~> 0.7.0"},
{:exla, "~> 0.7.0"}
])

Nx.Defn.global_default_options(compiler: EXLA)
Expand Down
6 changes: 3 additions & 3 deletions examples/metrics.exs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Mix.install([
{:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"},
{:nx, "~> 0.3.0"},
{:exla, "~> 0.3.0"},
{:nx, "~> 0.7.0"},
{:exla, "~> 0.7.0"},
{:vega_lite, "~> 0.1.1"},
{:jason, "~> 1.4"}
])
Expand All @@ -17,7 +17,7 @@ defmodule Problem do

defn evaluate_rastrigin(genomes) do
sums =
(10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi))
(10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi))
|> Nx.sum(axes: [1])

-sums
Expand Down
4 changes: 2 additions & 2 deletions examples/one_max.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

Mix.install([
{:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"},
{:nx, "~> 0.3.0"},
{:exla, "~> 0.3.0"}
{:nx, "~> 0.7.0"},
{:exla, "~> 0.7.0"}
])

Nx.Defn.global_default_options(compiler: EXLA)
Expand Down
6 changes: 3 additions & 3 deletions examples/rastrigin.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

Mix.install([
{:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"},
{:nx, "~> 0.3.0"},
{:exla, "~> 0.3.0"}
{:nx, "~> 0.7.0"},
{:exla, "~> 0.7.0"}
])

Nx.Defn.global_default_options(compiler: EXLA)
Expand All @@ -17,7 +17,7 @@ defmodule Problem do

defn evaluate(genomes) do
sums =
(10 + Nx.power(genomes, 2) - 10 * Nx.cos(genomes * @two_pi))
(10 + Nx.pow(genomes, 2) - 10 * Nx.cos(genomes * @two_pi))
|> Nx.sum(axes: [1])

-sums
Expand Down
8 changes: 4 additions & 4 deletions examples/tsp.exs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

Mix.install([
{:meow, "~> 0.1.0-dev", github: "jonatanklosko/meow"},
{:nx, "~> 0.3.0"},
{:exla, "~> 0.3.0"},
{:req, "~> 0.3.0"}
{:nx, "~> 0.7.0"},
{:exla, "~> 0.7.0"},
{:req, "~> 0.5.0"}
])

Nx.Defn.global_default_options(compiler: EXLA)
Expand Down Expand Up @@ -56,7 +56,7 @@ defmodule Problem do
end

{tsp_size, distances} =
Problem.load_distances("https://people.sc.fsu.edu/~jburkardt/datasets/tsp/fri26_d.txt")
Problem.load_distances("https://raw.githubusercontent.com/dada8397/mpi_tsp/master/fri26_d.txt")

algorithm =
Meow.objective(&Problem.evaluate(&1, distances))
Expand Down
4 changes: 2 additions & 2 deletions lib/meow/ops.ex
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ defmodule Meow.Ops do
n when is_integer(n) ->
n..n

min..max ->
min..max
min..max//step ->
min..max//step

other ->
raise ArgumentError,
Expand Down
57 changes: 27 additions & 30 deletions lib/meow_nx/crossover.ex
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ defmodule MeowNx.Crossover do
* [On the Virtues of Parameterized Uniform Crossover](http://www.mli.gmu.edu/papers/91-95/91-18.pdf)
"""
defn uniform(parents, opts \\ []) do
defn uniform(parents, prng_key, opts \\ []) do
opts = keyword!(opts, probability: 0.5)
probability = opts[:probability]

Expand All @@ -49,8 +49,10 @@ defmodule MeowNx.Crossover do

swapped_parents = Utils.swap_adjacent_rows(parents)

{random, _prng_key} = Nx.Random.uniform(prng_key, shape: {half_n, length})

swap? =
Nx.random_uniform({half_n, length})
random
|> Nx.less_equal(probability)
|> Utils.duplicate_rows()

Expand All @@ -64,20 +66,19 @@ defmodule MeowNx.Crossover do
a random split point and swaps all genes $x_i$, $y_i$
on one side of that split point.
"""
defn single_point(parents) do
defn single_point(parents, prng_key) do
{n, length} = Nx.shape(parents)
half_n = div(n, 2)

swapped_parents = Utils.swap_adjacent_rows(parents)

# Generate n / 2 split points (like [5, 2, 3]), and replicate
# them for adjacent parents (like [5, 5, 2, 2, 3, 3])
split_idx =
Nx.random_uniform({half_n, 1}, 1, length)
|> Utils.duplicate_rows()
{random, _prng_key} = Nx.Random.uniform(prng_key, 1, length, shape: {half_n, 1})
random = Nx.as_type(random, {:u, 32})
split_idx = Utils.duplicate_rows(random)

swap? = Nx.less_equal(split_idx, Nx.iota({1, length}))

Nx.select(swap?, swapped_parents, parents)
end

Expand All @@ -92,28 +93,23 @@ defmodule MeowNx.Crossover do
* `:points` - the number of crossover points
"""
defn multi_point(parents, opts \\ []) do
defn multi_point(parents, prng_key, opts \\ []) do
opts = keyword!(opts, [:points])
points = opts[:points]

{n, length} = Nx.shape(parents)
half_n = div(n, 2)

transform({length, points}, fn {length, points} ->
unless Elixir.Kernel.<(points, length) do
raise ArgumentError,
"#{points}-point crossover is not valid for genome of length #{length}"
end
end)
validate_crossover(length, points)

swapped_parents = Utils.swap_adjacent_rows(parents)

# For each pair of parents we generate k unique crossover points,
# then we convert each of them to 1-point crossover mask and finally
# combine these masks using gen-wise XOR (sum modulo 2)

split_idx =
Utils.random_idx_without_replacement(
{split_idx, _prng_key} =
Utils.random_idx_without_replacement(prng_key,
shape: {half_n, points, 1},
min: 1,
max: length,
Expand All @@ -129,6 +125,13 @@ defmodule MeowNx.Crossover do
Nx.select(swap?, swapped_parents, parents)
end

deftransformp validate_crossover(length, points) do
unless Elixir.Kernel.<(points, length) do
raise ArgumentError,
"#{points}-point crossover is not valid for genome of length #{length}"
end
end

@doc """
Performs blend-alpha crossover, also referred to as BLX-alpha.
Expand Down Expand Up @@ -161,7 +164,7 @@ defmodule MeowNx.Crossover do
* [Tackling Real-Coded Genetic Algorithms: Operators and Tools for Behavioural Analysis](https://sci2s.ugr.es/sites/default/files/files/ScientificImpact/AIRE12-1998.PDF), Section 4.3
* [Multiobjective Evolutionary Algorithms forElectric Power Dispatch Problem](https://www.researchgate.net/figure/Blend-crossover-operator-BLX_fig1_226044085), Fig. 1.
"""
defn blend_alpha(parents, opts \\ []) do
defn blend_alpha(parents, prng_key, opts \\ []) do
opts = keyword!(opts, alpha: 0.5)
alpha = opts[:alpha]

Expand All @@ -175,7 +178,8 @@ defmodule MeowNx.Crossover do
# may be negative (for y < x), but then we shift from x
# in the opposite direction, so it works as expected.

gamma = (1 + 2 * alpha) * Nx.random_uniform({half_n, length}) - alpha
{u, _prng_key} = Nx.Random.uniform(prng_key, shape: {half_n, length})
gamma = (1 + 2 * alpha) * u - alpha
gamma = Utils.duplicate_rows(gamma)

x + gamma * (y - x)
Expand All @@ -200,7 +204,7 @@ defmodule MeowNx.Crossover do
* [Self-Adaptive Genetic Algorithms with Simulated Binary Crossover](https://eldorado.tu-dortmund.de/bitstream/2003/5370/1/ci61.pdf)
* [Engineering Analysis and Design Using Genetic Algorithms / Lecture 4: Real-Coded Genetic Algorithms](https://engineering.purdue.edu/~sudhoff/ee630/Lecture04.pdf)
"""
defn simulated_binary(parents, opts \\ []) do
defn simulated_binary(parents, prng_key, opts \\ []) do
opts = keyword!(opts, [:eta])
eta = opts[:eta]

Expand All @@ -209,17 +213,10 @@ defmodule MeowNx.Crossover do

{x, y} = {parents, Utils.swap_adjacent_rows(parents)}

beta_base =
Nx.random_uniform({half_n, length})
|> Nx.map(fn u ->
if Nx.less(u, 0.5) do
2 * u
else
1 / (2 * (1 - u))
end
end)

beta = Nx.power(beta_base, 1 / (eta + 1))
{u, _prng_key} = Nx.Random.uniform(prng_key, shape: {half_n, length})
beta_base = Nx.select(u < 0.5, 2 * u, 1 / (2 * (1 - u)))

beta = Nx.pow(beta_base, 1 / (eta + 1))
beta = Utils.duplicate_rows(beta)

0.5 * ((1 + beta) * x + (1 - beta) * y)
Expand Down
12 changes: 8 additions & 4 deletions lib/meow_nx/init.ex
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@ defmodule MeowNx.Init do
* `:max` - the maximum possible value of a gene. Required.
"""
defn real_random_uniform(opts \\ []) do
defn real_random_uniform(prng_key, opts \\ []) do
opts = keyword!(opts, [:n, :length, :min, :max])
n = opts[:n]
length = opts[:length]
min = opts[:min]
max = opts[:max]

Nx.random_uniform({n, length}, min, max, type: {:f, 64})
{random, _prng_key} =
Nx.Random.uniform(prng_key, min, max, shape: {n, length}, type: {:f, 64})

random
end

@doc """
Expand All @@ -42,11 +45,12 @@ defmodule MeowNx.Init do
* `:length` - the length of a single genome. Required.
"""
defn binary_random_uniform(opts \\ []) do
defn binary_random_uniform(prng_key, opts \\ []) do
opts = keyword!(opts, [:n, :length])
n = opts[:n]
length = opts[:length]

Nx.random_uniform({n, length}, 0, 2) |> Nx.as_type({:u, 8})
{random, _prng_key} = Nx.Random.uniform(prng_key, 0, 2, shape: {n, length})
Nx.as_type(random, {:u, 8})
end
end
13 changes: 8 additions & 5 deletions lib/meow_nx/metric.ex
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,18 @@ defmodule MeowNx.Metric do
defn fitness_entropy(_genomes, fitness, opts \\ []) do
opts = keyword!(opts, [:precision])

values =
transform({fitness, opts[:precision]}, fn
{fitness, nil} -> fitness
{fitness, precision} -> fitness |> Nx.divide(precision) |> Nx.round()
end)
values = transform_fitness(fitness, opts[:precision])

MeowNx.Utils.entropy(values)
end

deftransformp transform_fitness(fitness, precision) do
case precision do
nil -> fitness
_ -> fitness |> Nx.divide(precision) |> Nx.round()
end
end

@doc """
Calculates the mean Euclidean distance between each pair
of genomes.
Expand Down
Loading

0 comments on commit 5be2fff

Please sign in to comment.