diff options
-rw-r--r-- | core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 | ||||
-rw-r--r-- | core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 3 | ||||
-rw-r--r-- | python/pyspark/java_gateway.py | 6 | ||||
-rw-r--r-- | python/pyspark/rdd.py | 99 | ||||
-rw-r--r-- | python/pyspark/rddsampler.py | 112 |
5 files changed, 215 insertions, 9 deletions
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index 6c41b97780..e512423fd6 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.util.ReflectionUtils import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, SparkEnv, TaskContext} import spark.util.NextIterator -import org.apache.hadoop.conf.Configurable +import org.apache.hadoop.conf.{Configuration, Configurable} /** @@ -132,4 +132,6 @@ class HadoopRDD[K, V]( override def checkpoint() { // Do nothing. Hadoop RDD should not be checkpointed. } + + def getConf: Configuration = confBroadcast.value.value } diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index 184685528e..b1877dc06e 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -120,4 +120,7 @@ class NewHadoopRDD[K, V]( val theSplit = split.asInstanceOf[NewHadoopPartition] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } + + def getConf: Configuration = confBroadcast.value.value } + diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 18011c0dc9..3ccf062c86 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -17,6 +17,7 @@ import os import sys +import signal from subprocess import Popen, PIPE from threading import Thread from py4j.java_gateway import java_import, JavaGateway, GatewayClient @@ -30,7 +31,10 @@ def launch_gateway(): # proper classpath and SPARK_MEM settings from spark-env.sh command = [os.path.join(SPARK_HOME, "spark-class"), "py4j.GatewayServer", "--die-on-broken-pipe", "0"] - proc = Popen(command, stdout=PIPE, stdin=PIPE) + # Don't send ctrl-c / SIGINT to the Java gateway: + def preexec_function(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_function) # Determine which ephemeral port the server started on: port = int(proc.stdout.readline()) # Create a thread to echo output from the GatewayServer, which is required diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1e9b3bb5c0..914118ccdd 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -21,6 +21,7 @@ from collections import defaultdict from itertools import chain, ifilter, imap, product import operator import os +import sys import shlex from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile @@ -32,6 +33,7 @@ from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter +from pyspark.rddsampler import RDDSampler from py4j.java_collections import ListConverter, MapConverter @@ -165,14 +167,60 @@ class RDD(object): .reduceByKey(lambda x, _: x) \ .map(lambda (x, _): x) - # TODO: sampling needs to be re-implemented due to Batch - #def sample(self, withReplacement, fraction, seed): - # jrdd = self._jrdd.sample(withReplacement, fraction, seed) - # return RDD(jrdd, self.ctx) + def sample(self, withReplacement, fraction, seed): + """ + Return a sampled subset of this RDD (relies on numpy and falls back + on default random generator if numpy is unavailable). + + >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP + [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] + """ + return self.mapPartitionsWithSplit(RDDSampler(withReplacement, fraction, seed).func, True) + + # this is ported from scala/spark/RDD.scala + def takeSample(self, withReplacement, num, seed): + """ + Return a fixed-size sampled subset of this RDD (currently requires numpy). + + >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP + [4, 2, 1, 8, 2, 7, 0, 4, 1, 4] + """ + + fraction = 0.0 + total = 0 + multiplier = 3.0 + initialCount = self.count() + maxSelected = 0 + + if (num < 0): + raise ValueError + + if initialCount > sys.maxint - 1: + maxSelected = sys.maxint - 1 + else: + maxSelected = initialCount + + if num > initialCount and not withReplacement: + total = maxSelected + fraction = multiplier * (maxSelected + 1) / initialCount + else: + fraction = multiplier * (num + 1) / initialCount + total = num - #def takeSample(self, withReplacement, num, seed): - # vals = self._jrdd.takeSample(withReplacement, num, seed) - # return [load_pickle(bytes(x)) for x in vals] + samples = self.sample(withReplacement, fraction, seed).collect() + + # If the first sample didn't turn out large enough, keep trying to take samples; + # this shouldn't happen often because we use a big multiplier for their initial size. + # See: scala/spark/RDD.scala + while len(samples) < total: + if seed > sys.maxint - 2: + seed = -1 + seed += 1 + samples = self.sample(withReplacement, fraction, seed).collect() + + sampler = RDDSampler(withReplacement, fraction, seed+1) + sampler.shuffle(samples) + return samples[0:total] def union(self, other): """ @@ -754,6 +802,43 @@ class RDD(object): """ return python_cogroup(self, other, numPartitions) + def subtractByKey(self, other, numPartitions=None): + """ + Return each (key, value) pair in C{self} that has no pair with matching key + in C{other}. + + >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)]) + >>> y = sc.parallelize([("a", 3), ("c", None)]) + >>> sorted(x.subtractByKey(y).collect()) + [('b', 4), ('b', 5)] + """ + filter_func = lambda tpl: len(tpl[1][0]) > 0 and len(tpl[1][1]) == 0 + map_func = lambda tpl: [(tpl[0], val) for val in tpl[1][0]] + return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func) + + def subtract(self, other, numPartitions=None): + """ + Return each value in C{self} that is not contained in C{other}. + + >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 3)]) + >>> y = sc.parallelize([("a", 3), ("c", None)]) + >>> sorted(x.subtract(y).collect()) + [('a', 1), ('b', 4), ('b', 5)] + """ + rdd = other.map(lambda x: (x, True)) # note: here 'True' is just a placeholder + return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0]) # note: here 'True' is just a placeholder + + def keyBy(self, f): + """ + Creates tuples of the elements in this RDD by applying C{f}. + + >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x) + >>> y = sc.parallelize(zip(range(0,5), range(0,5))) + >>> sorted(x.cogroup(y).collect()) + [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))] + """ + return self.map(lambda x: (f(x), x)) + # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the # keys in the pairs. This could be an expensive operation, since those diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py new file mode 100644 index 0000000000..aca2ef3b51 --- /dev/null +++ b/python/pyspark/rddsampler.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import random + +class RDDSampler(object): + def __init__(self, withReplacement, fraction, seed): + try: + import numpy + self._use_numpy = True + except ImportError: + print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling." + self._use_numpy = False + + self._seed = seed + self._withReplacement = withReplacement + self._fraction = fraction + self._random = None + self._split = None + self._rand_initialized = False + + def initRandomGenerator(self, split): + if self._use_numpy: + import numpy + self._random = numpy.random.RandomState(self._seed) + for _ in range(0, split): + # discard the next few values in the sequence to have a + # different seed for the different splits + self._random.randint(sys.maxint) + else: + import random + random.seed(self._seed) + for _ in range(0, split): + # discard the next few values in the sequence to have a + # different seed for the different splits + random.randint(0, sys.maxint) + self._split = split + self._rand_initialized = True + + def getUniformSample(self, split): + if not self._rand_initialized or split != self._split: + self.initRandomGenerator(split) + + if self._use_numpy: + return self._random.random_sample() + else: + return random.uniform(0.0, 1.0) + + def getPoissonSample(self, split, mean): + if not self._rand_initialized or split != self._split: + self.initRandomGenerator(split) + + if self._use_numpy: + return self._random.poisson(mean) + else: + # here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by + # drawing a sequence of numbers delta_j ~ Exp(mean) + num_arrivals = 1 + cur_time = 0.0 + + cur_time += random.expovariate(mean) + + if cur_time > 1.0: + return 0 + + while(cur_time <= 1.0): + cur_time += random.expovariate(mean) + num_arrivals += 1 + + return (num_arrivals - 1) + + def shuffle(self, vals): + if self._random == None or split != self._split: + self.initRandomGenerator(0) # this should only ever called on the master so + # the split does not matter + + if self._use_numpy: + self._random.shuffle(vals) + else: + random.shuffle(vals, self._random) + + def func(self, split, iterator): + if self._withReplacement: + for obj in iterator: + # For large datasets, the expected number of occurrences of each element in a sample with + # replacement is Poisson(frac). We use that to get a count for each element. + count = self.getPoissonSample(split, mean = self._fraction) + for _ in range(0, count): + yield obj + else: + for obj in iterator: + if self.getUniformSample(split) <= self._fraction: + yield obj + + + + |