From 1dac964c1b996d38c65818414fc8401961a1de8a Mon Sep 17 00:00:00 2001
From: Jeff Zhang
Date: Tue, 26 Jan 2016 17:31:19 -0800
Subject: [SPARK-11622][MLLIB] Make LibSVMRelation extends HadoopFsRelation
 and…
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

… Add LibSVMOutputWriter

The behavior of LibSVMRelation is unchanged except for the added LibSVMOutputWriter:
* Partitioning is still not supported
* Multiple input paths are not supported

Author: Jeff Zhang

Closes #9595 from zjffdu/SPARK-11622.
---
 .../spark/ml/source/libsvm/LibSVMRelation.scala    | 102 ++++++++++++++++++---
 .../ml/source/libsvm/LibSVMRelationSuite.scala     |  23 ++++-
 project/MimaExcludes.scala                         |   4 +
 3 files changed, 113 insertions(+), 16 deletions(-)
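For orientation, the round trip this change enables looks like the sketch below. It is illustrative, not part of the patch: the paths are placeholders, `sqlContext` is assumed to already exist, and both options are shown with their default values.

import org.apache.spark.sql.{DataFrame, SQLContext}

def roundTrip(sqlContext: SQLContext): DataFrame = {
  // Read LIBSVM text into a DataFrame with a double "label" column
  // and a vector "features" column.
  val df = sqlContext.read.format("libsvm")
    .option("numFeatures", "-1")    // default; -1 means infer the feature count
    .option("vectorType", "sparse") // default; "dense" yields dense feature vectors
    .load("/tmp/sample_libsvm_data.txt")

  // Write through the new LibSVMOutputWriter, then read the result back.
  // Partitioned writes and multiple input paths are rejected with an
  // IOException, per the notes above.
  df.write.format("libsvm").save("/tmp/libsvm_roundtrip")
  sqlContext.read.format("libsvm").load("/tmp/libsvm_roundtrip")
}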
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 1bed542c40..b9c364b05d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -17,16 +17,21 @@
 
 package org.apache.spark.ml.source.libsvm
 
+import java.io.IOException
+
 import com.google.common.base.Objects
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.{NullWritable, Text}
+import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
-import org.apache.spark.Logging
 import org.apache.spark.annotation.Since
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types._
 
 /**
  * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
@@ -37,14 +42,10 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
  */
 private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
     (@transient val sqlContext: SQLContext)
-  extends BaseRelation with TableScan with Logging with Serializable {
-
-  override def schema: StructType = StructType(
-    StructField("label", DoubleType, nullable = false) ::
-      StructField("features", new VectorUDT(), nullable = false) :: Nil
-  )
+  extends HadoopFsRelation with Serializable {
 
-  override def buildScan(): RDD[Row] = {
+  override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus])
+    : RDD[Row] = {
     val sc = sqlContext.sparkContext
     val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
     val sparse = vectorType == "sparse"
@@ -66,8 +67,63 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val
     case _ => false
   }
+
+  override def prepareJobForWrite(job: _root_.org.apache.hadoop.mapreduce.Job):
+    _root_.org.apache.spark.sql.sources.OutputWriterFactory = {
+    new OutputWriterFactory {
+      override def newInstance(
+          path: String,
+          dataSchema: StructType,
+          context: TaskAttemptContext): OutputWriter = {
+        new LibSVMOutputWriter(path, dataSchema, context)
+      }
+    }
+  }
+
+  override def paths: Array[String] = Array(path)
+
+  override def dataSchema: StructType = StructType(
+    StructField("label", DoubleType, nullable = false) ::
+      StructField("features", new VectorUDT(), nullable = false) :: Nil)
 }
+
+private[libsvm] class LibSVMOutputWriter(
+    path: String,
+    dataSchema: StructType,
+    context: TaskAttemptContext)
+  extends OutputWriter {
+
+  private[this] val buffer = new Text()
+
+  private val recordWriter: RecordWriter[NullWritable, Text] = {
+    new TextOutputFormat[NullWritable, Text]() {
+      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+        val configuration = context.getConfiguration
+        val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID")
+        val taskAttemptId = context.getTaskAttemptID
+        val split = taskAttemptId.getTaskID.getId
+        new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
+      }
+    }.getRecordWriter(context)
+  }
+
+  override def write(row: Row): Unit = {
+    val label = row.get(0)
+    val vector = row.get(1).asInstanceOf[Vector]
+    val sb = new StringBuilder(label.toString)
+    vector.foreachActive { case (i, v) =>
+      sb += ' '
+      sb ++= s"${i + 1}:$v"
+    }
+    buffer.set(sb.mkString)
+    recordWriter.write(NullWritable.get(), buffer)
+  }
+
+  override def close(): Unit = {
+    recordWriter.close(context)
+  }
+}
 
 /**
  * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]].
  * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and
@@ -99,16 +155,32 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val
  * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
  */
 @Since("1.6.0")
-class DefaultSource extends RelationProvider with DataSourceRegister {
+class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
 
   @Since("1.6.0")
   override def shortName(): String = "libsvm"
 
-  @Since("1.6.0")
-  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
-    : BaseRelation = {
-    val path = parameters.getOrElse("path",
-      throw new IllegalArgumentException("'path' must be specified"))
+  private def verifySchema(dataSchema: StructType): Unit = {
+    if (dataSchema.size != 2 ||
+      (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
+        || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
+      throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
+    }
+  }
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      paths: Array[String],
+      dataSchema: Option[StructType],
+      partitionColumns: Option[StructType],
+      parameters: Map[String, String]): HadoopFsRelation = {
+    val path = if (paths.length == 1) paths(0)
+      else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data")
+      else throw new IOException("Multiple input paths are not supported for libsvm data")
+    if (partitionColumns.isDefined && !partitionColumns.get.isEmpty) {
+      throw new IOException("Partition is not supported for libsvm data")
+    }
+    dataSchema.foreach(verifySchema(_))
     val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
     val vectorType = parameters.getOrElse("vectorType", "sparse")
     new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 5f4d5f11bd..528d9e21cb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.source.libsvm
 
-import java.io.File
+import java.io.{File, IOException}
 
 import com.google.common.base.Charsets
 import com.google.common.io.Files
@@ -25,6 +25,7 @@ import com.google.common.io.Files
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.util.Utils
 
 class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -82,4 +83,24 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
     val v = row1.getAs[SparseVector](1)
     assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
   }
+
+  test("write libsvm data and read it again") {
+    val df = sqlContext.read.format("libsvm").load(path)
+    val tempDir2 = Utils.createTempDir()
+    val writepath = tempDir2.toURI.toString
+    df.write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
+
+    val df2 = sqlContext.read.format("libsvm").load(writepath)
+    val row1 = df2.first()
+    val v = row1.getAs[SparseVector](1)
+    assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+  }
+
+  test("write libsvm data failed due to invalid schema") {
+    val df = sqlContext.read.format("text").load(path)
+    val e = intercept[IOException] {
+      df.write.format("libsvm").save(path + "_2")
+    }
+    assert(e.getMessage.contains("Illegal schema for libsvm data"))
+  }
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 643bee6969..fc7dc2181d 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -203,6 +203,10 @@ object MimaExcludes {
       // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus")
+    ) ++ Seq(
+      // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"),
+      ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation")
     )
 
     case v if v.startsWith("1.6") => Seq(
-- 
cgit v1.2.3
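A closing note on the output format: LibSVMOutputWriter.write emits one text line per row, the label followed by a space-separated `index:value` pair for each active entry of the features vector, converting the vector's 0-based indices to LIBSVM's 1-based convention via `i + 1`. A minimal sketch of that formatting loop, with a made-up label and vector:

import org.apache.spark.mllib.linalg.Vectors

// Mirrors the loop in LibSVMOutputWriter.write (hypothetical inputs).
val label = 1.0
val features = Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))
val sb = new StringBuilder(label.toString)
features.foreachActive { case (i, v) =>
  sb ++= s" ${i + 1}:$v" // active index 0 becomes "1:...", index 2 becomes "3:...", etc.
}
println(sb.mkString) // prints: 1.0 1:1.0 3:2.0 5:3.0

This matches the round-trip test above, which writes that vector out and reads it back unchanged.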