From 7f079cf7afe8b35bb06195ffbc068ccf69a9ce4c Mon Sep 17 00:00:00 2001 From: hughsimpson Date: Thu, 19 Dec 2024 16:34:57 +0000 Subject: [PATCH] codegen: Permit autogenerating types aliases for endpoints (#4213) --- .../scala/sttp/tapir/codegen/GenScala.scala | 12 +- .../sttp/tapir/codegen/BasicGenerator.scala | 16 +- .../tapir/codegen/EndpointGenerator.scala | 253 ++++++++++++------ .../tapir/codegen/BasicGeneratorSpec.scala | 3 +- .../ClassDefinitionGeneratorSpec.scala | 3 +- .../tapir/codegen/EndpointGeneratorSpec.scala | 15 +- .../sttp/tapir/sbt/OpenapiCodegenKeys.scala | 2 + .../sttp/tapir/sbt/OpenapiCodegenPlugin.scala | 3 + .../sttp/tapir/sbt/OpenapiCodegenTask.scala | 4 +- .../oneOf-json-roundtrip/Expected.scala.txt | 32 ++- .../ExpectedJsonSerdes.scala.txt | 4 + .../ExpectedSchemas.scala.txt | 3 + .../oneOf-json-roundtrip/build.sbt | 3 +- .../oneOf-json-roundtrip/swagger.yaml | 42 ++- 14 files changed, 293 insertions(+), 102 deletions(-) diff --git a/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala b/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala index 361167d58e..8de7b9cbd1 100644 --- a/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala +++ b/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala @@ -65,6 +65,9 @@ object GenScala { private val streamingImplementationOpt: Opts[Option[String]] = Opts.option[String]("streamingImplementation", "Capability to use for binary streams", "s").orNone + private val generateEndpointTypesOpt: Opts[Boolean] = + Opts.flag("generateEndpointTypes", "Whether to emit explicit type aliases for endpoint declarations", "e").orFalse + private val destDirOpt: Opts[File] = Opts .option[String]("destdir", "Destination directory", "d") @@ -88,7 +91,8 @@ object GenScala { jsonLibOpt, validateNonDiscriminatedOneOfsOpt, maxSchemasPerFileOpt, - streamingImplementationOpt + streamingImplementationOpt, + generateEndpointTypesOpt ) .mapN { case ( @@ -101,7 +105,8 @@ object GenScala { jsonLib, validateNonDiscriminatedOneOfs, maxSchemasPerFile, - streamingImplementation + streamingImplementation, + generateEndpointTypes ) => val objectName = maybeObjectName.getOrElse(DefaultObjectName) @@ -116,7 +121,8 @@ object GenScala { jsonLib.getOrElse("circe"), streamingImplementation.getOrElse("fs2"), validateNonDiscriminatedOneOfs, - maxSchemasPerFile.getOrElse(400) + maxSchemasPerFile.getOrElse(400), + generateEndpointTypes ) ) destFiles <- contents.toVector.traverse { case (fileName, content) => writeGeneratedFile(destDir, fileName, content) } diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala index b7d21b9cdd..571fcc236b 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala @@ -25,6 +25,10 @@ object StreamingImplementation extends Enumeration { val Akka, FS2, Pekko, Zio = Value type StreamingImplementation = Value } +object EndpointCapabilites extends Enumeration { + val Akka, FS2, Nothing, Pekko, Zio = Value + type EndpointCapabilites = Value +} object BasicGenerator { @@ -40,7 +44,8 @@ object BasicGenerator { jsonSerdeLib: String, streamingImplementation: String, validateNonDiscriminatedOneOfs: Boolean, - maxSchemasPerFile: Int + maxSchemasPerFile: Int, + generateEndpointTypes: Boolean ): Map[String, String] = { val normalisedJsonLib = jsonSerdeLib.toLowerCase match { case "circe" => JsonSerdeLib.Circe @@ -65,7 +70,14 @@ object BasicGenerator { } val EndpointDefs(endpointsByTag, queryOrPathParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) = - endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib, normalisedStreamingImplementation) + endpointGenerator.endpointDefs( + doc, + useHeadTagForObjectNames, + targetScala3, + normalisedJsonLib, + normalisedStreamingImplementation, + generateEndpointTypes + ) val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) = classGenerator .classDefs( diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala index 7b388ffc93..a30d874efd 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala @@ -2,6 +2,8 @@ package sttp.tapir.codegen import io.circe.Json import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType, strippedToCamelCase} import sttp.tapir.codegen.JsonSerdeLib.JsonSerdeLib +import sttp.tapir.codegen.EndpointCapabilites +import sttp.tapir.codegen.EndpointCapabilites.EndpointCapabilites import sttp.tapir.codegen.StreamingImplementation import sttp.tapir.codegen.StreamingImplementation.StreamingImplementation import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, OpenapiParameter, OpenapiPath, OpenapiRequestBody, OpenapiResponse} @@ -11,6 +13,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{ OpenapiSchemaBinary, OpenapiSchemaEnum, OpenapiSchemaMap, + OpenapiSchemaOneOf, OpenapiSchemaRef, OpenapiSchemaSimpleType, OpenapiSchemaString @@ -22,7 +25,18 @@ case class Location(path: String, method: String) { override def toString: String = s"${method.toUpperCase} ${path}" } -case class GeneratedEndpoint(name: String, definition: String, maybeLocalEnums: Option[String]) +case class EndpointTypes(security: Seq[String], in: Seq[String], err: Seq[String], out: Seq[String]) { + private def toType(types: Seq[String]) = types match { + case Nil => "Unit" + case t +: Nil => t + case seq => seq.mkString("(", ", ", ")") + } + def securityTypes = toType(security) + def inTypes = toType(in) + def errTypes = toType(err) + def outTypes = toType(out) +} +case class GeneratedEndpoint(name: String, definition: String, maybeLocalEnums: Option[String], types: EndpointTypes) case class GeneratedEndpointsForFile(maybeFileName: Option[String], generatedEndpoints: Seq[GeneratedEndpoint]) case class GeneratedEndpoints( @@ -54,22 +68,38 @@ class EndpointGenerator { private[codegen] def allEndpoints: String = "generatedEndpoints" + private def capabilityImpl(streamingImplementation: StreamingImplementation): String = streamingImplementation match { + case StreamingImplementation.Akka => "sttp.capabilities.akka.AkkaStreams" + case StreamingImplementation.FS2 => "sttp.capabilities.fs2.Fs2Streams[cats.effect.IO]" + case StreamingImplementation.Pekko => "sttp.capabilities.pekko.PekkoStreams" + case StreamingImplementation.Zio => "sttp.capabilities.zio.ZioStreams" + } + def endpointDefs( doc: OpenapiDocument, useHeadTagForObjectNames: Boolean, targetScala3: Boolean, jsonSerdeLib: JsonSerdeLib, - streamingImplementation: StreamingImplementation + streamingImplementation: StreamingImplementation, + generateEndpointTypes: Boolean ): EndpointDefs = { + val capabilities = capabilityImpl(streamingImplementation) val components = Option(doc.components).flatten val GeneratedEndpoints(endpointsByFile, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam) = doc.paths - .map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib, streamingImplementation)) + .map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib, streamingImplementation, doc)) .foldLeft(GeneratedEndpoints(Nil, Set.empty, Set.empty, false))(_ merge _) val endpointDecls = endpointsByFile.map { case GeneratedEndpointsForFile(k, ge) => val definitions = ge - .map { case GeneratedEndpoint(name, definition, maybeEnums) => - s"""lazy val $name = + .map { case GeneratedEndpoint(name, definition, maybeEnums, types) => + val theCapabilities = if (definition.contains(".capabilities.")) capabilities else "Any" + val endpointTypeDecl = + if (generateEndpointTypes) + s"type ${name.capitalize}Endpoint = Endpoint[${types.securityTypes}, ${types.inTypes}, ${types.errTypes}, ${types.outTypes}, $theCapabilities]\n" + else "" + + val maybeType = if (generateEndpointTypes) s": ${name.capitalize}Endpoint" else "" + s"""${endpointTypeDecl}lazy val $name$maybeType = |${indent(2)(definition)}${maybeEnums.fold("")("\n" + _)} |""".stripMargin } @@ -89,7 +119,8 @@ class EndpointGenerator { useHeadTagForObjectNames: Boolean, targetScala3: Boolean, jsonSerdeLib: JsonSerdeLib, - streamingImplementation: StreamingImplementation + streamingImplementation: StreamingImplementation, + doc: OpenapiDocument )(p: OpenapiPath): GeneratedEndpoints = { val parameters = components.map(_.parameters).getOrElse(Map.empty) val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty) @@ -111,15 +142,19 @@ class EndpointGenerator { } val name = strippedToCamelCase(m.operationId.getOrElse(m.methodType + p.url.capitalize)) - val (inParams, maybeLocalEnums) = + val (pathDecl, pathTypes) = urlMapper(p.url, m.resolvedParameters) + val (securityDecl, securityTypes) = security(securitySchemes, m.security) + val (inParams, maybeLocalEnums, inTypes) = ins(m.resolvedParameters, m.requestBody, name, targetScala3, jsonSerdeLib, streamingImplementation) + val (outDecl, outTypes, errTypes) = outs(m.responses, streamingImplementation, doc, targetScala3) + val allTypes = EndpointTypes(securityTypes.toSeq, pathTypes ++ inTypes, errTypes.toSeq, outTypes.toSeq) val definition = s"""|endpoint | .${m.methodType} - | ${urlMapper(p.url, m.resolvedParameters)} - |${indent(2)(security(securitySchemes, m.security))} + | $pathDecl + |${indent(2)(securityDecl)} |${indent(2)(inParams)} - |${indent(2)(outs(m.responses, streamingImplementation))} + |${indent(2)(outDecl)} |${indent(2)(tags(m.tags))} |$attributeString |""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n") @@ -148,7 +183,7 @@ class EndpointGenerator { } .toSet ( - (maybeTargetFileName, GeneratedEndpoint(name, definition, maybeLocalEnums)), + (maybeTargetFileName, GeneratedEndpoint(name, definition, maybeLocalEnums, allTypes)), (queryOrPathParamRefs, jsonParamRefs), maybeLocalEnums.isDefined ) @@ -167,26 +202,30 @@ class EndpointGenerator { ) } - private def urlMapper(url: String, parameters: Seq[OpenapiParameter])(implicit location: Location): String = { + private def urlMapper(url: String, parameters: Seq[OpenapiParameter])(implicit location: Location): (String, Seq[String]) = { // .in(("books" / path[String]("genre") / path[Int]("year")).mapTo[BooksFromYear]) - val inPath = url.split('/').filter(_.nonEmpty) map { segment => - if (segment.startsWith("{")) { - val name = segment.drop(1).dropRight(1) - val param = parameters.find(_.name == name) - param.fold(bail(s"URLParam $name not found!")) { p => - p.schema match { - case st: OpenapiSchemaSimpleType => - val (t, _) = mapSchemaSimpleTypeToType(st) - val desc = p.description.fold("")(d => s""".description("$d")""") - s"""path[$t]("$name")$desc""" - case _ => bail("Can't create non-simple params to url yet") + val (inPath, tpes) = url + .split('/') + .filter(_.nonEmpty) + .map { segment => + if (segment.startsWith("{")) { + val name = segment.drop(1).dropRight(1) + val param = parameters.find(_.name == name) + param.fold(bail(s"URLParam $name not found!")) { p => + p.schema match { + case st: OpenapiSchemaSimpleType => + val (t, _) = mapSchemaSimpleTypeToType(st) + val desc = p.description.fold("")(d => s""".description("$d")""") + s"""path[$t]("$name")$desc""" -> Some(t) + case _ => bail("Can't create non-simple params to url yet") + } } + } else { + '"' + segment + '"' -> None } - } else { - '"' + segment + '"' } - } - ".in((" + inPath.mkString(" / ") + "))" + .unzip + ".in((" + inPath.mkString(" / ") + "))" -> tpes.toSeq.flatten } private def security(securitySchemes: Map[String, OpenapiSecuritySchemeType], security: Seq[Seq[String]])(implicit location: Location) = { @@ -195,16 +234,16 @@ class EndpointGenerator { security.headOption .flatMap(_.headOption) - .fold("") { schemeName => + .fold("" -> Option.empty[String]) { schemeName => securitySchemes.get(schemeName) match { case Some(OpenapiSecuritySchemeType.OpenapiSecuritySchemeBearerType) => - ".securityIn(auth.bearer[String]())" + ".securityIn(auth.bearer[String]())" -> Some("String") case Some(OpenapiSecuritySchemeType.OpenapiSecuritySchemeBasicType) => - ".securityIn(auth.basic[UsernamePassword]())" + ".securityIn(auth.basic[UsernamePassword]())" -> Some("String") case Some(OpenapiSecuritySchemeType.OpenapiSecuritySchemeApiKeyType(in, name)) => - s""".securityIn(auth.apiKey($in[String]("$name")))""" + s""".securityIn(auth.apiKey($in[String]("$name")))""" -> Some("String") case None => bail(s"Unknown security scheme $schemeName!") @@ -219,7 +258,7 @@ class EndpointGenerator { targetScala3: Boolean, jsonSerdeLib: JsonSerdeLib, streamingImplementation: StreamingImplementation - )(implicit location: Location): (String, Option[String]) = { + )(implicit location: Location): (String, Option[String], Seq[String]) = { def getEnumParamDefn(param: OpenapiParameter, e: OpenapiSchemaEnum, isArray: Boolean) = { val enumName = endpointName.capitalize + strippedToCamelCase(param.name).capitalize val enumParamRefs = if (param.in == "query" || param.in == "path") Set(enumName) else Set.empty[String] @@ -237,14 +276,20 @@ class EndpointGenerator { // 'exploded' params have no distinction between an empty list and an absent value, so don't wrap in 'Option' for them val noOptionWrapper = required || (isArray && param.isExploded) val req = if (noOptionWrapper) tpe else s"Option[$tpe]" + val outType = (isArray, noOptionWrapper) match { + case (true, true) => s"List[$enumName]" + case (true, false) => s"Option[List[$enumName]]" + case (false, true) => enumName + case (false, false) => s"Option[$enumName]" + } def mapToList = if (!isArray) "" else if (noOptionWrapper) s".map(_.values)($arrayType(_))" else s".map(_.map(_.values))(_.map($arrayType(_)))" val desc = param.description.map(d => JavaEscape.escapeString(d)).fold("")(d => s""".description("$d")""") - s""".in(${param.in}[$req]("${param.name}")$mapToList$desc)""" -> Some(enumDefn) + (s""".in(${param.in}[$req]("${param.name}")$mapToList$desc)""", Some(enumDefn), outType) } // .in(query[Limit]("limit").description("Maximum number of books to retrieve")) // .in(header[AuthToken]("X-Auth-Token")) - val (params, maybeEnumDefns) = parameters + val (params, maybeEnumDefns, inTypes) = parameters .filter(_.in != "path") .map { param => param.schema match { @@ -252,7 +297,7 @@ class EndpointGenerator { val (t, _) = mapSchemaSimpleTypeToType(st) val req = if (param.required.getOrElse(true)) t else s"Option[$t]" val desc = param.description.map(d => JavaEscape.escapeString(d)).fold("")(d => s""".description("$d")""") - s""".in(${param.in}[$req]("${param.name}")$desc)""" -> None + (s""".in(${param.in}[$req]("${param.name}")$desc)""", None, req) case OpenapiSchemaArray(st: OpenapiSchemaSimpleType, _) => val (t, _) = mapSchemaSimpleTypeToType(st) val arrayType = if (param.isExploded) "ExplodedValues" else "CommaSeparatedValues" @@ -263,25 +308,32 @@ class EndpointGenerator { val req = if (noOptionWrapper) arr else s"Option[$arr]" def mapToList = if (noOptionWrapper) s".map(_.values)($arrayType(_))" else s".map(_.map(_.values))(_.map($arrayType(_)))" val desc = param.description.map(d => JavaEscape.escapeString(d)).fold("")(d => s""".description("$d")""") - s""".in(${param.in}[$req]("${param.name}")$mapToList$desc)""" -> None + (s""".in(${param.in}[$req]("${param.name}")$mapToList$desc)""", None, req) case e @ OpenapiSchemaEnum(_, _, _) => getEnumParamDefn(param, e, isArray = false) case OpenapiSchemaArray(e: OpenapiSchemaEnum, _) => getEnumParamDefn(param, e, isArray = true) case x => bail(s"Can't create non-simple params to input - found $x") } } - .unzip + .unzip3 - val rqBody = requestBody.flatMap { b => + val (rqBody, maybeReqType) = requestBody.flatMap { b => if (b.content.isEmpty) None else if (b.content.size != 1) bail(s"We can handle only one requestBody content! Saw ${b.content.map(_.contentType)}") - else Some(s".in(${contentTypeMapper(b.content.head.contentType, b.content.head.schema, streamingImplementation, b.required)})") - } + else { + val (decl, tpe) = contentTypeMapper(b.content.head.contentType, b.content.head.schema, streamingImplementation, b.required) + Some(s".in($decl)" -> tpe) + } + }.unzip - (params ++ rqBody).mkString("\n") -> maybeEnumDefns.foldLeft(Option.empty[String]) { - case (acc, None) => acc - case (None, Some(nxt)) => Some(nxt.mkString("\n")) - case (Some(acc), Some(nxt)) => Some(acc + "\n" + nxt.mkString("\n")) - } + ( + (params ++ rqBody).mkString("\n"), + maybeEnumDefns.foldLeft(Option.empty[String]) { + case (acc, None) => acc + case (None, Some(nxt)) => Some(nxt.mkString("\n")) + case (Some(acc), Some(nxt)) => Some(acc + "\n" + nxt.mkString("\n")) + }, + inTypes ++ maybeReqType + ) } private def tags(openapiTags: Option[Seq[String]]): String = { @@ -305,7 +357,14 @@ class EndpointGenerator { // treats redirects as ok private val okStatus = """([23]\d\d)""".r private val errorStatus = """([45]\d\d)""".r - private def outs(responses: Seq[OpenapiResponse], streamingImplementation: StreamingImplementation)(implicit location: Location) = { + private def outs( + responses: Seq[OpenapiResponse], + streamingImplementation: StreamingImplementation, + doc: OpenapiDocument, + targetScala3: Boolean + )(implicit + location: Location + ) = { // .errorOut(stringBody) // .out(jsonBody[List[Book]]) @@ -320,44 +379,83 @@ class EndpointGenerator { case _ => bail("We can handle only one return content!") } } - def bodyFmt(resp: OpenapiResponse): String = { + def bodyFmt(resp: OpenapiResponse): (String, Option[String]) = { val d = s""".description("${JavaEscape.escapeString(resp.description)}")""" resp.content match { - case Nil => "" + case Nil => "" -> None case content +: Nil => - s"${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d" + val (decl, tpe) = contentTypeMapper(content.contentType, content.schema, streamingImplementation) + s"$decl$d" -> Some(tpe) } } - def mappedGroup(group: Seq[OpenapiResponse]) = group match { - case Nil => None + def mappedGroup(group: Seq[OpenapiResponse]): (Option[String], Option[String]) = group match { + case Nil => None -> None case resp +: Nil => resp.content match { case Nil => val d = s""".description("${JavaEscape.escapeString(resp.description)}")""" - resp.code match { - case "200" | "default" => None - case okStatus(s) => Some(s"statusCode(sttp.model.StatusCode($s))$d") - case errorStatus(s) => Some(s"statusCode(sttp.model.StatusCode($s))$d") - } + ( + resp.code match { + case "200" | "default" => None + case okStatus(s) => Some(s"statusCode(sttp.model.StatusCode($s))$d") + case errorStatus(s) => Some(s"statusCode(sttp.model.StatusCode($s))$d") + }, + None + ) case _ => + val (decl, tpe) = bodyFmt(resp) Some(resp.code match { - case "200" | "default" => s"${bodyFmt(resp)}" - case okStatus(s) => s"${bodyFmt(resp)}.and(statusCode(sttp.model.StatusCode($s)))" - case errorStatus(s) => s"${bodyFmt(resp)}.and(statusCode(sttp.model.StatusCode($s)))" - }) + case "200" | "default" => decl + case okStatus(s) => s"$decl.and(statusCode(sttp.model.StatusCode($s)))" + case errorStatus(s) => s"$decl.and(statusCode(sttp.model.StatusCode($s)))" + }) -> tpe } case many => if (many.map(_.code).distinct.size != many.size) bail("Cannot construct schema for multiple responses with same status code") - val oneOfs = many.map { m => + val (oneOfs, types) = many.map { m => + val (decl, tpe) = bodyFmt(m) val code = if (m.code == "default") "400" else m.code - s"oneOfVariant(sttp.model.StatusCode(${code}), ${bodyFmt(m)})" - } - Some(s"oneOf(${oneOfs.mkString(", ")})") + s"oneOfVariant(sttp.model.StatusCode(${code}), $decl)" -> tpe + }.unzip + val parentMap = doc.components.toSeq + .flatMap(_.schemas) + .collect { case (k, v: OpenapiSchemaOneOf) => + v.types.map { + case r: OpenapiSchemaRef => r.stripped -> k + case x: OpenapiSchemaSimpleType => mapSchemaSimpleTypeToType(x)._1 -> k + case x => bail(s"Unexpected oneOf child type $x") + } + } + .flatten + .groupBy(_._1) + .map { case (k, vs) => k -> vs.map(_._2) } + .toMap + val allElemTypes = many + .flatMap(_.content.map(_.schema)) + .map { + case r: OpenapiSchemaRef => r.stripped + case x: OpenapiSchemaSimpleType => mapSchemaSimpleTypeToType(x)._1 + case x => bail(s"Unexpected oneOf elem type $x") + } + .distinct + val commmonType = + if (allElemTypes.size == 1) allElemTypes.head + else + allElemTypes.map { s => parentMap.getOrElse(s, Nil).toSet }.reduce(_ intersect _) match { + case s if s.isEmpty && targetScala3 => types.mkString(" | ") + case s if s.isEmpty => "Any" + case s if targetScala3 => s.mkString(" & ") + case s => s.mkString(" with ") + } + Some(s"oneOf[$commmonType](${oneOfs.mkString(", ")})") -> Some(commmonType) } - val mappedOuts = mappedGroup(outs).map(s => s".out($s)") - val mappedErrorOuts = mappedGroup(errorOuts).map(s => s".errorOut($s)") - Seq(mappedErrorOuts, mappedOuts).flatten.mkString("\n") + val (outDecls, outTypes) = mappedGroup(outs) + val mappedOuts = outDecls.map(s => s".out($s)") + val (errDecls, errTypes) = mappedGroup(errorOuts) + val mappedErrorOuts = errDecls.map(s => s".errorOut($s)") + + (Seq(mappedErrorOuts, mappedOuts).flatten.mkString("\n"), outTypes, errTypes) } private def contentTypeMapper( @@ -365,10 +463,10 @@ class EndpointGenerator { schema: OpenapiSchemaType, streamingImplementation: StreamingImplementation, required: Boolean = true - )(implicit location: Location) = { + )(implicit location: Location): (String, String) = { contentType match { case "text/plain" => - "stringBody" + "stringBody" -> "String" case "application/json" => val outT = schema match { case st: OpenapiSchemaSimpleType => @@ -383,27 +481,22 @@ class EndpointGenerator { case x => bail(s"Can't create non-simple or array params as output (found $x)") } val req = if (required) outT else s"Option[$outT]" - s"jsonBody[$req]" + s"jsonBody[$req]" -> req case "multipart/form-data" => schema match { case _: OpenapiSchemaBinary => - "multipartBody" + "multipartBody" -> " Seq[Part[Array[Byte]]]" case schemaRef: OpenapiSchemaRef => val (t, _) = mapSchemaSimpleTypeToType(schemaRef, multipartForm = true) - s"multipartBody[$t]" + s"multipartBody[$t]" -> t case x => bail(s"$contentType only supports schema ref or binary. Found $x") } case "application/octet-stream" => - val capability = streamingImplementation match { - case StreamingImplementation.Akka => "sttp.capabilities.akka.AkkaStreams" - case StreamingImplementation.FS2 => "sttp.capabilities.fs2.Fs2Streams[cats.effect.IO]" - case StreamingImplementation.Pekko => "sttp.capabilities.pekko.PekkoStreams" - case StreamingImplementation.Zio => "sttp.capabilities.zio.ZioStreams" - } + val capability = capabilityImpl(streamingImplementation) schema match { case _: OpenapiSchemaString => - s"streamTextBody($capability)(CodecFormat.OctetStream())" + s"streamTextBody($capability)(CodecFormat.OctetStream())" -> s"$capability.BinaryStream" case schema => val outT = schema match { case st: OpenapiSchemaSimpleType => @@ -417,7 +510,7 @@ class EndpointGenerator { s"Map[String, $t]" case x => bail(s"Can't create this param as output (found $x)") } - s"streamBody($capability)(Schema.binary[$outT], CodecFormat.OctetStream())" + s"streamBody($capability)(Schema.binary[$outT], CodecFormat.OctetStream())" -> s"$capability.BinaryStream" } case x => bail(s"Not all content types supported! Found $x") diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala index a4f334339b..d1b5d61501 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala @@ -18,7 +18,8 @@ class BasicGeneratorSpec extends CompileCheckTestBase { jsonSerdeLib = jsonSerdeLib, validateNonDiscriminatedOneOfs = true, maxSchemasPerFile = 400, - streamingImplementation = "fs2" + streamingImplementation = "fs2", + generateEndpointTypes = false ) } def gen( diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala index b73060d8a7..de41c97a6c 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala @@ -396,7 +396,8 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase { useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe, - streamingImplementation = StreamingImplementation.FS2 + streamingImplementation = StreamingImplementation.FS2, + generateEndpointTypes = false ) .endpointDecls(None) } diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala index 453ea55dab..c0f199ef31 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala @@ -68,7 +68,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe, - streamingImplementation = StreamingImplementation.FS2 + streamingImplementation = StreamingImplementation.FS2, + generateEndpointTypes = false ) .endpointDecls(None) generatedCode should include("val getTestAsdId =") @@ -153,7 +154,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe, - streamingImplementation = StreamingImplementation.FS2 + streamingImplementation = StreamingImplementation.FS2, + generateEndpointTypes = false ) .endpointDecls(None) shouldCompile () } @@ -205,7 +207,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe, - streamingImplementation = StreamingImplementation.FS2 + streamingImplementation = StreamingImplementation.FS2, + generateEndpointTypes = false ) .endpointDecls(None) generatedCode should include( @@ -272,7 +275,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { jsonSerdeLib = "circe", validateNonDiscriminatedOneOfs = true, maxSchemasPerFile = 400, - streamingImplementation = "fs2" + streamingImplementation = "fs2", + generateEndpointTypes = false )("TapirGeneratedEndpoints") generatedCode should include( """file: sttp.model.Part[java.io.File]""" @@ -294,7 +298,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { jsonSerdeLib = "circe", validateNonDiscriminatedOneOfs = true, maxSchemasPerFile = 400, - streamingImplementation = "fs2" + streamingImplementation = "fs2", + generateEndpointTypes = false )("TapirGeneratedEndpoints") generatedCode shouldCompile () val expectedAttrDecls = Seq( diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala index ba00d8a34e..510f5b8540 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala @@ -11,6 +11,7 @@ case class OpenApiConfiguration( streamingImplementation: String, validateNonDiscriminatedOneOfs: Boolean, maxSchemasPerFile: Int, + generateEndpointTypes: Boolean, additionalPackages: List[(String, File)] ) @@ -27,6 +28,7 @@ trait OpenapiCodegenKeys { lazy val openapiMaxSchemasPerFile = settingKey[Int]("Maximum number of schemas to generate for a single file") lazy val openapiAdditionalPackages = settingKey[List[(String, File)]]("Addition package -> spec mappings to generate.") lazy val openapiStreamingImplementation = settingKey[String]("Implementation for streamTextBody. Supports: akka, fs2, pekko, zio.") + lazy val openapiGenerateEndpointTypes = settingKey[Boolean]("Whether to emit explicit types for endpoint denfs") lazy val openapiOpenApiConfiguration = settingKey[OpenApiConfiguration]("Aggregation of other settings. Manually set value will be disregarded.") diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala index 5f086617e1..16ebcb092b 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala @@ -32,6 +32,7 @@ object OpenapiCodegenPlugin extends AutoPlugin { openapiStreamingImplementation.value, openapiValidateNonDiscriminatedOneOfs.value, openapiMaxSchemasPerFile.value, + openapiGenerateEndpointTypes.value, openapiAdditionalPackages.value ) def openapiCodegenDefaultSettings: Seq[Setting[_]] = Seq( @@ -44,6 +45,7 @@ object OpenapiCodegenPlugin extends AutoPlugin { openapiMaxSchemasPerFile := 400, openapiAdditionalPackages := Nil, openapiStreamingImplementation := "fs2", + openapiGenerateEndpointTypes := false, standardParamSetting ) @@ -73,6 +75,7 @@ object OpenapiCodegenPlugin extends AutoPlugin { c.streamingImplementation, c.validateNonDiscriminatedOneOfs, c.maxSchemasPerFile, + c.generateEndpointTypes, srcDir, taskStreams.cacheDirectory, sv.startsWith("3"), diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala index e689e3a09e..ea07a0a226 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala @@ -14,6 +14,7 @@ case class OpenapiCodegenTask( streamingImplementation: String, validateNonDiscriminatedOneOfs: Boolean, maxSchemasPerFile: Int, + generateEndpointTypes: Boolean, dir: File, cacheDir: File, targetScala3: Boolean, @@ -59,7 +60,8 @@ case class OpenapiCodegenTask( jsonSerdeLib, streamingImplementation, validateNonDiscriminatedOneOfs, - maxSchemasPerFile + maxSchemasPerFile, + generateEndpointTypes ) .map { case (objectName, fileBody) => val file = directory / s"$objectName.scala" diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt index 4b2c350770..98e8d347f2 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt @@ -44,7 +44,6 @@ object TapirGeneratedEndpoints { support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values) } - case class EnumExtraParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends ExtraParamSupport[T] { // Case-insensitive mapping def decode(s: String): sttp.tapir.DecodeResult[T] = @@ -63,9 +62,16 @@ object TapirGeneratedEndpoints { } def extraCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): ExtraParamSupport[T] = EnumExtraParamSupport(enumName, T) + sealed trait Error sealed trait ADTWithoutDiscriminator sealed trait ADTWithDiscriminator sealed trait ADTWithDiscriminatorNoMapping + case class SimpleError ( + message: String + ) extends Error + case class NotFoundError ( + reason: String + ) extends Error case class SubtypeWithoutD1 ( s: String, i: Option[Int] = None, @@ -119,34 +125,39 @@ object TapirGeneratedEndpoints { - lazy val getBinaryTest = + type GetBinaryTestEndpoint = Endpoint[Unit, Unit, Unit, sttp.capabilities.pekko.PekkoStreams.BinaryStream, sttp.capabilities.pekko.PekkoStreams] + lazy val getBinaryTest: GetBinaryTestEndpoint = endpoint .get .in(("binary" / "test")) .out(streamBody(sttp.capabilities.pekko.PekkoStreams)(Schema.binary[Array[Byte]], CodecFormat.OctetStream()).description("Response CSV body")) - lazy val postBinaryTest = + type PostBinaryTestEndpoint = Endpoint[Unit, sttp.capabilities.pekko.PekkoStreams.BinaryStream, Unit, String, sttp.capabilities.pekko.PekkoStreams] + lazy val postBinaryTest: PostBinaryTestEndpoint = endpoint .post .in(("binary" / "test")) .in(streamBody(sttp.capabilities.pekko.PekkoStreams)(Schema.binary[Array[Byte]], CodecFormat.OctetStream())) .out(jsonBody[String].description("successful operation")) - lazy val putAdtTest = + type PutAdtTestEndpoint = Endpoint[Unit, ADTWithoutDiscriminator, Unit, ADTWithoutDiscriminator, Any] + lazy val putAdtTest: PutAdtTestEndpoint = endpoint .put .in(("adt" / "test")) .in(jsonBody[ADTWithoutDiscriminator]) .out(jsonBody[ADTWithoutDiscriminator].description("successful operation")) - lazy val postAdtTest = + type PostAdtTestEndpoint = Endpoint[Unit, ADTWithDiscriminatorNoMapping, Unit, ADTWithDiscriminator, Any] + lazy val postAdtTest: PostAdtTestEndpoint = endpoint .post .in(("adt" / "test")) .in(jsonBody[ADTWithDiscriminatorNoMapping]) .out(jsonBody[ADTWithDiscriminator].description("successful operation")) - lazy val postInlineEnumTest = + type PostInlineEnumTestEndpoint = Endpoint[Unit, (PostInlineEnumTestQueryEnum, Option[PostInlineEnumTestQueryOptEnum], List[PostInlineEnumTestQuerySeqEnum], Option[List[PostInlineEnumTestQueryOptSeqEnum]], ObjectWithInlineEnum), Unit, Unit, Any] + lazy val postInlineEnumTest: PostInlineEnumTestEndpoint = endpoint .post .in(("inline" / "enum" / "test")) @@ -197,7 +208,14 @@ object TapirGeneratedEndpoints { extraCodecSupport[PostInlineEnumTestQueryOptSeqEnum]("PostInlineEnumTestQueryOptSeqEnum", PostInlineEnumTestQueryOptSeqEnum) } + type GetOneofErrorTestEndpoint = Endpoint[Unit, Unit, Error, Unit, Any] + lazy val getOneofErrorTest: GetOneofErrorTestEndpoint = + endpoint + .get + .in(("oneof" / "error" / "test")) + .errorOut(oneOf[Error](oneOfVariant(sttp.model.StatusCode(404), jsonBody[NotFoundError].description("Not found")), oneOfVariant(sttp.model.StatusCode(400), jsonBody[SimpleError].description("Not found")))) + .out(statusCode(sttp.model.StatusCode(204)).description("No response")) - lazy val generatedEndpoints = List(getBinaryTest, postBinaryTest, putAdtTest, postAdtTest, postInlineEnumTest) + lazy val generatedEndpoints = List(getBinaryTest, postBinaryTest, putAdtTest, postAdtTest, postInlineEnumTest, getOneofErrorTest) } diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedJsonSerdes.scala.txt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedJsonSerdes.scala.txt index eff2439305..9013ccbdfc 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedJsonSerdes.scala.txt +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedJsonSerdes.scala.txt @@ -16,6 +16,10 @@ object TapirGeneratedEndpointsJsonSerdes { } } yield res } + implicit lazy val simpleErrorJsonDecoder: io.circe.Decoder[SimpleError] = io.circe.generic.semiauto.deriveDecoder[SimpleError] + implicit lazy val simpleErrorJsonEncoder: io.circe.Encoder[SimpleError] = io.circe.generic.semiauto.deriveEncoder[SimpleError] + implicit lazy val notFoundErrorJsonDecoder: io.circe.Decoder[NotFoundError] = io.circe.generic.semiauto.deriveDecoder[NotFoundError] + implicit lazy val notFoundErrorJsonEncoder: io.circe.Encoder[NotFoundError] = io.circe.generic.semiauto.deriveEncoder[NotFoundError] implicit lazy val subtypeWithoutD1JsonDecoder: io.circe.Decoder[SubtypeWithoutD1] = io.circe.generic.semiauto.deriveDecoder[SubtypeWithoutD1] implicit lazy val subtypeWithoutD1JsonEncoder: io.circe.Encoder[SubtypeWithoutD1] = io.circe.generic.semiauto.deriveEncoder[SubtypeWithoutD1] implicit lazy val subtypeWithD1JsonDecoder: io.circe.Decoder[SubtypeWithD1] = io.circe.generic.semiauto.deriveDecoder[SubtypeWithD1] diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedSchemas.scala.txt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedSchemas.scala.txt index 90179526d3..f8a7ec11c4 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedSchemas.scala.txt +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedSchemas.scala.txt @@ -4,8 +4,10 @@ object TapirGeneratedEndpointsSchemas { import sttp.tapir.generated.TapirGeneratedEndpoints._ import sttp.tapir.generic.auto._ implicit lazy val anEnumTapirSchema: sttp.tapir.Schema[AnEnum] = sttp.tapir.Schema.derived + implicit lazy val notFoundErrorTapirSchema: sttp.tapir.Schema[NotFoundError] = sttp.tapir.Schema.derived implicit lazy val objectWithInlineEnumInlineEnumTapirSchema: sttp.tapir.Schema[ObjectWithInlineEnumInlineEnum] = sttp.tapir.Schema.derived implicit lazy val objectWithInlineEnumTapirSchema: sttp.tapir.Schema[ObjectWithInlineEnum] = sttp.tapir.Schema.derived + implicit lazy val simpleErrorTapirSchema: sttp.tapir.Schema[SimpleError] = sttp.tapir.Schema.derived implicit lazy val subtypeWithD1TapirSchema: sttp.tapir.Schema[SubtypeWithD1] = sttp.tapir.Schema.derived implicit lazy val subtypeWithD2TapirSchema: sttp.tapir.Schema[SubtypeWithD2] = sttp.tapir.Schema.derived implicit lazy val subtypeWithoutD1TapirSchema: sttp.tapir.Schema[SubtypeWithoutD1] = sttp.tapir.Schema.derived @@ -38,6 +40,7 @@ object TapirGeneratedEndpointsSchemas { case _ => throw new IllegalStateException("Derived schema for ADTWithDiscriminatorNoMapping should be a coproduct") } } + implicit lazy val errorTapirSchema: sttp.tapir.Schema[Error] = sttp.tapir.Schema.derived implicit lazy val subtypeWithoutD3TapirSchema: sttp.tapir.Schema[SubtypeWithoutD3] = sttp.tapir.Schema.derived implicit lazy val aDTWithoutDiscriminatorTapirSchema: sttp.tapir.Schema[ADTWithoutDiscriminator] = sttp.tapir.Schema.derived } diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt index 034a656e36..dd0e5c6bae 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt @@ -3,7 +3,8 @@ lazy val root = (project in file(".")) .settings( scalaVersion := "2.13.15", version := "0.1", - openapiStreamingImplementation := "pekko" + openapiStreamingImplementation := "pekko", + openapiGenerateEndpointTypes := true ) libraryDependencies ++= Seq( diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/swagger.yaml b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/swagger.yaml index c380f44bd7..20b5b94910 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/swagger.yaml +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/swagger.yaml @@ -127,6 +127,24 @@ paths: application/json: schema: $ref: '#/components/schemas/ObjectWithInlineEnum' + '/oneof/error/test': + get: + responses: + "204": + description: "No response" + "404": + description: Not found + content: + application/json: + schema: + $ref: '#/components/schemas/NotFoundError' + default: + description: Not found + content: + application/json: + schema: + $ref: '#/components/schemas/SimpleError' + components: schemas: @@ -247,4 +265,26 @@ components: - foo1 - foo2 - foo3 - - foo4 \ No newline at end of file + - foo4 + Error: + title: Error + type: object + oneOf: + - $ref: '#/components/schemas/NotFoundError' + - $ref: '#/components/schemas/SimpleError' + NotFoundError: + title: NotFoundError + required: + - reason + type: object + properties: + reason: + type: string + SimpleError: + title: SimpleError + required: + - message + type: object + properties: + message: + type: string \ No newline at end of file