diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-02-03 21:36:36 -0800 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-02-03 21:36:36 -0800 |
commit | f7b4e428be75d189b9ae50c4302c08f3c49e0161 (patch) | |
tree | 4a8d5ffa8bf1d8094295653c0a2fe8ccaa65558b | |
parent | 3bfaf3ab1d14a85a749f87f1bcd37e553e8440e7 (diff) | |
parent | e61729113d3bf165d1ab9bd83ea55d52fd0bb72e (diff) | |
download | spark-f7b4e428be75d189b9ae50c4302c08f3c49e0161.tar.gz spark-f7b4e428be75d189b9ae50c4302c08f3c49e0161.tar.bz2 spark-f7b4e428be75d189b9ae50c4302c08f3c49e0161.zip |
Merge pull request #445 from JoshRosen/pyspark_fixes
Fix exit status in PySpark unit tests; fix/optimize PySpark's RDD.take()
-rw-r--r-- | core/src/main/scala/spark/api/python/PythonRDD.scala | 11 | ||||
-rw-r--r-- | python/pyspark/accumulators.py | 9 | ||||
-rw-r--r-- | python/pyspark/broadcast.py | 9 | ||||
-rw-r--r-- | python/pyspark/context.py | 4 | ||||
-rw-r--r-- | python/pyspark/rdd.py | 8 |
5 files changed, 19 insertions, 22 deletions
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 39758e94f4..ab8351e55e 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -238,6 +238,11 @@ private[spark] object PythonRDD { } def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + import scala.collection.JavaConverters._ + writeIteratorToPickleFile(items.asScala, filename) + } + + def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { writeAsPickle(item, file) @@ -245,8 +250,10 @@ private[spark] object PythonRDD { file.close() } - def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] = - rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head + def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = { + implicit val cm : ClassManifest[T] = rdd.elementClassManifest + rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator + } } private object Pickle { diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 61fcbbd376..3e9d7d36da 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -196,12 +196,3 @@ def _start_update_server(): thread.daemon = True thread.start() return server - - -def _test(): - import doctest - doctest.testmod() - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 93876fa738..def810dd46 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -37,12 +37,3 @@ class Broadcast(object): def __reduce__(self): self._pickle_registry.add(self) return (_from_id, (self.bid, )) - - -def _test(): - import doctest - doctest.testmod() - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6831f9b7f8..657fe6f989 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -256,8 +256,10 @@ def _test(): globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) - doctest.testmod(globs=globs) + (failure_count, test_count) = doctest.testmod(globs=globs) globs['sc'].stop() + if failure_count: + exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 41ea6e6e14..4cda6cf661 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -372,6 +372,10 @@ class RDD(object): items = [] for partition in range(self._jrdd.splits().size()): iterator = self.ctx._takePartition(self._jrdd.rdd(), partition) + # Each item in the iterator is a string, Python object, batch of + # Python objects. Regardless, it is sufficient to take `num` + # of these objects in order to collect `num` Python objects: + iterator = iterator.take(num) items.extend(self._collect_iterator_through_file(iterator)) if len(items) >= num: break @@ -748,8 +752,10 @@ def _test(): # The small batch size here ensures that we see multiple batches, # even in these small test examples: globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - doctest.testmod(globs=globs) + (failure_count, test_count) = doctest.testmod(globs=globs) globs['sc'].stop() + if failure_count: + exit(-1) if __name__ == "__main__": |