aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
blob: 08e3f670f57f6ea63feeb59b2b2a8a397acda394 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.api.python

import java.io.{File, DataInputStream, IOException}
import java.net.{Socket, SocketException, InetAddress}

import scala.collection.JavaConversions._

import org.apache.spark._

/**
 * Launches a long-running `python/pyspark/daemon.py` process (resolved against `SPARK_HOME`)
 * and hands out sockets connected to it on 127.0.0.1.
 *
 * The daemon reports the port it listens on as the first 4-byte int on its stdout; everything
 * it writes afterwards, on stdout and stderr, is copied byte-for-byte to this JVM's stderr.
 * `create`, `startDaemon` and `stopDaemon` all synchronize on this instance.
 *
 * @param pythonExec Python executable used to run the daemon script
 * @param envVars    extra environment variables added to the daemon's environment
 */
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
    extends Logging {
  /** Handle to the running daemon process, or null when no daemon is running. */
  var daemon: Process = null
  /** The daemon is only ever contacted on the loopback address. */
  val daemonHost: InetAddress = InetAddress.getByAddress(Array(127, 0, 0, 1))
  /** Port reported by the running daemon; 0 when no daemon is running. */
  var daemonPort: Int = 0

  /**
   * Returns a new socket connected to the daemon, starting the daemon first if needed.
   *
   * If connecting fails with a `SocketException` (e.g. the daemon has exited), the daemon is
   * restarted once and the connection retried; a second failure propagates to the caller.
   */
  def create(): Socket = {
    synchronized {
      // Start the daemon if it hasn't been started
      startDaemon()

      // Attempt to connect, restart and retry once if it fails
      try {
        new Socket(daemonHost, daemonPort)
      } catch {
        case _: SocketException =>
          logWarning("Python daemon unexpectedly quit, attempting to restart")
          stopDaemon()
          startDaemon()
          new Socket(daemonHost, daemonPort)
      }
    }
  }

  /** Shuts down the daemon, if one is running. */
  def stop(): Unit = {
    stopDaemon()
  }

  /**
   * Starts the daemon unless one is already running, and records the port it reports.
   * On any failure the half-started daemon is torn down and the original error is rethrown.
   */
  private def startDaemon(): Unit = {
    synchronized {
      if (daemon == null) {
        try {
          val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
          // NOTE(review): if SPARK_HOME is unset the script path becomes "null/python/...", and the
          // failure typically only surfaces as an EOFException from readInt() below.
          val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
          val workerEnv = pb.environment()
          workerEnv.putAll(envVars)
          workerEnv.put("PYTHONPATH", pythonPathWith(sparkHome, workerEnv.get("PYTHONPATH")))
          val process = pb.start()
          daemon = process

          // Grab the stream here instead of reading `daemon` from the reader thread: if startup
          // fails, stopDaemon() nulls `daemon`, and the thread would die with an NPE and drop
          // exactly the error output needed to diagnose the failure.
          redirectStreamToStderr(process.getErrorStream, "stderr reader for " + pythonExec)

          // The first thing the daemon writes to stdout is the port it is listening on.
          val in = new DataInputStream(process.getInputStream)
          daemonPort = in.readInt()

          // Redirect further stdout output to our stderr
          redirectStreamToStderr(in, "stdout reader for " + pythonExec)
        } catch {
          case e: Throwable =>
            stopDaemon()
            throw e
        }
      }

      // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
      // detect our disappearance.
    }
  }

  /**
   * Prepends `<sparkHome>/python/` to an existing PYTHONPATH value. When there is no existing
   * value (null), only the Spark entry is returned rather than appending a literal "null" entry.
   */
  private def pythonPathWith(sparkHome: String, existing: String): String = {
    val sparkPython = sparkHome + "/python/"
    Option(existing).map(sparkPython + File.pathSeparator + _).getOrElse(sparkPython)
  }

  /**
   * Starts a background thread that copies `in` to System.err until EOF; IOExceptions end the
   * copy silently. The thread is marked as a daemon thread so it never keeps the JVM alive: the
   * Python daemon only exits once our end of its stdin closes, i.e. when this JVM goes away.
   */
  private def redirectStreamToStderr(in: java.io.InputStream, threadName: String): Unit = {
    val pump = new Thread(threadName) {
      override def run(): Unit = {
        scala.util.control.Exception.ignoring(classOf[IOException]) {
          // FIXME HACK: We copy the stream on the level of bytes to
          // attempt to dodge encoding problems.
          val buf = new Array[Byte](1024)
          var len = in.read(buf)
          while (len != -1) {
            System.err.write(buf, 0, len)
            len = in.read(buf)
          }
        }
      }
    }
    pump.setDaemon(true)
    pump.start()
  }

  /** Destroys the daemon process (if any) and resets the factory to its "not running" state. */
  private def stopDaemon(): Unit = {
    synchronized {
      // Request shutdown of existing daemon (Process.destroy; SIGTERM on Unix JVMs)
      if (daemon != null) {
        daemon.destroy()
      }

      daemon = null
      daemonPort = 0
    }
  }
}