summaryrefslogtreecommitdiff
path: root/src/library/scala/sys/process/ProcessImpl.scala
blob: 2b7fcdeb73b6a24895c1e64b39b90749ffa5e6a8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
/*                     __                                               *\
**     ________ ___   / /  ___     Scala API                            **
**    / __/ __// _ | / /  / _ |    (c) 2003-2013, LAMP/EPFL             **
**  __\ \/ /__/ __ |/ /__/ __ |    http://scala-lang.org/               **
** /____/\___/_/ |_/____/_/ | |                                         **
**                          |/                                          **
\*                                                                      */

package scala
package sys
package process

import processInternal._
import java.io.{ PipedInputStream, PipedOutputStream }

private[process] trait ProcessImpl {
  self: Process.type =>

  /** Evaluates the supplied code on a freshly started Thread and hands that Thread back.
   *
   *  The one-argument form starts a non-daemon thread.  The two-argument form lets the caller
   *  pick the daemon flag, which is applied before the thread is started.
   */
  private[process] object Spawn {
    def apply(f: => Unit): Thread = apply(f, daemon = false)

    def apply(f: => Unit, daemon: Boolean): Thread = {
      // `f` is by-name, so it is only evaluated inside `run`, i.e. on the new thread.
      val worker = new Thread(new Runnable { def run(): Unit = f })
      worker setDaemon daemon
      worker.start()
      worker
    }
  }
  /** Minimal future: evaluates a computation on a thread started through `Spawn` and exposes
   *  the outcome through a blocking accessor.
   */
  private[process] object Future {
    /** Starts evaluating `f` on a new thread.
     *
     *  @param f the computation to run in the background
     *  @return  a function that blocks until `f` has finished and then either yields its value
     *           or rethrows whatever `f` threw
     */
    def apply[T](f: => T): () => T = {
      val result = new SyncVar[Either[Throwable, T]]
      // Every Throwable is captured, not only Exception: if an Error (a failed `assert`, `???`,
      // a StackOverflowError, ...) escaped the worker, `result` would never be set and the
      // accessor returned below would block forever.  The failure is not swallowed; it is
      // rethrown on the thread that asks for the value.
      def run(): Unit =
        try result set Right(f)
        catch { case e: Throwable => result set Left(e) }

      Spawn(run())

      () => result.get match {
        case Right(value)    => value
        case Left(exception) => throw exception
      }
    }
  }

  /** Sequential composition in which `b` is run only when `a` exited with code 0. */
  private[process] class AndProcess(
    a: ProcessBuilder,
    b: ProcessBuilder,
    io: ProcessIO
  ) extends SequentialProcess(a, b, io, exitCode => exitCode == 0)

  /** Sequential composition in which `b` is run only when `a` exited with a non-zero code. */
  private[process] class OrProcess(
    a: ProcessBuilder,
    b: ProcessBuilder,
    io: ProcessIO
  ) extends SequentialProcess(a, b, io, exitCode => exitCode != 0)

  /** Sequential composition in which `b` is always run after `a`, whatever `a`'s exit code. */
  private[process] class ProcessSequence(
    a: ProcessBuilder,
    b: ProcessBuilder,
    io: ProcessIO
  ) extends SequentialProcess(a, b, io, (_: Int) => true)

  /** Runs `a` to completion and then, if `evaluateSecondProcess` accepts its exit code, runs `b`.
   *
   *  Both builders are run with the same `io`.
   *
   *  @param evaluateSecondProcess decides, from the exit code of `a`, whether `b` is run at all
   */
  private[process] class SequentialProcess(
    a: ProcessBuilder,
    b: ProcessBuilder,
    io: ProcessIO,
    evaluateSecondProcess: Int => Boolean
  ) extends CompoundProcess {

    /** Yields the outcome of `b` when `b` was run, otherwise the outcome of `a`.
     *  The result is `None` whenever the corresponding `runInterruptible` call yields `None`.
     */
    protected[this] override def runAndExitValue(): Option[Int] = {
      val first = a.run(io)
      runInterruptible(first.exitValue())(first.destroy()) match {
        case Some(codeA) if evaluateSecondProcess(codeA) =>
          val second = b.run(io)
          runInterruptible(second.exitValue())(second.destroy())
        case firstOutcome =>
          // Either `a` was cut short (None) or its exit code rules out running `b`.
          firstOutcome
      }
    }
  }

  /** A `Process` that additionally exposes an explicit `start()` operation. */
  private[process] abstract class BasicProcess extends Process {
    /** Starts the process; what starting involves is defined by the concrete subclass. */
    def start(): Unit
  }

  /** Base class for processes assembled from other processes.
   *
   *  The actual work (`runAndExitValue`) is executed on a thread of its own, which is created
   *  the first time `getExitValue` or `destroyer` is touched.
   */
  private[process] abstract class CompoundProcess extends BasicProcess {
    /** Interrupts the thread that is executing `runAndExitValue`. */
    def destroy(): Unit = destroyer()

    /** Blocks until the worker thread has finished and returns the recorded exit code;
     *  fails through `scala.sys.error` when no exit code was recorded.
     */
    def exitValue(): Int = getExitValue() match {
      case Some(exitCode) => exitCode
      case None           => scala.sys.error("No exit code: process destroyed.")
    }

    /** Forces the lazy pair below, which is what spawns the worker thread. */
    def start(): Unit = getExitValue

    protected lazy val (getExitValue, destroyer) = {
      val recorded = new SyncVar[Option[Int]]()
      // Start out with "no exit code"; the worker overwrites this when it completes.
      recorded set None
      val worker = Spawn(recorded set runAndExitValue())

      val awaitExit   = Future { worker.join(); recorded.get }
      val interruptIt = () => worker.interrupt()
      (awaitExit, interruptIt)
    }

    /** Start and block until the exit value is available, then return it in `Some`.
     *  Return `None` if destroyed (use 'run').
     */
    protected[this] def runAndExitValue(): Option[Int]

    /** Evaluates `action` and wraps its result in `Some`.  When the `onInterrupt` handler
     *  fires instead, `destroyImpl` is evaluated and the result is `None`.
     */
    protected[this] def runInterruptible[T](action: => T)(destroyImpl: => Unit): Option[T] =
      try Some(action)
      catch onInterrupt {
        destroyImpl
        None
      }
  }

  /** Runs `a` and `b` together, routing a stream of `a` into the input of `b`.
   *
   *  A `PipeSource` thread and a `PipeSink` thread are joined by a
   *  `PipedOutputStream`/`PipedInputStream` pair.  The stream handed to the source thread is
   *  installed through `defaultIO.withError` when `toError` is true and through
   *  `defaultIO.withOutput` otherwise; the stream handed to the sink thread is installed
   *  through `defaultIO.withInput`.
   */
  private[process] class PipedProcesses(a: ProcessBuilder, b: ProcessBuilder, defaultIO: ProcessIO, toError: Boolean) extends CompoundProcess {
    /** Starts both processes and waits for them.  Yields the exit code of `b` when
     *  `b.hasExitValue`, otherwise the exit code of `a`; both pipe ends are closed on the
     *  way out, whether the wait completed or not.
     */
    protected[this] override def runAndExitValue() = {
      // Source side: the thread waits on `currentSource` for a stream to feed into `pipeOut`.
      val currentSource = new SyncVar[Option[InputStream]]
      val pipeOut       = new PipedOutputStream
      val source        = new PipeSource(currentSource, pipeOut, a.toString)
      source.start()

      // Sink side: the thread waits on `currentSink` for a stream to drain `pipeIn` into.
      val pipeIn      = new PipedInputStream(pipeOut)
      val currentSink = new SyncVar[Option[OutputStream]]
      val sink        = new PipeSink(pipeIn, currentSink, b.toString)
      sink.start()

      // Publishes the stream obtained from `a` to the source thread.
      def handleOutOrError(fromOutput: InputStream) = currentSource put Some(fromOutput)

      val firstIO =
        if (toError)
          defaultIO.withError(handleOutOrError)
        else
          defaultIO.withOutput(handleOutOrError)
      // Publishes the input stream of `b` to the sink thread.
      val secondIO = defaultIO.withInput(toInput => currentSink put Some(toInput))

      // Note the order: `b` is started before `a`.
      val second = b.run(secondIO)
      val first = a.run(firstIO)
      try {
        runInterruptible {
          val exit1 = first.exitValue()
          // `None` is the signal on which PipeSource.run / PipeSink.run stop looping.
          currentSource put None
          currentSink put None
          val exit2 = second.exitValue()
          // Since file redirection (e.g. #>) is implemented as a piped process,
          // we ignore its exit value so cmd #> file doesn't always return 0.
          if (b.hasExitValue) exit2 else exit1
        } {
          // Interrupt path of runInterruptible: tear down both processes.
          first.destroy()
          second.destroy()
        }
      }
      finally {
        BasicIO close pipeIn
        BasicIO close pipeOut
      }
    }
  }

  /** Common base of the two pipe-pumping threads.
   *
   *  @param isSink  selects which stream `runloop` closes afterwards: the destination when true,
   *                 the source when false
   *  @param labelFn produces the label that is included in I/O error reports
   */
  private[process] abstract class PipeThread(isSink: Boolean, labelFn: () => String) extends Thread {
    def run(): Unit

    /** Transfers `src` into `dst` via `BasicIO.transferFully`, routes I/O failures to
     *  `ioHandler`, and always closes the stream selected by `isSink`.
     */
    private[process] def runloop(src: InputStream, dst: OutputStream): Unit = {
      try BasicIO.transferFully(src, dst)
      catch ioFailure(ioHandler)
      finally {
        if (isSink) BasicIO.close(dst)
        else BasicIO.close(src)
      }
    }

    /** Reports an I/O failure on standard output, followed by its stack trace. */
    private def ioHandler(e: IOException): Unit = {
      println("I/O error " + e.getMessage + " for process: " + labelFn())
      e.printStackTrace()
    }
  }

  /** Thread that feeds `pipe`: it repeatedly takes the stream currently published in
   *  `currentSource` and passes it to `runloop` together with `pipe`.  A published `None`
   *  ends the thread and closes `pipe`.
   */
  private[process] class PipeSource(
    currentSource: SyncVar[Option[InputStream]],
    pipe: PipedOutputStream,
    label: => String
  ) extends PipeThread(false, () => label) {

    final override def run(): Unit =
      currentSource.get match {
        case None =>
          // Stop signal: clear the slot and close our end of the pipe.
          currentSource.unset()
          BasicIO close pipe
        case Some(in) =>
          // Clear the slot even if the transfer fails, then wait for the next stream.
          try runloop(in, pipe)
          finally currentSource.unset()

          run()
      }
  }
  /** Thread that drains `pipe`: it repeatedly takes the stream currently published in
   *  `currentSink` and passes `pipe` and that stream to `runloop`.  A published `None`
   *  ends the thread.
   */
  private[process] class PipeSink(
    pipe: PipedInputStream,
    currentSink: SyncVar[Option[OutputStream]],
    label: => String
  ) extends PipeThread(true, () => label) {

    final override def run(): Unit =
      currentSink.get match {
        case None =>
          // Stop signal: just clear the slot.
          currentSink.unset()
        case Some(out) =>
          // Clear the slot even if the transfer fails, then wait for the next stream.
          try runloop(pipe, out)
          finally currentSink.unset()

          run()
      }
  }

  /** A `Process` whose exit value is the result of evaluating `action`.
   *
   *  `action` is handed to `Future` when the instance is constructed.  `exitValue` blocks until
   *  that computation has produced its result; `destroy` does nothing.
   */
  private[process] class DummyProcess(action: => Int) extends Process {
    private[this] val exitCode = Future(action)

    override def exitValue(): Int = exitCode()

    override def destroy(): Unit = ()
  }
  /** A thin wrapper around a java.lang.Process.
   *
   *  `exitValue` waits for the process, interrupts `inputThread`, and then joins every thread
   *  in `outputThreads` before returning the exit code.
   *
   *  @param p             the wrapped process
   *  @param inputThread   the thread created to write to the input stream of the process
   *  @param outputThreads the threads created to read the output and error streams of the process
   */
  private[process] class SimpleProcess(p: JProcess, inputThread: Thread, outputThreads: List[Thread]) extends Process {
    override def exitValue(): Int = {
      // Wait for termination; whatever happens, tell the input thread it may stop.
      try p.waitFor()
      finally inputThread.interrupt()

      // waitFor does not guarantee that all output has been consumed, so wait for the readers.
      for (reader <- outputThreads) reader.join()

      p.exitValue()
    }

    override def destroy(): Unit =
      try {
        // On destroy there is no point in consuming any more output.
        for (reader <- outputThreads) reader.interrupt()
        p.destroy()
      }
      finally inputThread.interrupt()
  }
  /** A `Process` view of work that runs on `thread`.
   *
   *  `exitValue` joins the thread and maps the flag read from `success` to an exit code:
   *  0 for true, 1 for false.  `destroy` interrupts the thread.
   */
  private[process] final class ThreadProcess(thread: Thread, success: SyncVar[Boolean]) extends Process {
    override def exitValue(): Int = {
      thread.join()
      success.get match {
        case true  => 0
        case false => 1
      }
    }

    override def destroy(): Unit = thread.interrupt()
  }
}