Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala'):

 mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 135 ++++++++++----------
 1 file changed, 63 insertions(+), 72 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index b9c364b05d..976343ed96 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -19,74 +19,23 @@ package org.apache.spark.ml.source.libsvm
import java.io.IOException
-import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
-import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.annotation.Since
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
-
-/**
- * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
- * @param path File path of LibSVM format
- * @param numFeatures The number of features
- * @param vectorType The type of vector. It can be 'sparse' or 'dense'
- * @param sqlContext The Spark SQLContext
- */
-private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
- (@transient val sqlContext: SQLContext)
- extends HadoopFsRelation with Serializable {
-
- override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus])
- : RDD[Row] = {
- val sc = sqlContext.sparkContext
- val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
- val sparse = vectorType == "sparse"
- baseRdd.map { pt =>
- val features = if (sparse) pt.features.toSparse else pt.features.toDense
- Row(pt.label, features)
- }
- }
-
- override def hashCode(): Int = {
- Objects.hashCode(path, Double.box(numFeatures), vectorType)
- }
-
- override def equals(other: Any): Boolean = other match {
- case that: LibSVMRelation =>
- path == that.path &&
- numFeatures == that.numFeatures &&
- vectorType == that.vectorType
- case _ =>
- false
- }
-
- override def prepareJobForWrite(job: _root_.org.apache.hadoop.mapreduce.Job):
- _root_.org.apache.spark.sql.sources.OutputWriterFactory = {
- new OutputWriterFactory {
- override def newInstance(
- path: String,
- dataSchema: StructType,
- context: TaskAttemptContext): OutputWriter = {
- new LibSVMOutputWriter(path, dataSchema, context)
- }
- }
- }
-
- override def paths: Array[String] = Array(path)
-
- override def dataSchema: StructType = StructType(
- StructField("label", DoubleType, nullable = false) ::
- StructField("features", new VectorUDT(), nullable = false) :: Nil)
-}
-
+import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.collection.BitSet
private[libsvm] class LibSVMOutputWriter(
path: String,
@@ -124,6 +73,7 @@ private[libsvm] class LibSVMOutputWriter(
recordWriter.close(context)
}
}
+
/**
* `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]].
* The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and
@@ -155,7 +105,7 @@ private[libsvm] class LibSVMOutputWriter(
* @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
*/
@Since("1.6.0")
-class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+class DefaultSource extends FileFormat with DataSourceRegister {
@Since("1.6.0")
override def shortName(): String = "libsvm"
@@ -167,22 +117,63 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
}
}
+ override def inferSchema(
+ sqlContext: SQLContext,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] = {
+ Some(
+ StructType(
+ StructField("label", DoubleType, nullable = false) ::
+ StructField("features", new VectorUDT(), nullable = false) :: Nil))
+ }
- override def createRelation(
+ override def prepareWrite(
sqlContext: SQLContext,
- paths: Array[String],
- dataSchema: Option[StructType],
- partitionColumns: Option[StructType],
- parameters: Map[String, String]): HadoopFsRelation = {
- val path = if (paths.length == 1) paths(0)
- else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data")
- else throw new IOException("Multiple input paths are not supported for libsvm data")
- if (partitionColumns.isDefined && !partitionColumns.get.isEmpty) {
- throw new IOException("Partition is not supported for libsvm data")
+ job: Job,
+ options: Map[String, String],
+ dataSchema: StructType): OutputWriterFactory = {
+ new OutputWriterFactory {
+ override def newInstance(
+ path: String,
+ bucketId: Option[Int],
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ if (bucketId.isDefined) { sys.error("LibSVM doesn't support bucketing") }
+ new LibSVMOutputWriter(path, dataSchema, context)
+ }
+ }
+ }
+
+ override def buildInternalScan(
+ sqlContext: SQLContext,
+ dataSchema: StructType,
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ bucketSet: Option[BitSet],
+ inputFiles: Array[FileStatus],
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ options: Map[String, String]): RDD[InternalRow] = {
+ // TODO: This does not handle cases where column pruning has been performed.
+
+ verifySchema(dataSchema)
+ val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
+
+ val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString
+ else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data")
+ else throw new IOException("Multiple input paths are not supported for libsvm data.")
+
+ val numFeatures = options.getOrElse("numFeatures", "-1").toInt
+ val vectorType = options.getOrElse("vectorType", "sparse")
+
+ val sc = sqlContext.sparkContext
+ val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
+ val sparse = vectorType == "sparse"
+ baseRdd.map { pt =>
+ val features = if (sparse) pt.features.toSparse else pt.features.toDense
+ Row(pt.label, features)
+ }.mapPartitions { externalRows =>
+ val converter = RowEncoder(dataSchema)
+ externalRows.map(converter.toRow)
}
- dataSchema.foreach(verifySchema(_))
- val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
- val vectorType = parameters.getOrElse("vectorType", "sparse")
- new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
}
}
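
Usage note: the scaladoc kept by this patch documents `libsvm` as a standard Spark SQL data source producing a `label: Double` column and a `features: Vector` column, configured through the `numFeatures` and `vectorType` options (each input line is LIBSVM text, i.e. `label index1:value1 index2:value2 ...`). A minimal read sketch against that API; the input path and feature count follow the sample data shipped with Spark and are illustrative, and an existing SparkContext `sc` is assumed:

  import org.apache.spark.sql.SQLContext

  val sqlContext = new SQLContext(sc)

  // `numFeatures` avoids the extra pass MLUtils would otherwise need to
  // determine the feature dimension; `vectorType` selects the sparse or
  // dense representation, matching the options read in buildInternalScan.
  val df = sqlContext.read.format("libsvm")
    .option("numFeatures", "780")
    .option("vectorType", "sparse")
    .load("data/mllib/sample_libsvm_data.txt")

  df.show(5)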
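The patch also carries the write path over to the new interface (prepareWrite plus the unchanged LibSVMOutputWriter), with bucketed output explicitly rejected in newInstance. A corresponding write sketch, assuming the `df` from the read example above; the output path is illustrative:

  // Round-trips the two-column DataFrame back to LIBSVM text files.
  // A bucketing spec would hit the sys.error guard in newInstance above.
  df.write.format("libsvm").save("target/tmp/libsvm_output")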
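Implementation note on the tail of buildInternalScan: a FileFormat scan must return RDD[InternalRow], so the external Rows built from the parsed LabeledPoints are run through a RowEncoder; doing this inside mapPartitions builds the encoder once per partition rather than once per row. A standalone sketch of that conversion using the same catalyst APIs the patch relies on, with the schema simplified to the label column so the sketch stays self-contained:

  import org.apache.spark.sql.Row
  import org.apache.spark.sql.catalyst.encoders.RowEncoder
  import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

  val schema = StructType(StructField("label", DoubleType, nullable = false) :: Nil)

  // RowEncoder(schema) yields an ExpressionEncoder[Row]; toRow converts an
  // external Row into catalyst's InternalRow representation.
  val converter = RowEncoder(schema)
  val internalRow = converter.toRow(Row(1.0))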