Diffstat (limited to 'python')
-rw-r--r--  python/epydoc.conf                      |   4
-rw-r--r--  python/pyspark/__init__.py              |  32
-rw-r--r--  python/pyspark/accumulators.py          |  19
-rw-r--r--  python/pyspark/broadcast.py             |  11
-rw-r--r--  python/pyspark/conf.py                  | 171
-rw-r--r--  python/pyspark/context.py               | 174
-rw-r--r--  python/pyspark/java_gateway.py          |   2
-rw-r--r--  python/pyspark/mllib/__init__.py        |  20
-rw-r--r--  python/pyspark/mllib/_common.py         | 227
-rw-r--r--  python/pyspark/mllib/classification.py  |  86
-rw-r--r--  python/pyspark/mllib/clustering.py      |  79
-rw-r--r--  python/pyspark/mllib/recommendation.py  |  74
-rw-r--r--  python/pyspark/mllib/regression.py      | 110
-rw-r--r--  python/pyspark/rdd.py                   | 182
-rw-r--r--  python/pyspark/serializers.py           | 303
-rw-r--r--  python/pyspark/shell.py                 |   2
-rw-r--r--  python/pyspark/tests.py                 |  22
-rw-r--r--  python/pyspark/worker.py                |  44
-rwxr-xr-x  python/run-tests                        |   4
-rwxr-xr-x  python/test_support/userlibrary.py      |  17
20 files changed, 1381 insertions, 202 deletions
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 1d0d002d36..95a6af0974 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -32,6 +32,6 @@ target: docs/
private: no
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join
pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
- pyspark.rddsampler pyspark.daemon
+ pyspark.rddsampler pyspark.daemon pyspark.mllib._common
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 1f35f6f939..2b2c3a061a 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -20,28 +20,34 @@ PySpark is the Python API for Spark.
Public classes:
- - L{SparkContext<pyspark.context.SparkContext>}
- Main entry point for Spark functionality.
- - L{RDD<pyspark.rdd.RDD>}
- A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
- - L{Broadcast<pyspark.broadcast.Broadcast>}
- A broadcast variable that gets reused across tasks.
- - L{Accumulator<pyspark.accumulators.Accumulator>}
- An "add-only" shared variable that tasks can only add values to.
- - L{SparkFiles<pyspark.files.SparkFiles>}
- Access files shipped with jobs.
- - L{StorageLevel<pyspark.storagelevel.StorageLevel>}
- Finer-grained cache persistence levels.
+ - L{SparkContext<pyspark.context.SparkContext>}
+ Main entry point for Spark functionality.
+ - L{RDD<pyspark.rdd.RDD>}
+ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+ - L{Broadcast<pyspark.broadcast.Broadcast>}
+ A broadcast variable that gets reused across tasks.
+ - L{Accumulator<pyspark.accumulators.Accumulator>}
+ An "add-only" shared variable that tasks can only add values to.
+ - L{SparkConf<pyspark.conf.SparkConf>}
+ For configuring Spark.
+ - L{SparkFiles<pyspark.files.SparkFiles>}
+ Access files shipped with jobs.
+ - L{StorageLevel<pyspark.storagelevel.StorageLevel>}
+ Finer-grained cache persistence levels.
"""
+
+
+
import sys
import os
sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg"))
+from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.files import SparkFiles
from pyspark.storagelevel import StorageLevel
-__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel"]
+__all__ = ["SparkConf", "SparkContext", "RDD", "SparkFiles", "StorageLevel"]
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index d367f91967..2204e9c9ca 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -42,6 +42,13 @@
>>> a.value
13
+>>> b = sc.accumulator(0)
+>>> def g(x):
+... b.add(x)
+>>> rdd.foreach(g)
+>>> b.value
+6
+
>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
... def zero(self, value):
@@ -83,8 +90,10 @@ import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import read_int, read_with_length, load_pickle
+from pyspark.serializers import read_int, PickleSerializer
+
+pickleSer = PickleSerializer()
# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
@@ -139,9 +148,13 @@ class Accumulator(object):
raise Exception("Accumulator.value cannot be accessed inside tasks")
self._value = value
+ def add(self, term):
+ """Adds a term to this accumulator's value"""
+ self._value = self.accum_param.addInPlace(self._value, term)
+
def __iadd__(self, term):
"""The += operator; adds a term to this accumulator's value"""
- self._value = self.accum_param.addInPlace(self._value, term)
+ self.add(term)
return self
def __str__(self):
@@ -200,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
from pyspark.accumulators import _accumulatorRegistry
num_updates = read_int(self.rfile)
for _ in range(num_updates):
- (aid, update) = load_pickle(read_with_length(self.rfile))
+ (aid, update) = pickleSer._read_with_length(self.rfile)
_accumulatorRegistry[aid] += update
# Write a byte in acknowledgement
self.wfile.write(struct.pack("!b", 1))
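
A minimal sketch of the new Accumulator.add() method and a custom AccumulatorParam, mirroring the doctests added above (assumes a running SparkContext named sc; the vector sizes are illustrative):

    # Sum values into an accumulator with add(), then accumulate NumPy vectors
    # with a custom AccumulatorParam, as in the doctest above.
    from numpy import array, zeros
    from pyspark.accumulators import AccumulatorParam

    counter = sc.accumulator(0)
    sc.parallelize([1, 2, 3, 4]).foreach(lambda x: counter.add(x))
    assert counter.value == 10

    class VectorParam(AccumulatorParam):
        def zero(self, value):
            return zeros(len(value))   # identity element for addInPlace
        def addInPlace(self, v1, v2):
            v1 += v2                   # merge two partial sums
            return v1

    vec = sc.accumulator(array([0.0, 0.0]), VectorParam())
    def add_vec(x):
        vec.add(x)
    sc.parallelize([array([1.0, 2.0]), array([3.0, 4.0])]).foreach(add_vec)
    # vec.value is now array([4.0, 6.0])
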
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index dfdaba274f..43f40f8783 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -45,7 +45,18 @@ def _from_id(bid):
class Broadcast(object):
+ """
+ A broadcast variable created with
+ L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}.
+ Access its value through C{.value}.
+ """
+
def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
+ """
+ Should not be called directly by users -- use
+ L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}
+ instead.
+ """
self.value = value
self.bid = bid
self._jbroadcast = java_broadcast
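
A short sketch of the Broadcast usage documented above: create the variable with SparkContext.broadcast() and read it through .value inside tasks (assumes a running SparkContext sc; the lookup table is illustrative):

    # Ship a read-only lookup table to every executor once.
    lookup = sc.broadcast({"a": 1, "b": 2, "c": 3})

    words = sc.parallelize(["a", "b", "c", "a"])
    total = words.map(lambda w: lookup.value.get(w, 0)).sum()
    assert total == 7   # 1 + 2 + 3 + 1
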
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
new file mode 100644
index 0000000000..d72aed6a30
--- /dev/null
+++ b/python/pyspark/conf.py
@@ -0,0 +1,171 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+>>> from pyspark.conf import SparkConf
+>>> from pyspark.context import SparkContext
+>>> conf = SparkConf()
+>>> conf.setMaster("local").setAppName("My app")
+<pyspark.conf.SparkConf object at ...>
+>>> conf.get("spark.master")
+u'local'
+>>> conf.get("spark.app.name")
+u'My app'
+>>> sc = SparkContext(conf=conf)
+>>> sc.master
+u'local'
+>>> sc.appName
+u'My app'
+>>> sc.sparkHome == None
+True
+
+>>> conf = SparkConf()
+>>> conf.setSparkHome("/path")
+<pyspark.conf.SparkConf object at ...>
+>>> conf.get("spark.home")
+u'/path'
+>>> conf.setExecutorEnv("VAR1", "value1")
+<pyspark.conf.SparkConf object at ...>
+>>> conf.setExecutorEnv(pairs = [("VAR3", "value3"), ("VAR4", "value4")])
+<pyspark.conf.SparkConf object at ...>
+>>> conf.get("spark.executorEnv.VAR1")
+u'value1'
+>>> print conf.toDebugString()
+spark.executorEnv.VAR1=value1
+spark.executorEnv.VAR3=value3
+spark.executorEnv.VAR4=value4
+spark.home=/path
+>>> sorted(conf.getAll(), key=lambda p: p[0])
+[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')]
+"""
+
+
+class SparkConf(object):
+ """
+ Configuration for a Spark application. Used to set various Spark
+ parameters as key-value pairs.
+
+ Most of the time, you would create a SparkConf object with
+ C{SparkConf()}, which will load values from C{spark.*} Java system
+ properties and any C{spark.conf} on your Spark classpath. In this
+ case, system properties take priority over C{spark.conf}, and any
+ parameters you set directly on the C{SparkConf} object take priority
+ over both of those.
+
+ For unit tests, you can also call C{SparkConf(False)} to skip
+ loading external settings and get the same configuration no matter
+ what is on the classpath.
+
+ All setter methods in this class support chaining. For example,
+ you can write C{conf.setMaster("local").setAppName("My app")}.
+
+ Note that once a SparkConf object is passed to Spark, it is cloned
+ and can no longer be modified by the user.
+ """
+
+ def __init__(self, loadDefaults=True, _jvm=None):
+ """
+ Create a new Spark configuration.
+
+ @param loadDefaults: whether to load values from Java system
+ properties and classpath (True by default)
+ @param _jvm: internal parameter used to pass a handle to the
+ Java VM; does not need to be set by users
+ """
+ from pyspark.context import SparkContext
+ SparkContext._ensure_initialized()
+ _jvm = _jvm or SparkContext._jvm
+ self._jconf = _jvm.SparkConf(loadDefaults)
+
+ def set(self, key, value):
+ """Set a configuration property."""
+ self._jconf.set(key, unicode(value))
+ return self
+
+ def setMaster(self, value):
+ """Set master URL to connect to."""
+ self._jconf.setMaster(value)
+ return self
+
+ def setAppName(self, value):
+ """Set application name."""
+ self._jconf.setAppName(value)
+ return self
+
+ def setSparkHome(self, value):
+ """Set path where Spark is installed on worker nodes."""
+ self._jconf.setSparkHome(value)
+ return self
+
+ def setExecutorEnv(self, key=None, value=None, pairs=None):
+ """Set an environment variable to be passed to executors."""
+ if (key is not None and pairs is not None) or (key is None and pairs is None):
+ raise Exception("Either pass one key-value pair or a list of pairs")
+ elif key is not None:
+ self._jconf.setExecutorEnv(key, value)
+ elif pairs is not None:
+ for (k, v) in pairs:
+ self._jconf.setExecutorEnv(k, v)
+ return self
+
+ def setAll(self, pairs):
+ """
+ Set multiple parameters, passed as a list of key-value pairs.
+
+ @param pairs: list of key-value pairs to set
+ """
+ for (k, v) in pairs:
+ self._jconf.set(k, v)
+ return self
+
+ def get(self, key, defaultValue=None):
+ """Get the configured value for some key, or return a default otherwise."""
+ if defaultValue is None: # Py4J doesn't call the right get() if we pass None
+ if not self._jconf.contains(key):
+ return None
+ return self._jconf.get(key)
+ else:
+ return self._jconf.get(key, defaultValue)
+
+ def getAll(self):
+ """Get all values as a list of key-value pairs."""
+ pairs = []
+ for elem in self._jconf.getAll():
+ pairs.append((elem._1(), elem._2()))
+ return pairs
+
+ def contains(self, key):
+ """Does this configuration contain a given key?"""
+ return self._jconf.contains(key)
+
+ def toDebugString(self):
+ """
+ Returns a printable version of the configuration, as a list of
+ key=value pairs, one per line.
+ """
+ return self._jconf.toDebugString()
+
+
+def _test():
+ import doctest
+ (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
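
A small sketch of the configuration behaviour described in the SparkConf docstring: setters chain, setAll() takes key-value pairs, and get() falls back to a default (the property names beyond spark.master and spark.app.name are illustrative):

    from pyspark.conf import SparkConf

    conf = (SparkConf()                      # loads spark.* system properties by default
            .setMaster("local[2]")
            .setAppName("conf-demo")
            .setAll([("spark.executorEnv.PATH_EXTRA", "/opt/bin")]))

    assert conf.contains("spark.master")
    assert conf.get("spark.app.name") == u"conf-demo"
    # get() with a default avoids passing None through Py4J (see get() above)
    assert conf.get("spark.nonexistent.key", "fallback") == "fallback"
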
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 597110321a..f955aad7a4 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -24,9 +24,10 @@ from tempfile import NamedTemporaryFile
from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
+from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
@@ -42,21 +43,22 @@ class SparkContext(object):
_gateway = None
_jvm = None
- _writeIteratorToPickleFile = None
- _takePartition = None
+ _writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
- def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
- environment=None, batchSize=1024):
+
+ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
+ environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None):
"""
- Create a new SparkContext.
+ Create a new SparkContext. At least the master and app name should be set,
+ either through the named parameters here or through C{conf}.
@param master: Cluster URL to connect to
(e.g. mesos://host:port, spark://host:port, local[4]).
- @param jobName: A name for your job, to display on the cluster web UI
+ @param appName: A name for your job, to display on the cluster web UI.
@param sparkHome: Location where Spark is installed on cluster nodes.
@param pyFiles: Collection of .zip or .py files to send to the cluster
and add to PYTHONPATH. These can be paths on the local file
@@ -66,29 +68,59 @@ class SparkContext(object):
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
+ @param serializer: The serializer for RDDs.
+ @param conf: A L{SparkConf} object setting Spark properties.
+
+
+ >>> from pyspark.context import SparkContext
+ >>> sc = SparkContext('local', 'test')
+
+ >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
"""
- with SparkContext._lock:
- if SparkContext._active_spark_context:
- raise ValueError("Cannot run multiple SparkContexts at once")
- else:
- SparkContext._active_spark_context = self
- if not SparkContext._gateway:
- SparkContext._gateway = launch_gateway()
- SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeIteratorToPickleFile = \
- SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
- SparkContext._takePartition = \
- SparkContext._jvm.PythonRDD.takePartition
- self.master = master
- self.jobName = jobName
- self.sparkHome = sparkHome or None # None becomes null in Py4J
+ SparkContext._ensure_initialized(self)
+
self.environment = environment or {}
- self.batchSize = batchSize # -1 represents a unlimited batch size
+ self._conf = conf or SparkConf(_jvm=self._jvm)
+ self._batchSize = batchSize # -1 represents an unlimited batch size
+ self._unbatched_serializer = serializer
+ if batchSize == 1:
+ self.serializer = self._unbatched_serializer
+ else:
+ self.serializer = BatchedSerializer(self._unbatched_serializer,
+ batchSize)
+
+ # Set any parameters passed directly to us on the conf
+ if master:
+ self._conf.setMaster(master)
+ if appName:
+ self._conf.setAppName(appName)
+ if sparkHome:
+ self._conf.setSparkHome(sparkHome)
+ if environment:
+ for key, value in environment.iteritems():
+ self._conf.setExecutorEnv(key, value)
+
+ # Check that we have at least the required parameters
+ if not self._conf.contains("spark.master"):
+ raise Exception("A master URL must be set in your configuration")
+ if not self._conf.contains("spark.app.name"):
+ raise Exception("An application name must be set in your configuration")
+
+ # Read back our properties from the conf in case we loaded some of them from
+ # the classpath or an external config file
+ self.master = self._conf.get("spark.master")
+ self.appName = self._conf.get("spark.app.name")
+ self.sparkHome = self._conf.get("spark.home", None)
+ for (k, v) in self._conf.getAll():
+ if k.startswith("spark.executorEnv."):
+ varName = k[len("spark.executorEnv."):]
+ self.environment[varName] = v
# Create the Java SparkContext through Py4J
- empty_string_array = self._gateway.new_array(self._jvm.String, 0)
- self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
- empty_string_array)
+ self._jsc = self._jvm.JavaSparkContext(self._conf._jconf)
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
@@ -99,6 +131,7 @@ class SparkContext(object):
self._jvm.PythonAccumulatorParam(host, port))
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
# been pickled, so it can determine which Java broadcast objects to
@@ -115,10 +148,33 @@ class SparkContext(object):
self.addPyFile(path)
# Create a temporary directory inside spark.local.dir:
- local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir()
+ local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
+ @classmethod
+ def _ensure_initialized(cls, instance=None):
+ with SparkContext._lock:
+ if not SparkContext._gateway:
+ SparkContext._gateway = launch_gateway()
+ SparkContext._jvm = SparkContext._gateway.jvm
+ SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
+
+ if instance:
+ if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
+ raise ValueError("Cannot run multiple SparkContexts at once")
+ else:
+ SparkContext._active_spark_context = instance
+
+ @classmethod
+ def setSystemProperty(cls, key, value):
+ """
+ Set a Java system property, such as spark.executor.memory. This must
+ be invoked before instantiating SparkContext.
+ """
+ SparkContext._ensure_initialized()
+ SparkContext._jvm.java.lang.System.setProperty(key, value)
+
@property
def defaultParallelism(self):
"""
@@ -158,15 +214,17 @@ class SparkContext(object):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
- batchSize = min(len(c) // numSlices, self.batchSize)
+ batchSize = min(len(c) // numSlices, self._batchSize)
if batchSize > 1:
- c = batched(c, batchSize)
- for x in c:
- write_with_length(dump_pickle(x), tempFile)
+ serializer = BatchedSerializer(self._unbatched_serializer,
+ batchSize)
+ else:
+ serializer = self._unbatched_serializer
+ serializer.dump_stream(c, tempFile)
tempFile.close()
- readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
- jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
- return RDD(jrdd, self)
+ readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+ jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+ return RDD(jrdd, self, serializer)
def textFile(self, name, minSplits=None):
"""
@@ -175,29 +233,50 @@ class SparkContext(object):
RDD of Strings.
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
- jrdd = self._jsc.textFile(name, minSplits)
- return RDD(jrdd, self)
+ return RDD(self._jsc.textFile(name, minSplits), self,
+ MUTF8Deserializer())
- def _checkpointFile(self, name):
+ def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
- return RDD(jrdd, self)
+ return RDD(jrdd, self, input_deserializer)
def union(self, rdds):
"""
Build the union of a list of RDDs.
+
+ This supports unions() of RDDs with different serialized formats,
+ although this forces them to be reserialized using the default
+ serializer:
+
+ >>> path = os.path.join(tempdir, "union-text.txt")
+ >>> with open(path, "w") as testFile:
+ ... testFile.write("Hello")
+ >>> textFile = sc.textFile(path)
+ >>> textFile.collect()
+ [u'Hello']
+ >>> parallelized = sc.parallelize(["World!"])
+ >>> sorted(sc.union([textFile, parallelized]).collect())
+ [u'Hello', 'World!']
"""
+ first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+ if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+ rdds = [x._reserialize() for x in rdds]
first = rdds[0]._jrdd
rest = [x._jrdd for x in rdds[1:]]
- rest = ListConverter().convert(rest, self.gateway._gateway_client)
- return RDD(self._jsc.union(first, rest), self)
+ rest = ListConverter().convert(rest, self._gateway._gateway_client)
+ return RDD(self._jsc.union(first, rest), self,
+ rdds[0]._jrdd_deserializer)
def broadcast(self, value):
"""
- Broadcast a read-only variable to the cluster, returning a C{Broadcast}
+ Broadcast a read-only variable to the cluster, returning a
+ L{Broadcast<pyspark.broadcast.Broadcast>}
object for reading it in distributed functions. The variable will be
sent to each cluster only once.
"""
- jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+ pickleSer = PickleSerializer()
+ pickled = pickleSer.dumps(value)
+ jbroadcast = self._jsc.broadcast(bytearray(pickled))
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
@@ -209,7 +288,7 @@ class SparkContext(object):
and floating-point numbers if you do not provide one. For other types,
a custom AccumulatorParam can be used.
"""
- if accum_param == None:
+ if accum_param is None:
if isinstance(value, int):
accum_param = accumulators.INT_ACCUMULATOR_PARAM
elif isinstance(value, float):
@@ -268,17 +347,12 @@ class SparkContext(object):
self._python_includes.append(filename)
sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode
- def setCheckpointDir(self, dirName, useExisting=False):
+ def setCheckpointDir(self, dirName):
"""
Set the directory under which RDDs are going to be checkpointed. The
directory must be a HDFS path if running on a cluster.
-
- If the directory does not exist, it will be created. If the directory
- exists and C{useExisting} is set to true, then the exisiting directory
- will be used. Otherwise an exception will be thrown to prevent
- accidental overriding of checkpoint files in the existing directory.
"""
- self._jsc.sc().setCheckpointDir(dirName, useExisting)
+ self._jsc.sc().setCheckpointDir(dirName)
def _getJavaStorageLevel(self, storageLevel):
"""
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index b872ae61d5..7243ee6861 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -60,7 +60,9 @@ def launch_gateway():
# Connect to the gateway
gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
# Import the classes used by PySpark
+ java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
+ java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
java_import(gateway.jvm, "scala.Tuple2")
return gateway
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
new file mode 100644
index 0000000000..b1a5df109b
--- /dev/null
+++ b/python/pyspark/mllib/__init__.py
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python bindings for MLlib.
+"""
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
new file mode 100644
index 0000000000..e74ba0fabc
--- /dev/null
+++ b/python/pyspark/mllib/_common.py
@@ -0,0 +1,227 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
+from pyspark import SparkContext
+
+# Double vector format:
+#
+# [8-byte 1] [8-byte length] [length*8 bytes of data]
+#
+# Double matrix format:
+#
+# [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data]
+#
+# This is all in machine-endian. That means that the Java interpreter and the
+# Python interpreter must agree on what endian the machine is.
+
+def _deserialize_byte_array(shape, ba, offset):
+ """Wrapper around ndarray aliasing hack.
+
+ >>> x = array([1.0, 2.0, 3.0, 4.0, 5.0])
+ >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
+ True
+ >>> x = array([1.0, 2.0, 3.0, 4.0]).reshape(2,2)
+ >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
+ True
+ """
+ ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64",
+ order='C')
+ return ar.copy()
+
+def _serialize_double_vector(v):
+ """Serialize a double vector into a mutually understood format."""
+ if type(v) != ndarray:
+ raise TypeError("_serialize_double_vector called on a %s; "
+ "wanted ndarray" % type(v))
+ if v.dtype != float64:
+ raise TypeError("_serialize_double_vector called on an ndarray of %s; "
+ "wanted ndarray of float64" % v.dtype)
+ if v.ndim != 1:
+ raise TypeError("_serialize_double_vector called on a %ddarray; "
+ "wanted a 1darray" % v.ndim)
+ length = v.shape[0]
+ ba = bytearray(16 + 8*length)
+ header = ndarray(shape=[2], buffer=ba, dtype="int64")
+ header[0] = 1
+ header[1] = length
+ copyto(ndarray(shape=[length], buffer=ba, offset=16,
+ dtype="float64"), v)
+ return ba
+
+def _deserialize_double_vector(ba):
+ """Deserialize a double vector from a mutually understood format.
+
+ >>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0])
+ >>> array_equal(x, _deserialize_double_vector(_serialize_double_vector(x)))
+ True
+ """
+ if type(ba) != bytearray:
+ raise TypeError("_deserialize_double_vector called on a %s; "
+ "wanted bytearray" % type(ba))
+ if len(ba) < 16:
+ raise TypeError("_deserialize_double_vector called on a %d-byte array, "
+ "which is too short" % len(ba))
+ if (len(ba) & 7) != 0:
+ raise TypeError("_deserialize_double_vector called on a %d-byte array, "
+ "which is not a multiple of 8" % len(ba))
+ header = ndarray(shape=[2], buffer=ba, dtype="int64")
+ if header[0] != 1:
+ raise TypeError("_deserialize_double_vector called on bytearray "
+ "with wrong magic")
+ length = header[1]
+ if len(ba) != 8*length + 16:
+ raise TypeError("_deserialize_double_vector called on bytearray "
+ "with wrong length")
+ return _deserialize_byte_array([length], ba, 16)
+
+def _serialize_double_matrix(m):
+ """Serialize a double matrix into a mutually understood format."""
+ if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2):
+ rows = m.shape[0]
+ cols = m.shape[1]
+ ba = bytearray(24 + 8 * rows * cols)
+ header = ndarray(shape=[3], buffer=ba, dtype="int64")
+ header[0] = 2
+ header[1] = rows
+ header[2] = cols
+ copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24,
+ dtype="float64", order='C'), m)
+ return ba
+ else:
+ raise TypeError("_serialize_double_matrix called on a "
+ "non-double-matrix")
+
+def _deserialize_double_matrix(ba):
+ """Deserialize a double matrix from a mutually understood format."""
+ if type(ba) != bytearray:
+ raise TypeError("_deserialize_double_matrix called on a %s; "
+ "wanted bytearray" % type(ba))
+ if len(ba) < 24:
+ raise TypeError("_deserialize_double_matrix called on a %d-byte array, "
+ "which is too short" % len(ba))
+ if (len(ba) & 7) != 0:
+ raise TypeError("_deserialize_double_matrix called on a %d-byte array, "
+ "which is not a multiple of 8" % len(ba))
+ header = ndarray(shape=[3], buffer=ba, dtype="int64")
+ if (header[0] != 2):
+ raise TypeError("_deserialize_double_matrix called on bytearray "
+ "with wrong magic")
+ rows = header[1]
+ cols = header[2]
+ if (len(ba) != 8*rows*cols + 24):
+ raise TypeError("_deserialize_double_matrix called on bytearray "
+ "with wrong length")
+ return _deserialize_byte_array([rows, cols], ba, 24)
+
+def _linear_predictor_typecheck(x, coeffs):
+ """Check that x is a one-dimensional vector of the right shape.
+ This is a temporary hackaround until I actually implement bulk predict."""
+ if type(x) == ndarray:
+ if x.ndim == 1:
+ if x.shape == coeffs.shape:
+ pass
+ else:
+ raise RuntimeError("Got array of %d elements; wanted %d"
+ % (shape(x)[0], shape(coeffs)[0]))
+ else:
+ raise RuntimeError("Bulk predict not yet supported.")
+ elif (type(x) == RDD):
+ raise RuntimeError("Bulk predict not yet supported.")
+ else:
+ raise TypeError("Argument of type " + type(x).__name__ + " unsupported")
+
+def _get_unmangled_rdd(data, serializer):
+ dataBytes = data.map(serializer)
+ dataBytes._bypass_serializer = True
+ dataBytes.cache()
+ return dataBytes
+
+# Map a pickled Python RDD of numpy double vectors to a Java RDD of
+# _serialized_double_vectors
+def _get_unmangled_double_vector_rdd(data):
+ return _get_unmangled_rdd(data, _serialize_double_vector)
+
+class LinearModel(object):
+ """Something that has a vector of coefficients and an intercept."""
+ def __init__(self, coeff, intercept):
+ self._coeff = coeff
+ self._intercept = intercept
+
+class LinearRegressionModelBase(LinearModel):
+ """A linear regression model.
+
+ >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1)
+ >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6
+ True
+ """
+ def predict(self, x):
+ """Predict the value of the dependent variable given a vector x"""
+ """containing values for the independent variables."""
+ _linear_predictor_typecheck(x, self._coeff)
+ return dot(self._coeff, x) + self._intercept
+
+# If we weren't given initial weights, take a zero vector of the appropriate
+# length.
+def _get_initial_weights(initial_weights, data):
+ if initial_weights is None:
+ initial_weights = data.first()
+ if type(initial_weights) != ndarray:
+ raise TypeError("At least one data element has type "
+ + type(initial_weights).__name__ + " which is not ndarray")
+ if initial_weights.ndim != 1:
+ raise TypeError("At least one data element has "
+ + str(initial_weights.ndim) + " dimensions, which is not 1")
+ initial_weights = ones([initial_weights.shape[0] - 1])
+ return initial_weights
+
+# train_func should take two parameters, namely data and initial_weights, and
+# return the result of a call to the appropriate JVM stub.
+# _regression_train_wrapper is responsible for setup and error checking.
+def _regression_train_wrapper(sc, train_func, klass, data, initial_weights):
+ initial_weights = _get_initial_weights(initial_weights, data)
+ dataBytes = _get_unmangled_double_vector_rdd(data)
+ ans = train_func(dataBytes, _serialize_double_vector(initial_weights))
+ if len(ans) != 2:
+ raise RuntimeError("JVM call result had unexpected length")
+ elif type(ans[0]) != bytearray:
+ raise RuntimeError("JVM call result had first element of type "
+ + type(ans[0]).__name__ + " which is not bytearray")
+ elif type(ans[1]) != float:
+ raise RuntimeError("JVM call result had second element of type "
+ + type(ans[0]).__name__ + " which is not float")
+ return klass(_deserialize_double_vector(ans[0]), ans[1])
+
+def _serialize_rating(r):
+ ba = bytearray(16)
+ intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
+ doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8)
+ intpart[0], intpart[1], doublepart[0] = r
+ return ba
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs,
+ optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
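
A round-trip sketch of the private double-vector format described in the header comment above (an 8-byte magic of 1, an 8-byte length, then the float64 data, all machine-endian), using the helpers added in this file:

    from numpy import array, array_equal, ndarray
    from pyspark.mllib._common import (_serialize_double_vector,
                                       _deserialize_double_vector)

    v = array([1.0, 2.0, 3.0])
    ba = _serialize_double_vector(v)

    # 16-byte header (magic + length) followed by 3 * 8 bytes of float64 data
    assert len(ba) == 16 + 8 * len(v)
    header = ndarray(shape=[2], buffer=ba, dtype="int64")
    assert header[0] == 1 and header[1] == 3

    assert array_equal(v, _deserialize_double_vector(ba))
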
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
new file mode 100644
index 0000000000..70de332d34
--- /dev/null
+++ b/python/pyspark/mllib/classification.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import array, dot, shape
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+ _serialize_double_matrix, _deserialize_double_matrix, \
+ _serialize_double_vector, _deserialize_double_vector, \
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+ LinearModel, _linear_predictor_typecheck
+from math import exp, log
+
+class LogisticRegressionModel(LinearModel):
+ """A linear binary classification model derived from logistic regression.
+
+ >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2)
+ >>> lrm = LogisticRegressionWithSGD.train(sc, sc.parallelize(data))
+ >>> lrm.predict(array([1.0])) != None
+ True
+ """
+ def predict(self, x):
+ _linear_predictor_typecheck(x, self._coeff)
+ margin = dot(x, self._coeff) + self._intercept
+ prob = 1/(1 + exp(-margin))
+ return 1 if prob > 0.5 else 0
+
+class LogisticRegressionWithSGD(object):
+ @classmethod
+ def train(cls, sc, data, iterations=100, step=1.0,
+ mini_batch_fraction=1.0, initial_weights=None):
+ """Train a logistic regression model on the given data."""
+ return _regression_train_wrapper(sc, lambda d, i:
+ sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(d._jrdd,
+ iterations, step, mini_batch_fraction, i),
+ LogisticRegressionModel, data, initial_weights)
+
+class SVMModel(LinearModel):
+ """A support vector machine.
+
+ >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2)
+ >>> svm = SVMWithSGD.train(sc, sc.parallelize(data))
+ >>> svm.predict(array([1.0])) != None
+ True
+ """
+ def predict(self, x):
+ _linear_predictor_typecheck(x, self._coeff)
+ margin = dot(x, self._coeff) + self._intercept
+ return 1 if margin >= 0 else 0
+
+class SVMWithSGD(object):
+ @classmethod
+ def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+ mini_batch_fraction=1.0, initial_weights=None):
+ """Train a support vector machine on the given data."""
+ return _regression_train_wrapper(sc, lambda d, i:
+ sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(d._jrdd,
+ iterations, step, reg_param, mini_batch_fraction, i),
+ SVMModel, data, initial_weights)
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs,
+ optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
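
A usage sketch for the two classifiers added here, following the calling convention in their doctests (train() takes the SparkContext as its first argument in this version; assumes a running SparkContext sc, and the toy data set is illustrative):

    from numpy import array
    from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD

    # Each row is [label, feature]; labels are 0/1.
    data = sc.parallelize([array([0.0, 0.0]), array([0.0, 1.0]),
                           array([1.0, 2.0]), array([1.0, 3.0])])

    lrm = LogisticRegressionWithSGD.train(sc, data, iterations=10)
    svm = SVMWithSGD.train(sc, data, iterations=10)

    print lrm.predict(array([2.5]))   # 0 or 1
    print svm.predict(array([2.5]))
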
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
new file mode 100644
index 0000000000..8cf20e591a
--- /dev/null
+++ b/python/pyspark/mllib/clustering.py
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import array, dot
+from math import sqrt
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+ _serialize_double_matrix, _deserialize_double_matrix, \
+ _serialize_double_vector, _deserialize_double_vector, \
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper
+
+class KMeansModel(object):
+ """A clustering model derived from the k-means method.
+
+ >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2)
+ >>> clusters = KMeans.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random")
+ >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0]))
+ True
+ >>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0]))
+ True
+ >>> clusters = KMeans.train(sc, sc.parallelize(data), 2)
+ """
+ def __init__(self, centers_):
+ self.centers = centers_
+
+ def predict(self, x):
+ """Find the cluster to which x belongs in this model."""
+ best = 0
+ best_distance = 1e75
+ for i in range(0, self.centers.shape[0]):
+ diff = x - self.centers[i]
+ distance = sqrt(dot(diff, diff))
+ if distance < best_distance:
+ best = i
+ best_distance = distance
+ return best
+
+class KMeans(object):
+ @classmethod
+ def train(cls, sc, data, k, maxIterations=100, runs=1,
+ initialization_mode="k-means||"):
+ """Train a k-means clustering model."""
+ dataBytes = _get_unmangled_double_vector_rdd(data)
+ ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd,
+ k, maxIterations, runs, initialization_mode)
+ if len(ans) != 1:
+ raise RuntimeError("JVM call result had unexpected length")
+ elif type(ans[0]) != bytearray:
+ raise RuntimeError("JVM call result had first element of type "
+ + type(ans[0]).__name__ + " which is not bytearray")
+ return KMeansModel(_deserialize_double_matrix(ans[0]))
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs,
+ optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
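
A short sketch extending the KMeans doctest above: train on two well-separated clusters, then query the model's predictions and centers (assumes a running SparkContext sc):

    from numpy import array
    from pyspark.mllib.clustering import KMeans

    points = sc.parallelize([array([0.0, 0.0]), array([1.0, 1.0]),
                             array([9.0, 8.0]), array([8.0, 9.0])])

    model = KMeans.train(sc, points, 2, maxIterations=10, runs=5,
                         initialization_mode="random")

    # Nearby points land in the same cluster; model.centers is a 2x2 ndarray.
    assert model.predict(array([0.0, 0.0])) == model.predict(array([1.0, 1.0]))
    print model.centers
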
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
new file mode 100644
index 0000000000..14d06cba21
--- /dev/null
+++ b/python/pyspark/mllib/recommendation.py
@@ -0,0 +1,74 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+ _serialize_double_matrix, _deserialize_double_matrix, \
+ _serialize_double_vector, _deserialize_double_vector, \
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper
+
+class MatrixFactorizationModel(object):
+ """A matrix factorisation model trained by regularized alternating
+ least-squares.
+
+ >>> r1 = (1, 1, 1.0)
+ >>> r2 = (1, 2, 2.0)
+ >>> r3 = (2, 1, 2.0)
+ >>> ratings = sc.parallelize([r1, r2, r3])
+ >>> model = ALS.trainImplicit(sc, ratings, 1)
+ >>> model.predict(2,2) is not None
+ True
+ """
+
+ def __init__(self, sc, java_model):
+ self._context = sc
+ self._java_model = java_model
+
+ def __del__(self):
+ self._context._gateway.detach(self._java_model)
+
+ def predict(self, user, product):
+ return self._java_model.predict(user, product)
+
+class ALS(object):
+ @classmethod
+ def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
+ ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
+ mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd,
+ rank, iterations, lambda_, blocks)
+ return MatrixFactorizationModel(sc, mod)
+
+ @classmethod
+ def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
+ ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
+ mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd,
+ rank, iterations, lambda_, blocks, alpha)
+ return MatrixFactorizationModel(sc, mod)
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs,
+ optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
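
A usage sketch for the explicit-feedback ALS.train() path, mirroring the trainImplicit doctest above (assumes a running SparkContext sc; the ratings are illustrative (user, product, rating) triples):

    from pyspark.mllib.recommendation import ALS

    ratings = sc.parallelize([(1, 1, 5.0), (1, 2, 1.0),
                              (2, 1, 4.5), (2, 2, 1.5)])

    model = ALS.train(sc, ratings, rank=1, iterations=10)

    # Predicted rating of product 2 by user 1, returned as a plain float
    print model.predict(1, 2)
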
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
new file mode 100644
index 0000000000..a3a68b29e0
--- /dev/null
+++ b/python/pyspark/mllib/regression.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import array, dot
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+ _serialize_double_matrix, _deserialize_double_matrix, \
+ _serialize_double_vector, _deserialize_double_vector, \
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+ _linear_predictor_typecheck
+
+class LinearModel(object):
+ """Something that has a vector of coefficients and an intercept."""
+ def __init__(self, coeff, intercept):
+ self._coeff = coeff
+ self._intercept = intercept
+
+class LinearRegressionModelBase(LinearModel):
+ """A linear regression model.
+
+ >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1)
+ >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6
+ True
+ """
+ def predict(self, x):
+ """Predict the value of the dependent variable given a vector x"""
+ """containing values for the independent variables."""
+ _linear_predictor_typecheck(x, self._coeff)
+ return dot(self._coeff, x) + self._intercept
+
+class LinearRegressionModel(LinearRegressionModelBase):
+ """A linear regression model derived from a least-squares fit.
+
+ >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
+ >>> lrm = LinearRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ """
+
+class LinearRegressionWithSGD(object):
+ @classmethod
+ def train(cls, sc, data, iterations=100, step=1.0,
+ mini_batch_fraction=1.0, initial_weights=None):
+ """Train a linear regression model on the given data."""
+ return _regression_train_wrapper(sc, lambda d, i:
+ sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
+ d._jrdd, iterations, step, mini_batch_fraction, i),
+ LinearRegressionModel, data, initial_weights)
+
+class LassoModel(LinearRegressionModelBase):
+ """A linear regression model derived from a least-squares fit with an
+ l_1 penalty term.
+
+ >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
+ >>> lrm = LassoWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ """
+
+class LassoWithSGD(object):
+ @classmethod
+ def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+ mini_batch_fraction=1.0, initial_weights=None):
+ """Train a Lasso regression model on the given data."""
+ return _regression_train_wrapper(sc, lambda d, i:
+ sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd,
+ iterations, step, reg_param, mini_batch_fraction, i),
+ LassoModel, data, initial_weights)
+
+class RidgeRegressionModel(LinearRegressionModelBase):
+ """A linear regression model derived from a least-squares fit with an
+ l_2 penalty term.
+
+ >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
+ >>> lrm = RidgeRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ """
+
+class RidgeRegressionWithSGD(object):
+ @classmethod
+ def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+ mini_batch_fraction=1.0, initial_weights=None):
+ """Train a ridge regression model on the given data."""
+ return _regression_train_wrapper(sc, lambda d, i:
+ sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(d._jrdd,
+ iterations, step, reg_param, mini_batch_fraction, i),
+ RidgeRegressionModel, data, initial_weights)
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs,
+ optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
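
A usage sketch for the regression trainers added here, using the same [label, feature] row layout as the doctests (assumes a running SparkContext sc; step sizes and iteration counts are illustrative):

    from numpy import array
    from pyspark.mllib.regression import (LinearRegressionWithSGD,
                                          LassoWithSGD, RidgeRegressionWithSGD)

    # Rows are [label, feature]: roughly y = x.
    data = sc.parallelize([array([0.0, 0.0]), array([1.0, 1.0]),
                           array([2.0, 2.0]), array([3.0, 3.0])])

    lrm = LinearRegressionWithSGD.train(sc, data, iterations=50, step=0.1)
    lasso = LassoWithSGD.train(sc, data, iterations=50, step=0.1, reg_param=0.01)
    ridge = RidgeRegressionWithSGD.train(sc, data, iterations=50, step=0.1, reg_param=0.01)

    for model in (lrm, lasso, ridge):
        print model.predict(array([2.5]))   # near 2.5 once the fit converges
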
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 58e1849cad..f87923e6fa 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,7 +18,7 @@
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
import operator
import os
import sys
@@ -27,9 +27,8 @@ from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
-from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
- read_from_pickle_file
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+ BatchedSerializer, CloudPickleSerializer, pack_long
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -48,12 +47,15 @@ class RDD(object):
operated on in parallel.
"""
- def __init__(self, jrdd, ctx):
+ def __init__(self, jrdd, ctx, jrdd_deserializer):
self._jrdd = jrdd
self.is_cached = False
self.is_checkpointed = False
self.ctx = ctx
- self._partitionFunc = None
+ self._jrdd_deserializer = jrdd_deserializer
+
+ def __repr__(self):
+ return self._jrdd.toString()
@property
def context(self):
@@ -117,8 +119,6 @@ class RDD(object):
else:
return None
- # TODO persist(self, storageLevel)
-
def map(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each element of this RDD.
@@ -227,7 +227,7 @@ class RDD(object):
total = num
samples = self.sample(withReplacement, fraction, seed).collect()
-
+
# If the first sample didn't turn out large enough, keep trying to take samples;
# this shouldn't happen often because we use a big multiplier for their initial size.
# See: scala/spark/RDD.scala
@@ -249,7 +249,23 @@ class RDD(object):
>>> rdd.union(rdd).collect()
[1, 1, 2, 3, 1, 1, 2, 3]
"""
- return RDD(self._jrdd.union(other._jrdd), self.ctx)
+ if self._jrdd_deserializer == other._jrdd_deserializer:
+ rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+ self._jrdd_deserializer)
+ return rdd
+ else:
+ # These RDDs contain data in different serialized formats, so we
+ # must normalize them to the default serializer.
+ self_copy = self._reserialize()
+ other_copy = other._reserialize()
+ return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+ self.ctx.serializer)
+
+ def _reserialize(self):
+ if self._jrdd_deserializer == self.ctx.serializer:
+ return self
+ else:
+ return self.map(lambda x: x, preservesPartitioning=True)
def __add__(self, other):
"""
@@ -263,7 +279,55 @@ class RDD(object):
raise TypeError
return self.union(other)
- # TODO: sort
+ def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x):
+ """
+ Sorts this RDD, which is assumed to consist of (key, value) pairs.
+
+ >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+ >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
+ [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+ >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
+ >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
+ >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
+ [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
+ """
+ if numPartitions is None:
+ numPartitions = self.ctx.defaultParallelism
+
+ bounds = list()
+
+ # first compute the boundary of each part via sampling: we want to partition
+ # the key-space into bins such that the bins have roughly the same
+ # number of (key, value) pairs falling into them
+ if numPartitions > 1:
+ rddSize = self.count()
+ maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
+ fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+
+ samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+ samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+ # we have numPartitions many parts but one of the them has
+ # an implicit boundary
+ for i in range(0, numPartitions - 1):
+ index = (len(samples) - 1) * (i + 1) / numPartitions
+ bounds.append(samples[index])
+
+ def rangePartitionFunc(k):
+ p = 0
+ while p < len(bounds) and keyfunc(k) > bounds[p]:
+ p += 1
+ if ascending:
+ return p
+ else:
+ return numPartitions-1-p
+
+ def mapFunc(iterator):
+ yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+ return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
+ .mapPartitions(mapFunc, preservesPartitioning=True)
+ .flatMap(lambda x: x, preservesPartitioning=True))
def glom(self):
"""
@@ -288,17 +352,9 @@ class RDD(object):
[(1, 1), (1, 2), (2, 1), (2, 2)]
"""
# Due to batching, we can't use the Java cartesian method.
- java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
- def unpack_batches(pair):
- (x, y) = pair
- if type(x) == Batch or type(y) == Batch:
- xs = x.items if type(x) == Batch else [x]
- ys = y.items if type(y) == Batch else [y]
- for pair in product(xs, ys):
- yield pair
- else:
- yield pair
- return java_cartesian.flatMap(unpack_batches)
+ deserializer = CartesianDeserializer(self._jrdd_deserializer,
+ other._jrdd_deserializer)
+ return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
def groupBy(self, f, numPartitions=None):
"""
@@ -345,8 +401,8 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- picklesInJava = self._jrdd.collect().iterator()
- return list(self._collect_iterator_through_file(picklesInJava))
+ bytesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
@@ -354,10 +410,10 @@ class RDD(object):
# file and read it back.
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
- self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+ self.ctx._writeToFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
- for item in read_from_pickle_file(tempFile):
+ for item in self._jrdd_deserializer.load_stream(tempFile):
yield item
os.unlink(tempFile.name)
@@ -425,7 +481,7 @@ class RDD(object):
3
"""
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
-
+
def stats(self):
"""
Return a L{StatCounter} object that captures the mean, variance
@@ -462,7 +518,7 @@ class RDD(object):
0.816...
"""
return self.stats().stdev()
-
+
def sampleStdev(self):
"""
Compute the sample standard deviation of this RDD's elements (which corrects for bias in
@@ -523,9 +579,14 @@ class RDD(object):
# Take only up to num elements from each partition we try
mapped = self.mapPartitions(takeUpToNum)
items = []
+ # TODO(shivaram): Similar to the scala implementation, update the take
+ # method to scan multiple splits based on an estimate of how many elements
+ # we have per-split.
for partition in range(mapped._jrdd.splits().size()):
- iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
- items.extend(self._collect_iterator_through_file(iterator))
+ partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+ partitionsToTake[0] = partition
+ iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+ items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break
return items[:num]
@@ -552,7 +613,10 @@ class RDD(object):
'0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
"""
def func(split, iterator):
- return (str(x).encode("utf-8") for x in iterator)
+ for x in iterator:
+ if not isinstance(x, basestring):
+ x = unicode(x)
+ yield x.encode("utf-8")
keyed = PipelinedRDD(self, func)
keyed._bypass_serializer = True
keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
@@ -689,20 +753,23 @@ class RDD(object):
# Transferring O(n) objects to Java is too expensive. Instead, we'll
# form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
+ outputSerializer = self.ctx._unbatched_serializer
def add_shuffle_key(split, iterator):
+
buckets = defaultdict(list)
+
for (k, v) in iterator:
buckets[partitionFunc(k) % numPartitions].append((k, v))
for (split, items) in buckets.iteritems():
- yield str(split)
- yield dump_pickle(Batch(items))
+ yield pack_long(split)
+ yield outputSerializer.dumps(items)
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
- rdd = RDD(jrdd, self.ctx)
+ rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique, even if
# partitionFunc is a lambda:
rdd._partitionFunc = partitionFunc
@@ -739,7 +806,8 @@ class RDD(object):
numPartitions = self.ctx.defaultParallelism
def combineLocally(iterator):
combiners = {}
- for (k, v) in iterator:
+ for x in iterator:
+ (k, v) = x
if k not in combiners:
combiners[k] = createCombiner(v)
else:
@@ -830,9 +898,9 @@ class RDD(object):
>>> y = sc.parallelize([("a", 3), ("c", None)])
>>> sorted(x.subtractByKey(y).collect())
[('b', 4), ('b', 5)]
- """
- filter_func = lambda tpl: len(tpl[1][0]) > 0 and len(tpl[1][1]) == 0
- map_func = lambda tpl: [(tpl[0], val) for val in tpl[1][0]]
+ """
+ filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0
+ map_func = lambda (key, vals): [(key, val) for val in vals[0]]
return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
def subtract(self, other, numPartitions=None):
@@ -881,50 +949,52 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
- if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+ if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+ # This transformation is the first in its stage:
+ self.func = func
+ self.preservesPartitioning = preservesPartitioning
+ self._prev_jrdd = prev._jrdd
+ self._prev_jrdd_deserializer = prev._jrdd_deserializer
+ else:
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
self.func = pipeline_func
self.preservesPartitioning = \
prev.preservesPartitioning and preservesPartitioning
- self._prev_jrdd = prev._prev_jrdd
- else:
- self.func = func
- self.preservesPartitioning = preservesPartitioning
- self._prev_jrdd = prev._jrdd
+ self._prev_jrdd = prev._prev_jrdd # maintain the pipeline
+ self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
+ self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
@property
def _jrdd(self):
if self._jrdd_val:
return self._jrdd_val
- func = self.func
- if not self._bypass_serializer and self.ctx.batchSize != 1:
- oldfunc = self.func
- batchSize = self.ctx.batchSize
- def batched_func(split, iterator):
- return batched(oldfunc(split, iterator), batchSize)
- func = batched_func
- cmds = [func, self._bypass_serializer]
- pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+ if self._bypass_serializer:
+ serializer = NoOpSerializer()
+ else:
+ serializer = self.ctx.serializer
+ command = (self.func, self._prev_jrdd_deserializer, serializer)
+ pickled_command = CloudPickleSerializer().dumps(command)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
- class_manifest = self._prev_jrdd.classManifest()
+ class_tag = self._prev_jrdd.classTag()
env = MapConverter().convert(self.ctx.environment,
self.ctx._gateway._gateway_client)
includes = ListConverter().convert(self.ctx._python_includes,
self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator, class_manifest)
+ bytearray(pickled_command), env, includes, self.preservesPartitioning,
+ self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
+ class_tag)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
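
A hedged sketch of the (func, deserializer, serializer) command tuple that is
now shipped to the JVM; the real code uses CloudPickleSerializer so that
lambdas and closures survive pickling, but for a module-level function plain
cPickle shows the same round trip (the placeholder strings stand in for
serializer objects):

    import cPickle

    def double(split, iterator):
        # The function receives the partition index and an iterator.
        return (2 * x for x in iterator)

    command = (double, "<input deserializer>", "<output serializer>")
    pickled_command = cPickle.dumps(command, 2)
    func, deserializer, serializer = cPickle.loads(pickled_command)
    # list(func(0, [1, 2, 3])) -> [2, 4, 6]
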
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fecacd1241..2a500ab919 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -15,45 +15,269 @@
# limitations under the License.
#
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serializes objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 means an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
+from pyspark import cloudpickle
+
+
+__all__ = ["PickleSerializer", "MarshalSerializer"]
+
+
+class SpecialLengths(object):
+ END_OF_DATA_SECTION = -1
+ PYTHON_EXCEPTION_THROWN = -2
+ TIMING_DATA = -3
+
+
+class Serializer(object):
+
+ def dump_stream(self, iterator, stream):
+ """
+ Serialize an iterator of objects to the output stream.
+ """
+ raise NotImplementedError
+
+ def load_stream(self, stream):
+ """
+ Return an iterator of deserialized objects from the input stream.
+ """
+ raise NotImplementedError
+
+
+ def _load_stream_without_unbatching(self, stream):
+ return self.load_stream(stream)
+
+ # Note: our notion of "equality" is that output generated by
+ # equal serializers can be deserialized using the same serializer.
+
+ # This default implementation handles the simple cases;
+ # subclasses should override __eq__ as appropriate.
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
+ """
+ Serializer that writes objects as a stream of (length, data) pairs,
+    where C{length} is a 32-bit integer and C{data} is C{length} bytes.
+ """
+
+ def dump_stream(self, iterator, stream):
+ for obj in iterator:
+ self._write_with_length(obj, stream)
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self._read_with_length(stream)
+ except EOFError:
+ return
+
+ def _write_with_length(self, obj, stream):
+ serialized = self.dumps(obj)
+ write_int(len(serialized), stream)
+ stream.write(serialized)
+
+ def _read_with_length(self, stream):
+ length = read_int(stream)
+ obj = stream.read(length)
+ if obj == "":
+ raise EOFError
+ return self.loads(obj)
+
+ def dumps(self, obj):
+ """
+ Serialize an object into a byte array.
+ When batching is used, this will be called with an array of objects.
+ """
+ raise NotImplementedError
+
+ def loads(self, obj):
+ """
+ Deserialize an object from a byte array.
+ """
+ raise NotImplementedError
+
+
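
A hedged sketch of the length-prefixed framing described above, using an
in-memory stream and assuming the 4-byte big-endian length that read_int
consumes (write_frame/read_frame are illustrative names only):

    import struct
    from StringIO import StringIO

    def write_frame(payload, stream):
        # 4-byte big-endian length, then the payload bytes.
        stream.write(struct.pack("!i", len(payload)))
        stream.write(payload)

    def read_frame(stream):
        length = struct.unpack("!i", stream.read(4))[0]
        return stream.read(length)

    buf = StringIO()
    write_frame("hello", buf)
    buf.seek(0)
    # read_frame(buf) -> 'hello'
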
+class BatchedSerializer(Serializer):
+ """
+ Serializes a stream of objects in batches by calling its wrapped
+ Serializer with streams of objects.
+ """
+
+ UNLIMITED_BATCH_SIZE = -1
+
+ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+ self.serializer = serializer
+ self.batchSize = batchSize
+
+ def _batched(self, iterator):
+ if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+ yield list(iterator)
+ else:
+ items = []
+ count = 0
+ for item in iterator:
+ items.append(item)
+ count += 1
+ if count == self.batchSize:
+ yield items
+ items = []
+ count = 0
+ if items:
+ yield items
+
+ def dump_stream(self, iterator, stream):
+ self.serializer.dump_stream(self._batched(iterator), stream)
+
+ def load_stream(self, stream):
+ return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+ def _load_stream_without_unbatching(self, stream):
+ return self.serializer.load_stream(stream)
+
+ def __eq__(self, other):
+ return isinstance(other, BatchedSerializer) and \
+ other.serializer == self.serializer
+
+ def __str__(self):
+ return "BatchedSerializer<%s>" % str(self.serializer)
+
+
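
A simplified, pure-Python sketch of the grouping that _batched performs
(ignoring the unlimited-batch-size case handled above):

    def batched(iterator, batch_size):
        items = []
        for item in iterator:
            items.append(item)
            if len(items) == batch_size:
                yield items
                items = []
        if items:
            yield items

    # list(batched(range(5), 2)) -> [[0, 1], [2, 3], [4]]
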
+class CartesianDeserializer(FramedSerializer):
+ """
+ Deserializes the JavaRDD cartesian() of two PythonRDDs.
+ """
+
+ def __init__(self, key_ser, val_ser):
+ self.key_ser = key_ser
+ self.val_ser = val_ser
+
+ def load_stream(self, stream):
+ key_stream = self.key_ser._load_stream_without_unbatching(stream)
+ val_stream = self.val_ser._load_stream_without_unbatching(stream)
+ key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+ val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+ for (keys, vals) in izip(key_stream, val_stream):
+ keys = keys if key_is_batched else [keys]
+ vals = vals if val_is_batched else [vals]
+ for pair in product(keys, vals):
+ yield pair
+
+ def __eq__(self, other):
+ return isinstance(other, CartesianDeserializer) and \
+ self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+ def __str__(self):
+ return "CartesianDeserializer<%s, %s>" % \
+ (str(self.key_ser), str(self.val_ser))
+
+
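
A hedged illustration of how the cartesian deserializer pairs the two streams:
corresponding batches of keys and values are zipped together, then their
cross-product is emitted (plain lists stand in for the deserialized streams):

    from itertools import izip, product

    key_batches = [[1, 2], [3]]
    val_batches = [["a", "b"], ["c"]]
    pairs = [pair for (keys, vals) in izip(key_batches, val_batches)
                  for pair in product(keys, vals)]
    # pairs == [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b'), (3, 'c')]
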
+class NoOpSerializer(FramedSerializer):
+
+ def loads(self, obj): return obj
+ def dumps(self, obj): return obj
-class Batch(object):
+class PickleSerializer(FramedSerializer):
"""
- Used to store multiple RDD entries as a single Java object.
+ Serializes objects using Python's cPickle serializer:
- This relieves us from having to explicitly track whether an RDD
- is stored as batches of objects and avoids problems when processing
- the union() of batched and unbatched RDDs (e.g. the union() of textFile()
- with another RDD).
+ http://docs.python.org/2/library/pickle.html
+
+ This serializer supports nearly any Python object, but may
+ not be as fast as more specialized serializers.
+ """
+
+ def dumps(self, obj): return cPickle.dumps(obj, 2)
+ loads = cPickle.loads
+
+class CloudPickleSerializer(PickleSerializer):
+
+ def dumps(self, obj): return cloudpickle.dumps(obj, 2)
+
+
+class MarshalSerializer(FramedSerializer):
"""
- def __init__(self, items):
- self.items = items
+    Serializes objects using Python's marshal serializer:
+ http://docs.python.org/2/library/marshal.html
-def batched(iterator, batchSize):
- if batchSize == -1: # unlimited batch size
- yield Batch(list(iterator))
- else:
- items = []
- count = 0
- for item in iterator:
- items.append(item)
- count += 1
- if count == batchSize:
- yield Batch(items)
- items = []
- count = 0
- if items:
- yield Batch(items)
+ This serializer is faster than PickleSerializer but supports fewer datatypes.
+ """
+ dumps = marshal.dumps
+ loads = marshal.loads
-def dump_pickle(obj):
- return cPickle.dumps(obj, 2)
+class MUTF8Deserializer(Serializer):
+ """
+ Deserializes streams written by Java's DataOutputStream.writeUTF().
+ """
+
+ def loads(self, stream):
+ length = struct.unpack('>H', stream.read(2))[0]
+ return stream.read(length).decode('utf8')
-load_pickle = cPickle.loads
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self.loads(stream)
+ except struct.error:
+ return
+ except EOFError:
+ return
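
A hedged sketch of the framing that DataOutputStream.writeUTF() produces and
that loads() reads above: an unsigned 2-byte big-endian length followed by
(modified) UTF-8 bytes; for typical text, plain UTF-8 decoding suffices:

    import struct
    from StringIO import StringIO

    payload = u"caf\u00e9".encode("utf-8")
    buf = StringIO()
    buf.write(struct.pack(">H", len(payload)))
    buf.write(payload)
    buf.seek(0)

    length = struct.unpack(">H", buf.read(2))[0]
    text = buf.read(length).decode("utf-8")
    # text == u"caf\u00e9"
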
def read_long(stream):
@@ -67,6 +291,10 @@ def write_long(value, stream):
stream.write(struct.pack("!q", value))
+def pack_long(value):
+ return struct.pack("!q", value)
+
+
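
For reference, the "!q" format packs the split number as an 8-byte big-endian
long, which the JVM side can consume directly:

    import struct
    # struct.pack("!q", 3) -> '\x00\x00\x00\x00\x00\x00\x00\x03'
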
def read_int(stream):
length = stream.read(4)
if length == "":
@@ -81,24 +309,3 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
stream.write(obj)
-
-
-def read_with_length(stream):
- length = read_int(stream)
- obj = stream.read(length)
- if obj == "":
- raise EOFError
- return obj
-
-
-def read_from_pickle_file(stream):
- try:
- while True:
- obj = load_pickle(read_with_length(stream))
- if type(obj) == Batch: # We don't care about inheritance
- for item in obj.items:
- yield item
- else:
- yield obj
- except EOFError:
- return
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index a475959090..ef07eb437b 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -42,7 +42,7 @@ print "Using Python version %s (%s, %s)" % (
platform.python_version(),
platform.python_build()[0],
platform.python_build()[1])
-print "Spark context avaiable as sc."
+print "Spark context available as sc."
if add_files != None:
print "Adding files: [%s]" % ", ".join(add_files)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29d6a128f6..7acb6eaf10 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -19,6 +19,8 @@
Unit tests for PySpark; additional tests are implemented as doctests in
individual modules.
"""
+from fileinput import input
+from glob import glob
import os
import shutil
import sys
@@ -71,8 +73,8 @@ class TestCheckpoint(PySparkTestCase):
time.sleep(1) # 1 second
self.assertTrue(flatMappedRDD.isCheckpointed())
self.assertEqual(flatMappedRDD.collect(), result)
- self.assertEqual(self.checkpointDir.name,
- os.path.dirname(flatMappedRDD.getCheckpointFile()))
+ self.assertEqual("file:" + self.checkpointDir.name,
+ os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile())))
def test_checkpoint_and_restore(self):
parCollection = self.sc.parallelize([1, 2, 3, 4])
@@ -86,7 +88,8 @@ class TestCheckpoint(PySparkTestCase):
time.sleep(1) # 1 second
self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
- recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+ flatMappedRDD._jrdd_deserializer)
self.assertEquals([1, 2, 3, 4], recovered.collect())
@@ -137,6 +140,19 @@ class TestAddFile(PySparkTestCase):
self.assertEqual("Hello World from inside a package!", UserClass().hello())
+class TestRDDFunctions(PySparkTestCase):
+
+ def test_save_as_textfile_with_unicode(self):
+ # Regression test for SPARK-970
+ x = u"\u00A1Hola, mundo!"
+ data = self.sc.parallelize([x])
+ tempFile = NamedTemporaryFile(delete=True)
+ tempFile.close()
+ data.saveAsTextFile(tempFile.name)
+ raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
+ self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
+
+
class TestIO(PySparkTestCase):
def test_stdout_redirection(self):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d63c2aaef7..f2b3f3c142 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,23 +23,22 @@ import sys
import time
import socket
import traceback
-from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
- read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+ write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
-def load_obj(infile):
- return load_pickle(standard_b64decode(infile.readline().strip()))
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
def report_times(outfile, boot, init, finish):
- write_int(-3, outfile)
+ write_int(SpecialLengths.TIMING_DATA, outfile)
write_long(1000 * boot, outfile)
write_long(1000 * init, outfile)
write_long(1000 * finish, outfile)
@@ -52,7 +51,7 @@ def main(infile, outfile):
return
# fetch name of workdir
- spark_files_dir = load_pickle(read_with_length(infile))
+ spark_files_dir = mutf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
@@ -60,38 +59,33 @@ def main(infile, outfile):
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
- value = read_with_length(infile)
- _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+ value = pickleSer._read_with_length(infile)
+ _broadcastRegistry[bid] = Broadcast(bid, value)
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
- sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+ filename = mutf8_deserializer.loads(infile)
+ sys.path.append(os.path.join(spark_files_dir, filename))
- # now load function
- func = load_obj(infile)
- bypassSerializer = load_obj(infile)
- if bypassSerializer:
- dumps = lambda x: x
- else:
- dumps = dump_pickle
+ command = pickleSer._read_with_length(infile)
+ (func, deserializer, serializer) = command
init_time = time.time()
- iterator = read_from_pickle_file(infile)
try:
- for obj in func(split_index, iterator):
- write_with_length(dumps(obj), outfile)
+ iterator = deserializer.load_stream(infile)
+ serializer.dump_stream(func(split_index, iterator), outfile)
except Exception as e:
- write_int(-2, outfile)
+ write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
sys.exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
# Mark the beginning of the accumulators section of the output
- write_int(-1, outfile)
- for aid, accum in _accumulatorRegistry.items():
- write_with_length(dump_pickle((aid, accum._value)), outfile)
- write_int(-1, outfile)
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ write_int(len(_accumulatorRegistry), outfile)
+ for (aid, accum) in _accumulatorRegistry.items():
+ pickleSer._write_with_length((aid, accum._value), outfile)
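
A hedged, self-contained sketch of the worker's new inner loop, using
in-memory streams and a cPickle-based framing in place of the real socket
streams and serializer objects (dump_stream/load_stream here are simplified
stand-ins):

    import cPickle
    import struct
    from StringIO import StringIO

    def dump_stream(iterator, stream):
        # Length-prefixed pickled records, mirroring FramedSerializer.
        for obj in iterator:
            data = cPickle.dumps(obj, 2)
            stream.write(struct.pack("!i", len(data)))
            stream.write(data)

    def load_stream(stream):
        while True:
            header = stream.read(4)
            if len(header) < 4:
                return
            length = struct.unpack("!i", header)[0]
            yield cPickle.loads(stream.read(length))

    infile, outfile = StringIO(), StringIO()
    dump_stream([1, 2, 3], infile)
    infile.seek(0)

    func = lambda split, it: (x * 10 for x in it)
    # The core of main(): deserialize the input, apply func, serialize the output.
    dump_stream(func(0, load_stream(infile)), outfile)
    # outfile now holds the framed records 10, 20, 30.
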
if __name__ == '__main__':
diff --git a/python/run-tests b/python/run-tests
index 8a08ae3df9..feba97cee0 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -29,14 +29,16 @@ FAILED=0
rm -f unit-tests.log
function run_test() {
- $FWDIR/bin/pyspark $1 2>&1 | tee -a unit-tests.log
+ SPARK_TESTING=0 $FWDIR/bin/pyspark $1 2>&1 | tee -a unit-tests.log
FAILED=$((PIPESTATUS[0]||$FAILED))
}
run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
+run_test "pyspark/conf.py"
run_test "-m doctest pyspark/broadcast.py"
run_test "-m doctest pyspark/accumulators.py"
+run_test "-m doctest pyspark/serializers.py"
run_test "pyspark/tests.py"
if [[ $FAILED != 0 ]]; then
diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py
index 5bb6f5009f..8e4a6292bc 100755
--- a/python/test_support/userlibrary.py
+++ b/python/test_support/userlibrary.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
Used to test shipping of code dependencies with SparkContext.addPyFile().
"""