diff options
author | BenFradet <benjamin.fradet@gmail.com> | 2015-11-22 22:05:01 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-11-22 22:05:01 -0800 |
commit | 4be360d4ee6cdb4d06306feca38ddef5212608cf (patch) | |
tree | c1965ef7d05999d71a8029dafe93b9f86b9e1831 | |
parent | d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8 (diff) | |
download | spark-4be360d4ee6cdb4d06306feca38ddef5212608cf.tar.gz spark-4be360d4ee6cdb4d06306feca38ddef5212608cf.tar.bz2 spark-4be360d4ee6cdb4d06306feca38ddef5212608cf.zip |
[SPARK-11902][ML] Unhandled case in VectorAssembler#transform
There is an unhandled case in the transform method of VectorAssembler if one of the input columns is not of one of the supported types: DoubleType, NumericType, BooleanType or VectorUDT.
So, if you try to transform a column of StringType you get a cryptic "scala.MatchError: StringType".
This PR fixes this by throwing a SparkException when an input column has an unsupported type.
Author: BenFradet <benjamin.fradet@gmail.com>
Closes #9885 from BenFradet/SPARK-11902.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 2 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 11 |
2 files changed, 13 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 0feec05498..801096fed2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String)
             val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
             Array.fill(numAttrs)(NumericAttribute.defaultAttr)
           }
+        case otherType =>
+          throw new SparkException(s"VectorAssembler does not support the $otherType type")
       }
     }
     val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index fb21ab6b9b..9c1c00f41a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -69,6 +69,17 @@ class VectorAssemblerSuite
     }
   }
 
+  test("transform should throw an exception in case of unsupported type") {
+    val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("a", "b", "c"))
+      .setOutputCol("features")
+    val thrown = intercept[SparkException] {
+      assembler.transform(df)
+    }
+    assert(thrown.getMessage contains "VectorAssembler does not support the StringType type")
+  }
+
   test("ML attributes") {
     val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
     val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)