diff options
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 135 |
1 files changed, 63 insertions, 72 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index b9c364b05d..976343ed96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -19,74 +19,23 @@ package org.apache.spark.ml.source.libsvm import java.io.IOException -import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} -import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.annotation.Since +import org.apache.spark.broadcast.Broadcast import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ - -/** - * LibSVMRelation provides the DataFrame constructed from LibSVM format data. - * @param path File path of LibSVM format - * @param numFeatures The number of features - * @param vectorType The type of vector. It can be 'sparse' or 'dense' - * @param sqlContext The Spark SQLContext - */ -private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) - (@transient val sqlContext: SQLContext) - extends HadoopFsRelation with Serializable { - - override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]) - : RDD[Row] = { - val sc = sqlContext.sparkContext - val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) - val sparse = vectorType == "sparse" - baseRdd.map { pt => - val features = if (sparse) pt.features.toSparse else pt.features.toDense - Row(pt.label, features) - } - } - - override def hashCode(): Int = { - Objects.hashCode(path, Double.box(numFeatures), vectorType) - } - - override def equals(other: Any): Boolean = other match { - case that: LibSVMRelation => - path == that.path && - numFeatures == that.numFeatures && - vectorType == that.vectorType - case _ => - false - } - - override def prepareJobForWrite(job: _root_.org.apache.hadoop.mapreduce.Job): - _root_.org.apache.spark.sql.sources.OutputWriterFactory = { - new OutputWriterFactory { - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new LibSVMOutputWriter(path, dataSchema, context) - } - } - } - - override def paths: Array[String] = Array(path) - - override def dataSchema: StructType = StructType( - StructField("label", DoubleType, nullable = false) :: - StructField("features", new VectorUDT(), nullable = false) :: Nil) -} - +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection.BitSet private[libsvm] class LibSVMOutputWriter( path: String, @@ -124,6 +73,7 @@ private[libsvm] class LibSVMOutputWriter( recordWriter.close(context) } } + /** * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and @@ -155,7 +105,7 @@ private[libsvm] class LibSVMOutputWriter( * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] */ @Since("1.6.0") -class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +class DefaultSource extends FileFormat with DataSourceRegister { @Since("1.6.0") override def shortName(): String = "libsvm" @@ -167,22 +117,63 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}") } } + override def inferSchema( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + Some( + StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false) :: Nil)) + } - override def createRelation( + override def prepareWrite( sqlContext: SQLContext, - paths: Array[String], - dataSchema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - val path = if (paths.length == 1) paths(0) - else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data") - else throw new IOException("Multiple input paths are not supported for libsvm data") - if (partitionColumns.isDefined && !partitionColumns.get.isEmpty) { - throw new IOException("Partition is not supported for libsvm data") + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + if (bucketId.isDefined) { sys.error("LibSVM doesn't support bucketing") } + new LibSVMOutputWriter(path, dataSchema, context) + } + } + } + + override def buildInternalScan( + sqlContext: SQLContext, + dataSchema: StructType, + requiredColumns: Array[String], + filters: Array[Filter], + bucketSet: Option[BitSet], + inputFiles: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration], + options: Map[String, String]): RDD[InternalRow] = { + // TODO: This does not handle cases where column pruning has been performed. + + verifySchema(dataSchema) + val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_") + + val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString + else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data") + else throw new IOException("Multiple input paths are not supported for libsvm data.") + + val numFeatures = options.getOrElse("numFeatures", "-1").toInt + val vectorType = options.getOrElse("vectorType", "sparse") + + val sc = sqlContext.sparkContext + val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) + val sparse = vectorType == "sparse" + baseRdd.map { pt => + val features = if (sparse) pt.features.toSparse else pt.features.toDense + Row(pt.label, features) + }.mapPartitions { externalRows => + val converter = RowEncoder(dataSchema) + externalRows.map(converter.toRow) } - dataSchema.foreach(verifySchema(_)) - val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt - val vectorType = parameters.getOrElse("vectorType", "sparse") - new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) } } |