Diffstat (limited to 'python')

-rw-r--r--                    python/epydoc.conf                          17
-rwxr-xr-x                    python/examples/als.py                      22
-rwxr-xr-x (was -rw-r--r--)   python/examples/kmeans.py                   20
-rwxr-xr-x                    python/examples/logistic_regression.py      71
-rwxr-xr-x (was -rw-r--r--)   python/examples/pi.py                       20
-rwxr-xr-x (was -rw-r--r--)   python/examples/transitive_closure.py       24
-rwxr-xr-x (was -rw-r--r--)   python/examples/wordcount.py                20
-rw-r--r--                    python/pyspark/accumulators.py              17
-rw-r--r--                    python/pyspark/broadcast.py                 17
-rw-r--r--                    python/pyspark/context.py                   28
-rw-r--r--                    python/pyspark/daemon.py                   181
-rw-r--r--                    python/pyspark/files.py                     17
-rw-r--r--                    python/pyspark/java_gateway.py              17
-rw-r--r--                    python/pyspark/join.py                      20
-rw-r--r--                    python/pyspark/rdd.py                      105
-rw-r--r--                    python/pyspark/serializers.py               21
-rw-r--r--                    python/pyspark/shell.py                     17
-rw-r--r--                    python/pyspark/tests.py                     60
-rw-r--r--                    python/pyspark/worker.py                    72
-rwxr-xr-x                    python/run-tests                            23

20 files changed, 673 insertions, 116 deletions
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 45102cd9fe..d5d5aa5454 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -1,5 +1,22 @@
[epydoc] # Epydoc section marker (required by ConfigParser)
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
# Information about the project.
name: PySpark
url: http://spark-project.org
diff --git a/python/examples/als.py b/python/examples/als.py
index 010f80097f..a77dfb2577 100755
--- a/python/examples/als.py
+++ b/python/examples/als.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
This example requires numpy (http://www.numpy.org/)
"""
@@ -31,8 +48,7 @@ def update(i, vec, mat, ratings):
if __name__ == "__main__":
if len(sys.argv) < 2:
- print >> sys.stderr, \
- "Usage: PythonALS <master> <M> <U> <F> <iters> <slices>"
+ print >> sys.stderr, "Usage: als <master> <M> <U> <F> <iters> <slices>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)])
M = int(sys.argv[2]) if len(sys.argv) > 2 else 100
@@ -67,5 +83,5 @@ if __name__ == "__main__":
usb = sc.broadcast(us)
error = rmse(R, ms, us)
- print "Iteration %d:" % i
+ print "Iteration %d:" % i
print "\nRMSE: %5.4f\n" % error
diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py
index 72cf9f88c6..ba31af92fc 100644..100755
--- a/python/examples/kmeans.py
+++ b/python/examples/kmeans.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
This example requires numpy (http://www.numpy.org/)
"""
@@ -24,8 +41,7 @@ def closestPoint(p, centers):
if __name__ == "__main__":
if len(sys.argv) < 5:
- print >> sys.stderr, \
- "Usage: PythonKMeans <master> <file> <k> <convergeDist>"
+ print >> sys.stderr, "Usage: kmeans <master> <file> <k> <convergeDist>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonKMeans")
lines = sc.textFile(sys.argv[2])
diff --git a/python/examples/logistic_regression.py b/python/examples/logistic_regression.py
index f13698a86f..1117dea538 100755
--- a/python/examples/logistic_regression.py
+++ b/python/examples/logistic_regression.py
@@ -1,5 +1,23 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
-This example requires numpy (http://www.numpy.org/)
+A logistic regression implementation that uses NumPy (http://www.numpy.org) to act on batches
+of input data using efficient matrix operations.
"""
from collections import namedtuple
from math import exp
@@ -10,48 +28,45 @@ import numpy as np
from pyspark import SparkContext
-N = 100000 # Number of data points
D = 10 # Number of dimensions
-R = 0.7 # Scaling factor
-ITERATIONS = 5
-np.random.seed(42)
-
-
-DataPoint = namedtuple("DataPoint", ['x', 'y'])
-from lr import DataPoint # So that DataPoint is properly serialized
-def generateData():
- def generatePoint(i):
- y = -1 if i % 2 == 0 else 1
- x = np.random.normal(size=D) + (y * R)
- return DataPoint(x, y)
- return [generatePoint(i) for i in range(N)]
-
+# Read a batch of points from the input file into a NumPy matrix object. We operate on batches to
+# make further computations faster.
+# The data file contains lines of the form <label> <x1> <x2> ... <xD>. We load each block of these
+# into a NumPy array of size numLines * (D + 1) and pull out column 0 vs the others in gradient().
+def readPointBatch(iterator):
+ strs = list(iterator)
+ matrix = np.zeros((len(strs), D + 1))
+ for i in xrange(len(strs)):
+ matrix[i] = np.fromstring(strs[i].replace(',', ' '), dtype=np.float32, sep=' ')
+ return [matrix]
if __name__ == "__main__":
- if len(sys.argv) == 1:
- print >> sys.stderr, \
- "Usage: PythonLR <master> [<slices>]"
+ if len(sys.argv) != 4:
+ print >> sys.stderr, "Usage: logistic_regression <master> <file> <iters>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)])
- slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
- points = sc.parallelize(generateData(), slices).cache()
+ points = sc.textFile(sys.argv[2]).mapPartitions(readPointBatch).cache()
+ iterations = int(sys.argv[3])
# Initialize w to a random value
w = 2 * np.random.ranf(size=D) - 1
print "Initial w: " + str(w)
+ # Compute logistic regression gradient for a matrix of data points
+ def gradient(matrix, w):
+ Y = matrix[:,0] # point labels (first column of input file)
+ X = matrix[:,1:] # point coordinates
+ # For each point (x, y), compute gradient function, then sum these up
+ return ((1.0 / (1.0 + np.exp(-Y * X.dot(w))) - 1.0) * Y * X.T).sum(1)
+
def add(x, y):
x += y
return x
- for i in range(1, ITERATIONS + 1):
- print "On iteration %i" % i
-
- gradient = points.map(lambda p:
- (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x
- ).reduce(add)
- w -= gradient
+ for i in range(iterations):
+ print "On iteration %i" % (i + 1)
+ w -= points.map(lambda m: gradient(m, w)).reduce(add)
print "Final w: " + str(w)
diff --git a/python/examples/pi.py b/python/examples/pi.py
index 127cba029b..ab0645fc2f 100644..100755
--- a/python/examples/pi.py
+++ b/python/examples/pi.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import sys
from random import random
from operator import add
@@ -7,8 +24,7 @@ from pyspark import SparkContext
if __name__ == "__main__":
if len(sys.argv) == 1:
- print >> sys.stderr, \
- "Usage: PythonPi <master> [<slices>]"
+ print >> sys.stderr, "Usage: pi <master> [<slices>]"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonPi")
slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
diff --git a/python/examples/transitive_closure.py b/python/examples/transitive_closure.py
index 73f7f8fbaf..744cce6651 100644..100755
--- a/python/examples/transitive_closure.py
+++ b/python/examples/transitive_closure.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import sys
from random import Random
@@ -20,11 +37,10 @@ def generateGraph():
if __name__ == "__main__":
if len(sys.argv) == 1:
- print >> sys.stderr, \
- "Usage: PythonTC <master> [<slices>]"
+ print >> sys.stderr, "Usage: transitive_closure <master> [<slices>]"
exit(-1)
- sc = SparkContext(sys.argv[1], "PythonTC")
- slices = sys.argv[2] if len(sys.argv) > 2 else 2
+ sc = SparkContext(sys.argv[1], "PythonTransitiveClosure")
+ slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
tc = sc.parallelize(generateGraph(), slices).cache()
# Linear transitive closure: each round grows paths by one edge,
diff --git a/python/examples/wordcount.py b/python/examples/wordcount.py
index 857160624b..a6de22766a 100644..100755
--- a/python/examples/wordcount.py
+++ b/python/examples/wordcount.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import sys
from operator import add
@@ -6,8 +23,7 @@ from pyspark import SparkContext
if __name__ == "__main__":
if len(sys.argv) < 3:
- print >> sys.stderr, \
- "Usage: PythonWordCount <master> <file>"
+ print >> sys.stderr, "Usage: wordcount <master> <file>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonWordCount")
lines = sc.textFile(sys.argv[2], 1)
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 3e9d7d36da..d367f91967 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index def810dd46..dfdaba274f 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 657fe6f989..c2b49ff37a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import os
import shutil
import sys
@@ -124,14 +141,21 @@ class SparkContext(object):
def parallelize(self, c, numSlices=None):
"""
Distribute a local Python collection to form an RDD.
+
+ >>> sc.parallelize(range(5), 5).glom().collect()
+ [[0], [1], [2], [3], [4]]
"""
numSlices = numSlices or self.defaultParallelism
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
- if self.batchSize != 1:
- c = batched(c, self.batchSize)
+ # Make sure we distribute data evenly if it's smaller than self.batchSize
+ if "__len__" not in dir(c):
+ c = list(c) # Make it a list so we can compute its length
+ batchSize = min(len(c) // numSlices, self.batchSize)
+ if batchSize > 1:
+ c = batched(c, batchSize)
for x in c:
write_with_length(dump_pickle(x), tempFile)
tempFile.close()
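The new batching logic in parallelize() caps the batch size so that small collections still spread across all requested slices. A minimal sketch of just that calculation, with a stand-in batched() helper (the real one lives in pyspark.serializers) and an assumed configured batch size of 1024:

```python
def batched(seq, batch_size):
    # Stand-in for pyspark.serializers.batched: chunk a sequence into
    # fixed-size batches, for illustration only.
    return [seq[i:i + batch_size] for i in range(0, len(seq), batch_size)]

def choose_batches(c, numSlices, configuredBatchSize=1024):
    # Mirrors the logic added to parallelize() above: never batch so
    # aggressively that fewer batches than slices remain.
    if "__len__" not in dir(c):
        c = list(c)            # make it a list so we can compute its length
    batchSize = min(len(c) // numSlices, configuredBatchSize)
    if batchSize > 1:
        return batched(c, batchSize)
    return list(c)

# Five elements over five slices gives batchSize 1, so nothing is batched and
# the glom() doctest above sees one element per partition.
print(choose_batches(range(5), 5))        # [0, 1, 2, 3, 4]
print(choose_batches(range(100), 2))      # two batches of 50
```
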
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
new file mode 100644
index 0000000000..eb18ec08c9
--- /dev/null
+++ b/python/pyspark/daemon.py
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import signal
+import socket
+import sys
+import traceback
+import multiprocessing
+from ctypes import c_bool
+from errno import EINTR, ECHILD
+from socket import AF_INET, SOCK_STREAM, SOMAXCONN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from pyspark.worker import main as worker_main
+from pyspark.serializers import write_int
+
+try:
+ POOLSIZE = multiprocessing.cpu_count()
+except NotImplementedError:
+ POOLSIZE = 4
+
+exit_flag = multiprocessing.Value(c_bool, False)
+
+
+def should_exit():
+ global exit_flag
+ return exit_flag.value
+
+
+def compute_real_exit_code(exit_code):
+ # SystemExit's code can be integer or string, but os._exit only accepts integers
+ import numbers
+ if isinstance(exit_code, numbers.Integral):
+ return exit_code
+ else:
+ return 1
+
+
+def worker(listen_sock):
+ # Redirect stdout to stderr
+ os.dup2(2, 1)
+ sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1
+
+ # Manager sends SIGHUP to request termination of workers in the pool
+ def handle_sighup(*args):
+ assert should_exit()
+ signal.signal(SIGHUP, handle_sighup)
+
+ # Cleanup zombie children
+ def handle_sigchld(*args):
+ pid = status = None
+ try:
+ while (pid, status) != (0, 0):
+ pid, status = os.waitpid(0, os.WNOHANG)
+ except EnvironmentError as err:
+ if err.errno == EINTR:
+ # retry
+ handle_sigchld()
+ elif err.errno != ECHILD:
+ raise
+ signal.signal(SIGCHLD, handle_sigchld)
+
+ # Handle clients
+ while not should_exit():
+ # Wait until a client arrives or we have to exit
+ sock = None
+ while not should_exit() and sock is None:
+ try:
+ sock, addr = listen_sock.accept()
+ except EnvironmentError as err:
+ if err.errno != EINTR:
+ raise
+
+ if sock is not None:
+ # Fork a child to handle the client.
+ # The client is handled in the child so that the manager
+ # never receives SIGCHLD unless a worker crashes.
+ if os.fork() == 0:
+ # Leave the worker pool
+ signal.signal(SIGHUP, SIG_DFL)
+ listen_sock.close()
+ # Read the socket using fdopen instead of socket.makefile() because the latter
+ # seems to be very slow; note that we need to dup() the file descriptor because
+ # otherwise writes also cause a seek that makes us miss data on the read side.
+ infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+ outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+ exit_code = 0
+ try:
+ worker_main(infile, outfile)
+ except SystemExit as exc:
+ exit_code = exc.code
+ finally:
+ outfile.flush()
+ sock.close()
+ os._exit(compute_real_exit_code(exit_code))
+ else:
+ sock.close()
+
+
+def launch_worker(listen_sock):
+ if os.fork() == 0:
+ try:
+ worker(listen_sock)
+ except Exception as err:
+ traceback.print_exc()
+ os._exit(1)
+ else:
+ assert should_exit()
+ os._exit(0)
+
+
+def manager():
+ # Create a new process group to corral our children
+ os.setpgid(0, 0)
+
+ # Create a listening socket on the AF_INET loopback interface
+ listen_sock = socket.socket(AF_INET, SOCK_STREAM)
+ listen_sock.bind(('127.0.0.1', 0))
+ listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
+ listen_host, listen_port = listen_sock.getsockname()
+ write_int(listen_port, sys.stdout)
+
+ # Launch initial worker pool
+ for idx in range(POOLSIZE):
+ launch_worker(listen_sock)
+ listen_sock.close()
+
+ def shutdown():
+ global exit_flag
+ exit_flag.value = True
+
+ # Gracefully exit on SIGTERM, don't die on SIGHUP
+ signal.signal(SIGTERM, lambda signum, frame: shutdown())
+ signal.signal(SIGHUP, SIG_IGN)
+
+ # Cleanup zombie children
+ def handle_sigchld(*args):
+ try:
+ pid, status = os.waitpid(0, os.WNOHANG)
+ if status != 0 and not should_exit():
+ raise RuntimeError("worker crashed: %s, %s" % (pid, status))
+ except EnvironmentError as err:
+ if err.errno not in (ECHILD, EINTR):
+ raise
+ signal.signal(SIGCHLD, handle_sigchld)
+
+ # Initialization complete
+ sys.stdout.close()
+ try:
+ while not should_exit():
+ try:
+ # Spark tells us to exit by closing stdin
+ if os.read(0, 512) == '':
+ shutdown()
+ except EnvironmentError as err:
+ if err.errno != EINTR:
+ shutdown()
+ raise
+ finally:
+ signal.signal(SIGTERM, SIG_DFL)
+ exit_flag.value = True
+ # Send SIGHUP to notify workers of shutdown
+ os.kill(0, SIGHUP)
+
+
+if __name__ == '__main__':
+ manager()
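The daemon advertises its listening port by writing a 4-byte big-endian integer to stdout, then hands each accepted connection to a forked copy of worker.main(). A hedged sketch of driving it by hand, mirroring the TestDaemon case further down; the struct format and the daemon path are assumptions (write_int/read_int in pyspark.serializers, current directory python/):

```python
import os
import struct
import sys
from socket import socket, AF_INET, SOCK_STREAM
from subprocess import Popen, PIPE

# Start the daemon; it prints its port as a 4-byte big-endian int.
daemon_path = os.path.join("pyspark", "daemon.py")   # assumes CWD is python/
daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
(port,) = struct.unpack("!i", daemon.stdout.read(4))

# Connect once; sending a split index of -1 makes the forked worker
# return immediately (see the split_index == -1 check in worker.main()).
sock = socket(AF_INET, SOCK_STREAM)
sock.connect(('127.0.0.1', port))
sock.send(b"\xff\xff\xff\xff")
sock.close()

# Closing the daemon's stdin asks the manager to shut the pool down.
daemon.stdin.close()
```
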
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
index 001b7a28b6..89bcbcfe06 100644
--- a/python/pyspark/files.py
+++ b/python/pyspark/files.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import os
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 2329e536cc..e503fb7621 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import os
import sys
from subprocess import Popen, PIPE
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index 7036c47980..5f4294fb1b 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -32,13 +32,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
-def _do_python_join(rdd, other, numSplits, dispatch):
+def _do_python_join(rdd, other, numPartitions, dispatch):
vs = rdd.map(lambda (k, v): (k, (1, v)))
ws = other.map(lambda (k, v): (k, (2, v)))
- return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch)
+ return vs.union(ws).groupByKey(numPartitions).flatMapValues(dispatch)
-def python_join(rdd, other, numSplits):
+def python_join(rdd, other, numPartitions):
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
@@ -47,10 +47,10 @@ def python_join(rdd, other, numSplits):
elif n == 2:
wbuf.append(v)
return [(v, w) for v in vbuf for w in wbuf]
- return _do_python_join(rdd, other, numSplits, dispatch)
+ return _do_python_join(rdd, other, numPartitions, dispatch)
-def python_right_outer_join(rdd, other, numSplits):
+def python_right_outer_join(rdd, other, numPartitions):
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
@@ -61,10 +61,10 @@ def python_right_outer_join(rdd, other, numSplits):
if not vbuf:
vbuf.append(None)
return [(v, w) for v in vbuf for w in wbuf]
- return _do_python_join(rdd, other, numSplits, dispatch)
+ return _do_python_join(rdd, other, numPartitions, dispatch)
-def python_left_outer_join(rdd, other, numSplits):
+def python_left_outer_join(rdd, other, numPartitions):
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
@@ -75,10 +75,10 @@ def python_left_outer_join(rdd, other, numSplits):
if not wbuf:
wbuf.append(None)
return [(v, w) for v in vbuf for w in wbuf]
- return _do_python_join(rdd, other, numSplits, dispatch)
+ return _do_python_join(rdd, other, numPartitions, dispatch)
-def python_cogroup(rdd, other, numSplits):
+def python_cogroup(rdd, other, numPartitions):
vs = rdd.map(lambda (k, v): (k, (1, v)))
ws = other.map(lambda (k, v): (k, (2, v)))
def dispatch(seq):
@@ -89,4 +89,4 @@ def python_cogroup(rdd, other, numSplits):
elif n == 2:
wbuf.append(v)
return (vbuf, wbuf)
- return vs.union(ws).groupByKey(numSplits).mapValues(dispatch)
+ return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
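All of these joins share the same shape: tag values from the left RDD with 1 and values from the right with 2, union, group by key, and run a dispatch function over the tagged sequence. The inner-join dispatch can be tried in isolation; the input below simulates what groupByKey would hand it for a single key:

```python
def dispatch(seq):
    # Same dispatch as python_join above: 1-tagged values came from the
    # left RDD, 2-tagged values from the right; emit their cross product.
    vbuf, wbuf = [], []
    for (n, v) in seq:
        if n == 1:
            vbuf.append(v)
        elif n == 2:
            wbuf.append(v)
    return [(v, w) for v in vbuf for w in wbuf]

# For key 'a', left values [1] and right values [2, 3] join to two pairs.
print(dispatch([(1, 1), (2, 2), (2, 3)]))   # [(1, 2), (1, 3)]
```
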
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 4cda6cf661..51c2cb9806 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
@@ -143,7 +160,7 @@ class RDD(object):
>>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
[1, 2, 3]
"""
- return self.map(lambda x: (x, "")) \
+ return self.map(lambda x: (x, None)) \
.reduceByKey(lambda x, _: x) \
.map(lambda (x, _): x)
@@ -215,7 +232,7 @@ class RDD(object):
yield pair
return java_cartesian.flatMap(unpack_batches)
- def groupBy(self, f, numSplits=None):
+ def groupBy(self, f, numPartitions=None):
"""
Return an RDD of grouped items.
@@ -224,7 +241,7 @@ class RDD(object):
>>> sorted([(x, sorted(y)) for (x, y) in result])
[(0, [2, 8]), (1, [1, 1, 3, 5])]
"""
- return self.map(lambda x: (f(x), x)).groupByKey(numSplits)
+ return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
def pipe(self, command, env={}):
"""
@@ -250,7 +267,11 @@ class RDD(object):
>>> def f(x): print x
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
- self.map(f).collect() # Force evaluation
+ def processPartition(iterator):
+ for x in iterator:
+ f(x)
+ yield None
+ self.mapPartitions(processPartition).collect() # Force evaluation
def collect(self):
"""
@@ -274,8 +295,8 @@ class RDD(object):
def reduce(self, f):
"""
- Reduces the elements of this RDD using the specified associative binary
- operator.
+ Reduces the elements of this RDD using the specified commutative and
+ associative binary operator.
>>> from operator import add
>>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
@@ -369,13 +390,16 @@ class RDD(object):
>>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
[2, 3, 4, 5, 6]
"""
+ def takeUpToNum(iterator):
+ taken = 0
+ while taken < num:
+ yield next(iterator)
+ taken += 1
+ # Take only up to num elements from each partition we try
+ mapped = self.mapPartitions(takeUpToNum)
items = []
- for partition in range(self._jrdd.splits().size()):
- iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
- # Each item in the iterator is a string, Python object, batch of
- # Python objects. Regardless, it is sufficient to take `num`
- # of these objects in order to collect `num` Python objects:
- iterator = iterator.take(num)
+ for partition in range(mapped._jrdd.splits().size()):
+ iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
items.extend(self._collect_iterator_through_file(iterator))
if len(items) >= num:
break
@@ -399,7 +423,7 @@ class RDD(object):
>>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name)
>>> from fileinput import input
>>> from glob import glob
- >>> ''.join(input(glob(tempFile.name + "/part-0000*")))
+ >>> ''.join(sorted(input(glob(tempFile.name + "/part-0000*"))))
'0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
"""
def func(split, iterator):
@@ -422,22 +446,22 @@ class RDD(object):
"""
return dict(self.collect())
- def reduceByKey(self, func, numSplits=None):
+ def reduceByKey(self, func, numPartitions=None):
"""
Merge the values for each key using an associative reduce function.
This will also perform the merging locally on each mapper before
sending results to a reducer, similarly to a "combiner" in MapReduce.
- Output will be hash-partitioned with C{numSplits} splits, or the
- default parallelism level if C{numSplits} is not specified.
+ Output will be hash-partitioned with C{numPartitions} partitions, or
+ the default parallelism level if C{numPartitions} is not specified.
>>> from operator import add
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(rdd.reduceByKey(add).collect())
[('a', 2), ('b', 1)]
"""
- return self.combineByKey(lambda x: x, func, func, numSplits)
+ return self.combineByKey(lambda x: x, func, func, numPartitions)
def reduceByKeyLocally(self, func):
"""
@@ -474,7 +498,7 @@ class RDD(object):
"""
return self.map(lambda x: x[0]).countByValue()
- def join(self, other, numSplits=None):
+ def join(self, other, numPartitions=None):
"""
Return an RDD containing all pairs of elements with matching keys in
C{self} and C{other}.
@@ -489,9 +513,9 @@ class RDD(object):
>>> sorted(x.join(y).collect())
[('a', (1, 2)), ('a', (1, 3))]
"""
- return python_join(self, other, numSplits)
+ return python_join(self, other, numPartitions)
- def leftOuterJoin(self, other, numSplits=None):
+ def leftOuterJoin(self, other, numPartitions=None):
"""
Perform a left outer join of C{self} and C{other}.
@@ -506,9 +530,9 @@ class RDD(object):
>>> sorted(x.leftOuterJoin(y).collect())
[('a', (1, 2)), ('b', (4, None))]
"""
- return python_left_outer_join(self, other, numSplits)
+ return python_left_outer_join(self, other, numPartitions)
- def rightOuterJoin(self, other, numSplits=None):
+ def rightOuterJoin(self, other, numPartitions=None):
"""
Perform a right outer join of C{self} and C{other}.
@@ -523,10 +547,10 @@ class RDD(object):
>>> sorted(y.rightOuterJoin(x).collect())
[('a', (2, 1)), ('b', (None, 4))]
"""
- return python_right_outer_join(self, other, numSplits)
+ return python_right_outer_join(self, other, numPartitions)
# TODO: add option to control map-side combining
- def partitionBy(self, numSplits, partitionFunc=hash):
+ def partitionBy(self, numPartitions, partitionFunc=hash):
"""
Return a copy of the RDD partitioned using the specified partitioner.
@@ -535,22 +559,22 @@ class RDD(object):
>>> set(sets[0]).intersection(set(sets[1]))
set([])
"""
- if numSplits is None:
- numSplits = self.ctx.defaultParallelism
+ if numPartitions is None:
+ numPartitions = self.ctx.defaultParallelism
# Transferring O(n) objects to Java is too expensive. Instead, we'll
- # form the hash buckets in Python, transferring O(numSplits) objects
+ # form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
for (k, v) in iterator:
- buckets[partitionFunc(k) % numSplits].append((k, v))
+ buckets[partitionFunc(k) % numPartitions].append((k, v))
for (split, items) in buckets.iteritems():
yield str(split)
yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx._jvm.PythonPartitioner(numSplits,
+ partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx)
@@ -561,7 +585,7 @@ class RDD(object):
# TODO: add control over map-side aggregation
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
- numSplits=None):
+ numPartitions=None):
"""
Generic function to combine the elements for each key using a custom
set of aggregation functions.
@@ -586,8 +610,8 @@ class RDD(object):
>>> sorted(x.combineByKey(str, add, add).collect())
[('a', '11'), ('b', '1')]
"""
- if numSplits is None:
- numSplits = self.ctx.defaultParallelism
+ if numPartitions is None:
+ numPartitions = self.ctx.defaultParallelism
def combineLocally(iterator):
combiners = {}
for (k, v) in iterator:
@@ -597,7 +621,7 @@ class RDD(object):
combiners[k] = mergeValue(combiners[k], v)
return combiners.iteritems()
locally_combined = self.mapPartitions(combineLocally)
- shuffled = locally_combined.partitionBy(numSplits)
+ shuffled = locally_combined.partitionBy(numPartitions)
def _mergeCombiners(iterator):
combiners = {}
for (k, v) in iterator:
@@ -609,10 +633,10 @@ class RDD(object):
return shuffled.mapPartitions(_mergeCombiners)
# TODO: support variant with custom partitioner
- def groupByKey(self, numSplits=None):
+ def groupByKey(self, numPartitions=None):
"""
Group the values for each key in the RDD into a single sequence.
- Hash-partitions the resulting RDD with into numSplits partitions.
+    Hash-partitions the resulting RDD into numPartitions partitions.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(x.groupByKey().collect())
@@ -630,7 +654,7 @@ class RDD(object):
return a + b
return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
- numSplits)
+ numPartitions)
# TODO: add tests
def flatMapValues(self, f):
@@ -659,7 +683,7 @@ class RDD(object):
return self.cogroup(other)
    # TODO: add variant with custom partitioner
- def cogroup(self, other, numSplits=None):
+ def cogroup(self, other, numPartitions=None):
"""
For each key k in C{self} or C{other}, return a resulting RDD that
contains a tuple with the list of values for that key in C{self} as well
@@ -670,7 +694,7 @@ class RDD(object):
>>> sorted(x.cogroup(y).collect())
[('a', ([1], [2])), ('b', ([4], []))]
"""
- return python_cogroup(self, other, numSplits)
+ return python_cogroup(self, other, numPartitions)
# TODO: `lookup` is disabled because we can't make direct comparisons based
# on the key; we need to compare the hash of the key to the hash of the
@@ -732,9 +756,8 @@ class PipelinedRDD(RDD):
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
class_manifest = self._prev_jrdd.classManifest()
- env = copy.copy(self.ctx.environment)
- env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
- env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
+ env = MapConverter().convert(self.ctx.environment,
+ self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator, class_manifest)
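The reworked take() limits how much of each partition is consumed by wrapping the partition iterator in a small generator. The same trick in isolation, with num passed explicitly instead of captured from the enclosing scope:

```python
def takeUpToNum(iterator, num):
    # Standalone version of the inner generator used by take() above:
    # pull at most num elements from the partition's iterator.
    taken = 0
    while taken < num:
        yield next(iterator)
        taken += 1

partition = iter([2, 3, 4, 5, 6])
print(list(takeUpToNum(partition, 3)))   # [2, 3, 4]
print(next(partition))                   # 5 -- the rest was never consumed
```
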
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 115cf28cc2..fecacd1241 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import struct
import cPickle
@@ -46,6 +63,10 @@ def read_long(stream):
return struct.unpack("!q", length)[0]
+def write_long(value, stream):
+ stream.write(struct.pack("!q", value))
+
+
def read_int(stream):
length = stream.read(4)
if length == "":
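write_long is the counterpart of the existing read_long: an 8-byte big-endian frame, used by worker.py below to report timings in milliseconds. A round-trip sketch (the in-memory stream and the read_long stand-in are for illustration only):

```python
import io
import struct

def write_long(value, stream):
    # Same 8-byte big-endian framing as added above.
    stream.write(struct.pack("!q", value))

def read_long(stream):
    return struct.unpack("!q", stream.read(8))[0]

buf = io.BytesIO()
write_long(int(1373500800.5 * 1000), buf)   # e.g. a timestamp in milliseconds
buf.seek(0)
print(read_long(buf))                       # 1373500800500
```
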
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index c8297b662e..9b4b4e78cb 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
An interactive shell.
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6a1962d267..dfd841b10a 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
Unit tests for PySpark; additional tests are implemented as doctests in
individual modules.
@@ -12,6 +29,7 @@ import unittest
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.java_gateway import SPARK_HOME
+from pyspark.serializers import read_int
class PySparkTestCase(unittest.TestCase):
@@ -117,5 +135,47 @@ class TestIO(PySparkTestCase):
self.sc.parallelize([1]).foreach(func)
+class TestDaemon(unittest.TestCase):
+ def connect(self, port):
+ from socket import socket, AF_INET, SOCK_STREAM
+ sock = socket(AF_INET, SOCK_STREAM)
+ sock.connect(('127.0.0.1', port))
+ # send a split index of -1 to shutdown the worker
+ sock.send("\xFF\xFF\xFF\xFF")
+ sock.close()
+ return True
+
+ def do_termination_test(self, terminator):
+ from subprocess import Popen, PIPE
+ from errno import ECONNREFUSED
+
+ # start daemon
+ daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
+ daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
+
+ # read the port number
+ port = read_int(daemon.stdout)
+
+ # daemon should accept connections
+ self.assertTrue(self.connect(port))
+
+ # request shutdown
+ terminator(daemon)
+ time.sleep(1)
+
+ # daemon should no longer accept connections
+ with self.assertRaises(EnvironmentError) as trap:
+ self.connect(port)
+ self.assertEqual(trap.exception.errno, ECONNREFUSED)
+
+ def test_termination_stdin(self):
+ """Ensure that daemon and workers terminate when stdin is closed."""
+ self.do_termination_test(lambda daemon: daemon.stdin.close())
+
+ def test_termination_sigterm(self):
+ """Ensure that daemon and workers terminate on SIGTERM."""
+ from signal import SIGTERM
+ self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+
if __name__ == "__main__":
unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 812e7a9da5..75d692beeb 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1,8 +1,26 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
Worker that receives input from Piped RDD.
"""
import os
import sys
+import time
import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
@@ -12,48 +30,60 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, read_with_length, write_int, \
- read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+ read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
-# Redirect stdout to stderr so that users must return values from functions.
-old_stdout = os.fdopen(os.dup(1), 'w')
-os.dup2(2, 1)
+def load_obj(infile):
+ return load_pickle(standard_b64decode(infile.readline().strip()))
-def load_obj():
- return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
+def report_times(outfile, boot, init, finish):
+ write_int(-3, outfile)
+ write_long(1000 * boot, outfile)
+ write_long(1000 * init, outfile)
+ write_long(1000 * finish, outfile)
-def main():
- split_index = read_int(sys.stdin)
- spark_files_dir = load_pickle(read_with_length(sys.stdin))
+def main(infile, outfile):
+ boot_time = time.time()
+ split_index = read_int(infile)
+ if split_index == -1: # for unit tests
+ return
+ spark_files_dir = load_pickle(read_with_length(infile))
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
sys.path.append(spark_files_dir)
- num_broadcast_variables = read_int(sys.stdin)
+ num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
- bid = read_long(sys.stdin)
- value = read_with_length(sys.stdin)
+ bid = read_long(infile)
+ value = read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
- func = load_obj()
- bypassSerializer = load_obj()
+ func = load_obj(infile)
+ bypassSerializer = load_obj(infile)
if bypassSerializer:
dumps = lambda x: x
else:
dumps = dump_pickle
- iterator = read_from_pickle_file(sys.stdin)
+ init_time = time.time()
+ iterator = read_from_pickle_file(infile)
try:
for obj in func(split_index, iterator):
- write_with_length(dumps(obj), old_stdout)
+ write_with_length(dumps(obj), outfile)
except Exception as e:
- write_int(-2, old_stdout)
- write_with_length(traceback.format_exc(), old_stdout)
+ write_int(-2, outfile)
+ write_with_length(traceback.format_exc(), outfile)
sys.exit(-1)
+ finish_time = time.time()
+ report_times(outfile, boot_time, init_time, finish_time)
# Mark the beginning of the accumulators section of the output
- write_int(-1, old_stdout)
+ write_int(-1, outfile)
for aid, accum in _accumulatorRegistry.items():
- write_with_length(dump_pickle((aid, accum._value)), old_stdout)
+ write_with_length(dump_pickle((aid, accum._value)), outfile)
+ write_int(-1, outfile)
if __name__ == '__main__':
- main()
+ # Redirect stdout to stderr so that users must return values from functions.
+ old_stdout = os.fdopen(os.dup(1), 'w')
+ os.dup2(2, 1)
+ main(sys.stdin, old_stdout)
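On the wire, main() now frames everything it writes to outfile: ordinary results are length-prefixed, and negative "lengths" act as sentinels (-3 timings, -2 a traceback, -1 the accumulator section). A hedged Python stand-in for the consuming side; the real reader lives on the JVM side, so the names and loop structure here are assumptions:

```python
import struct

def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]

def read_long(stream):
    return struct.unpack("!q", stream.read(8))[0]

def read_worker_output(stream):
    # Interpret the stream produced by main() above: non-negative ints are
    # record lengths, negative ints are the sentinels written with write_int().
    records = []
    while True:
        code = read_int(stream)
        if code >= 0:                        # a length-prefixed result record
            records.append(stream.read(code))
        elif code == -3:                     # report_times(): boot/init/finish in ms
            boot, init, finish = [read_long(stream) for _ in range(3)]
        elif code == -2:                     # the worker raised; a traceback follows
            raise RuntimeError(stream.read(read_int(stream)))
        elif code == -1:                     # accumulator updates begin; stop here
            return records
```
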
diff --git a/python/run-tests b/python/run-tests
index a3a9ff5dcb..6643faa2e0 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -1,8 +1,29 @@
#!/usr/bin/env bash
-# Figure out where the Scala framework is installed
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+# Figure out where the Spark framework is installed
FWDIR="$(cd `dirname $0`; cd ../; pwd)"
+# CD into the python directory to find things on the right path
+cd "$FWDIR/python"
+
FAILED=0
$FWDIR/pyspark pyspark/rdd.py