diff options
Diffstat (limited to 'scalalib/src/scalafmt/ScalafmtWorker.scala')
-rw-r--r-- | scalalib/src/scalafmt/ScalafmtWorker.scala | 99 |
1 files changed, 70 insertions, 29 deletions
diff --git a/scalalib/src/scalafmt/ScalafmtWorker.scala b/scalalib/src/scalafmt/ScalafmtWorker.scala index f9c7e9b4..25b58f24 100644 --- a/scalalib/src/scalafmt/ScalafmtWorker.scala +++ b/scalalib/src/scalafmt/ScalafmtWorker.scala @@ -8,6 +8,7 @@ import mill.api.Ctx import org.scalafmt.interfaces.Scalafmt import scala.collection.mutable +import mill.api.Result object ScalafmtWorkerModule extends ExternalModule { def worker: Worker[ScalafmtWorker] = T.worker { new ScalafmtWorker() } @@ -21,44 +22,84 @@ private[scalafmt] class ScalafmtWorker { def reformat(input: Seq[PathRef], scalafmtConfig: PathRef)(implicit ctx: Ctx): Unit = { - val toFormat = - if (scalafmtConfig.sig != configSig) input - else - input.filterNot(ref => reformatted.get(ref.path).contains(ref.sig)) - - if (toFormat.nonEmpty) { - ctx.log.info(s"Formatting ${toFormat.size} Scala sources") - reformatAction(toFormat.map(_.path), - scalafmtConfig.path) - reformatted ++= toFormat.map { ref => - val updRef = PathRef(ref.path) - updRef.path -> updRef.sig - } - configSig = scalafmtConfig.sig + reformatAction(input, scalafmtConfig, dryRun = false) + } + + def checkFormat(input: Seq[PathRef], + scalafmtConfig: PathRef)(implicit ctx: Ctx): Result[Unit] = { + + val misformatted = reformatAction(input, scalafmtConfig, dryRun = true) + if (misformatted.isEmpty) { + Result.Success(()) } else { - ctx.log.info(s"Everything is formatted already") + val out = ctx.log.outputStream + for (u <- misformatted) { + out.println(u.path.toString) + } + Result.Failure(s"Found ${misformatted.length} misformatted files") } } - private val cliFlags = Seq("--non-interactive", "--quiet") + // run scalafmt over input files and return any files that changed + // (only save changes to files if dryRun is false) + private def reformatAction( + input: Seq[PathRef], + scalafmtConfig: PathRef, + dryRun: Boolean + )(implicit ctx: Ctx): Seq[PathRef] = { + + // only consider files that have changed since last reformat + val toConsider = + if (scalafmtConfig.sig != configSig) input + else input.filterNot(ref => reformatted.get(ref.path).contains(ref.sig)) + + if (toConsider.nonEmpty) { - private def reformatAction(toFormat: Seq[os.Path], - config: os.Path)(implicit ctx: Ctx) = { - val scalafmt = - Scalafmt + if (dryRun) { + ctx.log.info(s"Checking format of ${toConsider.size} Scala sources") + } else { + ctx.log.info(s"Formatting ${toConsider.size} Scala sources") + } + + val scalafmt = Scalafmt .create(this.getClass.getClassLoader) .withRespectVersion(false) - val configPath = - if (os.exists(config)) - config.toNIO - else - JPaths.get(getClass.getResource("default.scalafmt.conf").toURI) + val configPath = + if (os.exists(scalafmtConfig.path)) + scalafmtConfig.path.toNIO + else + JPaths.get(getClass.getResource("default.scalafmt.conf").toURI) + + // keeps track of files that are misformatted + val misformatted = mutable.ListBuffer.empty[PathRef] + + def markFormatted(path: PathRef) = { + val updRef = PathRef(path.path) + reformatted += updRef.path -> updRef.sig + } + + toConsider.foreach { pathToFormat => + val code = os.read(pathToFormat.path) + val formattedCode = scalafmt.format(configPath, pathToFormat.path.toNIO, code) + + if (code != formattedCode) { + misformatted += pathToFormat + if (!dryRun) { + os.write.over(pathToFormat.path, formattedCode) + markFormatted(pathToFormat) + } + } else { + markFormatted(pathToFormat) + } - toFormat.foreach { pathToFormat => - val code = os.read(pathToFormat) - val formatteCode = scalafmt.format(configPath, pathToFormat.toNIO, code) - os.write.over(pathToFormat, formatteCode) + } + configSig = scalafmtConfig.sig + misformatted.toList + } else { + ctx.log.info(s"Everything is formatted already") + Nil } } + } |