diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-08-03 16:44:25 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-03 16:44:25 -0700 |
commit | 13675c742a71cbdc8324701c3694775ce1dd5c62 (patch) | |
tree | e2979af1ae6bf87427f2850d9e8f6bc8791f7560 /mllib/src/test | |
parent | a2409d1c8e8ddec04b529ac6f6a12b5993f0eeda (diff) | |
download | spark-13675c742a71cbdc8324701c3694775ce1dd5c62.tar.gz spark-13675c742a71cbdc8324701c3694775ce1dd5c62.tar.bz2 spark-13675c742a71cbdc8324701c3694775ce1dd5c62.zip |
[SPARK-8874] [ML] Add missing methods in Word2Vec
Add missing methods
1. getVectors
2. findSynonyms
to W2Vec scala and python API
mengxr
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #7263 from MechCoder/missing_methods_w2vec and squashes the following commits:
149d5ca [MechCoder] minor doc
69d91b7 [MechCoder] [SPARK-8874] [ML] Add missing methods in Word2Vec
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index aa6ce533fd..adcda0e623 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -67,5 +67,67 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") } } + + test("getVectors") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + + val codes = Map( + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) + ) + val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) } + + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val realVectors = model.getVectors.sort("word").select("vector").map { + case Row(v: Vector) => v + }.collect() + + realVectors.zip(expectedVectors).foreach { + case (real, expected) => + assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.") + } + } + + test("findSynonyms") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644) + val (synonyms, similarity) = model.findSynonyms("a", 2).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + assert(synonyms.toArray === Array("b", "c")) + expectedSimilarity.zip(similarity).map { + case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) + } + + } } |