aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala21
2 files changed, 23 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 14c05123c6..d53f3df514 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
+ val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
+ SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 0b441f8b80..613cc3d60b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.getVectors.collect() === instance.getVectors.collect())
}
+
+ test("Word2Vec works with input that is non-nullable (NGram)") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val sentence = "a q s t q s t b b b s t m s t m q "
+ val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text")
+
+ val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams")
+ val ngramDF = ngram.transform(docDF)
+
+ val model = new Word2Vec()
+ .setVectorSize(2)
+ .setInputCol("ngrams")
+ .setOutputCol("result")
+ .fit(ngramDF)
+
+ // Just test that this transformation succeeds
+ model.transform(ngramDF).collect()
+ }
+
}