From 4be360d4ee6cdb4d06306feca38ddef5212608cf Mon Sep 17 00:00:00 2001 From: BenFradet Date: Sun, 22 Nov 2015 22:05:01 -0800 Subject: [SPARK-11902][ML] Unhandled case in VectorAssembler#transform There is an unhandled case in the transform method of VectorAssembler if one of the input columns doesn't have one of the supported type DoubleType, NumericType, BooleanType or VectorUDT. So, if you try to transform a column of StringType you get a cryptic "scala.MatchError: StringType". This PR aims to fix this, throwing a SparkException when dealing with an unknown column type. Author: BenFradet Closes #9885 from BenFradet/SPARK-11902. --- .../scala/org/apache/spark/ml/feature/VectorAssembler.scala | 2 ++ .../org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 0feec05498..801096fed2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String) val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) Array.fill(numAttrs)(NumericAttribute.defaultAttr) } + case otherType => + throw new SparkException(s"VectorAssembler does not support the $otherType type") } } val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index fb21ab6b9b..9c1c00f41a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -69,6 +69,17 @@ class VectorAssemblerSuite } } + test("transform should throw an exception in case of unsupported type") { + val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val assembler = new VectorAssembler() + .setInputCols(Array("a", "b", "c")) + .setOutputCol("features") + val thrown = intercept[SparkException] { + assembler.transform(df) + } + assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + } + test("ML attributes") { val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) -- cgit v1.2.3