aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSean Owen <sowen@cloudera.com>2016-09-24 08:06:41 +0100
committerSean Owen <sowen@cloudera.com>2016-09-24 08:06:41 +0100
commitf3fe55439e4c865c26502487a1bccf255da33f4a (patch)
treed534e9bbc36c8aced17f63d94df7eac1cbdbd5d5
parent7c382524a959a2bc9b3d2fca44f6f0b41aba4e3c (diff)
downloadspark-f3fe55439e4c865c26502487a1bccf255da33f4a.tar.gz
spark-f3fe55439e4c865c26502487a1bccf255da33f4a.tar.bz2
spark-f3fe55439e4c865c26502487a1bccf255da33f4a.zip
[SPARK-10835][ML] Word2Vec should accept non-null string array, in addition to existing null string array
## What changes were proposed in this pull request? To match Tokenizer and for compatibility with Word2Vec, output a nullable string array type in NGram ## How was this patch tested? Jenkins tests. Author: Sean Owen <sowen@cloudera.com> Closes #15179 from srowen/SPARK-10835.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala21
2 files changed, 23 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 14c05123c6..d53f3df514 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
+ val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
+ SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 0b441f8b80..613cc3d60b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.getVectors.collect() === instance.getVectors.collect())
}
+
+ test("Word2Vec works with input that is non-nullable (NGram)") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val sentence = "a q s t q s t b b b s t m s t m q "
+ val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text")
+
+ val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams")
+ val ngramDF = ngram.transform(docDF)
+
+ val model = new Word2Vec()
+ .setVectorSize(2)
+ .setInputCol("ngrams")
+ .setOutputCol("result")
+ .fit(ngramDF)
+
+ // Just test that this transformation succeeds
+ model.transform(ngramDF).collect()
+ }
+
}