aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala3
1 files changed, 2 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index d5cb05f29b..a6dfe58e56 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -71,7 +71,8 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
- SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
+ val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
+ SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}