Skip to content

Commit

Permalink
Minor code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
mdedetrich committed Jan 23, 2024
1 parent 02994bf commit f84c6bd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 38 deletions.
52 changes: 25 additions & 27 deletions src/main/scala/com/typesafe/sbt/MultiJvmPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import Keys._
import java.io.File
import java.lang.Boolean.getBoolean

import scala.Console.{ GREEN, RESET }
import scala.sys.process.Process

import sbtassembly.AssemblyPlugin.assemblySettings
Expand All @@ -20,7 +19,7 @@ import AssemblyKeys._

object MultiJvmPlugin extends AutoPlugin {

case class Options(jvm: Seq[String], extra: String => Seq[String], run: String => Seq[String])
final case class Options(jvm: Seq[String], extra: String => Seq[String], run: String => Seq[String])

object autoImport extends MultiJvmKeys

Expand Down Expand Up @@ -181,9 +180,9 @@ object MultiJvmPlugin extends AutoPlugin {
}
}

def multiName(name: String, marker: String) = name.split(marker).head
def multiName(name: String, marker: String): String = name.split(marker).head

def multiSimpleName(name: String) = name.split("\\.").last
def multiSimpleName(name: String): String = name.split("\\.").last

def javaCommand(javaHome: Option[File], name: String): File = {
val home = javaHome.getOrElse(new File(System.getProperty("java.home")))
Expand All @@ -198,7 +197,7 @@ object MultiJvmPlugin extends AutoPlugin {
options: Seq[String],
fullClasspath: Classpath,
multiRunCopiedClassDir: File
) = {
): String => Seq[String] = {
val directoryBasedClasspathEntries = fullClasspath.files.filter(_.isDirectory)
// Copy over just the jars to this folder.
fullClasspath.files
Expand All @@ -211,11 +210,12 @@ object MultiJvmPlugin extends AutoPlugin {
(testClass: String) => { Seq("-cp", cp, runner, "-s", testClass) ++ options }
}

def scalaMultiNodeOptionsForScalatest(runner: String, options: Seq[String]) = { (testClass: String) =>
{ Seq(runner, "-s", testClass) ++ options }
def scalaMultiNodeOptionsForScalatest(runner: String, options: Seq[String]): String => Seq[String] = {
(testClass: String) =>
{ Seq(runner, "-s", testClass) ++ options }
}

def scalaOptionsForApps(classpath: Classpath) = {
def scalaOptionsForApps(classpath: Classpath): String => Seq[String] = {
val cp = classpath.files.absString
(mainClass: String) => Seq("-cp", cp, mainClass)
}
Expand Down Expand Up @@ -270,12 +270,12 @@ object MultiJvmPlugin extends AutoPlugin {
List()
else
tests.map { case (_name, classes) =>
multi(_name, classes, marker, javaBin, options, srcDir, false, createLogger, log)
multi(_name, classes, marker, javaBin, options, srcDir, input = false, createLogger, log)
}
Tests.Output(
Tests.overall(results.map(_._2)),
Tests.overall(results.map { case (_, testResult) => testResult }),
Map.empty,
results.map(result => Tests.Summary("multi-jvm", result._1))
results.map { case (testClass, _) => Tests.Summary("multi-jvm", testClass) }
)
}

Expand Down Expand Up @@ -304,7 +304,7 @@ object MultiJvmPlugin extends AutoPlugin {

def runParser: (State, Seq[String]) => complete.Parser[String] = {
import complete.DefaultParsers._
(state, appClasses) => Space ~> token(NotSpace examples appClasses.toSet)
(_, appClasses) => Space ~> token(NotSpace examples appClasses.toSet)
}

def multi(
Expand All @@ -318,10 +318,9 @@ object MultiJvmPlugin extends AutoPlugin {
createLogger: String => Logger,
log: Logger
): (String, TestResult) = {
val logName = "* " + name
log.info(if (log.ansiCodesSupported) GREEN + logName + RESET else logName)
log.info("* " + name)
val classesHostsJavas = getClassesHostsJavas(classes, IndexedSeq.empty, IndexedSeq.empty, "")
val hosts = classesHostsJavas.map(_._2)
val hosts = classesHostsJavas.map { case (_, hostAndUser, _) => hostAndUser }
val processes = classes.zipWithIndex map { case (testClass, index) =>
val className = multiSimpleName(testClass)
val jvmName = "JVM-" + (index + 1) + "-" + className
Expand Down Expand Up @@ -432,9 +431,9 @@ object MultiJvmPlugin extends AutoPlugin {
)
}
Tests.Output(
Tests.overall(results.map(_._2)),
Tests.overall(results.map { case (_, testResult) => testResult }),
Map.empty,
results.map(result => Tests.Summary("multi-jvm", result._1))
results.map { case (testClass, _) => Tests.Summary("multi-jvm", testClass) }
)
}

Expand All @@ -453,16 +452,15 @@ object MultiJvmPlugin extends AutoPlugin {
createLogger: String => Logger,
log: Logger
): (String, TestResult) = {
val logName = "* " + name
log.info(if (log.ansiCodesSupported) GREEN + logName + RESET else logName)
log.info("* " + name)
val classesHostsJavas = getClassesHostsJavas(classes, hostsAndUsers, javas, defaultJava)
val hosts = classesHostsJavas.map(_._2)
val hostAndUsers = classesHostsJavas.map { case (_, hostAndUser, _) => hostAndUser }
// TODO move this out, maybe to the hosts string as well?
val syncProcesses = classesHostsJavas.map { case (testClass, hostAndUser, _) =>
(testClass + " sync", Jvm.syncJar(testJar, hostAndUser, targetDir, log))
}
val syncResult = processExitCodes(name, syncProcesses, log)
if (syncResult._2 == TestResult.Passed) {
val (syncName, syncTestResult) = processExitCodes(name, syncProcesses, log)
if (syncTestResult == TestResult.Passed) {
val processes = classesHostsJavas.zipWithIndex map { case ((testClass, hostAndUser, java), index) =>
val jvmName = "JVM-" + (index + 1)
val jvmLogger = createLogger(jvmName)
Expand All @@ -471,7 +469,7 @@ object MultiJvmPlugin extends AutoPlugin {
val optionsFromFile =
optionsFile map (IO.read(_)) map (_.trim.replace("\\n", " ").split("\\s+").toList) getOrElse Seq
.empty[String]
val multiNodeOptions = getMultiNodeCommandLineOptions(hosts, index, classes.size)
val multiNodeOptions = getMultiNodeCommandLineOptions(hostAndUsers, index, classes.size)
val allJvmOptions = options.jvm ++ optionsFromFile ++ options.extra(className) ++ multiNodeOptions
val runOptions = options.run(testClass)
val connectInput = input && index == 0
Expand All @@ -494,15 +492,15 @@ object MultiJvmPlugin extends AutoPlugin {
}
processExitCodes(name, processes, log)
} else
syncResult
(syncName, syncTestResult)
}

private def padSeqOrDefaultTo(seq: IndexedSeq[String], default: String, max: Int): IndexedSeq[String] = {
val realSeq = if (seq.isEmpty) IndexedSeq(default) else seq
if (realSeq.size >= max)
realSeq
else
(realSeq /: (0 until (max - realSeq.size)))((mySeq, pos) => mySeq :+ realSeq(pos % realSeq.size))
(0 until (max - realSeq.size)).foldLeft(realSeq)((mySeq, pos) => mySeq :+ realSeq(pos % realSeq.size))
}

private def getClassesHostsJavas(
Expand Down Expand Up @@ -540,13 +538,13 @@ object MultiJvmPlugin extends AutoPlugin {
if (hosts.isEmpty) {
if (hostsFile.exists && hostsFile.canRead) {
s.log.info("Using hosts defined in file " + hostsFile.getAbsolutePath)
IO.readLines(hostsFile).map(_.trim).filter(_.length > 0).toIndexedSeq
IO.readLines(hostsFile).map(_.trim).filter(_.nonEmpty).toIndexedSeq
} else
hosts.toIndexedSeq
} else {
if (hostsFile.exists && hostsFile.canRead)
s.log.info(
"Hosts from setting " + multiNodeHosts.key.label + " is overrriding file " + hostsFile.getAbsolutePath
"Hosts from setting " + multiNodeHosts.key.label + " is overriding file " + hostsFile.getAbsolutePath
)
hosts.toIndexedSeq
}
Expand Down
21 changes: 10 additions & 11 deletions src/main/scala/com/typesafe/sbt/multijvm/Jvm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ object Jvm {
runOptions: Seq[String],
logger: Logger,
connectInput: Boolean
) = {
): Process =
forkJava(javaBin, jvmOptions ++ runOptions, logger, connectInput)
}

def forkJava(javaBin: File, options: Seq[String], logger: Logger, connectInput: Boolean) = {
def forkJava(javaBin: File, options: Seq[String], logger: Logger, connectInput: Boolean): Process = {
val java = javaBin.toString
val command = (java :: options.toList).toArray
val builder = new JProcessBuilder(command: _*)
Expand All @@ -31,7 +30,7 @@ object Jvm {
/**
* check if the current operating system is some OS
*/
def isOS(os: String) = try {
def isOS(os: String): Boolean = try {
System.getProperty("os.name").toUpperCase startsWith os.toUpperCase
} catch {
case _: Throwable => false
Expand All @@ -40,7 +39,7 @@ object Jvm {
/**
* convert to proper path for the operating system
*/
def osPath(path: String) = if (isOS("WINDOWS")) Process(Seq("cygpath", path)).lineStream.mkString else path
def osPath(path: String): String = if (isOS("WINDOWS")) Process(Seq("cygpath", path)).lineStream.mkString else path

def syncJar(jarName: String, hostAndUser: String, remoteDir: String, sbtLogger: Logger): Process = {
val command: Array[String] = Array("ssh", hostAndUser, "mkdir -p " + remoteDir)
Expand Down Expand Up @@ -78,21 +77,21 @@ object Jvm {
}

class JvmBasicLogger(name: String) extends BasicLogger {
def jvm(message: String) = "[%s] %s" format (name, message)
def jvm(message: String): String = "[%s] %s" format (name, message)

def log(level: Level.Value, message: => String) = System.out.synchronized {
def log(level: Level.Value, message: => String): Unit = System.out.synchronized {
System.out.println(jvm(message))
}

def trace(t: => Throwable) = System.out.synchronized {
def trace(t: => Throwable): Unit = System.out.synchronized {
val traceLevel = getTrace
if (traceLevel >= 0) System.out.print(StackTrace.trimmed(t, traceLevel))
}

def success(message: => String) = log(Level.Info, message)
def control(event: ControlEvent.Value, message: => String) = log(Level.Info, message)
def success(message: => String): Unit = log(Level.Info, message)
def control(event: ControlEvent.Value, message: => String): Unit = log(Level.Info, message)

def logAll(events: Seq[LogEvent]) = System.out.synchronized { events.foreach(log) }
def logAll(events: Seq[LogEvent]): Unit = System.out.synchronized { events.foreach(log) }
}

final class JvmLogger(name: String) extends JvmBasicLogger(name)

0 comments on commit f84c6bd

Please sign in to comment.