diff --git a/.gitignore b/.gitignore index 22b845f..459e274 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ data/* *.png datasets/* .idea/ +*.svg diff --git a/poetry.lock b/poetry.lock index f2a1edf..00d5356 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1615,6 +1615,67 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] +[[package]] +name = "pyqt5" +version = "5.15.11" +description = "Python bindings for the Qt cross platform application toolkit" +optional = false +python-versions = ">=3.8" +files = [ + {file = "PyQt5-5.15.11-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c8b03dd9380bb13c804f0bdb0f4956067f281785b5e12303d529f0462f9afdc2"}, + {file = "PyQt5-5.15.11-cp38-abi3-macosx_11_0_x86_64.whl", hash = "sha256:6cd75628f6e732b1ffcfe709ab833a0716c0445d7aec8046a48d5843352becb6"}, + {file = "PyQt5-5.15.11-cp38-abi3-manylinux_2_17_x86_64.whl", hash = "sha256:cd672a6738d1ae33ef7d9efa8e6cb0a1525ecf53ec86da80a9e1b6ec38c8d0f1"}, + {file = "PyQt5-5.15.11-cp38-abi3-win32.whl", hash = "sha256:76be0322ceda5deecd1708a8d628e698089a1cea80d1a49d242a6d579a40babd"}, + {file = "PyQt5-5.15.11-cp38-abi3-win_amd64.whl", hash = "sha256:bdde598a3bb95022131a5c9ea62e0a96bd6fb28932cc1619fd7ba211531b7517"}, + {file = "PyQt5-5.15.11.tar.gz", hash = "sha256:fda45743ebb4a27b4b1a51c6d8ef455c4c1b5d610c90d2934c7802b5c1557c52"}, +] + +[package.dependencies] +PyQt5-Qt5 = ">=5.15.2,<5.16.0" +PyQt5-sip = ">=12.15,<13" + +[[package]] +name = "pyqt5-qt5" +version = "5.15.16" +description = "The subset of a Qt installation needed by PyQt5." +optional = false +python-versions = "*" +files = [ + {file = "PyQt5_Qt5-5.15.16-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:18b6fec012de60921fcb131cf2a21368171dc29050d43e4b81a64be407a36105"}, + {file = "PyQt5_Qt5-5.15.16-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e1a0e7ae35a7615c74a293705204579650930486a89af23082462f429dae504a"}, + {file = "PyQt5_Qt5-5.15.16-py3-none-manylinux2014_x86_64.whl", hash = "sha256:5ee1754a6460849cba76c0f0c490c0ccc3b514abc780b141cf772db22b76b54b"}, +] + +[[package]] +name = "pyqt5-sip" +version = "12.16.1" +description = "The sip module support for PyQt5" +optional = false +python-versions = ">=3.9" +files = [ + {file = "PyQt5_sip-12.16.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:72e8be41f053cd2338d373b70cfa8f725f05c7551e014161ae0484c7ffdb5b3d"}, + {file = "PyQt5_sip-12.16.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:389ea632df16794325652fe1a84d4ea77c7ed0240040979dada725473ad1d875"}, + {file = "PyQt5_sip-12.16.1-cp310-cp310-win32.whl", hash = "sha256:b3bfe33e849818a32164fc19346fc889a47ba8b23f803508eac9a5d9d06f59d9"}, + {file = "PyQt5_sip-12.16.1-cp310-cp310-win_amd64.whl", hash = "sha256:7fd6fbff57ba2cda32f1d5ea49500cff6f29e1cafdf41f40b99ed744bdac14a0"}, + {file = "PyQt5_sip-12.16.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:09bfdb5f9adea15a542cbe4b89873a6b290c4f1669f66bb5f1a24993ce8bbdd0"}, + {file = "PyQt5_sip-12.16.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:98b99fcdebbbfc999f4ab10829749151eb371b79201ecd98f20e934c16d0193e"}, + {file = "PyQt5_sip-12.16.1-cp311-cp311-win32.whl", hash = "sha256:67dbdc1b3be045caebfc75ee87966e23c6bee61d94cb634ddd71b634c9089890"}, + {file = "PyQt5_sip-12.16.1-cp311-cp311-win_amd64.whl", hash = "sha256:8afd633d5f35e4e5205680d310800d10d30fcbfb6bb7b852bfaa31097c1be449"}, + {file = "PyQt5_sip-12.16.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f6724c590de3d556c730ebda8b8f906b38373934472209e94d99357b52b56f5f"}, + {file = "PyQt5_sip-12.16.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:633cba509a98bd626def951bb948d4e736635acbd0b7fabd7be55a3a096a8a0b"}, + {file = "PyQt5_sip-12.16.1-cp312-cp312-win32.whl", hash = "sha256:2b35ff92caa569e540675ffcd79ffbf3e7092cccf7166f89e2a8b388db80aa1c"}, + {file = "PyQt5_sip-12.16.1-cp312-cp312-win_amd64.whl", hash = "sha256:a0f83f554727f43dfe92afbf3a8c51e83bb8b78c5f160b635d4359fad681cebe"}, + {file = "PyQt5_sip-12.16.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2349f118dc6f01ee71fe57d8bab9e606ecf241468989abb23b5691d5538d7a69"}, + {file = "PyQt5_sip-12.16.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:bbabdcc031e40417333bdd9e59d8add815c8f0663ebfcbcd3024bdeb55f35e9c"}, + {file = "PyQt5_sip-12.16.1-cp313-cp313-win32.whl", hash = "sha256:b6d06f6b49c7cd70db44277e21134390dcabb709da434d63754c9968ff6d98e2"}, + {file = "PyQt5_sip-12.16.1-cp313-cp313-win_amd64.whl", hash = "sha256:e2bd572cfb969089c2813c85d6e1393ec1a0aeecebc53934ba9f062acf440b50"}, + {file = "PyQt5_sip-12.16.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:639d0f3767f17c9b7e2c1161a563f1a8da8dc16a348f9fc24314882c0ba57c5e"}, + {file = "PyQt5_sip-12.16.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:c2b99bbd915f7649536b46be6f4740122c0ff1a76dc3dcb0c856c7cc770b3545"}, + {file = "PyQt5_sip-12.16.1-cp39-cp39-win32.whl", hash = "sha256:9f6878df5cca870cd48b406c48136842c02f7b0e2a3b7c47cb84dcddcebc5758"}, + {file = "PyQt5_sip-12.16.1-cp39-cp39-win_amd64.whl", hash = "sha256:ffd748efbc9396a7a72de0d617acfd248c04e02dd28cd80e6bc3bf26214786e7"}, + {file = "pyqt5_sip-12.16.1.tar.gz", hash = "sha256:8c831f8b619811a32369d72339faa50ae53a963f5fdfa4d71f845c63e9673125"}, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -2466,4 +2527,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "f1e36a38c991989d64dd329909d7c39bca5747df6a836f4942465ec858bb3e53" +content-hash = "8d8f0c7f732f5c55ef210430c2bfb10b71daf391f4d7e149ba14d982356598fb" diff --git a/pyproject.toml b/pyproject.toml index 4456cdc..6fceaee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ wandb = "^0.17.9" scipy = "^1.14.1" bayesian-optimization = "^2.0.1" safetensors = "^0.4.5" +pyqt5 = "^5.15.11" [build-system] diff --git a/three_nodes_classification/models.py b/three_nodes_classification/models.py index 40a9d5e..f4bb8e7 100644 --- a/three_nodes_classification/models.py +++ b/three_nodes_classification/models.py @@ -21,7 +21,7 @@ def forward(self, data): class SAGE(nn.Module): def __init__(self, normalise=False): super(SAGE, self).__init__() - self.conv = SAGEConv(1, 1, normalize=normalise, bias=True, aggr="sum") + self.conv = SAGEConv(1, 1, normalize=normalise, bias=False, aggr="sum") def forward(self, data): x, edge_index, _ = data.x, data.edge_index, data.batch @@ -39,9 +39,11 @@ def __init__(self, normalise=False, activation="simple"): self.conv = SAGEConv(1, 1, normalize=normalise, bias=False, aggr="sum") if activation == "gaussian": - self.activation = lambda x: torch.exp(-torch.pow(x[:, 0], 2)) + self.activation = lambda x: torch.exp(-torch.pow(x, 2)) elif activation == "simple": - self.activation = lambda x: 1 - torch.pow(x[:, 0], 2) + self.activation = lambda x: 1 - torch.pow(x, 2) + elif activation == "sigmoid": + self.activation = torch.sigmoid def forward(self, data): x, edge_index, _ = data.x, data.edge_index, data.batch @@ -50,6 +52,6 @@ def forward(self, data): x = x.view(-1, 3) - out = self.activation(x) + out = self.activation(x[:, 0]) return out diff --git a/three_nodes_classification/plot_loss.py b/three_nodes_classification/plot_loss.py index 3003051..5ee0658 100644 --- a/three_nodes_classification/plot_loss.py +++ b/three_nodes_classification/plot_loss.py @@ -6,8 +6,8 @@ from torch_geometric.loader import DataLoader import matplotlib -# matplotlib.use("Qt5Agg") -matplotlib.use("pgf") +matplotlib.use("Qt5Agg") +# matplotlib.use("pgf") def print_model_parameters(model): @@ -20,16 +20,20 @@ def print_model_parameters(model): data_loader = DataLoader(data_list, batch_size=1028, shuffle=True) # Create a grid of w1 and w2 values -w1_range = np.linspace(-1, 1, 20) -w2_range = np.linspace(-1, 1, 20) +w1_range = np.linspace(-1, 1, 100) +w2_range = np.linspace(-1, 1, 100) W1, W2 = np.meshgrid(w1_range, w2_range) # Initialize your model -model = NonLinearSAGE() +model = NonLinearSAGE(activation="gaussian") criterion = torch.nn.CrossEntropyLoss() Loss = np.zeros((len(w1_range), len(w2_range))) +max_loss = 0 +max_loss_w_1_weight = 0 +max_loss_w_2_weight = 0 + # Calculate loss over the grid for i in range(len(w1_range)): for j in range(len(w2_range)): @@ -42,9 +46,17 @@ def print_model_parameters(model): loss = criterion(out, batch.y) total_loss += loss.item() Loss[i, j] = total_loss / len(data_loader) - + if Loss[i, j] > max_loss: + max_loss = Loss[i, j] + max_loss_w_1_weight = W1[i, j] + max_loss_w_2_weight = W2[i, j] print(f"Finished {i+1}/{len(w1_range)} iterations") +print(f"Max loss: {max_loss:.4f}") +print(f"Max loss w1 weight: {max_loss_w_1_weight:.4f}") +print(f"Max loss w2 weight: {max_loss_w_2_weight:.4f}") + + # Plot the 3D surface fig = plt.figure() ax = fig.add_subplot(111, projection="3d") @@ -53,10 +65,9 @@ def print_model_parameters(model): ax.set_xlabel("w1") ax.set_ylabel("w2") ax.set_zlabel("Loss") -ax.set_title("3D Loss Surface") # high resolution plot -plt.savefig("three_neighbour_classifier_gaussian_activatoin.pgf", dpi=900) +plt.savefig("three_neighbour_classifier_gaussian_activation.svg", dpi=900) # plt.savefig("three_neighbour_classifier_gaussian_activatoin.png", dpi=900) plt.ion() diff --git a/three_nodes_classification/train.py b/three_nodes_classification/train.py index 7a481d0..c0a2bd7 100644 --- a/three_nodes_classification/train.py +++ b/three_nodes_classification/train.py @@ -6,7 +6,7 @@ from torch.optim import AdamW # Set the seed for reproducibility -torch.manual_seed(0) +# torch.manual_seed(29) # Training function @@ -31,7 +31,7 @@ def evaluate(model, data_loader): with torch.no_grad(): for batch in data_loader: outputs = model(batch) - predicted = (outputs > 0.5).float() + predicted = (outputs.abs() > 0.5).float() total += batch.y.size(0) correct += (predicted == batch.y).sum().item() accuracy = 100 * correct / total @@ -48,17 +48,16 @@ def main( num_epochs=100, num_runs=5, ): - # Generate training and test datasets - train_data_list = generate_three_nodes_dataset(samples) - test_data_list = generate_three_nodes_dataset(samples) - train_loader = DataLoader(train_data_list, batch_size=batch_size, shuffle=True) - test_loader = DataLoader(test_data_list, batch_size=batch_size, shuffle=True) accuracies = [] print(f"Training model: {model_type}, Normalise: {normalise}") for run in range(num_runs): + train_data_list = generate_three_nodes_dataset(samples) + test_data_list = generate_three_nodes_dataset(samples) + train_loader = DataLoader(train_data_list, batch_size=batch_size, shuffle=True) + test_loader = DataLoader(test_data_list, batch_size=batch_size, shuffle=True) # Select model type based on user input if model_type == "GCN": model = GCN(normalise=normalise)