aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala11
2 files changed, 13 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 0feec05498..801096fed2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String)
val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
Array.fill(numAttrs)(NumericAttribute.defaultAttr)
}
+ case otherType =>
+ throw new SparkException(s"VectorAssembler does not support the $otherType type")
}
}
val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index fb21ab6b9b..9c1c00f41a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -69,6 +69,17 @@ class VectorAssemblerSuite
}
}
+ test("transform should throw an exception in case of unsupported type") {
+ val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("a", "b", "c"))
+ .setOutputCol("features")
+ val thrown = intercept[SparkException] {
+ assembler.transform(df)
+ }
+ assert(thrown.getMessage contains "VectorAssembler does not support the StringType type")
+ }
+
test("ML attributes") {
val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)