Skip to content

Commit

Permalink
Make GraphSage test more robust
Browse files Browse the repository at this point in the history
Before it failed on MacOS

Discovered by Iavkan in #199
  • Loading branch information
FlorentinD committed Jun 3, 2022
1 parent 632893b commit 795ae2e
Showing 1 changed file with 22 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.api.Graph;
Expand Down Expand Up @@ -318,33 +317,32 @@ void testConvergence() {
assertThat(trainMetrics.ranIterationsPerEpoch()).containsExactly(2);
}

@ParameterizedTest
@CsvSource({
"0.01, false, 10",
"1.0, true, 7"
})
void batchesPerIteration(double batchSamplingRatio, boolean expectedConvergence, int expectedRanEpochs) {
var trainer = new GraphSageModelTrainer(
configBuilder.modelName("convergingModel:)")
.maybeBatchSamplingRatio(batchSamplingRatio)
.embeddingDimension(12)
.aggregator(AggregatorType.POOL)
.epochs(10)
.tolerance(1e-10)
.sampleSizes(List.of(5, 3))
.batchSize(5)
.maxIterations(100)
.randomSeed(42L)
.build(),
@Test
void batchesPerIteration() {
configBuilder.modelName("convergingModel:)")
.embeddingDimension(2)
.aggregator(AggregatorType.POOL)
.epochs(10)
.tolerance(1e-5)
.sampleSizes(List.of(1))
.batchSize(5)
.maxIterations(100)
.randomSeed(42L);

var trainResultWithoutSampling = new GraphSageModelTrainer(
configBuilder.maybeBatchSamplingRatio(1.0).build(),
Pools.DEFAULT,
ProgressTracker.NULL_TRACKER
);
).train(unweightedGraph, features);

var trainResult = trainer.train(unweightedGraph, features);
var trainResultWithSampling = new GraphSageModelTrainer(
configBuilder.maybeBatchSamplingRatio(0.01).build(),
Pools.DEFAULT,
ProgressTracker.NULL_TRACKER
).train(unweightedGraph, features);

var trainMetrics = trainResult.metrics();
assertThat(trainMetrics.didConverge()).isEqualTo(expectedConvergence);
assertThat(trainMetrics.ranEpochs()).isEqualTo(expectedRanEpochs);
// reason: sampling results in more stochastic gradient descent and different losses
assertThat(trainResultWithoutSampling.metrics().epochLosses().get(0)).isNotEqualTo(trainResultWithSampling.metrics().epochLosses().get(0));
}

@ParameterizedTest
Expand Down

0 comments on commit 795ae2e

Please sign in to comment.