diff --git a/algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java b/algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java index 0e39b016adc..f50d940b1a9 100644 --- a/algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java +++ b/algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java @@ -25,7 +25,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.ValueSource; import org.neo4j.gds.Orientation; import org.neo4j.gds.api.Graph; @@ -318,33 +317,32 @@ void testConvergence() { assertThat(trainMetrics.ranIterationsPerEpoch()).containsExactly(2); } - @ParameterizedTest - @CsvSource({ - "0.01, false, 10", - "1.0, true, 7" - }) - void batchesPerIteration(double batchSamplingRatio, boolean expectedConvergence, int expectedRanEpochs) { - var trainer = new GraphSageModelTrainer( - configBuilder.modelName("convergingModel:)") - .maybeBatchSamplingRatio(batchSamplingRatio) - .embeddingDimension(12) - .aggregator(AggregatorType.POOL) - .epochs(10) - .tolerance(1e-10) - .sampleSizes(List.of(5, 3)) - .batchSize(5) - .maxIterations(100) - .randomSeed(42L) - .build(), + @Test + void batchesPerIteration() { + configBuilder.modelName("convergingModel:)") + .embeddingDimension(2) + .aggregator(AggregatorType.POOL) + .epochs(10) + .tolerance(1e-5) + .sampleSizes(List.of(1)) + .batchSize(5) + .maxIterations(100) + .randomSeed(42L); + + var trainResultWithoutSampling = new GraphSageModelTrainer( + configBuilder.maybeBatchSamplingRatio(1.0).build(), Pools.DEFAULT, ProgressTracker.NULL_TRACKER - ); + ).train(unweightedGraph, features); - var trainResult = trainer.train(unweightedGraph, features); + var trainResultWithSampling = new GraphSageModelTrainer( + configBuilder.maybeBatchSamplingRatio(0.01).build(), + Pools.DEFAULT, + ProgressTracker.NULL_TRACKER + ).train(unweightedGraph, features); - var trainMetrics = trainResult.metrics(); - assertThat(trainMetrics.didConverge()).isEqualTo(expectedConvergence); - assertThat(trainMetrics.ranEpochs()).isEqualTo(expectedRanEpochs); + // reason: sampling results in more stochastic gradient descent and different losses + assertThat(trainResultWithoutSampling.metrics().epochLosses().get(0)).isNotEqualTo(trainResultWithSampling.metrics().epochLosses().get(0)); } @ParameterizedTest