diff --git a/src/main/scala/MockedStreams.scala b/src/main/scala/MockedStreams.scala index 52eab63..8886251 100644 --- a/src/main/scala/MockedStreams.scala +++ b/src/main/scala/MockedStreams.scala @@ -19,7 +19,7 @@ package com.madewithtea.mockedstreams import java.util.{Properties, UUID} import org.apache.kafka.common.serialization.Serde -import org.apache.kafka.streams.{StreamsBuilder, StreamsConfig} +import org.apache.kafka.streams.{StreamsBuilder, StreamsConfig, Topology} import org.apache.kafka.streams.state.ReadOnlyWindowStore import org.apache.kafka.test.{ProcessorTopologyTestDriver => Driver} @@ -31,14 +31,23 @@ object MockedStreams { case class Record(topic: String, key: Array[Byte], value: Array[Byte]) - case class Builder(topology: Option[(StreamsBuilder => Unit)] = None, + case class Builder(topology: Option[() => Topology] = None, configuration: Properties = new Properties(), stateStores: Seq[String] = Seq(), inputs: List[Record] = List.empty) { def config(configuration: Properties) = this.copy(configuration = configuration) - def topology(func: (StreamsBuilder => Unit)) = this.copy(topology = Some(func)) + def topology(func: (StreamsBuilder => Unit)) = { + val buildTopology = () => { + val builder = new StreamsBuilder() + func(builder) + builder.build() + } + this.copy(topology = Some(buildTopology)) + } + + def withTopology(t: () => Topology) = this.copy(topology = Some(t)) def stores(stores: Seq[String]) = this.copy(stateStores = stores) @@ -93,14 +102,9 @@ object MockedStreams { props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092") props.putAll(configuration) - val builder = new StreamsBuilder() - - topology match { - case Some(t) => t(builder) - case _ => throw new NoTopologySpecified - } + val t = topology.getOrElse(throw new NoTopologySpecified) - new Driver(new StreamsConfig(props), builder.build()) + new Driver(new StreamsConfig(props), t()) } private def produce(driver: Driver): Unit = { diff --git a/src/test/scala/MockedStreamsSpec.scala b/src/test/scala/MockedStreamsSpec.scala index c0eecf6..4e15b62 100644 --- a/src/test/scala/MockedStreamsSpec.scala +++ b/src/test/scala/MockedStreamsSpec.scala @@ -20,7 +20,7 @@ import org.apache.kafka.clients.consumer.ConsumerRecord import org.apache.kafka.common.serialization.Serdes import org.apache.kafka.streams.kstream._ import org.apache.kafka.streams.processor.TimestampExtractor -import org.apache.kafka.streams.{Consumed, StreamsBuilder, KeyValue, StreamsConfig} +import org.apache.kafka.streams._ import org.scalatest.{FlatSpec, Matchers} class MockedStreamsSpec extends FlatSpec with Matchers { @@ -156,6 +156,23 @@ class MockedStreamsSpec extends FlatSpec with Matchers { .shouldEqual(expectedCy.toMap) } + it should "accept already built topology" in { + import Fixtures.Uppercase._ + + def getTopology() = { + val builder = new StreamsBuilder() + topology(builder) + builder.build() + } + + val output = MockedStreams() + .withTopology(getTopology) + .input(InputTopic, strings, strings, input) + .output(OutputTopic, strings, strings, expected.size) + + output shouldEqual expected + } + class LastInitializer extends Initializer[Integer] { override def apply() = 0 }