From b3d39620c563e5f6a32a4082aa3908e1009c17d2 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sun, 8 Jan 2017 00:42:09 +0800
Subject: [SPARK-19085][SQL] cleanup OutputWriterFactory and OutputWriter

## What changes were proposed in this pull request?

`OutputWriterFactory`/`OutputWriter` are internal interfaces, so we can remove some unnecessary APIs:

1. `OutputWriterFactory.newWriter(path: String)`: no one calls it and no one implements it.
2. `OutputWriter.write(row: Row)`: during execution we only call `writeInternal`, which is odd given that `OutputWriter` is already an internal interface. We should rename `writeInternal` to `write`, remove `def write(row: Row)` and its related converter code, and have all implementations implement `def write(row: InternalRow)` directly.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan

Closes #16479 from cloud-fan/hive-writer.
---
 .../org/apache/spark/ml/source/libsvm/LibSVMRelation.scala      | 10 +++++++---
 .../apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala     |  6 +++---
 2 files changed, 10 insertions(+), 6 deletions(-)

(limited to 'mllib/src')

diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index b5aa7ce4e1..89bbc1556c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -45,9 +45,12 @@ private[libsvm] class LibSVMOutputWriter(
 
   private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
 
-  override def write(row: Row): Unit = {
-    val label = row.get(0)
-    val vector = row.get(1).asInstanceOf[Vector]
+  // This `asInstanceOf` is safe because it's guaranteed by `LibSVMFileFormat.verifySchema`
+  private val udt = dataSchema(1).dataType.asInstanceOf[VectorUDT]
+
+  override def write(row: InternalRow): Unit = {
+    val label = row.getDouble(0)
+    val vector = udt.deserialize(row.getStruct(1, udt.sqlType.length))
     writer.write(label.toString)
     vector.foreachActive { case (i, v) =>
       writer.write(s" ${i + 1}:$v")
@@ -115,6 +118,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
+    verifySchema(dataSchema)
     new OutputWriterFactory {
       override def newInstance(
           path: String,
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 2517de59fe..c701f38238 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.ml.source.libsvm
 
-import java.io.File
+import java.io.{File, IOException}
 import java.nio.charset.StandardCharsets
 
 import com.google.common.io.Files
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Row, SaveMode}
@@ -100,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("write libsvm data failed due to invalid schema") {
     val df = spark.read.format("text").load(path)
-    intercept[SparkException] {
+    intercept[IOException] {
       df.write.format("libsvm").save(path + "_2")
     }
   }
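To make the resulting contract concrete for `OutputWriter` implementers, here is a minimal Scala sketch. It is illustrative only: `SimpleOutputWriter` and `LabeledTextWriter` are hypothetical names, and the real `OutputWriter` trait in `org.apache.spark.sql.execution.datasources` has additional members beyond the two shown here.

```scala
// Hypothetical trimmed-down analogue of OutputWriter after this patch.
// Before the cleanup, user code overrode write(row: Row) and the framework
// called writeInternal(row: InternalRow), converting InternalRow to Row
// first; now there is a single entry point taking InternalRow directly.
import java.io.Writer

import org.apache.spark.sql.catalyst.InternalRow

trait SimpleOutputWriter {
  def write(row: InternalRow): Unit
  def close(): Unit
}

// A writer for a hypothetical (label: Double, text: String) schema. Fields
// are read positionally from the InternalRow, with no Row conversion.
class LabeledTextWriter(out: Writer) extends SimpleOutputWriter {
  override def write(row: InternalRow): Unit = {
    val label = row.getDouble(0)
    val text = row.getUTF8String(1) // Spark stores strings as UTF8String
    out.write(s"$label,$text\n")
  }

  override def close(): Unit = out.close()
}
```

`LibSVMOutputWriter` in the diff above follows the same pattern: it reads the label with `row.getDouble(0)` and recovers the `Vector` by deserializing the struct-encoded UDT column, rather than relying on a framework-supplied `Row` conversion.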