author     Prashant Sharma <prashant.s@imaginea.com>    2013-09-06 17:53:01 +0530
committer  Prashant Sharma <prashant.s@imaginea.com>    2013-09-06 17:53:01 +0530
commit     4106ae9fbf8a582697deba2198b3b966dec00bfe (patch)
tree       7c3046faee5f62f9ec4c4176125988d7cb5d70e2 /python
parent     e0dd24dc858777904335218f3001a24bffe73b27 (diff)
parent     5c7494d7c1b7301138fb3dc155a1b0c961126ec6 (diff)
Merged with master
Diffstat (limited to 'python')
-rw-r--r--               python/epydoc.conf                         |  18
-rwxr-xr-x               python/examples/als.py                     |  22
-rwxr-xr-x [-rw-r--r--]  python/examples/kmeans.py                  |  20
-rwxr-xr-x               python/examples/logistic_regression.py     |  71
-rwxr-xr-x               python/examples/pagerank.py                |  70
-rwxr-xr-x [-rw-r--r--]  python/examples/pi.py                      |  20
-rwxr-xr-x [-rw-r--r--]  python/examples/transitive_closure.py      |  22
-rwxr-xr-x [-rw-r--r--]  python/examples/wordcount.py               |  22
-rw-r--r--               python/lib/py4j0.7.jar                     | bin 103286 -> 0 bytes
-rw-r--r--               python/pyspark/__init__.py                 |  19
-rw-r--r--               python/pyspark/accumulators.py             |  17
-rw-r--r--               python/pyspark/broadcast.py                |  17
-rw-r--r--               python/pyspark/context.py                  |  46
-rw-r--r--               python/pyspark/daemon.py                   |  17
-rw-r--r--               python/pyspark/files.py                    |  19
-rw-r--r--               python/pyspark/java_gateway.py             |  36
-rw-r--r--               python/pyspark/rdd.py                      | 208
-rw-r--r--               python/pyspark/rddsampler.py               | 112
-rw-r--r--               python/pyspark/serializers.py              |  17
-rw-r--r--               python/pyspark/shell.py                    |  37
-rw-r--r--               python/pyspark/statcounter.py              | 109
-rw-r--r--               python/pyspark/tests.py                    |  41
-rwxr-xr-x               python/run-tests                           |  45
-rw-r--r--               python/test_support/userlib-0.1-py2.7.egg  | bin 0 -> 1945 bytes
25 files changed, 948 insertions(+), 98 deletions(-)
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 45102cd9fe..1d0d002d36 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -1,5 +1,22 @@
[epydoc] # Epydoc section marker (required by ConfigParser)
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
# Information about the project.
name: PySpark
url: http://spark-project.org
@@ -17,3 +34,4 @@ private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
+ pyspark.rddsampler pyspark.daemon
diff --git a/python/examples/als.py b/python/examples/als.py
index 010f80097f..a77dfb2577 100755
--- a/python/examples/als.py
+++ b/python/examples/als.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
This example requires numpy (http://www.numpy.org/)
"""
@@ -31,8 +48,7 @@ def update(i, vec, mat, ratings):
if __name__ == "__main__":
if len(sys.argv) < 2:
- print >> sys.stderr, \
- "Usage: PythonALS <master> <M> <U> <F> <iters> <slices>"
+ print >> sys.stderr, "Usage: als <master> <M> <U> <F> <iters> <slices>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)])
M = int(sys.argv[2]) if len(sys.argv) > 2 else 100
@@ -67,5 +83,5 @@ if __name__ == "__main__":
usb = sc.broadcast(us)
error = rmse(R, ms, us)
- print "Iteration %d:" % i
+ print "Iteration %d:" % i
print "\nRMSE: %5.4f\n" % error
diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py
index 72cf9f88c6..ba31af92fc 100644..100755
--- a/python/examples/kmeans.py
+++ b/python/examples/kmeans.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
This example requires numpy (http://www.numpy.org/)
"""
@@ -24,8 +41,7 @@ def closestPoint(p, centers):
if __name__ == "__main__":
if len(sys.argv) < 5:
- print >> sys.stderr, \
- "Usage: PythonKMeans <master> <file> <k> <convergeDist>"
+ print >> sys.stderr, "Usage: kmeans <master> <file> <k> <convergeDist>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonKMeans")
lines = sc.textFile(sys.argv[2])
diff --git a/python/examples/logistic_regression.py b/python/examples/logistic_regression.py
index f13698a86f..1117dea538 100755
--- a/python/examples/logistic_regression.py
+++ b/python/examples/logistic_regression.py
@@ -1,5 +1,23 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
-This example requires numpy (http://www.numpy.org/)
+A logistic regression implementation that uses NumPy (http://www.numpy.org) to act on batches
+of input data using efficient matrix operations.
"""
from collections import namedtuple
from math import exp
@@ -10,48 +28,45 @@ import numpy as np
from pyspark import SparkContext
-N = 100000 # Number of data points
D = 10 # Number of dimensions
-R = 0.7 # Scaling factor
-ITERATIONS = 5
-np.random.seed(42)
-
-
-DataPoint = namedtuple("DataPoint", ['x', 'y'])
-from lr import DataPoint # So that DataPoint is properly serialized
-def generateData():
- def generatePoint(i):
- y = -1 if i % 2 == 0 else 1
- x = np.random.normal(size=D) + (y * R)
- return DataPoint(x, y)
- return [generatePoint(i) for i in range(N)]
-
+# Read a batch of points from the input file into a NumPy matrix object. We operate on batches to
+# make further computations faster.
+# The data file contains lines of the form <label> <x1> <x2> ... <xD>. We load each block of these
+# into a NumPy array of size numLines * (D + 1) and pull out column 0 vs the others in gradient().
+def readPointBatch(iterator):
+ strs = list(iterator)
+ matrix = np.zeros((len(strs), D + 1))
+ for i in xrange(len(strs)):
+ matrix[i] = np.fromstring(strs[i].replace(',', ' '), dtype=np.float32, sep=' ')
+ return [matrix]
if __name__ == "__main__":
- if len(sys.argv) == 1:
- print >> sys.stderr, \
- "Usage: PythonLR <master> [<slices>]"
+ if len(sys.argv) != 4:
+ print >> sys.stderr, "Usage: logistic_regression <master> <file> <iters>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)])
- slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
- points = sc.parallelize(generateData(), slices).cache()
+ points = sc.textFile(sys.argv[2]).mapPartitions(readPointBatch).cache()
+ iterations = int(sys.argv[3])
# Initialize w to a random value
w = 2 * np.random.ranf(size=D) - 1
print "Initial w: " + str(w)
+ # Compute logistic regression gradient for a matrix of data points
+ def gradient(matrix, w):
+ Y = matrix[:,0] # point labels (first column of input file)
+ X = matrix[:,1:] # point coordinates
+ # For each point (x, y), compute gradient function, then sum these up
+ return ((1.0 / (1.0 + np.exp(-Y * X.dot(w))) - 1.0) * Y * X.T).sum(1)
+
def add(x, y):
x += y
return x
- for i in range(1, ITERATIONS + 1):
- print "On iteration %i" % i
-
- gradient = points.map(lambda p:
- (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x
- ).reduce(add)
- w -= gradient
+ for i in range(iterations):
+ print "On iteration %i" % (i + 1)
+ w -= points.map(lambda m: gradient(m, w)).reduce(add)
print "Final w: " + str(w)
diff --git a/python/examples/pagerank.py b/python/examples/pagerank.py
new file mode 100755
index 0000000000..cd774cf3a3
--- /dev/null
+++ b/python/examples/pagerank.py
@@ -0,0 +1,70 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#!/usr/bin/env python
+
+import re, sys
+from operator import add
+
+from pyspark import SparkContext
+
+
+def computeContribs(urls, rank):
+ """Calculates URL contributions to the rank of other URLs."""
+ num_urls = len(urls)
+ for url in urls: yield (url, rank / num_urls)
+
+
+def parseNeighbors(urls):
+ """Parses a urls pair string into urls pair."""
+ parts = re.split(r'\s+', urls)
+ return parts[0], parts[1]
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 3:
+ print >> sys.stderr, "Usage: pagerank <master> <file> <number_of_iterations>"
+ exit(-1)
+
+ # Initialize the spark context.
+ sc = SparkContext(sys.argv[1], "PythonPageRank")
+
+ # Loads in input file. It should be in format of:
+ # URL neighbor URL
+ # URL neighbor URL
+ # URL neighbor URL
+ # ...
+ lines = sc.textFile(sys.argv[2], 1)
+
+ # Loads all URLs from input file and initialize their neighbors.
+ links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache()
+
+ # Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one.
+ ranks = links.map(lambda (url, neighbors): (url, 1.0))
+
+ # Calculates and updates URL ranks continuously using PageRank algorithm.
+ for iteration in xrange(int(sys.argv[3])):
+ # Calculates URL contributions to the rank of other URLs.
+ contribs = links.join(ranks).flatMap(lambda (url, (urls, rank)):
+ computeContribs(urls, rank))
+
+ # Re-calculates URL ranks based on neighbor contributions.
+ ranks = contribs.reduceByKey(add).mapValues(lambda rank: rank * 0.85 + 0.15)
+
+ # Collects all URL ranks and dump them to console.
+ for (link, rank) in ranks.collect():
+ print "%s has rank: %s." % (link, rank)
diff --git a/python/examples/pi.py b/python/examples/pi.py
index 127cba029b..ab0645fc2f 100644..100755
--- a/python/examples/pi.py
+++ b/python/examples/pi.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import sys
from random import random
from operator import add
@@ -7,8 +24,7 @@ from pyspark import SparkContext
if __name__ == "__main__":
if len(sys.argv) == 1:
- print >> sys.stderr, \
- "Usage: PythonPi <master> [<slices>]"
+ print >> sys.stderr, "Usage: pi <master> [<slices>]"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonPi")
slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
diff --git a/python/examples/transitive_closure.py b/python/examples/transitive_closure.py
index 7f85a1008e..744cce6651 100644..100755
--- a/python/examples/transitive_closure.py
+++ b/python/examples/transitive_closure.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import sys
from random import Random
@@ -20,10 +37,9 @@ def generateGraph():
if __name__ == "__main__":
if len(sys.argv) == 1:
- print >> sys.stderr, \
- "Usage: PythonTC <master> [<slices>]"
+ print >> sys.stderr, "Usage: transitive_closure <master> [<slices>]"
exit(-1)
- sc = SparkContext(sys.argv[1], "PythonTC")
+ sc = SparkContext(sys.argv[1], "PythonTransitiveClosure")
slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
tc = sc.parallelize(generateGraph(), slices).cache()
diff --git a/python/examples/wordcount.py b/python/examples/wordcount.py
index 857160624b..b9139b9d76 100644..100755
--- a/python/examples/wordcount.py
+++ b/python/examples/wordcount.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import sys
from operator import add
@@ -6,8 +23,7 @@ from pyspark import SparkContext
if __name__ == "__main__":
if len(sys.argv) < 3:
- print >> sys.stderr, \
- "Usage: PythonWordCount <master> <file>"
+ print >> sys.stderr, "Usage: wordcount <master> <file>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonWordCount")
lines = sc.textFile(sys.argv[2], 1)
@@ -16,4 +32,4 @@ if __name__ == "__main__":
.reduceByKey(add)
output = counts.collect()
for (word, count) in output:
- print "%s : %i" % (word, count)
+ print "%s: %i" % (word, count)
diff --git a/python/lib/py4j0.7.jar b/python/lib/py4j0.7.jar
deleted file mode 100644
index 73b7ddb7d1..0000000000
--- a/python/lib/py4j0.7.jar
+++ /dev/null
Binary files differ
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 3e8bca62f0..fd5972d381 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -1,5 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
-PySpark is a Python API for Spark.
+PySpark is the Python API for Spark.
Public classes:
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 3e9d7d36da..d367f91967 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index def810dd46..dfdaba274f 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 657fe6f989..8fbf296509 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import os
import shutil
import sys
@@ -29,6 +46,7 @@ class SparkContext(object):
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
+ _python_includes = None # zip and egg files that need to be added to PYTHONPATH
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
@@ -86,16 +104,19 @@ class SparkContext(object):
# send.
self._pickled_broadcast_vars = set()
+ SparkFiles._sc = self
+ root_dir = SparkFiles.getRootDirectory()
+ sys.path.append(root_dir)
+
# Deploy any code dependencies specified in the constructor
+ self._python_includes = list()
for path in (pyFiles or []):
self.addPyFile(path)
- SparkFiles._sc = self
- sys.path.append(SparkFiles.getRootDirectory())
# Create a temporary directory inside spark.local.dir:
- local_dir = self._jvm.spark.Utils.getLocalDir()
+ local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir()
self._temp_dir = \
- self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
+ self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
@property
def defaultParallelism(self):
@@ -124,14 +145,21 @@ class SparkContext(object):
def parallelize(self, c, numSlices=None):
"""
Distribute a local Python collection to form an RDD.
+
+ >>> sc.parallelize(range(5), 5).glom().collect()
+ [[0], [1], [2], [3], [4]]
"""
numSlices = numSlices or self.defaultParallelism
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
- if self.batchSize != 1:
- c = batched(c, self.batchSize)
+ # Make sure we distribute data evenly if it's smaller than self.batchSize
+ if "__len__" not in dir(c):
+ c = list(c) # Make it a list so we can compute its length
+ batchSize = min(len(c) // numSlices, self.batchSize)
+ if batchSize > 1:
+ c = batched(c, batchSize)
for x in c:
write_with_length(dump_pickle(x), tempFile)
tempFile.close()
@@ -233,7 +261,11 @@ class SparkContext(object):
HTTP, HTTPS or FTP URI.
"""
self.addFile(path)
- filename = path.split("/")[-1]
+ (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix
+
+ if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
+ self._python_includes.append(filename)
+ sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode
def setCheckpointDir(self, dirName, useExisting=False):
"""
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 78c9457b84..eb18ec08c9 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import os
import signal
import socket
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
index 001b7a28b6..57ee14eeb7 100644
--- a/python/pyspark/files.py
+++ b/python/pyspark/files.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import os
@@ -35,4 +52,4 @@ class SparkFiles(object):
return cls._root_directory
else:
# This will have to change if we support multiple SparkContexts:
- return cls._sc._jvm.spark.SparkFiles.getRootDirectory()
+ return cls._sc._jvm.org.apache.spark.SparkFiles.getRootDirectory()
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 2329e536cc..e615c1e9b6 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -1,5 +1,24 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import os
import sys
+import signal
+import platform
from subprocess import Popen, PIPE
from threading import Thread
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
@@ -11,9 +30,18 @@ SPARK_HOME = os.environ["SPARK_HOME"]
def launch_gateway():
# Launch the Py4j gateway using Spark's run command so that we pick up the
# proper classpath and SPARK_MEM settings from spark-env.sh
- command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer",
+ on_windows = platform.system() == "Windows"
+ script = "spark-class.cmd" if on_windows else "spark-class"
+ command = [os.path.join(SPARK_HOME, script), "py4j.GatewayServer",
"--die-on-broken-pipe", "0"]
- proc = Popen(command, stdout=PIPE, stdin=PIPE)
+ if not on_windows:
+ # Don't send ctrl-c / SIGINT to the Java gateway:
+ def preexec_func():
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+ proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func)
+ else:
+ # preexec_fn not supported on Windows
+ proc = Popen(command, stdout=PIPE, stdin=PIPE)
# Determine which ephemeral port the server started on:
port = int(proc.stdout.readline())
# Create a thread to echo output from the GatewayServer, which is required
@@ -32,7 +60,7 @@ def launch_gateway():
# Connect to the gateway
gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
# Import the classes used by PySpark
- java_import(gateway.jvm, "spark.api.java.*")
- java_import(gateway.jvm, "spark.api.python.*")
+ java_import(gateway.jvm, "org.apache.spark.api.java.*")
+ java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "scala.Tuple2")
return gateway
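
Illustrative note (not part of this commit): preexec_fn runs in the forked child before the JVM is exec'd, and an ignored SIGINT stays ignored across exec, so Ctrl-C in the Python shell interrupts only the Python side, not the gateway. A minimal standalone sketch of the same trick with a hypothetical child process:

    import signal
    from subprocess import Popen

    def ignore_sigint():
        # executed in the child between fork() and exec()
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    proc = Popen(["sleep", "60"], preexec_fn=ignore_sigint)   # child ignores Ctrl-C
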
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 86bade9546..1831435b33 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1,9 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from itertools import chain, ifilter, imap, product
import operator
import os
+import sys
import shlex
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
@@ -14,6 +32,8 @@ from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
read_from_pickle_file
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
+from pyspark.statcounter import StatCounter
+from pyspark.rddsampler import RDDSampler
from py4j.java_collections import ListConverter, MapConverter
@@ -143,18 +163,64 @@ class RDD(object):
>>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
[1, 2, 3]
"""
- return self.map(lambda x: (x, "")) \
+ return self.map(lambda x: (x, None)) \
.reduceByKey(lambda x, _: x) \
.map(lambda (x, _): x)
- # TODO: sampling needs to be re-implemented due to Batch
- #def sample(self, withReplacement, fraction, seed):
- # jrdd = self._jrdd.sample(withReplacement, fraction, seed)
- # return RDD(jrdd, self.ctx)
+ def sample(self, withReplacement, fraction, seed):
+ """
+ Return a sampled subset of this RDD (relies on numpy and falls back
+ on default random generator if numpy is unavailable).
+
+ >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
+ [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
+ """
+ return self.mapPartitionsWithSplit(RDDSampler(withReplacement, fraction, seed).func, True)
+
+ # this is ported from scala/spark/RDD.scala
+ def takeSample(self, withReplacement, num, seed):
+ """
+ Return a fixed-size sampled subset of this RDD (currently requires numpy).
+
+ >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
+ [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
+ """
+
+ fraction = 0.0
+ total = 0
+ multiplier = 3.0
+ initialCount = self.count()
+ maxSelected = 0
+
+ if (num < 0):
+ raise ValueError
+
+ if initialCount > sys.maxint - 1:
+ maxSelected = sys.maxint - 1
+ else:
+ maxSelected = initialCount
- #def takeSample(self, withReplacement, num, seed):
- # vals = self._jrdd.takeSample(withReplacement, num, seed)
- # return [load_pickle(bytes(x)) for x in vals]
+ if num > initialCount and not withReplacement:
+ total = maxSelected
+ fraction = multiplier * (maxSelected + 1) / initialCount
+ else:
+ fraction = multiplier * (num + 1) / initialCount
+ total = num
+
+ samples = self.sample(withReplacement, fraction, seed).collect()
+
+ # If the first sample didn't turn out large enough, keep trying to take samples;
+ # this shouldn't happen often because we use a big multiplier for their initial size.
+ # See: scala/spark/RDD.scala
+ while len(samples) < total:
+ if seed > sys.maxint - 2:
+ seed = -1
+ seed += 1
+ samples = self.sample(withReplacement, fraction, seed).collect()
+
+ sampler = RDDSampler(withReplacement, fraction, seed+1)
+ sampler.shuffle(samples)
+ return samples[0:total]
def union(self, other):
"""
@@ -250,7 +316,11 @@ class RDD(object):
>>> def f(x): print x
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
- self.map(f).collect() # Force evaluation
+ def processPartition(iterator):
+ for x in iterator:
+ f(x)
+ yield None
+ self.mapPartitions(processPartition).collect() # Force evaluation
def collect(self):
"""
@@ -336,6 +406,63 @@ class RDD(object):
3
"""
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
+
+ def stats(self):
+ """
+ Return a L{StatCounter} object that captures the mean, variance
+ and count of the RDD's elements in one operation.
+ """
+ def redFunc(left_counter, right_counter):
+ return left_counter.mergeStats(right_counter)
+
+ return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc)
+
+ def mean(self):
+ """
+ Compute the mean of this RDD's elements.
+
+ >>> sc.parallelize([1, 2, 3]).mean()
+ 2.0
+ """
+ return self.stats().mean()
+
+ def variance(self):
+ """
+ Compute the variance of this RDD's elements.
+
+ >>> sc.parallelize([1, 2, 3]).variance()
+ 0.666...
+ """
+ return self.stats().variance()
+
+ def stdev(self):
+ """
+ Compute the standard deviation of this RDD's elements.
+
+ >>> sc.parallelize([1, 2, 3]).stdev()
+ 0.816...
+ """
+ return self.stats().stdev()
+
+ def sampleStdev(self):
+ """
+ Compute the sample standard deviation of this RDD's elements (which corrects for bias in
+ estimating the standard deviation by dividing by N-1 instead of N).
+
+ >>> sc.parallelize([1, 2, 3]).sampleStdev()
+ 1.0
+ """
+ return self.stats().sampleStdev()
+
+ def sampleVariance(self):
+ """
+ Compute the sample variance of this RDD's elements (which corrects for bias in
+ estimating the variance by dividing by N-1 instead of N).
+
+ >>> sc.parallelize([1, 2, 3]).sampleVariance()
+ 1.0
+ """
+ return self.stats().sampleVariance()
def countByValue(self):
"""
@@ -369,13 +496,16 @@ class RDD(object):
>>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
[2, 3, 4, 5, 6]
"""
+ def takeUpToNum(iterator):
+ taken = 0
+ while taken < num:
+ yield next(iterator)
+ taken += 1
+ # Take only up to num elements from each partition we try
+ mapped = self.mapPartitions(takeUpToNum)
items = []
- for partition in range(self._jrdd.splits().size()):
- iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
- # Each item in the iterator is a string, Python object, batch of
- # Python objects. Regardless, it is sufficient to take `num`
- # of these objects in order to collect `num` Python objects:
- iterator = iterator.take(num)
+ for partition in range(mapped._jrdd.splits().size()):
+ iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
items.extend(self._collect_iterator_through_file(iterator))
if len(items) >= num:
break
@@ -672,6 +802,43 @@ class RDD(object):
"""
return python_cogroup(self, other, numPartitions)
+ def subtractByKey(self, other, numPartitions=None):
+ """
+ Return each (key, value) pair in C{self} that has no pair with matching key
+ in C{other}.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)])
+ >>> y = sc.parallelize([("a", 3), ("c", None)])
+ >>> sorted(x.subtractByKey(y).collect())
+ [('b', 4), ('b', 5)]
+ """
+ filter_func = lambda tpl: len(tpl[1][0]) > 0 and len(tpl[1][1]) == 0
+ map_func = lambda tpl: [(tpl[0], val) for val in tpl[1][0]]
+ return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
+
+ def subtract(self, other, numPartitions=None):
+ """
+ Return each value in C{self} that is not contained in C{other}.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 3)])
+ >>> y = sc.parallelize([("a", 3), ("c", None)])
+ >>> sorted(x.subtract(y).collect())
+ [('a', 1), ('b', 4), ('b', 5)]
+ """
+ rdd = other.map(lambda x: (x, True)) # note: here 'True' is just a placeholder
+ return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0]) # note: here 'True' is just a placeholder
+
+ def keyBy(self, f):
+ """
+ Creates tuples of the elements in this RDD by applying C{f}.
+
+ >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x)
+ >>> y = sc.parallelize(zip(range(0,5), range(0,5)))
+ >>> sorted(x.cogroup(y).collect())
+ [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))]
+ """
+ return self.map(lambda x: (f(x), x))
+
# TODO: `lookup` is disabled because we can't make direct comparisons based
# on the key; we need to compare the hash of the key to the hash of the
# keys in the pairs. This could be an expensive operation, since those
@@ -732,11 +899,12 @@ class PipelinedRDD(RDD):
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
class_tag = self._prev_jrdd.classTag()
- env = copy.copy(self.ctx.environment)
- env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
- env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
+ env = MapConverter().convert(self.ctx.environment,
+ self.ctx._gateway._gateway_client)
+ includes = ListConverter().convert(self.ctx._python_includes,
+ self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
+ pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator, class_tag)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
@@ -752,7 +920,7 @@ def _test():
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
- (failure_count, test_count) = doctest.testmod(globs=globs)
+ (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
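
Illustrative note (not part of this commit): takeSample() above oversamples by a multiplier of 3.0 so that a single sample() pass usually returns at least num elements, and only retries with a fresh seed when it falls short. A minimal sketch of the fraction it requests, with hypothetical counts:

    initialCount = 1000          # hypothetical RDD size
    num = 10                     # elements wanted, without replacement
    multiplier = 3.0
    fraction = multiplier * (num + 1) / initialCount
    print fraction               # 0.033: sample(False, 0.033, seed) yields ~33 elements
                                 # on average; the result is then shuffled and truncated
                                 # to the first 10
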
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
new file mode 100644
index 0000000000..aca2ef3b51
--- /dev/null
+++ b/python/pyspark/rddsampler.py
@@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import random
+
+class RDDSampler(object):
+ def __init__(self, withReplacement, fraction, seed):
+ try:
+ import numpy
+ self._use_numpy = True
+ except ImportError:
+ print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling."
+ self._use_numpy = False
+
+ self._seed = seed
+ self._withReplacement = withReplacement
+ self._fraction = fraction
+ self._random = None
+ self._split = None
+ self._rand_initialized = False
+
+ def initRandomGenerator(self, split):
+ if self._use_numpy:
+ import numpy
+ self._random = numpy.random.RandomState(self._seed)
+ for _ in range(0, split):
+ # discard the next few values in the sequence to have a
+ # different seed for the different splits
+ self._random.randint(sys.maxint)
+ else:
+ import random
+ random.seed(self._seed)
+ for _ in range(0, split):
+ # discard the next few values in the sequence to have a
+ # different seed for the different splits
+ random.randint(0, sys.maxint)
+ self._split = split
+ self._rand_initialized = True
+
+ def getUniformSample(self, split):
+ if not self._rand_initialized or split != self._split:
+ self.initRandomGenerator(split)
+
+ if self._use_numpy:
+ return self._random.random_sample()
+ else:
+ return random.uniform(0.0, 1.0)
+
+ def getPoissonSample(self, split, mean):
+ if not self._rand_initialized or split != self._split:
+ self.initRandomGenerator(split)
+
+ if self._use_numpy:
+ return self._random.poisson(mean)
+ else:
+ # here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by
+ # drawing a sequence of numbers delta_j ~ Exp(mean)
+ num_arrivals = 1
+ cur_time = 0.0
+
+ cur_time += random.expovariate(mean)
+
+ if cur_time > 1.0:
+ return 0
+
+ while(cur_time <= 1.0):
+ cur_time += random.expovariate(mean)
+ num_arrivals += 1
+
+ return (num_arrivals - 1)
+
+ def shuffle(self, vals):
+ if self._random == None or split != self._split:
+ self.initRandomGenerator(0) # this should only ever called on the master so
+ # the split does not matter
+
+ if self._use_numpy:
+ self._random.shuffle(vals)
+ else:
+ random.shuffle(vals, self._random)
+
+ def func(self, split, iterator):
+ if self._withReplacement:
+ for obj in iterator:
+ # For large datasets, the expected number of occurrences of each element in a sample with
+ # replacement is Poisson(frac). We use that to get a count for each element.
+ count = self.getPoissonSample(split, mean = self._fraction)
+ for _ in range(0, count):
+ yield obj
+ else:
+ for obj in iterator:
+ if self.getUniformSample(split) <= self._fraction:
+ yield obj
+
+
+
+
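
Illustrative note (not part of this commit): the non-NumPy branch of getPoissonSample() counts exponential inter-arrival times that fall inside a unit interval; with rate `mean`, that count is Poisson-distributed with the same mean. A minimal standalone sketch of the counting trick:

    import random

    def poisson_by_arrivals(mean):
        # count arrivals of an exponential(rate=mean) process inside [0, 1]
        count = 0
        t = random.expovariate(mean)
        while t <= 1.0:
            count += 1
            t += random.expovariate(mean)
        return count

    samples = [poisson_by_arrivals(0.5) for _ in xrange(10000)]
    print sum(samples) / float(len(samples))   # close to 0.5
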
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 5a95144983..fecacd1241 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import struct
import cPickle
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 54ff1bf8e7..54823f8037 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -1,16 +1,51 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
An interactive shell.
This file is designed to be launched as a PYTHONSTARTUP script.
"""
import os
+import platform
import pyspark
from pyspark.context import SparkContext
+# this is the equivalent of ADD_JARS
+add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None
+
+sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell", pyFiles=add_files)
-sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell")
+print """Welcome to
+ ____ __
+ / __/__ ___ _____/ /__
+ _\ \/ _ \/ _ `/ __/ '_/
+ /__ / .__/\_,_/_/ /_/\_\ version 0.8.0
+ /_/
+"""
+print "Using Python version %s (%s, %s)" % (
+ platform.python_version(),
+ platform.python_build()[0],
+ platform.python_build()[1])
print "Spark context avaiable as sc."
+if add_files != None:
+ print "Adding files: [%s]" % ", ".join(add_files)
+
# The ./pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP,
# which allows us to execute the user's PYTHONSTARTUP file:
_pythonstartup = os.environ.get('OLD_PYTHONSTARTUP')
diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py
new file mode 100644
index 0000000000..8e1cbd4ad9
--- /dev/null
+++ b/python/pyspark/statcounter.py
@@ -0,0 +1,109 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This file is ported from spark/util/StatCounter.scala
+
+import copy
+import math
+
+class StatCounter(object):
+
+ def __init__(self, values=[]):
+ self.n = 0L # Running count of our values
+ self.mu = 0.0 # Running mean of our values
+ self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2)
+
+ for v in values:
+ self.merge(v)
+
+ # Add a value into this StatCounter, updating the internal statistics.
+ def merge(self, value):
+ delta = value - self.mu
+ self.n += 1
+ self.mu += delta / self.n
+ self.m2 += delta * (value - self.mu)
+ return self
+
+ # Merge another StatCounter into this one, adding up the internal statistics.
+ def mergeStats(self, other):
+ if not isinstance(other, StatCounter):
+ raise Exception("Can only merge Statcounters!")
+
+ if other is self: # reference equality holds
+ self.merge(copy.deepcopy(other)) # Avoid overwriting fields in a weird order
+ else:
+ if self.n == 0:
+ self.mu = other.mu
+ self.m2 = other.m2
+ self.n = other.n
+ elif other.n != 0:
+ delta = other.mu - self.mu
+ if other.n * 10 < self.n:
+ self.mu = self.mu + (delta * other.n) / (self.n + other.n)
+ elif self.n * 10 < other.n:
+ self.mu = other.mu - (delta * self.n) / (self.n + other.n)
+ else:
+ self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)
+
+ self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
+ self.n += other.n
+ return self
+
+ # Clone this StatCounter
+ def copy(self):
+ return copy.deepcopy(self)
+
+ def count(self):
+ return self.n
+
+ def mean(self):
+ return self.mu
+
+ def sum(self):
+ return self.n * self.mu
+
+ # Return the variance of the values.
+ def variance(self):
+ if self.n == 0:
+ return float('nan')
+ else:
+ return self.m2 / self.n
+
+ #
+ # Return the sample variance, which corrects for bias in estimating the variance by dividing
+ # by N-1 instead of N.
+ #
+ def sampleVariance(self):
+ if self.n <= 1:
+ return float('nan')
+ else:
+ return self.m2 / (self.n - 1)
+
+ # Return the standard deviation of the values.
+ def stdev(self):
+ return math.sqrt(self.variance())
+
+ #
+ # Return the sample standard deviation of the values, which corrects for bias in estimating the
+ # variance by dividing by N-1 instead of N.
+ #
+ def sampleStdev(self):
+ return math.sqrt(self.sampleVariance())
+
+ def __repr__(self):
+ return "(count: %s, mean: %s, stdev: %s)" % (self.count(), self.mean(), self.stdev())
+
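
Illustrative note (not part of this commit): merge() above is the classic one-pass (Welford) update, and mergeStats() combines two partial (n, mu, m2) summaries pairwise, so merging two halves agrees with a single pass over all values up to floating-point rounding. A minimal check with hypothetical data:

    values = [1.0, 2.0, 4.0, 8.0, 16.0, 32.0]            # hypothetical data

    def summarize(vals):
        # same running (n, mu, m2) update as StatCounter.merge
        n, mu, m2 = 0, 0.0, 0.0
        for v in vals:
            n += 1
            delta = v - mu
            mu += delta / n
            m2 += delta * (v - mu)
        return n, mu, m2

    def combine(a, b):
        # same pairwise combination as the general case of StatCounter.mergeStats
        (n1, mu1, m21), (n2, mu2, m22) = a, b
        delta = mu2 - mu1
        n = n1 + n2
        mu = (mu1 * n1 + mu2 * n2) / n
        m2 = m21 + m22 + (delta * delta * n1 * n2) / n
        return n, mu, m2

    print combine(summarize(values[:3]), summarize(values[3:]))
    print summarize(values)                              # both: (6, 10.5, 703.5), up to rounding
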
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1e34d47365..29d6a128f6 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
Unit tests for PySpark; additional tests are implemented as doctests in
individual modules.
@@ -47,7 +64,7 @@ class TestCheckpoint(PySparkTestCase):
flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
self.assertFalse(flatMappedRDD.isCheckpointed())
- self.assertIsNone(flatMappedRDD.getCheckpointFile())
+ self.assertTrue(flatMappedRDD.getCheckpointFile() is None)
flatMappedRDD.checkpoint()
result = flatMappedRDD.collect()
@@ -62,13 +79,13 @@ class TestCheckpoint(PySparkTestCase):
flatMappedRDD = parCollection.flatMap(lambda x: [x])
self.assertFalse(flatMappedRDD.isCheckpointed())
- self.assertIsNone(flatMappedRDD.getCheckpointFile())
+ self.assertTrue(flatMappedRDD.getCheckpointFile() is None)
flatMappedRDD.checkpoint()
flatMappedRDD.count() # forces a checkpoint to be computed
time.sleep(1) # 1 second
- self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+ self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
self.assertEquals([1, 2, 3, 4], recovered.collect())
@@ -108,6 +125,17 @@ class TestAddFile(PySparkTestCase):
from userlibrary import UserClass
self.assertEqual("Hello World!", UserClass().hello())
+ def test_add_egg_file_locally(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this fails due to `userlibrary` not being on the Python path:
+ def func():
+ from userlib import UserClass
+ self.assertRaises(ImportError, func)
+ path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg")
+ self.sc.addPyFile(path)
+ from userlib import UserClass
+ self.assertEqual("Hello World from inside a package!", UserClass().hello())
+
class TestIO(PySparkTestCase):
@@ -147,9 +175,12 @@ class TestDaemon(unittest.TestCase):
time.sleep(1)
# daemon should no longer accept connections
- with self.assertRaises(EnvironmentError) as trap:
+ try:
self.connect(port)
- self.assertEqual(trap.exception.errno, ECONNREFUSED)
+ except EnvironmentError as exception:
+ self.assertEqual(exception.errno, ECONNREFUSED)
+ else:
+ self.fail("Expected EnvironmentError to be raised")
def test_termination_stdin(self):
"""Ensure that daemon and workers terminate when stdin is closed."""
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 379bbfd4c2..d63c2aaef7 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1,9 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
Worker that receives input from Piped RDD.
"""
import os
import sys
import time
+import socket
import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
@@ -32,15 +50,26 @@ def main(infile, outfile):
split_index = read_int(infile)
if split_index == -1: # for unit tests
return
+
+ # fetch name of workdir
spark_files_dir = load_pickle(read_with_length(infile))
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
- sys.path.append(spark_files_dir)
+
+ # fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
value = read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+
+ # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
+ sys.path.append(spark_files_dir) # *.py files that were added will be copied here
+ num_python_includes = read_int(infile)
+ for _ in range(num_python_includes):
+ sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+
+ # now load function
func = load_obj(infile)
bypassSerializer = load_obj(infile)
if bypassSerializer:
@@ -66,7 +95,9 @@ def main(infile, outfile):
if __name__ == '__main__':
- # Redirect stdout to stderr so that users must return values from functions.
- old_stdout = os.fdopen(os.dup(1), 'w')
- os.dup2(2, 1)
- main(sys.stdin, old_stdout)
+ # Read a local port to connect to from stdin
+ java_port = int(sys.stdin.readline())
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.connect(("127.0.0.1", java_port))
+ sock_file = sock.makefile("a+", 65536)
+ main(sock_file, sock_file)
diff --git a/python/run-tests b/python/run-tests
index a3a9ff5dcb..cbc554ea9d 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -1,24 +1,43 @@
#!/usr/bin/env bash
-# Figure out where the Scala framework is installed
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+# Figure out where the Spark framework is installed
FWDIR="$(cd `dirname $0`; cd ../; pwd)"
-FAILED=0
-
-$FWDIR/pyspark pyspark/rdd.py
-FAILED=$(($?||$FAILED))
+# CD into the python directory to find things on the right path
+cd "$FWDIR/python"
-$FWDIR/pyspark pyspark/context.py
-FAILED=$(($?||$FAILED))
+FAILED=0
-$FWDIR/pyspark -m doctest pyspark/broadcast.py
-FAILED=$(($?||$FAILED))
+rm -f unit-tests.log
-$FWDIR/pyspark -m doctest pyspark/accumulators.py
-FAILED=$(($?||$FAILED))
+function run_test() {
+ $FWDIR/pyspark $1 2>&1 | tee -a unit-tests.log
+ FAILED=$((PIPESTATUS[0]||$FAILED))
+}
-$FWDIR/pyspark -m unittest pyspark.tests
-FAILED=$(($?||$FAILED))
+run_test "pyspark/rdd.py"
+run_test "pyspark/context.py"
+run_test "-m doctest pyspark/broadcast.py"
+run_test "-m doctest pyspark/accumulators.py"
+run_test "pyspark/tests.py"
if [[ $FAILED != 0 ]]; then
echo -en "\033[31m" # Red
diff --git a/python/test_support/userlib-0.1-py2.7.egg b/python/test_support/userlib-0.1-py2.7.egg
new file mode 100644
index 0000000000..1674c9cb22
--- /dev/null
+++ b/python/test_support/userlib-0.1-py2.7.egg
Binary files differ