Skip to content

Commit

Permalink
Improve three_nodes experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
joshniemela committed Dec 28, 2024
1 parent b9ecf63 commit 0495450
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
4 changes: 3 additions & 1 deletion three_nodes_classification/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ def generate_three_nodes_dataset(n: int, device="cpu") -> list[Data]:

node_values = torch.empty(n, 3, dtype=torch.float, device=device)

# Set the two side nodes to something between 0 and 10
# Set the two side nodes to something between 0 and 9
node_values[:, 1:] = torch.randint(10, (n, 2)).float()

# Set the root node to the sum of the two side nodes
node_values[:, 0] = torch.sum(node_values[:, 1:], 1)

# we replace half of the root values with random values
mask = torch.rand(n) > 0.5

# random from 0 to 18
node_values[mask, 0] = torch.randint(10 * 2 - 1, (mask.sum(),)).float()

# We need to reshape the tensor to have the correct shape
Expand Down
6 changes: 6 additions & 0 deletions three_nodes_regression/summary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@ results = CSV.read(joinpath(results_folder, "runs.csv"), DataFrame)

normalised = filter(row -> row.normalise, results)
not_normalised = filter(row -> !row.normalise, results)

print("Normalised")
print(describe(normalised))

print("Not normalised")
print(describe(not_normalised))

0 comments on commit 0495450

Please sign in to comment.