-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala       | 102
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala  |  23
-rw-r--r--  project/MimaExcludes.scala                                                        |   4
3 files changed, 113 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 1bed542c40..b9c364b05d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -17,16 +17,21 @@
package org.apache.spark.ml.source.libsvm
+import java.io.IOException
+
import com.google.common.base.Objects
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.{NullWritable, Text}
+import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
-import org.apache.spark.Logging
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types._
/**
* LibSVMRelation provides the DataFrame constructed from LibSVM format data.
@@ -37,14 +42,10 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
*/
private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
(@transient val sqlContext: SQLContext)
- extends BaseRelation with TableScan with Logging with Serializable {
-
- override def schema: StructType = StructType(
- StructField("label", DoubleType, nullable = false) ::
- StructField("features", new VectorUDT(), nullable = false) :: Nil
- )
+ extends HadoopFsRelation with Serializable {
- override def buildScan(): RDD[Row] = {
+ override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus])
+ : RDD[Row] = {
val sc = sqlContext.sparkContext
val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
val sparse = vectorType == "sparse"
@@ -66,8 +67,63 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val
case _ =>
false
}
+
+ override def prepareJobForWrite(job: _root_.org.apache.hadoop.mapreduce.Job):
+ _root_.org.apache.spark.sql.sources.OutputWriterFactory = {
+ new OutputWriterFactory {
+ override def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ new LibSVMOutputWriter(path, dataSchema, context)
+ }
+ }
+ }
+
+ override def paths: Array[String] = Array(path)
+
+ override def dataSchema: StructType = StructType(
+ StructField("label", DoubleType, nullable = false) ::
+ StructField("features", new VectorUDT(), nullable = false) :: Nil)
}
+
+private[libsvm] class LibSVMOutputWriter(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext)
+ extends OutputWriter {
+
+ private[this] val buffer = new Text()
+
+ private val recordWriter: RecordWriter[NullWritable, Text] = {
+ new TextOutputFormat[NullWritable, Text]() {
+ override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+ val configuration = context.getConfiguration
+ val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID")
+ val taskAttemptId = context.getTaskAttemptID
+ val split = taskAttemptId.getTaskID.getId
+ new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
+ }
+ }.getRecordWriter(context)
+ }
+
+ override def write(row: Row): Unit = {
+ val label = row.get(0)
+ val vector = row.get(1).asInstanceOf[Vector]
+ val sb = new StringBuilder(label.toString)
+ vector.foreachActive { case (i, v) =>
+ sb += ' '
+ sb ++= s"${i + 1}:$v"
+ }
+ buffer.set(sb.mkString)
+ recordWriter.write(NullWritable.get(), buffer)
+ }
+
+ override def close(): Unit = {
+ recordWriter.close(context)
+ }
+}
/**
* `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]].
* The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and
@@ -99,16 +155,32 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val
* @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
*/
@Since("1.6.0")
-class DefaultSource extends RelationProvider with DataSourceRegister {
+class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
@Since("1.6.0")
override def shortName(): String = "libsvm"
- @Since("1.6.0")
- override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
- : BaseRelation = {
- val path = parameters.getOrElse("path",
- throw new IllegalArgumentException("'path' must be specified"))
+ private def verifySchema(dataSchema: StructType): Unit = {
+ if (dataSchema.size != 2 ||
+ (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
+ || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
+ throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
+ }
+ }
+
+ override def createRelation(
+ sqlContext: SQLContext,
+ paths: Array[String],
+ dataSchema: Option[StructType],
+ partitionColumns: Option[StructType],
+ parameters: Map[String, String]): HadoopFsRelation = {
+ val path = if (paths.length == 1) paths(0)
+ else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data")
+ else throw new IOException("Multiple input paths are not supported for libsvm data")
+ if (partitionColumns.isDefined && !partitionColumns.get.isEmpty) {
+ throw new IOException("Partition is not supported for libsvm data")
+ }
+ dataSchema.foreach(verifySchema(_))
val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
val vectorType = parameters.getOrElse("vectorType", "sparse")
new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
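
With this change the libsvm source is both readable and writable through the ordinary
DataFrame API. A minimal usage sketch (not part of the patch; it assumes a SQLContext
named sqlContext is in scope and the input/output paths are placeholders):

  import org.apache.spark.sql.SaveMode

  // Read LIBSVM text into a DataFrame with "label" and "features" columns.
  // numFeatures and vectorType are the options handled by DefaultSource above.
  val df = sqlContext.read.format("libsvm")
    .option("numFeatures", "100")
    .option("vectorType", "sparse")
    .load("data/sample_libsvm_data.txt")

  // Writing goes through LibSVMOutputWriter, which serializes each row back to
  // "label index:value ..." lines with 1-based feature indices.
  df.write.format("libsvm").mode(SaveMode.Overwrite).save("target/libsvm-out")
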
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 5f4d5f11bd..528d9e21cb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.source.libsvm
-import java.io.File
+import java.io.{File, IOException}
import com.google.common.base.Charsets
import com.google.common.io.Files
@@ -25,6 +25,7 @@ import com.google.common.io.Files
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.SaveMode
import org.apache.spark.util.Utils
class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -82,4 +83,24 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
val v = row1.getAs[SparseVector](1)
assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
}
+
+ test("write libsvm data and read it again") {
+ val df = sqlContext.read.format("libsvm").load(path)
+ val tempDir2 = Utils.createTempDir()
+ val writepath = tempDir2.toURI.toString
+ df.write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
+
+ val df2 = sqlContext.read.format("libsvm").load(writepath)
+ val row1 = df2.first()
+ val v = row1.getAs[SparseVector](1)
+ assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+ }
+
+ test("write libsvm data failed due to invalid schema") {
+ val df = sqlContext.read.format("text").load(path)
+ val e = intercept[IOException] {
+ df.write.format("libsvm").save(path + "_2")
+ }
+ assert(e.getMessage.contains("Illegal schema for libsvm data"))
+ }
}
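
For reference, the roundtrip test above relies on LibSVMOutputWriter emitting standard
LIBSVM text, one labeled example per line with 1-based feature indices. A row whose
features are Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))) and whose label is,
say, 1.0 (the label value here is illustrative, not taken from the test fixture) would
be written as:

  1.0 1:1.0 3:2.0 5:3.0

MLUtils.loadLibSVMFile then parses such lines back into the same label/features pair
on the subsequent read.
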
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 643bee6969..fc7dc2181d 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -203,6 +203,10 @@ object MimaExcludes {
// SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus")
+ ) ++ Seq(
+ // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation")
)
case v if v.startsWith("1.6") =>
Seq(