summaryrefslogtreecommitdiff
path: root/main/src/modules/Assembly.scala
blob: 141bc226810055e7c60fce8d35a240e4f9dacd1e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package mill.modules

import java.io.InputStream
import java.util.jar.JarFile
import java.util.regex.Pattern

import geny.Generator
import mill.Agg

import scala.collection.JavaConverters._

/**
 * Groups the files found on a classpath by the name they map to, using
 * [[Assembly.Rule]]s to decide whether same-named files are dropped,
 * collected together, or kept once.
 */
object Assembly {

  /**
   * Default rule set: append every `reference.conf`, exclude the jar manifest,
   * and exclude names matching the `.sf` / `.dsa` / `.rsa` patterns in any
   * letter case (presumably jar signature files).
   */
  val defaultRules: Seq[Rule] = {
    val excludedSuffixPatterns = Seq(
      ".*\\.[sS][fF]",
      ".*\\.[dD][sS][aA]",
      ".*\\.[rR][sS][aA]"
    )
    Seq(Rule.Append("reference.conf"), Rule.Exclude(JarFile.MANIFEST_NAME)) ++
      excludedSuffixPatterns.map(Rule.ExcludePattern(_))
  }

  /** A rule describing how entries with a given name are treated while grouping. */
  sealed trait Rule extends Product with Serializable
  object Rule {
    /** Collect every entry whose name equals `path`. */
    case class Append(path: String) extends Rule

    /** Drop every entry whose name equals `path`. */
    case class Exclude(path: String) extends Rule

    /** Collect every entry whose name is accepted by `pattern`. */
    case class AppendPattern(pattern: Pattern) extends Rule
    object AppendPattern {
      /** Convenience constructor that compiles `pattern` as a regular expression. */
      def apply(pattern: String): AppendPattern = new AppendPattern(Pattern.compile(pattern))
    }

    /** Drop every entry whose name is accepted by `pattern`. */
    case class ExcludePattern(pattern: Pattern) extends Rule
    object ExcludePattern {
      /** Convenience constructor that compiles `pattern` as a regular expression. */
      def apply(pattern: String): ExcludePattern = new ExcludePattern(Pattern.compile(pattern))
    }
  }

  /**
   * Walks `inputPaths` and groups every entry found by its mapping.
   *
   * Per entry, the first applicable step wins:
   *  1. an exact-path `Exclude` / `Append` rule for the mapping,
   *  2. any `ExcludePattern`,
   *  3. any `AppendPattern`,
   *  4. otherwise the first entry seen for a mapping is kept and later ones are ignored.
   *
   * Patterns are tested through `Pattern.asPredicate`, which per the JDK
   * documentation checks whether the pattern is found in the name rather than
   * requiring the whole name to match.
   *
   * @param inputPaths    classpath roots to scan
   * @param assemblyRules rules to apply; when several exact-path rules share a
   *                      path, the last one in the sequence is the one used
   * @return the grouped entries, keyed by mapping
   */
  def groupAssemblyEntries(inputPaths: Agg[os.Path], assemblyRules: Seq[Assembly.Rule]): Map[String, GroupedEntry] = {
    val exactRules = assemblyRules.collect {
      case rule: Rule.Append => rule.path -> rule
      case rule: Rule.Exclude => rule.path -> rule
    }.toMap

    val appendMatchers = assemblyRules.collect {
      case rule: Rule.AppendPattern => (name: String) => rule.pattern.asPredicate().test(name)
    }

    val excludeMatchers = assemblyRules.collect {
      case rule: Rule.ExcludePattern => (name: String) => rule.pattern.asPredicate().test(name)
    }

    classpathIterator(inputPaths).foldLeft(Map.empty[String, GroupedEntry]) { (grouped, entry) =>
      val name = entry.mapping

      // Adds `entry` to the group stored under `name`, starting from an empty
      // append group when nothing is stored there yet.
      def appended: Map[String, GroupedEntry] =
        grouped.updated(name, grouped.getOrElse(name, AppendEntry.empty).append(entry))

      exactRules.get(name) match {
        case Some(_: Rule.Exclude) => grouped
        case Some(_: Rule.Append) => appended
        case _ =>
          if (excludeMatchers.exists(matches => matches(name))) grouped
          else if (appendMatchers.exists(matches => matches(name))) appended
          else if (grouped.contains(name)) grouped
          else grouped.updated(name, WriteOnceEntry(entry))
      }
    }
  }

  /**
   * Produces one [[AssemblyEntry]] per file reachable from `inputPaths`.
   *
   * Paths that do not exist are skipped. A path that is a file is opened as a
   * jar and contributes its non-directory entries; any other existing path is
   * walked and contributes the files below it, named relative to that path.
   *
   * NOTE(review): the `JarFile` opened here is not closed by this method; the
   * returned jar entries read from it lazily when their stream is requested.
   */
  private def classpathIterator(inputPaths: Agg[os.Path]): Generator[AssemblyEntry] = {
    Generator.from(inputPaths)
      .filter(os.exists)
      .flatMap { root =>
        if (!os.isFile(root)) {
          os.walk.stream(root)
            .filter(os.isFile)
            .map(file => PathEntry(file.relativeTo(root).toString, file))
        } else {
          val jar = new JarFile(root.toIO)
          val fileEntries = jar.entries().asScala.filterNot(_.isDirectory)
          Generator.from(
            fileEntries.map(jarEntry => JarFileEntry(jarEntry.getName, () => jar.getInputStream(jarEntry)))
          )
        }
      }
  }
}

/**
 * Value stored for each mapping in the map built by
 * `Assembly.groupAssemblyEntries`.
 */
private[modules] sealed trait GroupedEntry {
  /**
   * Offers another entry with the same mapping to this group.
   *
   * @param entry the additional entry
   * @return the group to store afterwards; implementations decide whether
   *         `entry` is kept or ignored
   */
  def append(entry: AssemblyEntry): GroupedEntry
}

private[modules] object AppendEntry {
  /** Group holding no entries yet; the starting point for a new append group. */
  val empty: AppendEntry = AppendEntry(List.empty)
}

/**
 * Group that keeps every entry offered to it.
 *
 * @param entries the collected entries, most recently appended first
 */
private[modules] case class AppendEntry(entries: List[AssemblyEntry]) extends GroupedEntry {
  /** Returns a new group with `entry` placed at the head of the list. */
  def append(entry: AssemblyEntry): GroupedEntry = AppendEntry(entry :: entries)
}

/**
 * Group that holds exactly one entry and ignores any further ones.
 *
 * @param entry the entry that was stored first
 */
private[modules] case class WriteOnceEntry(entry: AssemblyEntry) extends GroupedEntry {
  // The parameter shadows the `entry` field; it is discarded and the group
  // is returned unchanged.
  def append(entry: AssemblyEntry): GroupedEntry = this
}

/** A single file found on the classpath, together with a way to read it. */
private[this] sealed trait AssemblyEntry {
  /** Name under which this entry is grouped. */
  def mapping: String
  /** Stream over this entry's content. */
  def inputStream: InputStream
}

/**
 * Entry backed by a file on disk.
 *
 * @param mapping name under which the entry is grouped
 * @param path    location of the file to read
 */
private[this] case class PathEntry(mapping: String, path: os.Path) extends AssemblyEntry {
  // Declared as a `def`, so each call asks os-lib for a new stream on `path`;
  // nothing in this class closes it.
  def inputStream: InputStream = os.read.inputStream(path)
}

/**
 * Entry whose content is obtained on demand from a supplier function.
 *
 * @param mapping name under which the entry is grouped
 * @param getIs   supplier invoked to obtain the entry's stream
 */
private[this] case class JarFileEntry(mapping: String, getIs: () => InputStream) extends AssemblyEntry {
  // Invokes the supplier on every call; the result is not cached here.
  def inputStream: InputStream = getIs()
}