-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala                 | 87
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala                              | 73
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala         |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala |  1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala                       |  9
5 files changed, 141 insertions(+), 34 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 13a13f0a7e..2e9b6be9a2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.source.libsvm
import java.io.IOException
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
@@ -26,12 +27,16 @@ import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
@@ -110,13 +115,16 @@ class DefaultSource extends FileFormat with DataSourceRegister {
@Since("1.6.0")
override def shortName(): String = "libsvm"
+ override def toString: String = "LibSVM"
+
private def verifySchema(dataSchema: StructType): Unit = {
if (dataSchema.size != 2 ||
(!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
|| !dataSchema(1).dataType.sameType(new VectorUDT()))) {
- throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
+ throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
}
}
+
override def inferSchema(
sqlContext: SQLContext,
options: Map[String, String],
@@ -127,6 +135,32 @@ class DefaultSource extends FileFormat with DataSourceRegister {
StructField("features", new VectorUDT(), nullable = false) :: Nil))
}
+ override def prepareRead(
+ sqlContext: SQLContext,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Map[String, String] = {
+ def computeNumFeatures(): Int = {
+ val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
+ val path = if (dataFiles.length == 1) {
+ dataFiles.head.getPath.toUri.toString
+ } else if (dataFiles.isEmpty) {
+ throw new IOException("No input path specified for libsvm data")
+ } else {
+ throw new IOException("Multiple input paths are not supported for libsvm data.")
+ }
+
+ val sc = sqlContext.sparkContext
+ val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
+ MLUtils.computeNumFeatures(parsed)
+ }
+
+ val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+ computeNumFeatures()
+ }
+
+ new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+ }
+
override def prepareWrite(
sqlContext: SQLContext,
job: Job,
@@ -158,7 +192,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
verifySchema(dataSchema)
val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
- val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString
+ val path = if (dataFiles.length == 1) dataFiles.head.getPath.toUri.toString
else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data")
else throw new IOException("Multiple input paths are not supported for libsvm data.")
@@ -176,4 +210,51 @@ class DefaultSource extends FileFormat with DataSourceRegister {
externalRows.map(converter.toRow)
}
}
+
+ override def buildReader(
+ sqlContext: SQLContext,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
+ val numFeatures = options("numFeatures").toInt
+ assert(numFeatures > 0)
+
+ val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
+
+ val broadcastedConf = sqlContext.sparkContext.broadcast(
+ new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration))
+ )
+
+ (file: PartitionedFile) => {
+ val points =
+ new HadoopFileLinesReader(file, broadcastedConf.value.value)
+ .map(_.toString.trim)
+ .filterNot(line => line.isEmpty || line.startsWith("#"))
+ .map { line =>
+ val (label, indices, values) = MLUtils.parseLibSVMRecord(line)
+ LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
+ }
+
+ val converter = RowEncoder(requiredSchema)
+
+ val unsafeRowIterator = points.map { pt =>
+ val features = if (sparse) pt.features.toSparse else pt.features.toDense
+ converter.toRow(Row(pt.label, features))
+ }
+
+ def toAttribute(f: StructField): AttributeReference =
+ AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+
+ // Appends partition values
+ val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute)
+ val joinedRow = new JoinedRow()
+ val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+ unsafeRowIterator.map { dataRow =>
+ appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+ }
+ }
+ }
}
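
For context, the options consumed in buildReader above ("numFeatures" and "vectorType") are the same ones a caller can pass through the public DataFrameReader API. A minimal usage sketch, assuming a SQLContext named sqlContext and the bundled sample file; the value 780 is a plausible feature count for that sample and otherwise arbitrary:

    import org.apache.spark.sql.DataFrame

    val df: DataFrame = sqlContext.read
      .format("libsvm")                          // resolved through shortName()
      .option("numFeatures", "780")              // if set and positive, prepareRead skips the extra scan
      .option("vectorType", "dense")             // "sparse" (default) or "dense"
      .load("data/mllib/sample_libsvm_data.txt")

    df.select("label", "features").show(5)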
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index c3b1d5cdd7..4b9d77949f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -67,42 +67,14 @@ object MLUtils {
path: String,
numFeatures: Int,
minPartitions: Int): RDD[LabeledPoint] = {
- val parsed = sc.textFile(path, minPartitions)
- .map(_.trim)
- .filter(line => !(line.isEmpty || line.startsWith("#")))
- .map { line =>
- val items = line.split(' ')
- val label = items.head.toDouble
- val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
- val indexAndValue = item.split(':')
- val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
- val value = indexAndValue(1).toDouble
- (index, value)
- }.unzip
-
- // check if indices are one-based and in ascending order
- var previous = -1
- var i = 0
- val indicesLength = indices.length
- while (i < indicesLength) {
- val current = indices(i)
- require(current > previous, s"indices should be one-based and in ascending order;"
- + " found current=$current, previous=$previous; line=\"$line\"")
- previous = current
- i += 1
- }
-
- (label, indices.toArray, values.toArray)
- }
+ val parsed = parseLibSVMFile(sc, path, minPartitions)
// Determine number of features.
val d = if (numFeatures > 0) {
numFeatures
} else {
parsed.persist(StorageLevel.MEMORY_ONLY)
- parsed.map { case (label, indices, values) =>
- indices.lastOption.getOrElse(0)
- }.reduce(math.max) + 1
+ computeNumFeatures(parsed)
}
parsed.map { case (label, indices, values) =>
@@ -110,6 +82,47 @@ object MLUtils {
}
}
+ private[spark] def computeNumFeatures(rdd: RDD[(Double, Array[Int], Array[Double])]): Int = {
+ rdd.map { case (label, indices, values) =>
+ indices.lastOption.getOrElse(0)
+ }.reduce(math.max) + 1
+ }
+
+ private[spark] def parseLibSVMFile(
+ sc: SparkContext,
+ path: String,
+ minPartitions: Int): RDD[(Double, Array[Int], Array[Double])] = {
+ sc.textFile(path, minPartitions)
+ .map(_.trim)
+ .filter(line => !(line.isEmpty || line.startsWith("#")))
+ .map(parseLibSVMRecord)
+ }
+
+ private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
+ val items = line.split(' ')
+ val label = items.head.toDouble
+ val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
+ val indexAndValue = item.split(':')
+ val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
+ val value = indexAndValue(1).toDouble
+ (index, value)
+ }.unzip
+
+ // check if indices are one-based and in ascending order
+ var previous = -1
+ var i = 0
+ val indicesLength = indices.length
+ while (i < indicesLength) {
+ val current = indices(i)
+ require(current > previous, s"indices should be one-based and in ascending order;"
+ + s" found current=$current, previous=$previous; line=\"$line\"")
+ previous = current
+ i += 1
+ }
+
+ (label, indices, values)
+ }
+
/**
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of
* partitions.
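
The refactored helpers (parseLibSVMFile, parseLibSVMRecord, computeNumFeatures) are private[spark]; external callers still go through loadLibSVMFile, which now delegates to them. A quick sketch of that public path, assuming a SparkContext named sc and an existing LIBSVM-formatted file at the path shown:

    import org.apache.spark.mllib.util.MLUtils

    // Each input line has the form "<label> <index1>:<value1> <index2>:<value2> ...",
    // with one-based, strictly ascending indices; blank lines and '#' comment lines are skipped.
    val points = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    println(points.first())   // a LabeledPoint with a sparse feature vector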
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index c66921f485..1850810270 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -299,6 +299,9 @@ case class DataSource(
"It must be specified manually")
}
+ val enrichedOptions =
+ format.prepareRead(sqlContext, caseInsensitiveOptions, fileCatalog.allFiles())
+
HadoopFsRelation(
sqlContext,
fileCatalog,
@@ -306,7 +309,7 @@ case class DataSource(
dataSchema = dataSchema.asNullable,
bucketSpec = bucketSpec,
format,
- options)
+ enrichedOptions)
case _ =>
throw new AnalysisException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 554298772a..a143ac6aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -59,6 +59,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
if (files.fileFormat.toString == "TestFileFormat" ||
files.fileFormat.isInstanceOf[parquet.DefaultSource] ||
files.fileFormat.toString == "ORC" ||
+ files.fileFormat.toString == "LibSVM" ||
files.fileFormat.isInstanceOf[csv.DefaultSource] ||
files.fileFormat.isInstanceOf[text.DefaultSource] ||
files.fileFormat.isInstanceOf[json.DefaultSource]) &&
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 6b95a3d25b..e8834d052c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -439,6 +439,15 @@ trait FileFormat {
files: Seq[FileStatus]): Option[StructType]
/**
+ * Prepares a read job and returns a potentially updated data source option [[Map]]. This method
+ * can be useful for collecting necessary global information for scanning input data.
+ */
+ def prepareRead(
+ sqlContext: SQLContext,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Map[String, String] = options
+
+ /**
* Prepares a write job and returns an [[OutputWriterFactory]]. Client-side job preparation can
* be put here. For example, a user-defined output committer can be configured by setting
* the output committer class under the spark.sql.sources.outputCommitterClass configuration key.
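
To see the new hook in isolation: an implementation overrides prepareRead when it needs a global look at the input files before per-file readers are built, and returns an enriched option map that DataSource threads into HadoopFsRelation (and from there into buildReader). A hedged fragment follows, with the remaining FileFormat members omitted and the derived option name ("totalInputBytes") purely illustrative:

    override def prepareRead(
        sqlContext: SQLContext,
        options: Map[String, String],
        files: Seq[FileStatus]): Map[String, String] = {
      // One cheap pass over file metadata; the derived value travels with the options map.
      val totalBytes = files.map(_.getLen).sum
      options + ("totalInputBytes" -> totalBytes.toString)
    }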