summaryrefslogtreecommitdiff
path: root/scalalib/src/scalafmt/ScalafmtWorker.scala
diff options
context:
space:
mode:
Diffstat (limited to 'scalalib/src/scalafmt/ScalafmtWorker.scala')
-rw-r--r--scalalib/src/scalafmt/ScalafmtWorker.scala99
1 files changed, 70 insertions, 29 deletions
diff --git a/scalalib/src/scalafmt/ScalafmtWorker.scala b/scalalib/src/scalafmt/ScalafmtWorker.scala
index f9c7e9b4..25b58f24 100644
--- a/scalalib/src/scalafmt/ScalafmtWorker.scala
+++ b/scalalib/src/scalafmt/ScalafmtWorker.scala
@@ -8,6 +8,7 @@ import mill.api.Ctx
import org.scalafmt.interfaces.Scalafmt
import scala.collection.mutable
+import mill.api.Result
object ScalafmtWorkerModule extends ExternalModule {
def worker: Worker[ScalafmtWorker] = T.worker { new ScalafmtWorker() }
@@ -21,44 +22,84 @@ private[scalafmt] class ScalafmtWorker {
def reformat(input: Seq[PathRef],
scalafmtConfig: PathRef)(implicit ctx: Ctx): Unit = {
- val toFormat =
- if (scalafmtConfig.sig != configSig) input
- else
- input.filterNot(ref => reformatted.get(ref.path).contains(ref.sig))
-
- if (toFormat.nonEmpty) {
- ctx.log.info(s"Formatting ${toFormat.size} Scala sources")
- reformatAction(toFormat.map(_.path),
- scalafmtConfig.path)
- reformatted ++= toFormat.map { ref =>
- val updRef = PathRef(ref.path)
- updRef.path -> updRef.sig
- }
- configSig = scalafmtConfig.sig
+ reformatAction(input, scalafmtConfig, dryRun = false)
+ }
+
+ def checkFormat(input: Seq[PathRef],
+ scalafmtConfig: PathRef)(implicit ctx: Ctx): Result[Unit] = {
+
+ val misformatted = reformatAction(input, scalafmtConfig, dryRun = true)
+ if (misformatted.isEmpty) {
+ Result.Success(())
} else {
- ctx.log.info(s"Everything is formatted already")
+ val out = ctx.log.outputStream
+ for (u <- misformatted) {
+ out.println(u.path.toString)
+ }
+ Result.Failure(s"Found ${misformatted.length} misformatted files")
}
}
- private val cliFlags = Seq("--non-interactive", "--quiet")
+ // run scalafmt over input files and return any files that changed
+ // (only save changes to files if dryRun is false)
+ private def reformatAction(
+ input: Seq[PathRef],
+ scalafmtConfig: PathRef,
+ dryRun: Boolean
+ )(implicit ctx: Ctx): Seq[PathRef] = {
+
+ // only consider files that have changed since last reformat
+ val toConsider =
+ if (scalafmtConfig.sig != configSig) input
+ else input.filterNot(ref => reformatted.get(ref.path).contains(ref.sig))
+
+ if (toConsider.nonEmpty) {
- private def reformatAction(toFormat: Seq[os.Path],
- config: os.Path)(implicit ctx: Ctx) = {
- val scalafmt =
- Scalafmt
+ if (dryRun) {
+ ctx.log.info(s"Checking format of ${toConsider.size} Scala sources")
+ } else {
+ ctx.log.info(s"Formatting ${toConsider.size} Scala sources")
+ }
+
+ val scalafmt = Scalafmt
.create(this.getClass.getClassLoader)
.withRespectVersion(false)
- val configPath =
- if (os.exists(config))
- config.toNIO
- else
- JPaths.get(getClass.getResource("default.scalafmt.conf").toURI)
+ val configPath =
+ if (os.exists(scalafmtConfig.path))
+ scalafmtConfig.path.toNIO
+ else
+ JPaths.get(getClass.getResource("default.scalafmt.conf").toURI)
+
+ // keeps track of files that are misformatted
+ val misformatted = mutable.ListBuffer.empty[PathRef]
+
+ def markFormatted(path: PathRef) = {
+ val updRef = PathRef(path.path)
+ reformatted += updRef.path -> updRef.sig
+ }
+
+ toConsider.foreach { pathToFormat =>
+ val code = os.read(pathToFormat.path)
+ val formattedCode = scalafmt.format(configPath, pathToFormat.path.toNIO, code)
+
+ if (code != formattedCode) {
+ misformatted += pathToFormat
+ if (!dryRun) {
+ os.write.over(pathToFormat.path, formattedCode)
+ markFormatted(pathToFormat)
+ }
+ } else {
+ markFormatted(pathToFormat)
+ }
- toFormat.foreach { pathToFormat =>
- val code = os.read(pathToFormat)
- val formatteCode = scalafmt.format(configPath, pathToFormat.toNIO, code)
- os.write.over(pathToFormat, formatteCode)
+ }
+ configSig = scalafmtConfig.sig
+ misformatted.toList
+ } else {
+ ctx.log.info(s"Everything is formatted already")
+ Nil
}
}
+
}