-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 65
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala | 82
-rw-r--r--  core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala | 85
-rw-r--r--  core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala | 126
-rw-r--r--  core/src/main/scala/org/apache/spark/input/PortableDataStream.scala | 218
-rw-r--r--  core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala | 51
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 2
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java | 79
-rw-r--r--  core/src/test/scala/org/apache/spark/FileSuite.scala | 184
10 files changed, 892 insertions, 5 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6bfcd8ceae..8b4db78397 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -41,7 +41,7 @@ import akka.actor.Props
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
-import org.apache.spark.input.WholeTextFileInputFormat
+import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
@@ -533,6 +533,69 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
minPartitions).setName(path)
}
+
+ /**
+ * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file
+ * (useful for binary data)
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @param minPartitions A suggestion value of the minimal splitting number for input data.
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ */
+ @Experimental
+ def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
+ RDD[(String, PortableDataStream)] = {
+ val job = new NewHadoopJob(hadoopConfiguration)
+ NewFileInputFormat.addInputPath(job, new Path(path))
+ val updateConf = job.getConfiguration
+ new BinaryFileRDD(
+ this,
+ classOf[StreamInputFormat],
+ classOf[String],
+ classOf[PortableDataStream],
+ updateConf,
+ minPartitions).setName(path)
+ }
+
+ /**
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ *
+ * @param path Directory to the input data files
+ * @param recordLength The length at which to split the records
+ * @return An RDD of data with values, represented as byte arrays
+ */
+ @Experimental
+ def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
+ : RDD[Array[Byte]] = {
+ conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
+ val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
+ classOf[FixedLengthBinaryInputFormat],
+ classOf[LongWritable],
+ classOf[BytesWritable],
+ conf=conf)
+ val data = br.map { case (k, v) => java.util.Arrays.copyOfRange(v.getBytes, 0, v.getLength) }
+ data
+ }
+
/**
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other
* necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable),
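
A minimal usage sketch of the two new SparkContext methods above, assuming an existing SparkContext named sc (the HDFS paths and the 4-byte record length are illustrative assumptions, not part of the patch):

    import org.apache.spark.input.PortableDataStream

    // each whole file becomes one (path, PortableDataStream) record
    val streams = sc.binaryFiles("hdfs://a-hdfs-path", minPartitions = 4)
    val fileSizes = streams.map { case (path, stream) => (path, stream.toArray().length) }

    // a flat file of fixed-length records, assumed here to be 4 bytes each
    val records = sc.binaryRecords("hdfs://a-hdfs-path/records.bin", recordLength = 4)
    val firstRecord: Array[Byte] = records.first()
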
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 0565adf4d4..e3aeba7e6c 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -21,6 +21,11 @@ import java.io.Closeable
import java.util
import java.util.{Map => JMap}
+import java.io.DataInputStream
+
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.spark.input.{PortableDataStream, FixedLengthBinaryInputFormat}
+
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
import scala.language.implicitConversions
@@ -32,7 +37,8 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.spark._
-import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam}
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
@@ -202,6 +208,8 @@ class JavaSparkContext(val sc: SparkContext)
def textFile(path: String, minPartitions: Int): JavaRDD[String] =
sc.textFile(path, minPartitions)
+
+
/**
* Read a directory of text files from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI. Each file is read as a single record and returned in a
@@ -245,6 +253,78 @@ class JavaSparkContext(val sc: SparkContext)
def wholeTextFiles(path: String): JavaPairRDD[String, String] =
new JavaPairRDD(sc.wholeTextFiles(path))
+ /**
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI as a PortableDataStream. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file,
+ * the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD<String, PortableDataStream> rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ *
+ * @param minPartitions A suggestion value of the minimal splitting number for input data.
+ */
+ def binaryFiles(path: String, minPartitions: Int): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, minPartitions))
+
+ /**
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI as a PortableDataStream. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file,
+ * the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD<String, PortableDataStream> rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ */
+ def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions))
+
+ /**
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ * @param path Directory to the input data files
+ * @param recordLength The length at which to split the records
+ * @return An RDD of data with values, represented as byte arrays
+ */
+ def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = {
+ new JavaRDD(sc.binaryRecords(path, recordLength))
+ }
+
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
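
The Java wrappers simply delegate to the SparkContext methods above; a brief sketch, written in Scala against the Java API and assuming an existing SparkContext sc (path and record length are illustrative):

    import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
    import org.apache.spark.input.PortableDataStream

    val jsc = new JavaSparkContext(sc)
    // whole files as (path, stream) pairs; toArray() pulls the bytes
    val files: JavaPairRDD[String, PortableDataStream] = jsc.binaryFiles("hdfs://a-hdfs-path")
    // fixed-length records, assumed here to be 8 bytes each
    val recs: JavaRDD[Array[Byte]] = jsc.binaryRecords("hdfs://a-hdfs-path/records.bin", 8)
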
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
new file mode 100644
index 0000000000..89b29af200
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+
+/**
+ * Custom Input Format for reading and splitting flat binary files that contain records,
+ * each of which is a fixed size in bytes. The fixed record size is specified through
+ * a parameter recordLength in the Hadoop configuration.
+ */
+private[spark] object FixedLengthBinaryInputFormat {
+ /** Property name to set in Hadoop JobConfs for record length */
+ val RECORD_LENGTH_PROPERTY = "org.apache.spark.input.FixedLengthBinaryInputFormat.recordLength"
+
+ /** Retrieves the record length property from a Hadoop configuration */
+ def getRecordLength(context: JobContext): Int = {
+ context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt
+ }
+}
+
+private[spark] class FixedLengthBinaryInputFormat
+ extends FileInputFormat[LongWritable, BytesWritable] {
+
+ private var recordLength = -1
+
+ /**
+ * Override of isSplitable to ensure initial computation of the record length
+ */
+ override def isSplitable(context: JobContext, filename: Path): Boolean = {
+ if (recordLength == -1) {
+ recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
+ }
+ if (recordLength <= 0) {
+ println("record length is less than 0, file cannot be split")
+ false
+ } else {
+ true
+ }
+ }
+
+ /**
+ * This input format overrides computeSplitSize() to make sure that each split
+ * only contains full records. Each InputSplit passed to FixedLengthBinaryRecordReader
+ * will start at the first byte of a record, and the last byte will be the last byte of a record.
+ */
+ override def computeSplitSize(blockSize: Long, minSize: Long, maxSize: Long): Long = {
+ val defaultSize = super.computeSplitSize(blockSize, minSize, maxSize)
+ // If the default size is less than the length of a record, make it equal to it
+ // Otherwise, make the split size as close as possible to the default size,
+ // but still contains a complete set of records, with the first record
+ // starting at the first byte in the split and the last record ending with the last byte
+ if (defaultSize < recordLength) {
+ recordLength.toLong
+ } else {
+ (defaultSize / recordLength) * recordLength
+ }
+ }
+
+ /**
+ * Create a FixedLengthBinaryRecordReader
+ */
+ override def createRecordReader(split: InputSplit, context: TaskAttemptContext)
+ : RecordReader[LongWritable, BytesWritable] = {
+ new FixedLengthBinaryRecordReader
+ }
+}
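
The split-size rounding in computeSplitSize can be checked with a small stand-alone sketch (the 128 MB default split size and the 7-byte record length are arbitrary assumptions for illustration):

    // mirrors the rounding performed by FixedLengthBinaryInputFormat.computeSplitSize
    val defaultSize = 128L * 1024 * 1024
    val recordLength = 7
    val splitSize =
      if (defaultSize < recordLength) recordLength.toLong
      else (defaultSize / recordLength) * recordLength
    // splitSize is the largest multiple of recordLength not exceeding defaultSize,
    // so every split begins and ends on a record boundary
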
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
new file mode 100644
index 0000000000..5164a74bec
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.IOException
+
+import org.apache.hadoop.fs.FSDataInputStream
+import org.apache.hadoop.io.compress.CompressionCodecFactory
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+
+/**
+ * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat.
+ * It uses the record length set in FixedLengthBinaryInputFormat to
+ * read one record at a time from the given InputSplit.
+ *
+ * Each call to nextKeyValue() updates the LongWritable key and BytesWritable value.
+ *
+ * key = record index (Long)
+ * value = the record itself (BytesWritable)
+ */
+private[spark] class FixedLengthBinaryRecordReader
+ extends RecordReader[LongWritable, BytesWritable] {
+
+ private var splitStart: Long = 0L
+ private var splitEnd: Long = 0L
+ private var currentPosition: Long = 0L
+ private var recordLength: Int = 0
+ private var fileInputStream: FSDataInputStream = null
+ private var recordKey: LongWritable = null
+ private var recordValue: BytesWritable = null
+
+ override def close() {
+ if (fileInputStream != null) {
+ fileInputStream.close()
+ }
+ }
+
+ override def getCurrentKey: LongWritable = {
+ recordKey
+ }
+
+ override def getCurrentValue: BytesWritable = {
+ recordValue
+ }
+
+ override def getProgress: Float = {
+ if (splitStart == splitEnd) {
+ 0.0f
+ } else {
+ Math.min(
+ (currentPosition - splitStart).toFloat / (splitEnd - splitStart), 1.0f)
+ }
+ }
+
+ override def initialize(inputSplit: InputSplit, context: TaskAttemptContext) {
+ // the file input
+ val fileSplit = inputSplit.asInstanceOf[FileSplit]
+
+ // the byte position this fileSplit starts at
+ splitStart = fileSplit.getStart
+
+ // splitEnd byte marker that the fileSplit ends at
+ splitEnd = splitStart + fileSplit.getLength
+
+ // the actual file we will be reading from
+ val file = fileSplit.getPath
+ // job configuration
+ val job = context.getConfiguration
+ // check compression
+ val codec = new CompressionCodecFactory(job).getCodec(file)
+ if (codec != null) {
+ throw new IOException("FixedLengthRecordReader does not support reading compressed files")
+ }
+ // get the record length
+ recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
+ // get the filesystem
+ val fs = file.getFileSystem(job)
+ // open the File
+ fileInputStream = fs.open(file)
+ // seek to the splitStart position
+ fileInputStream.seek(splitStart)
+ // set our current position
+ currentPosition = splitStart
+ }
+
+ override def nextKeyValue(): Boolean = {
+ if (recordKey == null) {
+ recordKey = new LongWritable()
+ }
+ // the key is a linear index of the record, given by the
+ // position at which the record starts, divided by the record length
+ recordKey.set(currentPosition / recordLength)
+ // the recordValue to place the bytes into
+ if (recordValue == null) {
+ recordValue = new BytesWritable(new Array[Byte](recordLength))
+ }
+ // read a record if the currentPosition is less than the split end
+ if (currentPosition < splitEnd) {
+ // setup a buffer to store the record
+ val buffer = recordValue.getBytes
+ fileInputStream.readFully(buffer, 0, recordLength)
+ // update our current position
+ currentPosition = currentPosition + recordLength
+ // return true
+ return true
+ }
+ false
+ }
+}
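
As a quick illustration of the key arithmetic in nextKeyValue (the numbers are purely illustrative):

    // with 6-byte records, a record that starts at byte offset 18 gets key 18 / 6 = 3,
    // i.e. it is the fourth record in the file, regardless of which split reads it
    val recordLength = 6
    val currentPosition = 18L
    val recordIndex = currentPosition / recordLength  // 3
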
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
new file mode 100644
index 0000000000..457472547f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.JavaConversions._
+
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * A general format for reading whole files in as streams, byte arrays,
+ * or other functions to be added
+ */
+private[spark] abstract class StreamFileInputFormat[T]
+ extends CombineFileInputFormat[String, T]
+{
+ override protected def isSplitable(context: JobContext, file: Path): Boolean = false
+
+ /**
+ * Allow minPartitions set by end-user in order to keep compatibility with the old Hadoop API;
+ * it is translated into a maximum split size via setMaxSplitSize
+ */
+ def setMinPartitions(context: JobContext, minPartitions: Int) {
+ val files = listStatus(context)
+ val totalLen = files.map { file =>
+ if (file.isDir) 0L else file.getLen
+ }.sum
+
+ val maxSplitSize = Math.ceil(totalLen * 1.0 / math.max(minPartitions, 1)).toLong
+ super.setMaxSplitSize(maxSplitSize)
+ }
+
+ def createRecordReader(split: InputSplit, taContext: TaskAttemptContext): RecordReader[String, T]
+
+}
+
+/**
+ * An abstract class of [[org.apache.hadoop.mapreduce.RecordReader RecordReader]]
+ * for reading files out as streams
+ */
+private[spark] abstract class StreamBasedRecordReader[T](
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends RecordReader[String, T] {
+
+ // True means the current file has been processed and should be skipped.
+ private var processed = false
+
+ private var key = ""
+ private var value: T = null.asInstanceOf[T]
+
+ override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
+ override def close() = {}
+
+ override def getProgress = if (processed) 1.0f else 0.0f
+
+ override def getCurrentKey = key
+
+ override def getCurrentValue = value
+
+ override def nextKeyValue = {
+ if (!processed) {
+ val fileIn = new PortableDataStream(split, context, index)
+ value = parseStream(fileIn)
+ fileIn.close() // if it has not been opened yet, close does nothing
+ key = fileIn.getPath
+ processed = true
+ true
+ } else {
+ false
+ }
+ }
+
+ /**
+ * Parse the stream (and close it afterwards) and return the value as type T
+ * @param inStream the stream to be read in
+ * @return the data parsed from the stream as type T
+ */
+ def parseStream(inStream: PortableDataStream): T
+}
+
+/**
+ * Reads the record in directly as a stream for other objects to manipulate and handle
+ */
+private[spark] class StreamRecordReader(
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends StreamBasedRecordReader[PortableDataStream](split, context, index) {
+
+ def parseStream(inStream: PortableDataStream): PortableDataStream = inStream
+}
+
+/**
+ * The format for the PortableDataStream files
+ */
+private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] {
+ override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = {
+ new CombineFileRecordReader[String, PortableDataStream](
+ split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader])
+ }
+}
+
+/**
+ * A class that allows data streams to be serialized and moved around by deferring creation of
+ * the underlying stream until it is actually read
+ * @note TaskAttemptContext is not serializable resulting in the confBytes construct
+ * @note CombineFileSplit is not serializable resulting in the splitBytes construct
+ */
+@Experimental
+class PortableDataStream(
+ @transient isplit: CombineFileSplit,
+ @transient context: TaskAttemptContext,
+ index: Integer)
+ extends Serializable {
+
+ // transient forces the file to be reopened after serialization
+ // it is also used for non-serializable classes
+
+ @transient private var fileIn: DataInputStream = null
+ @transient private var isOpen = false
+
+ private val confBytes = {
+ val baos = new ByteArrayOutputStream()
+ context.getConfiguration.write(new DataOutputStream(baos))
+ baos.toByteArray
+ }
+
+ private val splitBytes = {
+ val baos = new ByteArrayOutputStream()
+ isplit.write(new DataOutputStream(baos))
+ baos.toByteArray
+ }
+
+ @transient private lazy val split = {
+ val bais = new ByteArrayInputStream(splitBytes)
+ val nsplit = new CombineFileSplit()
+ nsplit.readFields(new DataInputStream(bais))
+ nsplit
+ }
+
+ @transient private lazy val conf = {
+ val bais = new ByteArrayInputStream(confBytes)
+ val nconf = new Configuration()
+ nconf.readFields(new DataInputStream(bais))
+ nconf
+ }
+ /**
+ * Calculate the path name independently of opening the file
+ */
+ @transient private lazy val path = {
+ val pathp = split.getPath(index)
+ pathp.toString
+ }
+
+ /**
+ * Create a new DataInputStream from the split and context
+ */
+ def open(): DataInputStream = {
+ if (!isOpen) {
+ val pathp = split.getPath(index)
+ val fs = pathp.getFileSystem(conf)
+ fileIn = fs.open(pathp)
+ isOpen = true
+ }
+ fileIn
+ }
+
+ /**
+ * Read the file as a byte array
+ */
+ def toArray(): Array[Byte] = {
+ open()
+ val innerBuffer = ByteStreams.toByteArray(fileIn)
+ close()
+ innerBuffer
+ }
+
+ /**
+ * Close the file (if it is currently open)
+ */
+ def close() = {
+ if (isOpen) {
+ try {
+ fileIn.close()
+ isOpen = false
+ } catch {
+ case ioe: java.io.IOException => // do nothing
+ }
+ }
+ }
+
+ def getPath(): String = path
+}
+
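
A minimal sketch of how a PortableDataStream is typically consumed, assuming an existing SparkContext sc (the path and the 4-byte header read are illustrative assumptions):

    import java.io.DataInputStream
    import org.apache.spark.input.PortableDataStream

    val streams = sc.binaryFiles("hdfs://a-hdfs-path")
    // materialize each file eagerly as a byte array
    val asBytes = streams.map { case (path, pds) => (path, pds.toArray()) }
    // or open the underlying stream lazily and read only a small header
    val headers = streams.map { case (path, pds) =>
      val in: DataInputStream = pds.open()
      val magic = in.readInt()  // first 4 bytes of the file
      pds.close()
      (path, magic)
    }
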
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index 4cb4505777..183bce3d8d 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -48,9 +48,10 @@ private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[Str
}
/**
- * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API.
+ * Allow minPartitions set by end-user in order to keep compatibility with the old Hadoop API;
+ * it is translated into a maximum split size via setMaxSplitSize
*/
- def setMaxSplitSize(context: JobContext, minPartitions: Int) {
+ def setMinPartitions(context: JobContext, minPartitions: Int) {
val files = listStatus(context)
val totalLen = files.map { file =>
if (file.isDir) 0L else file.getLen
diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
new file mode 100644
index 0000000000..6e66ddbdef
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapreduce._
+import org.apache.spark.input.StreamFileInputFormat
+import org.apache.spark.{Partition, SparkContext}
+
+private[spark] class BinaryFileRDD[T](
+ sc: SparkContext,
+ inputFormatClass: Class[_ <: StreamFileInputFormat[T]],
+ keyClass: Class[String],
+ valueClass: Class[T],
+ @transient conf: Configuration,
+ minPartitions: Int)
+ extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) {
+
+ override def getPartitions: Array[Partition] = {
+ val inputFormat = inputFormatClass.newInstance
+ inputFormat match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
+ }
+ val jobContext = newJobContext(conf, jobId)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
+ val rawSplits = inputFormat.getSplits(jobContext).toArray
+ val result = new Array[Partition](rawSplits.size)
+ for (i <- 0 until rawSplits.size) {
+ result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ }
+ result
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 3245632487..6d6b86721c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -263,7 +263,7 @@ private[spark] class WholeTextFileRDD(
case _ =>
}
val jobContext = newJobContext(conf, jobId)
- inputFormat.setMaxSplitSize(jobContext, minPartitions)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Partition](rawSplits.size)
for (i <- 0 until rawSplits.size) {
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index c21a4b30d7..59c86eecac 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -18,10 +18,13 @@
package org.apache.spark;
import java.io.*;
+import java.nio.channels.FileChannel;
+import java.nio.ByteBuffer;
import java.net.URI;
import java.util.*;
import java.util.concurrent.*;
+import org.apache.spark.input.PortableDataStream;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
@@ -863,6 +866,82 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(pairs, readRDD.collect());
}
+ @Test
+ public void binaryFiles() throws Exception {
+ // Reusing the wholeTextFiles example
+ byte[] content1 = "spark is easy to use.\n".getBytes("utf-8");
+
+ String tempDirName = tempDir.getAbsolutePath();
+ File file1 = new File(tempDirName + "/part-00000");
+
+ FileOutputStream fos1 = new FileOutputStream(file1);
+
+ FileChannel channel1 = fos1.getChannel();
+ ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
+ channel1.write(bbuf);
+ channel1.close();
+ JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName, 3);
+ List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
+ for (Tuple2<String, PortableDataStream> res : result) {
+ Assert.assertArrayEquals(content1, res._2().toArray());
+ }
+ }
+
+ @Test
+ public void binaryFilesCaching() throws Exception {
+ // Reusing the wholeTextFiles example
+ byte[] content1 = "spark is easy to use.\n".getBytes("utf-8");
+
+ String tempDirName = tempDir.getAbsolutePath();
+ File file1 = new File(tempDirName + "/part-00000");
+
+ FileOutputStream fos1 = new FileOutputStream(file1);
+
+ FileChannel channel1 = fos1.getChannel();
+ ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
+ channel1.write(bbuf);
+ channel1.close();
+
+ JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName).cache();
+ readRDD.foreach(new VoidFunction<Tuple2<String,PortableDataStream>>() {
+ @Override
+ public void call(Tuple2<String, PortableDataStream> pair) throws Exception {
+ pair._2().toArray(); // force the file to read
+ }
+ });
+
+ List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
+ for (Tuple2<String, PortableDataStream> res : result) {
+ Assert.assertArrayEquals(content1, res._2().toArray());
+ }
+ }
+
+ @Test
+ public void binaryRecords() throws Exception {
+ // Reusing the wholeTextFiles example
+ byte[] content1 = "spark isn't always easy to use.\n".getBytes("utf-8");
+ int numOfCopies = 10;
+ String tempDirName = tempDir.getAbsolutePath();
+ File file1 = new File(tempDirName + "/part-00000");
+
+ FileOutputStream fos1 = new FileOutputStream(file1);
+
+ FileChannel channel1 = fos1.getChannel();
+
+ for (int i = 0; i < numOfCopies; i++) {
+ ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
+ channel1.write(bbuf);
+ }
+ channel1.close();
+
+ JavaRDD<byte[]> readRDD = sc.binaryRecords(tempDirName, content1.length);
+ Assert.assertEquals(numOfCopies, readRDD.count());
+ List<byte[]> result = readRDD.collect();
+ for (byte[] res : result) {
+ Assert.assertArrayEquals(content1, res);
+ }
+ }
+
@SuppressWarnings("unchecked")
@Test
public void writeWithNewAPIHadoopFile() {
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index a2b74c4419..5e24196101 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark
import java.io.{File, FileWriter}
+import org.apache.spark.input.PortableDataStream
+import org.apache.spark.storage.StorageLevel
+
import scala.io.Source
import org.apache.hadoop.io._
@@ -224,6 +227,187 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
}
+ test("binary file input as byte array") {
+ sc = new SparkContext("local", "test")
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ channel.write(bbuf)
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryFiles(outFileName)
+ val (infile: String, indata: PortableDataStream) = inRdd.collect.head
+
+ // Make sure the name and array match
+ assert(infile.contains(outFileName)) // a prefix may get added
+ assert(indata.toArray === testOutput)
+ }
+
+ test("portabledatastream caching tests") {
+ sc = new SparkContext("local", "test")
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ channel.write(bbuf)
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryFiles(outFileName).cache()
+ inRdd.foreach{
+ curData: (String, PortableDataStream) =>
+ curData._2.toArray() // force the file to read
+ }
+ val mappedRdd = inRdd.map {
+ curData: (String, PortableDataStream) =>
+ (curData._2.getPath(), curData._2)
+ }
+ val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head
+
+ // Check that the stream content matches the bytes that were written
+
+ assert(indata.toArray === testOutput)
+ }
+
+ test("portabledatastream persist disk storage") {
+ sc = new SparkContext("local", "test")
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ channel.write(bbuf)
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryFiles(outFileName).persist(StorageLevel.DISK_ONLY)
+ inRdd.foreach{
+ curData: (String, PortableDataStream) =>
+ curData._2.toArray() // force the file to read
+ }
+ val mappedRdd = inRdd.map {
+ curData: (String, PortableDataStream) =>
+ (curData._2.getPath(), curData._2)
+ }
+ val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head
+
+ // Check that the stream content matches the bytes that were written
+
+ assert(indata.toArray === testOutput)
+ }
+
+ test("portabledatastream flatmap tests") {
+ sc = new SparkContext("local", "test")
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val numOfCopies = 3
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ channel.write(bbuf)
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryFiles(outFileName)
+ val mappedRdd = inRdd.map {
+ curData: (String, PortableDataStream) =>
+ (curData._2.getPath(), curData._2)
+ }
+ val copyRdd = mappedRdd.flatMap {
+ curData: (String, PortableDataStream) =>
+ for(i <- 1 to numOfCopies) yield (i, curData._2)
+ }
+
+ val copyArr: Array[(Int, PortableDataStream)] = copyRdd.collect()
+
+ // Check that each copy matches the bytes that were written
+ assert(copyArr.length == numOfCopies)
+ copyArr.foreach{
+ cEntry: (Int, PortableDataStream) =>
+ assert(cEntry._2.toArray === testOutput)
+ }
+
+ }
+
+ test("fixed record length binary file as byte array") {
+ // a fixed length of 6 bytes
+
+ sc = new SparkContext("local", "test")
+
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val testOutputCopies = 10
+
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ for(i <- 1 to testOutputCopies) {
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ channel.write(bbuf)
+ }
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryRecords(outFileName, testOutput.length)
+ // make sure there are enough elements
+ assert(inRdd.count == testOutputCopies)
+
+ // now just compare the first one
+ val indata: Array[Byte] = inRdd.collect.head
+ assert(indata === testOutput)
+ }
+
+ test ("negative binary record length should raise an exception") {
+ // a fixed length of 6 bytes
+ sc = new SparkContext("local", "test")
+
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val testOutputCopies = 10
+
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ for(i <- 1 to testOutputCopies) {
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ channel.write(bbuf)
+ }
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryRecords(outFileName, -1)
+
+ intercept[SparkException] {
+ inRdd.count
+ }
+ }
+
test("file caching") {
sc = new SparkContext("local", "test")
val out = new FileWriter(tempDir + "/input")