Skip to content

Commit

Permalink
Generate Kotlin-esque build function for using closures rather than B…
Browse files Browse the repository at this point in the history
…uilders
  • Loading branch information
Ryan O'Neill committed Oct 11, 2023
1 parent b9f6835 commit 094fa67
Show file tree
Hide file tree
Showing 14 changed files with 844 additions and 0 deletions.
1 change: 1 addition & 0 deletions gen-tests.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ val generateKotlinBuildersOnlyTests by tasks.creating(JavaExec::class) {
"--kotlin_out=wire-tests/src/commonTest/proto-kotlin",
"--kotlin_builders_only",
"redacted_test_builders_only.proto",
"simple_message_builders_only.proto",
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ import com.squareup.kotlinpoet.INT
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.KModifier.ABSTRACT
import com.squareup.kotlinpoet.KModifier.CONST
import com.squareup.kotlinpoet.KModifier.INLINE
import com.squareup.kotlinpoet.KModifier.OVERRIDE
import com.squareup.kotlinpoet.KModifier.PRIVATE
import com.squareup.kotlinpoet.KModifier.PUBLIC
import com.squareup.kotlinpoet.LONG
import com.squareup.kotlinpoet.LambdaTypeName
import com.squareup.kotlinpoet.MemberName
import com.squareup.kotlinpoet.MemberName.Companion.member
import com.squareup.kotlinpoet.NOTHING
Expand Down Expand Up @@ -548,6 +550,7 @@ class KotlinGenerator private constructor(

addDefaultFields(type, companionBuilder, nameAllocator)
addAdapter(type, companionBuilder)
if (buildersOnly || javaInterOp) addBuildClosure(type, companionBuilder, builderClassName)

val classBuilder = TypeSpec.classBuilder(className)
.apply {
Expand Down Expand Up @@ -1968,6 +1971,22 @@ class KotlinGenerator private constructor(
}
}

private fun addBuildClosure(type: MessageType, companionBuilder: TypeSpec.Builder, builderClassName: ClassName) {
val buildFn = FunSpec.builder("build")
.addModifiers(INLINE)
.addParameter("body",
LambdaTypeName.get(
receiver = builderClassName,
returnType = Unit::class.asClassName()
)
)
.addStatement("return %T().apply(body).build()", builderClassName)
.returns(generatedTypeName(type))
.build()

companionBuilder.addFunction(buildFn)
}

private fun redactFun(message: MessageType): FunSpec {
val className = typeToKotlinName.getValue(message.type) as ClassName
val nameAllocator = nameAllocator(message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,32 @@ class KotlinGeneratorTest {
assertThat(code).doesNotContain("InnerMessage public constructor(")
}

@Test
fun buildersOnlyOrJavaInteropGeneratesKotlinBuildClosure() {
val schema = buildSchema {
add(
"message.proto".toPath(),
"""
|syntax = "proto2";
|message SomeMessage {
| optional string a = 1;
| optional string b = 2;
|}
|
""".trimMargin(),
)
}
val code = KotlinWithProfilesGenerator(schema)
.generateKotlin("SomeMessage", buildersOnly = false, javaInterop = false)
assertThat(code).doesNotContain("fun build(body: Builder.() -> Unit): SomeMessage {")
val buildersOnlyCode = KotlinWithProfilesGenerator(schema)
.generateKotlin("SomeMessage", buildersOnly = true)
assertThat(buildersOnlyCode).contains("fun build(body: Builder.() -> Unit): SomeMessage {")
val javaInteropCode = KotlinWithProfilesGenerator(schema)
.generateKotlin("SomeMessage", javaInterop = true)
assertThat(javaInteropCode).contains("fun build(body: Builder.() -> Unit): SomeMessage {")
}

@Test
fun javaInteropAndBuildersOnly() {
val schema = buildSchema {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.squareup.wire

import com.squareup.wire.protos.kotlin.simple.buildersonly.SimpleMessage
import kotlin.test.Test
import kotlin.test.assertEquals

class KotlinBuildTest {
@Test fun kotlinBuildEquivalentToBuilderResult() {
val bb = 100

val kotlinBuildResult = SimpleMessage.build {
required_int32 = 4
optional_int32 = 5
optional_nested_msg = SimpleMessage.NestedMessage.build {
this.bb = bb
}
}
val builderResult = SimpleMessage.Builder()
.required_int32(4)
.optional_int32(5)
.optional_nested_msg(SimpleMessage.NestedMessage.Builder()
.bb(bb)
.build()
).build()

assertEquals(kotlinBuildResult, builderResult)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import kotlin.Boolean
import kotlin.Int
import kotlin.Long
import kotlin.String
import kotlin.Unit
import okio.ByteString

public class NotRedacted private constructor(
Expand Down Expand Up @@ -148,5 +149,7 @@ public class NotRedacted private constructor(
}

private const val serialVersionUID: Long = 0L

public inline fun build(body: Builder.() -> Unit): NotRedacted = Builder().apply(body).build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import kotlin.Boolean
import kotlin.Int
import kotlin.Long
import kotlin.String
import kotlin.Unit
import okio.ByteString

public class RedactedChild private constructor(
Expand Down Expand Up @@ -174,5 +175,7 @@ public class RedactedChild private constructor(
}

private const val serialVersionUID: Long = 0L

public inline fun build(body: Builder.() -> Unit): RedactedChild = Builder().apply(body).build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import kotlin.Boolean
import kotlin.Int
import kotlin.Long
import kotlin.String
import kotlin.Unit
import okio.ByteString

public class RedactedCycleA private constructor(
Expand Down Expand Up @@ -124,5 +125,8 @@ public class RedactedCycleA private constructor(
}

private const val serialVersionUID: Long = 0L

public inline fun build(body: Builder.() -> Unit): RedactedCycleA =
Builder().apply(body).build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import kotlin.Boolean
import kotlin.Int
import kotlin.Long
import kotlin.String
import kotlin.Unit
import okio.ByteString

public class RedactedCycleB private constructor(
Expand Down Expand Up @@ -124,5 +125,8 @@ public class RedactedCycleB private constructor(
}

private const val serialVersionUID: Long = 0L

public inline fun build(body: Builder.() -> Unit): RedactedCycleB =
Builder().apply(body).build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import kotlin.Boolean
import kotlin.Int
import kotlin.Long
import kotlin.String
import kotlin.Unit
import okio.ByteString

public class RedactedExtension private constructor(
Expand Down Expand Up @@ -150,5 +151,8 @@ public class RedactedExtension private constructor(
}

private const val serialVersionUID: Long = 0L

public inline fun build(body: Builder.() -> Unit): RedactedExtension =
Builder().apply(body).build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import kotlin.Boolean
import kotlin.Int
import kotlin.Long
import kotlin.String
import kotlin.Unit
import okio.ByteString

public class RedactedFields private constructor(
Expand Down Expand Up @@ -202,5 +203,8 @@ public class RedactedFields private constructor(
}

private const val serialVersionUID: Long = 0L

public inline fun build(body: Builder.() -> Unit): RedactedFields =
Builder().apply(body).build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import kotlin.Boolean
import kotlin.Int
import kotlin.Long
import kotlin.String
import kotlin.Unit
import kotlin.collections.List
import okio.ByteString

Expand Down Expand Up @@ -168,5 +169,8 @@ public class RedactedRepeated private constructor(
}

private const val serialVersionUID: Long = 0L

public inline fun build(body: Builder.() -> Unit): RedactedRepeated =
Builder().apply(body).build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import kotlin.Boolean
import kotlin.Int
import kotlin.Long
import kotlin.String
import kotlin.Unit
import kotlin.UnsupportedOperationException
import okio.ByteString

Expand Down Expand Up @@ -124,5 +125,8 @@ public class RedactedRequired private constructor(
}

private const val serialVersionUID: Long = 0L

public inline fun build(body: Builder.() -> Unit): RedactedRequired =
Builder().apply(body).build()
}
}
Loading

0 comments on commit 094fa67

Please sign in to comment.