diff options
author | Sean Owen <sowen@cloudera.com> | 2016-09-24 08:06:41 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-09-24 08:06:41 +0100 |
commit | f3fe55439e4c865c26502487a1bccf255da33f4a (patch) | |
tree | d534e9bbc36c8aced17f63d94df7eac1cbdbd5d5 | |
parent | 7c382524a959a2bc9b3d2fca44f6f0b41aba4e3c (diff) | |
download | spark-f3fe55439e4c865c26502487a1bccf255da33f4a.tar.gz spark-f3fe55439e4c865c26502487a1bccf255da33f4a.tar.bz2 spark-f3fe55439e4c865c26502487a1bccf255da33f4a.zip |
[SPARK-10835][ML] Word2Vec should accept non-null string array, in addition to existing null string array
## What changes were proposed in this pull request?
To match Tokenizer and for compatibility with Word2Vec, output a nullable string array type in NGram
## How was this patch tested?
Jenkins tests.
Author: Sean Owen <sowen@cloudera.com>
Closes #15179 from srowen/SPARK-10835.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 3 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 21 |
2 files changed, 23 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 14c05123c6..d53f3df514 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 0b441f8b80..613cc3d60b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newInstance = testDefaultReadWrite(instance) assert(newInstance.getVectors.collect() === instance.getVectors.collect()) } + + test("Word2Vec works with input that is non-nullable (NGram)") { + val spark = this.spark + import spark.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " + val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") + + val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams") + val ngramDF = ngram.transform(docDF) + + val model = new Word2Vec() + .setVectorSize(2) + .setInputCol("ngrams") + .setOutputCol("result") + .fit(ngramDF) + + // Just test that this transformation succeeds + model.transform(ngramDF).collect() + } + } |