diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/context.py | 29 | ||||
-rw-r--r-- | python/pyspark/files.py | 2 | ||||
-rw-r--r-- | python/pyspark/java_gateway.py | 12 | ||||
-rw-r--r-- | python/pyspark/rdd.py | 191 | ||||
-rw-r--r-- | python/pyspark/rddsampler.py | 112 | ||||
-rw-r--r-- | python/pyspark/shell.py | 20 | ||||
-rw-r--r-- | python/pyspark/statcounter.py | 109 | ||||
-rw-r--r-- | python/pyspark/tests.py | 24 | ||||
-rw-r--r-- | python/pyspark/worker.py | 13 |
9 files changed, 473 insertions, 39 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2f741cb345..8fbf296509 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -46,6 +46,7 @@ class SparkContext(object): _next_accum_id = 0 _active_spark_context = None _lock = Lock() + _python_includes = None # zip and egg files that need to be added to PYTHONPATH def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -103,16 +104,19 @@ class SparkContext(object): # send. self._pickled_broadcast_vars = set() + SparkFiles._sc = self + root_dir = SparkFiles.getRootDirectory() + sys.path.append(root_dir) + # Deploy any code dependencies specified in the constructor + self._python_includes = list() for path in (pyFiles or []): self.addPyFile(path) - SparkFiles._sc = self - sys.path.append(SparkFiles.getRootDirectory()) # Create a temporary directory inside spark.local.dir: - local_dir = self._jvm.spark.Utils.getLocalDir() + local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir() self._temp_dir = \ - self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath() + self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() @property def defaultParallelism(self): @@ -141,14 +145,21 @@ class SparkContext(object): def parallelize(self, c, numSlices=None): """ Distribute a local Python collection to form an RDD. + + >>> sc.parallelize(range(5), 5).glom().collect() + [[0], [1], [2], [3], [4]] """ numSlices = numSlices or self.defaultParallelism # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) - if self.batchSize != 1: - c = batched(c, self.batchSize) + # Make sure we distribute data evenly if it's smaller than self.batchSize + if "__len__" not in dir(c): + c = list(c) # Make it a list so we can compute its length + batchSize = min(len(c) // numSlices, self.batchSize) + if batchSize > 1: + c = batched(c, batchSize) for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() @@ -250,7 +261,11 @@ class SparkContext(object): HTTP, HTTPS or FTP URI. """ self.addFile(path) - filename = path.split("/")[-1] + (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix + + if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): + self._python_includes.append(filename) + sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode def setCheckpointDir(self, dirName, useExisting=False): """ diff --git a/python/pyspark/files.py b/python/pyspark/files.py index 89bcbcfe06..57ee14eeb7 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -52,4 +52,4 @@ class SparkFiles(object): return cls._root_directory else: # This will have to change if we support multiple SparkContexts: - return cls._sc._jvm.spark.SparkFiles.getRootDirectory() + return cls._sc._jvm.org.apache.spark.SparkFiles.getRootDirectory() diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index e503fb7621..26fbe0f080 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -17,6 +17,7 @@ import os import sys +import signal from subprocess import Popen, PIPE from threading import Thread from py4j.java_gateway import java_import, JavaGateway, GatewayClient @@ -28,9 +29,12 @@ SPARK_HOME = os.environ["SPARK_HOME"] def launch_gateway(): # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and SPARK_MEM settings from spark-env.sh - command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer", + command = [os.path.join(SPARK_HOME, "spark-class"), "py4j.GatewayServer", "--die-on-broken-pipe", "0"] - proc = Popen(command, stdout=PIPE, stdin=PIPE) + # Don't send ctrl-c / SIGINT to the Java gateway: + def preexec_function(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_function) # Determine which ephemeral port the server started on: port = int(proc.stdout.readline()) # Create a thread to echo output from the GatewayServer, which is required @@ -49,7 +53,7 @@ def launch_gateway(): # Connect to the gateway gateway = JavaGateway(GatewayClient(port=port), auto_convert=False) # Import the classes used by PySpark - java_import(gateway.jvm, "spark.api.java.*") - java_import(gateway.jvm, "spark.api.python.*") + java_import(gateway.jvm, "org.apache.spark.api.java.*") + java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c6a6b24c5a..914118ccdd 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -21,6 +21,7 @@ from collections import defaultdict from itertools import chain, ifilter, imap, product import operator import os +import sys import shlex from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile @@ -31,6 +32,8 @@ from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ read_from_pickle_file from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup +from pyspark.statcounter import StatCounter +from pyspark.rddsampler import RDDSampler from py4j.java_collections import ListConverter, MapConverter @@ -160,18 +163,64 @@ class RDD(object): >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) [1, 2, 3] """ - return self.map(lambda x: (x, "")) \ + return self.map(lambda x: (x, None)) \ .reduceByKey(lambda x, _: x) \ .map(lambda (x, _): x) - # TODO: sampling needs to be re-implemented due to Batch - #def sample(self, withReplacement, fraction, seed): - # jrdd = self._jrdd.sample(withReplacement, fraction, seed) - # return RDD(jrdd, self.ctx) + def sample(self, withReplacement, fraction, seed): + """ + Return a sampled subset of this RDD (relies on numpy and falls back + on default random generator if numpy is unavailable). + + >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP + [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] + """ + return self.mapPartitionsWithSplit(RDDSampler(withReplacement, fraction, seed).func, True) + + # this is ported from scala/spark/RDD.scala + def takeSample(self, withReplacement, num, seed): + """ + Return a fixed-size sampled subset of this RDD (currently requires numpy). + + >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP + [4, 2, 1, 8, 2, 7, 0, 4, 1, 4] + """ + + fraction = 0.0 + total = 0 + multiplier = 3.0 + initialCount = self.count() + maxSelected = 0 + + if (num < 0): + raise ValueError - #def takeSample(self, withReplacement, num, seed): - # vals = self._jrdd.takeSample(withReplacement, num, seed) - # return [load_pickle(bytes(x)) for x in vals] + if initialCount > sys.maxint - 1: + maxSelected = sys.maxint - 1 + else: + maxSelected = initialCount + + if num > initialCount and not withReplacement: + total = maxSelected + fraction = multiplier * (maxSelected + 1) / initialCount + else: + fraction = multiplier * (num + 1) / initialCount + total = num + + samples = self.sample(withReplacement, fraction, seed).collect() + + # If the first sample didn't turn out large enough, keep trying to take samples; + # this shouldn't happen often because we use a big multiplier for their initial size. + # See: scala/spark/RDD.scala + while len(samples) < total: + if seed > sys.maxint - 2: + seed = -1 + seed += 1 + samples = self.sample(withReplacement, fraction, seed).collect() + + sampler = RDDSampler(withReplacement, fraction, seed+1) + sampler.shuffle(samples) + return samples[0:total] def union(self, other): """ @@ -267,7 +316,11 @@ class RDD(object): >>> def f(x): print x >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ - self.map(f).collect() # Force evaluation + def processPartition(iterator): + for x in iterator: + f(x) + yield None + self.mapPartitions(processPartition).collect() # Force evaluation def collect(self): """ @@ -353,6 +406,63 @@ class RDD(object): 3 """ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() + + def stats(self): + """ + Return a L{StatCounter} object that captures the mean, variance + and count of the RDD's elements in one operation. + """ + def redFunc(left_counter, right_counter): + return left_counter.mergeStats(right_counter) + + return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc) + + def mean(self): + """ + Compute the mean of this RDD's elements. + + >>> sc.parallelize([1, 2, 3]).mean() + 2.0 + """ + return self.stats().mean() + + def variance(self): + """ + Compute the variance of this RDD's elements. + + >>> sc.parallelize([1, 2, 3]).variance() + 0.666... + """ + return self.stats().variance() + + def stdev(self): + """ + Compute the standard deviation of this RDD's elements. + + >>> sc.parallelize([1, 2, 3]).stdev() + 0.816... + """ + return self.stats().stdev() + + def sampleStdev(self): + """ + Compute the sample standard deviation of this RDD's elements (which corrects for bias in + estimating the standard deviation by dividing by N-1 instead of N). + + >>> sc.parallelize([1, 2, 3]).sampleStdev() + 1.0 + """ + return self.stats().sampleStdev() + + def sampleVariance(self): + """ + Compute the sample variance of this RDD's elements (which corrects for bias in + estimating the variance by dividing by N-1 instead of N). + + >>> sc.parallelize([1, 2, 3]).sampleVariance() + 1.0 + """ + return self.stats().sampleVariance() def countByValue(self): """ @@ -386,13 +496,16 @@ class RDD(object): >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) [2, 3, 4, 5, 6] """ + def takeUpToNum(iterator): + taken = 0 + while taken < num: + yield next(iterator) + taken += 1 + # Take only up to num elements from each partition we try + mapped = self.mapPartitions(takeUpToNum) items = [] - for partition in range(self._jrdd.splits().size()): - iterator = self.ctx._takePartition(self._jrdd.rdd(), partition) - # Each item in the iterator is a string, Python object, batch of - # Python objects. Regardless, it is sufficient to take `num` - # of these objects in order to collect `num` Python objects: - iterator = iterator.take(num) + for partition in range(mapped._jrdd.splits().size()): + iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition) items.extend(self._collect_iterator_through_file(iterator)) if len(items) >= num: break @@ -689,6 +802,43 @@ class RDD(object): """ return python_cogroup(self, other, numPartitions) + def subtractByKey(self, other, numPartitions=None): + """ + Return each (key, value) pair in C{self} that has no pair with matching key + in C{other}. + + >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)]) + >>> y = sc.parallelize([("a", 3), ("c", None)]) + >>> sorted(x.subtractByKey(y).collect()) + [('b', 4), ('b', 5)] + """ + filter_func = lambda tpl: len(tpl[1][0]) > 0 and len(tpl[1][1]) == 0 + map_func = lambda tpl: [(tpl[0], val) for val in tpl[1][0]] + return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func) + + def subtract(self, other, numPartitions=None): + """ + Return each value in C{self} that is not contained in C{other}. + + >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 3)]) + >>> y = sc.parallelize([("a", 3), ("c", None)]) + >>> sorted(x.subtract(y).collect()) + [('a', 1), ('b', 4), ('b', 5)] + """ + rdd = other.map(lambda x: (x, True)) # note: here 'True' is just a placeholder + return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0]) # note: here 'True' is just a placeholder + + def keyBy(self, f): + """ + Creates tuples of the elements in this RDD by applying C{f}. + + >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x) + >>> y = sc.parallelize(zip(range(0,5), range(0,5))) + >>> sorted(x.cogroup(y).collect()) + [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))] + """ + return self.map(lambda x: (f(x), x)) + # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the # keys in the pairs. This could be an expensive operation, since those @@ -749,11 +899,12 @@ class PipelinedRDD(RDD): self.ctx._gateway._gateway_client) self.ctx._pickled_broadcast_vars.clear() class_manifest = self._prev_jrdd.classManifest() - env = copy.copy(self.ctx.environment) - env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "") - env = MapConverter().convert(env, self.ctx._gateway._gateway_client) + env = MapConverter().convert(self.ctx.environment, + self.ctx._gateway._gateway_client) + includes = ListConverter().convert(self.ctx._python_includes, + self.ctx._gateway._gateway_client) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, + pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val @@ -769,7 +920,7 @@ def _test(): # The small batch size here ensures that we see multiple batches, # even in these small test examples: globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs) + (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py new file mode 100644 index 0000000000..aca2ef3b51 --- /dev/null +++ b/python/pyspark/rddsampler.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import random + +class RDDSampler(object): + def __init__(self, withReplacement, fraction, seed): + try: + import numpy + self._use_numpy = True + except ImportError: + print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling." + self._use_numpy = False + + self._seed = seed + self._withReplacement = withReplacement + self._fraction = fraction + self._random = None + self._split = None + self._rand_initialized = False + + def initRandomGenerator(self, split): + if self._use_numpy: + import numpy + self._random = numpy.random.RandomState(self._seed) + for _ in range(0, split): + # discard the next few values in the sequence to have a + # different seed for the different splits + self._random.randint(sys.maxint) + else: + import random + random.seed(self._seed) + for _ in range(0, split): + # discard the next few values in the sequence to have a + # different seed for the different splits + random.randint(0, sys.maxint) + self._split = split + self._rand_initialized = True + + def getUniformSample(self, split): + if not self._rand_initialized or split != self._split: + self.initRandomGenerator(split) + + if self._use_numpy: + return self._random.random_sample() + else: + return random.uniform(0.0, 1.0) + + def getPoissonSample(self, split, mean): + if not self._rand_initialized or split != self._split: + self.initRandomGenerator(split) + + if self._use_numpy: + return self._random.poisson(mean) + else: + # here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by + # drawing a sequence of numbers delta_j ~ Exp(mean) + num_arrivals = 1 + cur_time = 0.0 + + cur_time += random.expovariate(mean) + + if cur_time > 1.0: + return 0 + + while(cur_time <= 1.0): + cur_time += random.expovariate(mean) + num_arrivals += 1 + + return (num_arrivals - 1) + + def shuffle(self, vals): + if self._random == None or split != self._split: + self.initRandomGenerator(0) # this should only ever called on the master so + # the split does not matter + + if self._use_numpy: + self._random.shuffle(vals) + else: + random.shuffle(vals, self._random) + + def func(self, split, iterator): + if self._withReplacement: + for obj in iterator: + # For large datasets, the expected number of occurrences of each element in a sample with + # replacement is Poisson(frac). We use that to get a count for each element. + count = self.getPoissonSample(split, mean = self._fraction) + for _ in range(0, count): + yield obj + else: + for obj in iterator: + if self.getUniformSample(split) <= self._fraction: + yield obj + + + + diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index cc8cd9e3c4..54823f8037 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -21,13 +21,31 @@ An interactive shell. This file is designed to be launched as a PYTHONSTARTUP script. """ import os +import platform import pyspark from pyspark.context import SparkContext +# this is the equivalent of ADD_JARS +add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None -sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell") +sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell", pyFiles=add_files) + +print """Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /__ / .__/\_,_/_/ /_/\_\ version 0.8.0 + /_/ +""" +print "Using Python version %s (%s, %s)" % ( + platform.python_version(), + platform.python_build()[0], + platform.python_build()[1]) print "Spark context avaiable as sc." +if add_files != None: + print "Adding files: [%s]" % ", ".join(add_files) + # The ./pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py new file mode 100644 index 0000000000..8e1cbd4ad9 --- /dev/null +++ b/python/pyspark/statcounter.py @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This file is ported from spark/util/StatCounter.scala + +import copy +import math + +class StatCounter(object): + + def __init__(self, values=[]): + self.n = 0L # Running count of our values + self.mu = 0.0 # Running mean of our values + self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) + + for v in values: + self.merge(v) + + # Add a value into this StatCounter, updating the internal statistics. + def merge(self, value): + delta = value - self.mu + self.n += 1 + self.mu += delta / self.n + self.m2 += delta * (value - self.mu) + return self + + # Merge another StatCounter into this one, adding up the internal statistics. + def mergeStats(self, other): + if not isinstance(other, StatCounter): + raise Exception("Can only merge Statcounters!") + + if other is self: # reference equality holds + self.merge(copy.deepcopy(other)) # Avoid overwriting fields in a weird order + else: + if self.n == 0: + self.mu = other.mu + self.m2 = other.m2 + self.n = other.n + elif other.n != 0: + delta = other.mu - self.mu + if other.n * 10 < self.n: + self.mu = self.mu + (delta * other.n) / (self.n + other.n) + elif self.n * 10 < other.n: + self.mu = other.mu - (delta * self.n) / (self.n + other.n) + else: + self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n) + + self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n) + self.n += other.n + return self + + # Clone this StatCounter + def copy(self): + return copy.deepcopy(self) + + def count(self): + return self.n + + def mean(self): + return self.mu + + def sum(self): + return self.n * self.mu + + # Return the variance of the values. + def variance(self): + if self.n == 0: + return float('nan') + else: + return self.m2 / self.n + + # + # Return the sample variance, which corrects for bias in estimating the variance by dividing + # by N-1 instead of N. + # + def sampleVariance(self): + if self.n <= 1: + return float('nan') + else: + return self.m2 / (self.n - 1) + + # Return the standard deviation of the values. + def stdev(self): + return math.sqrt(self.variance()) + + # + # Return the sample standard deviation of the values, which corrects for bias in estimating the + # variance by dividing by N-1 instead of N. + # + def sampleStdev(self): + return math.sqrt(self.sampleVariance()) + + def __repr__(self): + return "(count: %s, mean: %s, stdev: %s)" % (self.count(), self.mean(), self.stdev()) + diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index dfd841b10a..29d6a128f6 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -64,7 +64,7 @@ class TestCheckpoint(PySparkTestCase): flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertIsNone(flatMappedRDD.getCheckpointFile()) + self.assertTrue(flatMappedRDD.getCheckpointFile() is None) flatMappedRDD.checkpoint() result = flatMappedRDD.collect() @@ -79,13 +79,13 @@ class TestCheckpoint(PySparkTestCase): flatMappedRDD = parCollection.flatMap(lambda x: [x]) self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertIsNone(flatMappedRDD.getCheckpointFile()) + self.assertTrue(flatMappedRDD.getCheckpointFile() is None) flatMappedRDD.checkpoint() flatMappedRDD.count() # forces a checkpoint to be computed time.sleep(1) # 1 second - self.assertIsNotNone(flatMappedRDD.getCheckpointFile()) + self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) self.assertEquals([1, 2, 3, 4], recovered.collect()) @@ -125,6 +125,17 @@ class TestAddFile(PySparkTestCase): from userlibrary import UserClass self.assertEqual("Hello World!", UserClass().hello()) + def test_add_egg_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlib import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg") + self.sc.addPyFile(path) + from userlib import UserClass + self.assertEqual("Hello World from inside a package!", UserClass().hello()) + class TestIO(PySparkTestCase): @@ -164,9 +175,12 @@ class TestDaemon(unittest.TestCase): time.sleep(1) # daemon should no longer accept connections - with self.assertRaises(EnvironmentError) as trap: + try: self.connect(port) - self.assertEqual(trap.exception.errno, ECONNREFUSED) + except EnvironmentError as exception: + self.assertEqual(exception.errno, ECONNREFUSED) + else: + self.fail("Expected EnvironmentError to be raised") def test_termination_stdin(self): """Ensure that daemon and workers terminate when stdin is closed.""" diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 75d692beeb..695f6dfb84 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -49,15 +49,26 @@ def main(infile, outfile): split_index = read_int(infile) if split_index == -1: # for unit tests return + + # fetch name of workdir spark_files_dir = load_pickle(read_with_length(infile)) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True - sys.path.append(spark_files_dir) + + # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) for _ in range(num_broadcast_variables): bid = read_long(infile) value = read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + + # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH + sys.path.append(spark_files_dir) # *.py files that were added will be copied here + num_python_includes = read_int(infile) + for _ in range(num_python_includes): + sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile)))) + + # now load function func = load_obj(infile) bypassSerializer = load_obj(infile) if bypassSerializer: |