aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/spray/boilerplate/BoilerplatePlugin.scala
blob: 9394be807dd1f576f6fbfba913161687f68965a3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/*
 * sbt-boilerplate is distributed under the 2-Clause BSD license. See the LICENSE file in the root
 * of the repository.
 *
 * Copyright (c) 2012 Johannes Rudolph
 */
package spray.boilerplate

import sbt._
import Keys._

/**
 * sbt plugin that expands `*.template` files found under `boilerplateSource` into generated
 * sources under `sourceManaged`, wired into both the `Compile` and `Test` configurations.
 *
 * The plugin uses `noTrigger`, so a project has to enable it explicitly.
 */
object BoilerplatePlugin extends AutoPlugin {
  override def trigger: PluginTrigger = noTrigger
  override def `requires`: Plugins = empty

  object autoImport {
    val boilerplateGenerate = taskKey[Seq[File]]("Generates boilerplate from template files")
    val boilerplateSource = settingKey[File]("Default directory containing boilerplate template sources.")
  }

  import autoImport._

  override def projectSettings: Seq[Def.Setting[_]] =
    inConfig(Compile)(rawBoilerplateSettings) ++ inConfig(Test)(rawBoilerplateSettings)

  /**
   * Per-configuration settings (applied through `inConfig`): the template directory, template
   * registration in `watchSources`, the generation task, its registration as a source
   * generator, and inclusion of managed sources in `packageSrc`.
   */
  private def rawBoilerplateSettings: Seq[Setting[_]] = {
    val inputFilter = "*.template"
    Seq(
      boilerplateSource := sourceDirectory.value / "boilerplate",
      // Register all templates with watchSources, minus those matched by excludeFilter.
      watchSources in Defaults.ConfigGlobal ++= ((boilerplateSource.value ** inputFilter) --- (boilerplateSource.value ** excludeFilter.value ** inputFilter)).get,
      boilerplateGenerate := generateFromTemplates(streams.value, boilerplateSource.value, sourceManaged.value),
      // Managed sources go into the source jar relative to sourceManaged, falling back to a flat name.
      mappings in packageSrc ++= managedSources.value pair (Path.relativeTo(sourceManaged.value) | Path.flat),
      // `<+=` is deprecated (and gone in sbt 1.x); `+= task.taskValue` is the supported form.
      sourceGenerators += boilerplateGenerate.taskValue)
  }

  /**
   * Expands every `*.template` file below `sourceDir` into a file at the same relative path
   * below `targetDir`.
   *
   * The output name drops the last extension (`.template`). If what remains has no extension,
   * `.scala` is appended (`Foo.template` -> `Foo.scala`, `Foo.java.template` -> `Foo.java`).
   * A target is (re)generated only when its template's modification time is strictly newer than
   * the target's; a missing target reports `lastModified == 0`, so it is always generated.
   *
   * @param streams   task streams used for logging
   * @param sourceDir root directory searched recursively for templates
   * @param targetDir root directory under which generated files are written
   * @return every target file, including those skipped as up to date
   */
  def generateFromTemplates(streams: TaskStreams, sourceDir: File, targetDir: File): Seq[File] = {
    val files = sourceDir ** "*.template"

    def changeExtension(f: File): File = {
      val name = f.getName
      // Drop everything from the last '.' on; `max 0` yields "" for a dot-less name, as before.
      val strippedName = name.take(name.lastIndexOf('.') max 0)
      val newName =
        if (!strippedName.contains(".")) s"$strippedName.scala"
        else strippedName
      new File(f.getParent, newName)
    }

    val mapping = (files pair rebase(sourceDir, targetDir)).map {
      case (orig, target) => (orig, changeExtension(target))
    }

    for ((templateFile, target) <- mapping) {
      if (templateFile.lastModified > target.lastModified) {
        streams.log.info(s"Generating '${target.getName}'")
        val template = IO.read(templateFile)
        // NOTE(review): 22 is forwarded to Generator unchanged; presumably the maximum arity to
        // expand to (Scala's tuple/function limit) — confirm against Generator.
        IO.write(target, Generator.generateFromTemplate(template, 22))
      } else
        streams.log.debug(s"Template '${templateFile.getName}' older than target. Ignoring.")
    }

    mapping.map(_._2)
  }
}