author    Michael Armbrust <michael@databricks.com>  2016-03-07 15:15:10 -0800
committer Reynold Xin <rxin@databricks.com>          2016-03-07 15:15:10 -0800
commit    e720dda42e806229ccfd970055c7b8a93eb447bf (patch)
tree      641c3d454b638a347adc4c51db8cc69c41b44ac2 /mllib
parent    0eea12a3d956b54bbbd73d21b296868852a04494 (diff)
[SPARK-13665][SQL] Separate the concerns of HadoopFsRelation

`HadoopFsRelation` is used for reading most files into Spark SQL. However, today this class mixes the concerns of file management, schema reconciliation, scan building, bucketing, partitioning, and writing data. As a result, many data sources are forced to reimplement the same functionality, and the various layers have accumulated a fair bit of inefficiency. This PR is a first cut at separating this into several components / interfaces that are each described below. Additionally, all implementations inside of Spark (parquet, csv, json, text, orc, libsvm) have been ported to the new API, `FileFormat`. External libraries, such as spark-avro, will also need to be ported to work with Spark 2.0.

### HadoopFsRelation

A simple `case class` that acts as a container for all of the metadata required to read from a datasource. All discovery, resolution and merging logic for schemas and partitions has been removed. This is an internal representation that no longer needs to be exposed to developers.

```scala
case class HadoopFsRelation(
    sqlContext: SQLContext,
    location: FileCatalog,
    partitionSchema: StructType,
    dataSchema: StructType,
    bucketSpec: Option[BucketSpec],
    fileFormat: FileFormat,
    options: Map[String, String]) extends BaseRelation
```

### FileFormat

The primary interface that will be implemented by each different format, including external libraries. Implementors are responsible for reading a given format and converting it into `InternalRow`, as well as writing out an `InternalRow`. A format can optionally return a schema that is inferred from a set of files.

```scala
trait FileFormat {
  def inferSchema(
      sqlContext: SQLContext,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType]

  def prepareWrite(
      sqlContext: SQLContext,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory

  def buildInternalScan(
      sqlContext: SQLContext,
      dataSchema: StructType,
      requiredColumns: Array[String],
      filters: Array[Filter],
      bucketSet: Option[BitSet],
      inputFiles: Array[FileStatus],
      broadcastedConf: Broadcast[SerializableConfiguration],
      options: Map[String, String]): RDD[InternalRow]
}
```

The current interface is based on what was required to get all the tests passing again, but it still mixes a couple of concerns (e.g. `bucketSet` is passed down to the scan instead of being resolved by the planner). Additionally, scans still return `RDD`s instead of iterators over single files. In a future PR, bucketing should be removed from this interface and the scan should be isolated to a single file.

### FileCatalog

This interface is used to list the files that make up a given relation, as well as to handle directory-based partitioning.

```scala
trait FileCatalog {
  def paths: Seq[Path]
  def partitionSpec(schema: Option[StructType]): PartitionSpec
  def allFiles(): Seq[FileStatus]
  def getStatus(path: Path): Array[FileStatus]
  def refresh(): Unit
}
```

Currently there are two implementations:
 - `HDFSFileCatalog` - based on code from the old `HadoopFsRelation`. Infers partitioning by recursive listing and caches this data for performance.
 - `HiveFileCatalog` - based on the above, but uses the partition spec from the Hive Metastore.
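To make the `FileFormat` contract above concrete, here is a rough, illustrative sketch (not part of this patch) of a minimal, read-only implementation that exposes each input line as a single string column. The `EchoFileFormat` name is invented; the `Row` to `InternalRow` conversion via `RowEncoder` mirrors what the ported libsvm source in this patch does.

```scala
// Illustrative only: a hypothetical, read-only FileFormat that exposes each
// input line as a single string column. Names are invented; error handling,
// column pruning, filter pushdown, and bucketing are ignored.
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.Job
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.collection.BitSet

class EchoFileFormat extends FileFormat with DataSourceRegister {

  override def shortName(): String = "echo"

  // The schema is fixed, so no files need to be inspected.
  override def inferSchema(
      sqlContext: SQLContext,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] =
    Some(StructType(StructField("value", StringType, nullable = true) :: Nil))

  // This sketch is read-only; a real format would return a working factory.
  override def prepareWrite(
      sqlContext: SQLContext,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory =
    sys.error("echo is read-only in this sketch")

  override def buildInternalScan(
      sqlContext: SQLContext,
      dataSchema: StructType,
      requiredColumns: Array[String],
      filters: Array[Filter],
      bucketSet: Option[BitSet],
      inputFiles: Array[FileStatus],
      broadcastedConf: Broadcast[SerializableConfiguration],
      options: Map[String, String]): RDD[InternalRow] = {
    val paths = inputFiles.map(_.getPath.toUri.toString).mkString(",")
    sqlContext.sparkContext
      .textFile(paths)
      .map(line => Row(line))
      .mapPartitions { rows =>
        // Convert external Rows to InternalRows, as the ported libsvm source does.
        val converter = RowEncoder(dataSchema)
        rows.map(converter.toRow)
      }
  }
}
```

A real implementation would also honour `requiredColumns` and `filters` and provide a working `prepareWrite`, as the libsvm port in the diff below does.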
### ResolvedDataSource

Produces a logical plan given the following description of a Data Source (which can come from `DataFrameReader` or a metastore):
 - `paths: Seq[String] = Nil`
 - `userSpecifiedSchema: Option[StructType] = None`
 - `partitionColumns: Array[String] = Array.empty`
 - `bucketSpec: Option[BucketSpec] = None`
 - `provider: String`
 - `options: Map[String, String]`

This class is responsible for deciding which of the Data Source APIs a given provider is using (including the non-file-based ones). All reconciliation of partitions, buckets, and schemas from metastores or inference is done here.

### DataSourceAnalysis / DataSourceStrategy

Responsible for analyzing and planning reading/writing of data using any of the Data Source APIs, including:
 - pruning the files from partitions that will be read, based on filters
 - appending partition columns*
 - applying additional filters when a data source cannot evaluate them internally
 - constructing an RDD that is bucketed correctly when required*
 - sanity checking schema match-up and other analysis when writing

*In the future we should do the following:
 - Break out file handling into its own Strategy, as it is sufficiently complex / isolated.
 - Push the appending of partition columns down into `FileFormat` to avoid an extra copy / unvectorization.
 - Use a custom RDD for scans instead of `SQLNewNewHadoopRDD2`.

Author: Michael Armbrust <michael@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>

Closes #11509 from marmbrus/fileDataSource.
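As a reading-side illustration (again, not part of this patch), the inputs listed under `ResolvedDataSource` are what a `DataFrameReader` call ultimately supplies. Assuming a `sqlContext` is in scope (e.g. in spark-shell), and using the libsvm source ported in the diff below with placeholder values:

```scala
// Illustration only: how a read maps onto the ResolvedDataSource description above.
// format() selects `provider`, option() entries become `options`, and load()
// supplies `paths`; schema, partition columns and buckets are left unspecified.
val df = sqlContext.read
  .format("libsvm")                     // provider
  .option("numFeatures", "780")         // options (the value here is a placeholder)
  .load("/tmp/sample_libsvm_data.txt")  // paths   (placeholder path)
// df has two columns, `label` (double) and `features` (vector), as inferred by
// the libsvm FileFormat ported below.
```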
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala      | 135
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala |   8
2 files changed, 67 insertions(+), 76 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index b9c364b05d..976343ed96 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -19,74 +19,23 @@ package org.apache.spark.ml.source.libsvm
import java.io.IOException
-import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
-import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.annotation.Since
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
-
-/**
- * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
- * @param path File path of LibSVM format
- * @param numFeatures The number of features
- * @param vectorType The type of vector. It can be 'sparse' or 'dense'
- * @param sqlContext The Spark SQLContext
- */
-private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
- (@transient val sqlContext: SQLContext)
- extends HadoopFsRelation with Serializable {
-
- override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus])
- : RDD[Row] = {
- val sc = sqlContext.sparkContext
- val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
- val sparse = vectorType == "sparse"
- baseRdd.map { pt =>
- val features = if (sparse) pt.features.toSparse else pt.features.toDense
- Row(pt.label, features)
- }
- }
-
- override def hashCode(): Int = {
- Objects.hashCode(path, Double.box(numFeatures), vectorType)
- }
-
- override def equals(other: Any): Boolean = other match {
- case that: LibSVMRelation =>
- path == that.path &&
- numFeatures == that.numFeatures &&
- vectorType == that.vectorType
- case _ =>
- false
- }
-
- override def prepareJobForWrite(job: _root_.org.apache.hadoop.mapreduce.Job):
- _root_.org.apache.spark.sql.sources.OutputWriterFactory = {
- new OutputWriterFactory {
- override def newInstance(
- path: String,
- dataSchema: StructType,
- context: TaskAttemptContext): OutputWriter = {
- new LibSVMOutputWriter(path, dataSchema, context)
- }
- }
- }
-
- override def paths: Array[String] = Array(path)
-
- override def dataSchema: StructType = StructType(
- StructField("label", DoubleType, nullable = false) ::
- StructField("features", new VectorUDT(), nullable = false) :: Nil)
-}
-
+import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.collection.BitSet
private[libsvm] class LibSVMOutputWriter(
path: String,
@@ -124,6 +73,7 @@ private[libsvm] class LibSVMOutputWriter(
recordWriter.close(context)
}
}
+
/**
* `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]].
* The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and
@@ -155,7 +105,7 @@ private[libsvm] class LibSVMOutputWriter(
* @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
*/
@Since("1.6.0")
-class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+class DefaultSource extends FileFormat with DataSourceRegister {
@Since("1.6.0")
override def shortName(): String = "libsvm"
@@ -167,22 +117,63 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
}
}
+ override def inferSchema(
+ sqlContext: SQLContext,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] = {
+ Some(
+ StructType(
+ StructField("label", DoubleType, nullable = false) ::
+ StructField("features", new VectorUDT(), nullable = false) :: Nil))
+ }
- override def createRelation(
+ override def prepareWrite(
sqlContext: SQLContext,
- paths: Array[String],
- dataSchema: Option[StructType],
- partitionColumns: Option[StructType],
- parameters: Map[String, String]): HadoopFsRelation = {
- val path = if (paths.length == 1) paths(0)
- else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data")
- else throw new IOException("Multiple input paths are not supported for libsvm data")
- if (partitionColumns.isDefined && !partitionColumns.get.isEmpty) {
- throw new IOException("Partition is not supported for libsvm data")
+ job: Job,
+ options: Map[String, String],
+ dataSchema: StructType): OutputWriterFactory = {
+ new OutputWriterFactory {
+ override def newInstance(
+ path: String,
+ bucketId: Option[Int],
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ if (bucketId.isDefined) { sys.error("LibSVM doesn't support bucketing") }
+ new LibSVMOutputWriter(path, dataSchema, context)
+ }
+ }
+ }
+
+ override def buildInternalScan(
+ sqlContext: SQLContext,
+ dataSchema: StructType,
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ bucketSet: Option[BitSet],
+ inputFiles: Array[FileStatus],
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ options: Map[String, String]): RDD[InternalRow] = {
+ // TODO: This does not handle cases where column pruning has been performed.
+
+ verifySchema(dataSchema)
+ val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
+
+ val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString
+ else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data")
+ else throw new IOException("Multiple input paths are not supported for libsvm data.")
+
+ val numFeatures = options.getOrElse("numFeatures", "-1").toInt
+ val vectorType = options.getOrElse("vectorType", "sparse")
+
+ val sc = sqlContext.sparkContext
+ val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
+ val sparse = vectorType == "sparse"
+ baseRdd.map { pt =>
+ val features = if (sparse) pt.features.toSparse else pt.features.toDense
+ Row(pt.label, features)
+ }.mapPartitions { externalRows =>
+ val converter = RowEncoder(dataSchema)
+ externalRows.map(converter.toRow)
}
- dataSchema.foreach(verifySchema(_))
- val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
- val vectorType = parameters.getOrElse("vectorType", "sparse")
- new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 528d9e21cb..84fc08be09 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -22,7 +22,7 @@ import java.io.{File, IOException}
import com.google.common.base.Charsets
import com.google.common.io.Files
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.SaveMode
@@ -88,7 +88,8 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
val df = sqlContext.read.format("libsvm").load(path)
val tempDir2 = Utils.createTempDir()
val writepath = tempDir2.toURI.toString
- df.write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
+    // TODO: Remove requirement to coalesce by supporting multiple reads.
+ df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
val df2 = sqlContext.read.format("libsvm").load(writepath)
val row1 = df2.first()
@@ -98,9 +99,8 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
test("write libsvm data failed due to invalid schema") {
val df = sqlContext.read.format("text").load(path)
- val e = intercept[IOException] {
+ val e = intercept[SparkException] {
df.write.format("libsvm").save(path + "_2")
}
- assert(e.getMessage.contains("Illegal schema for libsvm data"))
}
}