diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index fd16d3d6c2..7b2a451ca5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -55,7 +55,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { schema(c).dataType match { case DoubleType => UnresolvedAttribute(c) case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) - case _: NumericType => + case _: NumericType | BooleanType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() } } @@ -68,7 +68,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val outputColName = map(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) inputDataTypes.foreach { - case _: NumericType => + case _: NumericType | BooleanType => case t if t.isInstanceOf[VectorUDT] => case other => throw new IllegalArgumentException(s"Data type $other is not supported.") |