From 60e18ce7dd1016647b63586520b713efc45494a8 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 4 Apr 2014 17:29:29 -0700 Subject: SPARK-1414. Python API for SparkContext.wholeTextFiles Also clarified comment on each file having to fit in memory Author: Matei Zaharia Closes #327 from mateiz/py-whole-files and squashes the following commits: 9ad64a5 [Matei Zaharia] SPARK-1414. Python API for SparkContext.wholeTextFiles --- .../main/scala/org/apache/spark/SparkContext.scala | 2 +- .../apache/spark/api/java/JavaSparkContext.scala | 2 +- .../org/apache/spark/api/python/PythonRDD.scala | 6 ++- python/pyspark/context.py | 44 +++++++++++++++++++++- python/pyspark/serializers.py | 2 +- 5 files changed, 49 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 28a865c0ad..835cffe37a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -395,7 +395,7 @@ class SparkContext( * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @note Small files are perferred, large file is also allowable, but may cause bad performance. + * @note Small files are preferred, as each file will be loaded fully in memory. */ def wholeTextFiles(path: String): RDD[(String, String)] = { newAPIHadoopFile( diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 6cbdeac58d..a2855d4db1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -177,7 +177,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @note Small files are perferred, large file is also allowable, but may cause bad performance. + * @note Small files are preferred, as each file will be loaded fully in memory. */ def wholeTextFiles(path: String): JavaPairRDD[String, String] = new JavaPairRDD(sc.wholeTextFiles(path)) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b67286a4e3..32f1100406 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ +import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ @@ -206,6 +207,7 @@ private object SpecialLengths { } private[spark] object PythonRDD { + val UTF8 = Charset.forName("UTF-8") def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { @@ -266,7 +268,7 @@ private[spark] object PythonRDD { } def writeUTF(str: String, dataOut: DataOutputStream) { - val bytes = str.getBytes("UTF-8") + val bytes = str.getBytes(UTF8) dataOut.writeInt(bytes.length) dataOut.write(bytes) } @@ -286,7 +288,7 @@ private[spark] object PythonRDD { private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { - override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") + override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8) } /** diff --git a/python/pyspark/context.py b/python/pyspark/context.py index bf2454fd7e..ff1023bbfa 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -28,7 +28,8 @@ from pyspark.broadcast import Broadcast from pyspark.conf import SparkConf from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway -from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer +from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ + PairDeserializer from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD @@ -257,6 +258,45 @@ class SparkContext(object): return RDD(self._jsc.textFile(name, minSplits), self, UTF8Deserializer()) + def wholeTextFiles(self, path): + """ + Read a directory of text files from HDFS, a local file system + (available on all nodes), or any Hadoop-supported file system + URI. Each file is read as a single record and returned in a + key-value pair, where the key is the path of each file, the + value is the content of each file. + + For example, if you have the following files:: + + hdfs://a-hdfs-path/part-00000 + hdfs://a-hdfs-path/part-00001 + ... + hdfs://a-hdfs-path/part-nnnnn + + Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")}, + then C{rdd} contains:: + + (a-hdfs-path/part-00000, its content) + (a-hdfs-path/part-00001, its content) + ... + (a-hdfs-path/part-nnnnn, its content) + + NOTE: Small files are preferred, as each file will be loaded + fully in memory. + + >>> dirPath = os.path.join(tempdir, "files") + >>> os.mkdir(dirPath) + >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1: + ... file1.write("1") + >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2: + ... file2.write("2") + >>> textFiles = sc.wholeTextFiles(dirPath) + >>> sorted(textFiles.collect()) + [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')] + """ + return RDD(self._jsc.wholeTextFiles(path), self, + PairDeserializer(UTF8Deserializer(), UTF8Deserializer())) + def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) return RDD(jrdd, self, input_deserializer) @@ -425,7 +465,7 @@ def _test(): globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) - (failure_count, test_count) = doctest.testmod(globs=globs) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 12c63f186a..4d802924df 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -290,7 +290,7 @@ class MarshalSerializer(FramedSerializer): class UTF8Deserializer(Serializer): """ - Deserializes streams written by getBytes. + Deserializes streams written by String.getBytes. """ def loads(self, stream): -- cgit v1.2.3