author    Wenchen Fan <wenchen@databricks.com>  2017-01-08 00:42:09 +0800
committer Wenchen Fan <wenchen@databricks.com>  2017-01-08 00:42:09 +0800
commit    b3d39620c563e5f6a32a4082aa3908e1009c17d2 (patch)
tree      2b927e760096c78895228cb9d79f6730c8e9d182 /mllib/src
parent    cdda3372a39b508f7d159426749b682476c813b9 (diff)
[SPARK-19085][SQL] cleanup OutputWriterFactory and OutputWriter
## What changes were proposed in this pull request?

`OutputWriterFactory`/`OutputWriter` are internal interfaces, so we can remove some unnecessary APIs:

1. `OutputWriterFactory.newWriter(path: String)`: no one calls it and no one implements it.
2. `OutputWriter.write(row: Row)`: during execution we only call `writeInternal`, which is odd given that `OutputWriter` is already an internal interface. We should rename `writeInternal` to `write`, and remove `def write(row: Row)` and its related converter code. All implementations should just implement `def write(row: InternalRow)`.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #16479 from cloud-fan/hive-writer.
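To make the new contract concrete, here is a minimal sketch of what an `OutputWriter` implementation looks like after this change. The abstract class is paraphrased from the Spark source; the single-column text writer is a hypothetical example, not code from this patch.

```scala
import java.io.OutputStreamWriter

import org.apache.spark.sql.catalyst.InternalRow

// Paraphrased shape of the internal interface after this patch: the old
// `write(row: Row)` / `writeInternal` pair collapses into one method.
abstract class OutputWriter {
  def write(row: InternalRow): Unit
  def close(): Unit
}

// Hypothetical writer for a single string column, assuming `out` is an open
// stream for the target file.
class SingleColumnTextWriter(out: OutputStreamWriter) extends OutputWriter {
  override def write(row: InternalRow): Unit = {
    // Read the catalyst representation directly; no Row conversion needed.
    out.write(row.getUTF8String(0).toString)
    out.write('\n')
  }
  override def close(): Unit = out.close()
}
```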
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala       | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala  |  6
2 files changed, 10 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index b5aa7ce4e1..89bbc1556c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -45,9 +45,12 @@ private[libsvm] class LibSVMOutputWriter(
 
   private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
 
-  override def write(row: Row): Unit = {
-    val label = row.get(0)
-    val vector = row.get(1).asInstanceOf[Vector]
+  // This `asInstanceOf` is safe because it's guaranteed by `LibSVMFileFormat.verifySchema`
+  private val udt = dataSchema(1).dataType.asInstanceOf[VectorUDT]
+
+  override def write(row: InternalRow): Unit = {
+    val label = row.getDouble(0)
+    val vector = udt.deserialize(row.getStruct(1, udt.sqlType.length))
     writer.write(label.toString)
     vector.foreachActive { case (i, v) =>
       writer.write(s" ${i + 1}:$v")
@@ -115,6 +118,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
+    verifySchema(dataSchema)
     new OutputWriterFactory {
       override def newInstance(
           path: String,
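For context on the UDT round trip in the new `write` method above: `VectorUDT.sqlType` is a four-field struct (`type`, `size`, `indices`, `values`), which is why the writer asks `row.getStruct(1, udt.sqlType.length)` for that many fields before handing the struct to `udt.deserialize`. Below is a rough sketch of the same access pattern with a hypothetical two-column schema; note that `VectorUDT` is package-private, so this only compiles inside Spark's own packages (or in a REPL session with the class accessible).

```scala
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Hypothetical schema of the shape verifySchema requires: a double label
// followed by a vector column.
val dataSchema = StructType(Seq(
  StructField("label", DoubleType, nullable = false),
  StructField("features", new VectorUDT(), nullable = false)))

val udt = dataSchema(1).dataType.asInstanceOf[VectorUDT]

// Decode one catalyst row into (label, vector) the same way the writer does:
// getStruct pulls out the UDT's struct encoding, deserialize rebuilds the Vector.
def decode(row: InternalRow): (Double, Vector) =
  (row.getDouble(0), udt.deserialize(row.getStruct(1, udt.sqlType.length)))
```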
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 2517de59fe..c701f38238 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.ml.source.libsvm
 
-import java.io.File
+import java.io.{File, IOException}
 import java.nio.charset.StandardCharsets
 
 import com.google.common.io.Files
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Row, SaveMode}
@@ -100,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("write libsvm data failed due to invalid schema") {
     val df = spark.read.format("text").load(path)
-    intercept[SparkException] {
+    intercept[IOException] {
       df.write.format("libsvm").save(path + "_2")
     }
   }
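The expected exception changes because `prepareWrite` now calls `verifySchema` eagerly on the driver, so an invalid schema fails fast with the `IOException` thrown by the check itself, rather than surfacing as a `SparkException` from a failed write task. Below is a hedged sketch of the kind of check `verifySchema` performs; the exact conditions and message in `LibSVMFileFormat` may differ, and as in the previous sketch `VectorUDT` is only accessible inside Spark's own packages.

```scala
import java.io.IOException

import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types.{DoubleType, StructType}

// Illustrative only: require exactly a (double label, vector features) schema
// and fail with IOException otherwise, matching what the updated test expects.
def verifySchema(dataSchema: StructType): Unit = {
  val ok = dataSchema.size == 2 &&
    dataSchema(0).dataType == DoubleType &&
    dataSchema(1).dataType.isInstanceOf[VectorUDT]
  if (!ok) {
    throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
  }
}
```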