diff options
Diffstat (limited to 'main/src')
-rw-r--r-- | main/src/mill/modules/Assembly.scala | 127 | ||||
-rw-r--r-- | main/src/mill/modules/Jvm.scala | 80 |
2 files changed, 162 insertions, 45 deletions
diff --git a/main/src/mill/modules/Assembly.scala b/main/src/mill/modules/Assembly.scala new file mode 100644 index 00000000..b7b91248 --- /dev/null +++ b/main/src/mill/modules/Assembly.scala @@ -0,0 +1,127 @@ +package mill.modules + +import java.io.InputStream +import java.util.jar.JarFile +import java.util.regex.Pattern + +import ammonite.ops._ +import geny.Generator +import mill.Agg + +import scala.collection.JavaConverters._ + +object Assembly { + + val defaultRules: Seq[Rule] = Seq( + Rule.Append("reference.conf"), + Rule.Exclude(JarFile.MANIFEST_NAME), + Rule.ExcludePattern(".*\\.[sS][fF]"), + Rule.ExcludePattern(".*\\.[dD][sS][aA]"), + Rule.ExcludePattern(".*\\.[rR][sS][aA]") + ) + + sealed trait Rule extends Product with Serializable + object Rule { + case class Append(path: String) extends Rule + + object AppendPattern { + def apply(pattern: String): AppendPattern = AppendPattern(Pattern.compile(pattern)) + } + case class AppendPattern(pattern: Pattern) extends Rule + + case class Exclude(path: String) extends Rule + + object ExcludePattern { + def apply(pattern: String): ExcludePattern = ExcludePattern(Pattern.compile(pattern)) + } + case class ExcludePattern(pattern: Pattern) extends Rule + } + + def groupAssemblyEntries(inputPaths: Agg[Path], assemblyRules: Seq[Assembly.Rule]): Map[String, GroupedEntry] = { + val rulesMap = assemblyRules.collect { + case r@Rule.Append(path) => path -> r + case r@Rule.Exclude(path) => path -> r + }.toMap + + val appendPatterns = assemblyRules.collect { + case Rule.AppendPattern(pattern) => pattern.asPredicate().test(_) + } + + val excludePatterns = assemblyRules.collect { + case Rule.ExcludePattern(pattern) => pattern.asPredicate().test(_) + } + + classpathIterator(inputPaths).foldLeft(Map.empty[String, GroupedEntry]) { + case (entries, entry) => + val mapping = entry.mapping + + rulesMap.get(mapping) match { + case Some(_: Assembly.Rule.Exclude) => + entries + case Some(_: Assembly.Rule.Append) => + val newEntry = entries.getOrElse(mapping, AppendEntry.empty).append(entry) + entries + (mapping -> newEntry) + + case _ if excludePatterns.exists(_(mapping)) => + entries + case _ if appendPatterns.exists(_(mapping)) => + val newEntry = entries.getOrElse(mapping, AppendEntry.empty).append(entry) + entries + (mapping -> newEntry) + + case _ if !entries.contains(mapping) => + entries + (mapping -> WriteOnceEntry(entry)) + case _ => + entries + } + } + } + + private def classpathIterator(inputPaths: Agg[Path]): Generator[AssemblyEntry] = { + Generator.from(inputPaths) + .filter(exists) + .flatMap { + p => + if (p.isFile) { + val jf = new JarFile(p.toIO) + Generator.from( + for(entry <- jf.entries().asScala if !entry.isDirectory) + yield JarFileEntry(entry.getName, () => jf.getInputStream(entry)) + ) + } + else { + ls.rec.iter(p) + .filter(_.isFile) + .map(sub => PathEntry(sub.relativeTo(p).toString, sub)) + } + } + } +} + +private[modules] sealed trait GroupedEntry { + def append(entry: AssemblyEntry): GroupedEntry +} + +private[modules] object AppendEntry { + val empty: AppendEntry = AppendEntry(Nil) +} + +private[modules] case class AppendEntry(entries: List[AssemblyEntry]) extends GroupedEntry { + def append(entry: AssemblyEntry): GroupedEntry = copy(entries = entry :: this.entries) +} + +private[modules] case class WriteOnceEntry(entry: AssemblyEntry) extends GroupedEntry { + def append(entry: AssemblyEntry): GroupedEntry = this +} + +private[this] sealed trait AssemblyEntry { + def mapping: String + def inputStream: InputStream +} + +private[this] case class PathEntry(mapping: String, path: Path) extends AssemblyEntry { + def inputStream: InputStream = read.getInputStream(path) +} + +private[this] case class JarFileEntry(mapping: String, getIs: () => InputStream) extends AssemblyEntry { + def inputStream: InputStream = getIs() +} diff --git a/main/src/mill/modules/Jvm.scala b/main/src/mill/modules/Jvm.scala index 1a28189f..be683e4a 100644 --- a/main/src/mill/modules/Jvm.scala +++ b/main/src/mill/modules/Jvm.scala @@ -1,14 +1,14 @@ package mill.modules -import java.io.{ByteArrayInputStream, File, FileOutputStream} +import java.io._ import java.lang.reflect.Modifier -import java.net.{URI, URLClassLoader} -import java.nio.file.{FileSystems, Files, OpenOption, StandardOpenOption} +import java.net.URI +import java.nio.file.{FileSystems, Files, StandardOpenOption} import java.nio.file.attribute.PosixFilePermission +import java.util.Collections import java.util.jar.{JarEntry, JarFile, JarOutputStream} import ammonite.ops._ -import ammonite.util.Util import coursier.{Cache, Dependency, Fetch, Repository, Resolution} import geny.Generator import mill.main.client.InputPumper @@ -17,7 +17,7 @@ import mill.util.{Ctx, IO} import mill.util.Loose.Agg import scala.collection.mutable - +import scala.collection.JavaConverters._ object Jvm { @@ -232,18 +232,20 @@ object Jvm { PathRef(outputPath) } - def newOutputStream(p: java.nio.file.Path) = Files.newOutputStream( - p, - StandardOpenOption.TRUNCATE_EXISTING, - StandardOpenOption.CREATE - ) + def newOutputStream(p: java.nio.file.Path, append: Boolean = false) = { + val options = + if(append) Seq(StandardOpenOption.APPEND) + else Seq(StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.CREATE) + Files.newOutputStream(p, options:_*) + } def createAssembly(inputPaths: Agg[Path], mainClass: Option[String] = None, prependShellScript: String = "", base: Option[Path] = None, - isWin: Boolean = scala.util.Properties.isWin) - (implicit ctx: Ctx.Dest) = { + assemblyRules: Seq[Assembly.Rule] = Assembly.defaultRules) + (implicit ctx: Ctx.Dest with Ctx.Log): PathRef = { + val tmp = ctx.dest / "out-tmp.jar" val baseUri = "jar:" + tmp.toIO.getCanonicalFile.toURI.toASCIIString @@ -263,20 +265,22 @@ object Jvm { manifest.write(manifestOut) manifestOut.close() - def isSignatureFile(mapping: String): Boolean = - Set("sf", "rsa", "dsa").exists(ext => mapping.toLowerCase.endsWith(s".$ext")) - - for(v <- classpathIterator(inputPaths)){ - val (file, mapping) = v - val p = zipFs.getPath(mapping) - if (p.getParent != null) Files.createDirectories(p.getParent) - if (!isSignatureFile(mapping)) { - val outputStream = newOutputStream(p) - IO.stream(file, outputStream) - outputStream.close() + Assembly.groupAssemblyEntries(inputPaths, assemblyRules).view + .map { + case (mapping, aggregate) => + zipFs.getPath(mapping) -> aggregate } - file.close() - } + .foreach { + case (path, AppendEntry(entries)) => + val concatenated = new SequenceInputStream( + Collections.enumeration(entries.map(_.inputStream).asJava)) + writeEntry(path, concatenated, append = Files.exists(path)) + case (path, WriteOnceEntry(entry)) => + if (Files.notExists(path)) { + writeEntry(path, entry.inputStream, append = false) + } + } + zipFs.close() val output = ctx.dest / "out.jar" @@ -301,28 +305,14 @@ object Jvm { PathRef(output) } + private def writeEntry(p: java.nio.file.Path, is: InputStream, append: Boolean): Unit = { + if (p.getParent != null) Files.createDirectories(p.getParent) + val outputStream = newOutputStream(p, append) - def classpathIterator(inputPaths: Agg[Path]) = { - Generator.from(inputPaths) - .filter(exists) - .flatMap{ - p => - if (p.isFile) { - val jf = new JarFile(p.toIO) - import collection.JavaConverters._ - Generator.selfClosing(( - for(entry <- jf.entries().asScala if !entry.isDirectory) - yield (jf.getInputStream(entry), entry.getName), - () => jf.close() - )) - } - else { - ls.rec.iter(p) - .filter(_.isFile) - .map(sub => read.getInputStream(sub) -> sub.relativeTo(p).toString) - } - } + IO.stream(is, outputStream) + outputStream.close() + is.close() } def universalScript(shellCommands: String, |