aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala11
1 files changed, 10 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index bb4d5b983e..fb21ab6b9b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.param.ParamsSuite
@@ -25,7 +26,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
-class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class VectorAssemblerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new VectorAssembler)
@@ -101,4 +103,11 @@ class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
}
+
+ test("read/write") {
+ val t = new VectorAssembler()
+ .setInputCols(Array("myInputCol", "myInputCol2"))
+ .setOutputCol("myOutputCol")
+ testDefaultReadWrite(t)
+ }
}