From 795ae2e01141783503f9a11a685aeb157f295996 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Fri, 3 Jun 2022 15:24:20 +0200 Subject: [PATCH] Make GraphSage test more robust Before it failed on MacOS Discovered by Iavkan in #199 --- .../graphsage/GraphSageModelTrainerTest.java | 46 +++++++++---------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java b/algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java index 0e39b016adc..f50d940b1a9 100644 --- a/algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java +++ b/algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java @@ -25,7 +25,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.ValueSource; import org.neo4j.gds.Orientation; import org.neo4j.gds.api.Graph; @@ -318,33 +317,32 @@ void testConvergence() { assertThat(trainMetrics.ranIterationsPerEpoch()).containsExactly(2); } - @ParameterizedTest - @CsvSource({ - "0.01, false, 10", - "1.0, true, 7" - }) - void batchesPerIteration(double batchSamplingRatio, boolean expectedConvergence, int expectedRanEpochs) { - var trainer = new GraphSageModelTrainer( - configBuilder.modelName("convergingModel:)") - .maybeBatchSamplingRatio(batchSamplingRatio) - .embeddingDimension(12) - .aggregator(AggregatorType.POOL) - .epochs(10) - .tolerance(1e-10) - .sampleSizes(List.of(5, 3)) - .batchSize(5) - .maxIterations(100) - .randomSeed(42L) - .build(), + @Test + void batchesPerIteration() { + configBuilder.modelName("convergingModel:)") + .embeddingDimension(2) + .aggregator(AggregatorType.POOL) + .epochs(10) + .tolerance(1e-5) + .sampleSizes(List.of(1)) + .batchSize(5) + .maxIterations(100) + .randomSeed(42L); + + var trainResultWithoutSampling = new GraphSageModelTrainer( + configBuilder.maybeBatchSamplingRatio(1.0).build(), Pools.DEFAULT, ProgressTracker.NULL_TRACKER - ); + ).train(unweightedGraph, features); - var trainResult = trainer.train(unweightedGraph, features); + var trainResultWithSampling = new GraphSageModelTrainer( + configBuilder.maybeBatchSamplingRatio(0.01).build(), + Pools.DEFAULT, + ProgressTracker.NULL_TRACKER + ).train(unweightedGraph, features); - var trainMetrics = trainResult.metrics(); - assertThat(trainMetrics.didConverge()).isEqualTo(expectedConvergence); - assertThat(trainMetrics.ranEpochs()).isEqualTo(expectedRanEpochs); + // reason: sampling results in more stochastic gradient descent and different losses + assertThat(trainResultWithoutSampling.metrics().epochLosses().get(0)).isNotEqualTo(trainResultWithSampling.metrics().epochLosses().get(0)); } @ParameterizedTest