summaryrefslogtreecommitdiff
path: root/scalalib/src/scalafmt/ScalafmtWorker.scala
blob: 25b58f24a5f3c8a107bf0f6d5c06b530a0f86eb5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package mill.scalalib.scalafmt

import java.nio.file.{Paths => JPaths}

import mill._
import mill.define.{Discover, ExternalModule, Worker}
import mill.api.Ctx
import org.scalafmt.interfaces.Scalafmt

import scala.collection.mutable
import mill.api.Result

/**
 * External module that hands out the [[ScalafmtWorker]] used by scalafmt
 * tasks. Construction is wrapped in `T.worker`, presumably so mill keeps the
 * same instance (and its per-file formatting cache) alive between runs.
 */
object ScalafmtWorkerModule extends ExternalModule {

  /** The scalafmt worker instance, created via `T.worker`. */
  def worker: Worker[ScalafmtWorker] =
    T.worker {
      new ScalafmtWorker()
    }

  // Macro-generated discovery metadata for this external module.
  lazy val millDiscover: Discover[this.type] = Discover[this.type]
}

/**
 * Runs scalafmt over source files, remembering which files were last seen in
 * a correctly formatted state so that later calls only process files whose
 * contents (or the scalafmt config) changed since then.
 */
private[scalafmt] class ScalafmtWorker {
  // path -> PathRef signature of the file when it was last known to be formatted
  private val reformatted: mutable.Map[os.Path, Int] = mutable.Map.empty
  // signature of the config file the entries in `reformatted` were computed with
  private var configSig: Int = 0

  /**
   * Location of the bundled `default.scalafmt.conf`, used when the user's
   * config file does not exist. Computed once per worker.
   *
   * A resource packaged inside a jar has a `jar:` URI, which `Paths.get`
   * cannot resolve without an open zip FileSystem (it throws
   * FileSystemNotFoundException), so in that case the resource is copied to a
   * temporary file that is deleted on JVM exit.
   */
  private lazy val defaultConfigPath: java.nio.file.Path = {
    val resource = Option(getClass.getResource("default.scalafmt.conf")).getOrElse(
      throw new IllegalStateException(
        "Bundled resource default.scalafmt.conf not found on the classpath"
      )
    )
    if (resource.getProtocol == "file") {
      JPaths.get(resource.toURI)
    } else {
      val tmp = java.nio.file.Files.createTempFile("default", ".scalafmt.conf")
      tmp.toFile.deleteOnExit()
      val in = resource.openStream()
      try java.nio.file.Files.copy(in, tmp, java.nio.file.StandardCopyOption.REPLACE_EXISTING)
      finally in.close()
      tmp
    }
  }

  /**
   * Reformats every file in `input` (that is not known to be formatted
   * already) in place.
   *
   * @param input          source files to format
   * @param scalafmtConfig scalafmt config file; the bundled default is used if it does not exist
   */
  def reformat(input: Seq[PathRef],
               scalafmtConfig: PathRef)(implicit ctx: Ctx): Unit = {
    reformatAction(input, scalafmtConfig, dryRun = false)
  }

  /**
   * Checks, without modifying anything, whether the files in `input` are
   * formatted. Paths of misformatted files are printed to the log's output
   * stream.
   *
   * @return `Result.Success(())` if nothing is misformatted, otherwise a
   *         `Result.Failure` reporting how many files are misformatted
   */
  def checkFormat(input: Seq[PathRef],
                  scalafmtConfig: PathRef)(implicit ctx: Ctx): Result[Unit] = {

    val misformatted = reformatAction(input, scalafmtConfig, dryRun = true)
    if (misformatted.isEmpty) {
      Result.Success(())
    } else {
      val out = ctx.log.outputStream
      for (u <- misformatted) {
        out.println(u.path.toString)
      }
      Result.Failure(s"Found ${misformatted.length} misformatted files")
    }
  }

  /**
   * Runs scalafmt over the input files and returns those whose formatted
   * output differed from their current contents. Changes are only written
   * back when `dryRun` is false.
   */
  private def reformatAction(
    input: Seq[PathRef],
    scalafmtConfig: PathRef,
    dryRun: Boolean
  )(implicit ctx: Ctx): Seq[PathRef] = {

    val configChanged = scalafmtConfig.sig != configSig
    // Cached entries were computed under the previous config and say nothing
    // about formatting under the new one, including for files that are not in
    // this `input` (otherwise they would be wrongly skipped on a later call).
    if (configChanged) reformatted.clear()

    // only consider files that have changed since last reformat
    val toConsider =
      if (configChanged) input
      else input.filterNot(ref => reformatted.get(ref.path).contains(ref.sig))

    if (toConsider.nonEmpty) {

      if (dryRun) {
        ctx.log.info(s"Checking format of ${toConsider.size} Scala sources")
      } else {
        ctx.log.info(s"Formatting ${toConsider.size} Scala sources")
      }

      val scalafmt = Scalafmt
        .create(this.getClass.getClassLoader)
        .withRespectVersion(false)

      val configPath =
        if (os.exists(scalafmtConfig.path)) scalafmtConfig.path.toNIO
        else defaultConfigPath

      // keeps track of files that are misformatted
      val misformatted = mutable.ListBuffer.empty[PathRef]

      // Re-read the file's signature (it may have just been rewritten) and cache it.
      def markFormatted(path: PathRef): Unit = {
        val updRef = PathRef(path.path)
        reformatted += updRef.path -> updRef.sig
      }

      toConsider.foreach { pathToFormat =>
        val code = os.read(pathToFormat.path)
        val formattedCode = scalafmt.format(configPath, pathToFormat.path.toNIO, code)

        if (code != formattedCode) {
          misformatted += pathToFormat
          // in dry-run mode a misformatted file stays uncached so it is re-checked next time
          if (!dryRun) {
            os.write.over(pathToFormat.path, formattedCode)
            markFormatted(pathToFormat)
          }
        } else {
          markFormatted(pathToFormat)
        }
      }
      configSig = scalafmtConfig.sig
      misformatted.toList
    } else {
      ctx.log.info("Everything is formatted already")
      Nil
    }
  }

}