summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJakob Odersky <jakob@inpher.io>2019-09-30 19:21:21 -0400
committerJakob Odersky <jakob@inpher.io>2019-11-20 12:13:52 -0500
commit5e062b7e9e50cc6a1dbb12291fbc2643a59a0210 (patch)
tree4f7bb7c37f6588381a79e2e7d055c921ab608703
parent2b5c2465544e318c225d748a3c73b244978ed98c (diff)
downloadmill-scalafmt-check.tar.gz
mill-scalafmt-check.tar.bz2
mill-scalafmt-check.zip
Add task to only check formatting of Scala filesscalafmt-check
-rw-r--r--docs/pages/2 - Configuring Mill.md5
-rw-r--r--scalalib/src/scalafmt/ScalafmtModule.scala20
-rw-r--r--scalalib/src/scalafmt/ScalafmtWorker.scala99
3 files changed, 93 insertions, 31 deletions
diff --git a/docs/pages/2 - Configuring Mill.md b/docs/pages/2 - Configuring Mill.md
index 97fd1a37..f0b03790 100644
--- a/docs/pages/2 - Configuring Mill.md
+++ b/docs/pages/2 - Configuring Mill.md
@@ -259,9 +259,10 @@ object foo extends ScalaModule with ScalafmtModule {
}
```
-Now you can reformat code with `mill foo.reformat` command.
+Now you can reformat code with `mill foo.reformat` command, or only check for misformatted files with `mill checkFormat`.
-You can also reformat your project's code globally with `mill mill.scalalib.scalafmt.ScalafmtModule/reformatAll __.sources` command.
+You can also reformat your project's code globally with `mill mill.scalalib.scalafmt.ScalafmtModule/reformatAll __.sources` command,
+or only check the code's format with `mill mill.scalalib.scalafmt.ScalafmtModule/checkFormatAll __.sources`.
It will reformat all sources that matches `__.sources` query.
If you add a `.scalafmt.conf` file at the root of you project, it will be used
diff --git a/scalalib/src/scalafmt/ScalafmtModule.scala b/scalalib/src/scalafmt/ScalafmtModule.scala
index ea254e6d..ec67ff18 100644
--- a/scalalib/src/scalafmt/ScalafmtModule.scala
+++ b/scalalib/src/scalafmt/ScalafmtModule.scala
@@ -15,6 +15,15 @@ trait ScalafmtModule extends JavaModule {
)
}
+ def checkFormat(): Command[Unit] = T.command {
+ ScalafmtWorkerModule
+ .worker()
+ .checkFormat(
+ filesToFormat(sources()),
+ scalafmtConfig().head
+ )
+ }
+
def scalafmtConfig: Sources = T.sources(os.pwd / ".scalafmt.conf")
protected def filesToFormat(sources: Seq[PathRef]) = {
@@ -39,6 +48,17 @@ object ScalafmtModule extends ExternalModule with ScalafmtModule {
)
}
+ def checkFormatAll(sources: mill.main.Tasks[Seq[PathRef]]): Command[Unit] =
+ T.command {
+ val files = Task.sequence(sources.value)().flatMap(filesToFormat)
+ ScalafmtWorkerModule
+ .worker()
+ .checkFormat(
+ files,
+ scalafmtConfig().head
+ )
+ }
+
implicit def millScoptTargetReads[T] = new mill.main.Tasks.Scopt[T]()
lazy val millDiscover = Discover[this.type]
diff --git a/scalalib/src/scalafmt/ScalafmtWorker.scala b/scalalib/src/scalafmt/ScalafmtWorker.scala
index f9c7e9b4..25b58f24 100644
--- a/scalalib/src/scalafmt/ScalafmtWorker.scala
+++ b/scalalib/src/scalafmt/ScalafmtWorker.scala
@@ -8,6 +8,7 @@ import mill.api.Ctx
import org.scalafmt.interfaces.Scalafmt
import scala.collection.mutable
+import mill.api.Result
object ScalafmtWorkerModule extends ExternalModule {
def worker: Worker[ScalafmtWorker] = T.worker { new ScalafmtWorker() }
@@ -21,44 +22,84 @@ private[scalafmt] class ScalafmtWorker {
def reformat(input: Seq[PathRef],
scalafmtConfig: PathRef)(implicit ctx: Ctx): Unit = {
- val toFormat =
- if (scalafmtConfig.sig != configSig) input
- else
- input.filterNot(ref => reformatted.get(ref.path).contains(ref.sig))
-
- if (toFormat.nonEmpty) {
- ctx.log.info(s"Formatting ${toFormat.size} Scala sources")
- reformatAction(toFormat.map(_.path),
- scalafmtConfig.path)
- reformatted ++= toFormat.map { ref =>
- val updRef = PathRef(ref.path)
- updRef.path -> updRef.sig
- }
- configSig = scalafmtConfig.sig
+ reformatAction(input, scalafmtConfig, dryRun = false)
+ }
+
+ def checkFormat(input: Seq[PathRef],
+ scalafmtConfig: PathRef)(implicit ctx: Ctx): Result[Unit] = {
+
+ val misformatted = reformatAction(input, scalafmtConfig, dryRun = true)
+ if (misformatted.isEmpty) {
+ Result.Success(())
} else {
- ctx.log.info(s"Everything is formatted already")
+ val out = ctx.log.outputStream
+ for (u <- misformatted) {
+ out.println(u.path.toString)
+ }
+ Result.Failure(s"Found ${misformatted.length} misformatted files")
}
}
- private val cliFlags = Seq("--non-interactive", "--quiet")
+ // run scalafmt over input files and return any files that changed
+ // (only save changes to files if dryRun is false)
+ private def reformatAction(
+ input: Seq[PathRef],
+ scalafmtConfig: PathRef,
+ dryRun: Boolean
+ )(implicit ctx: Ctx): Seq[PathRef] = {
+
+ // only consider files that have changed since last reformat
+ val toConsider =
+ if (scalafmtConfig.sig != configSig) input
+ else input.filterNot(ref => reformatted.get(ref.path).contains(ref.sig))
+
+ if (toConsider.nonEmpty) {
- private def reformatAction(toFormat: Seq[os.Path],
- config: os.Path)(implicit ctx: Ctx) = {
- val scalafmt =
- Scalafmt
+ if (dryRun) {
+ ctx.log.info(s"Checking format of ${toConsider.size} Scala sources")
+ } else {
+ ctx.log.info(s"Formatting ${toConsider.size} Scala sources")
+ }
+
+ val scalafmt = Scalafmt
.create(this.getClass.getClassLoader)
.withRespectVersion(false)
- val configPath =
- if (os.exists(config))
- config.toNIO
- else
- JPaths.get(getClass.getResource("default.scalafmt.conf").toURI)
+ val configPath =
+ if (os.exists(scalafmtConfig.path))
+ scalafmtConfig.path.toNIO
+ else
+ JPaths.get(getClass.getResource("default.scalafmt.conf").toURI)
+
+ // keeps track of files that are misformatted
+ val misformatted = mutable.ListBuffer.empty[PathRef]
+
+ def markFormatted(path: PathRef) = {
+ val updRef = PathRef(path.path)
+ reformatted += updRef.path -> updRef.sig
+ }
+
+ toConsider.foreach { pathToFormat =>
+ val code = os.read(pathToFormat.path)
+ val formattedCode = scalafmt.format(configPath, pathToFormat.path.toNIO, code)
+
+ if (code != formattedCode) {
+ misformatted += pathToFormat
+ if (!dryRun) {
+ os.write.over(pathToFormat.path, formattedCode)
+ markFormatted(pathToFormat)
+ }
+ } else {
+ markFormatted(pathToFormat)
+ }
- toFormat.foreach { pathToFormat =>
- val code = os.read(pathToFormat)
- val formatteCode = scalafmt.format(configPath, pathToFormat.toNIO, code)
- os.write.over(pathToFormat, formatteCode)
+ }
+ configSig = scalafmtConfig.sig
+ misformatted.toList
+ } else {
+ ctx.log.info(s"Everything is formatted already")
+ Nil
}
}
+
}