aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-08-03 16:44:25 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-03 16:44:25 -0700
commit13675c742a71cbdc8324701c3694775ce1dd5c62 (patch)
treee2979af1ae6bf87427f2850d9e8f6bc8791f7560 /mllib/src/test
parenta2409d1c8e8ddec04b529ac6f6a12b5993f0eeda (diff)
downloadspark-13675c742a71cbdc8324701c3694775ce1dd5c62.tar.gz
spark-13675c742a71cbdc8324701c3694775ce1dd5c62.tar.bz2
spark-13675c742a71cbdc8324701c3694775ce1dd5c62.zip
[SPARK-8874] [ML] Add missing methods in Word2Vec
Add missing methods 1. getVectors 2. findSynonyms to W2Vec scala and python API mengxr Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #7263 from MechCoder/missing_methods_w2vec and squashes the following commits: 149d5ca [MechCoder] minor doc 69d91b7 [MechCoder] [SPARK-8874] [ML] Add missing methods in Word2Vec
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala62
1 files changed, 62 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index aa6ce533fd..adcda0e623 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -67,5 +67,67 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.")
}
}
+
+ test("getVectors") {
+
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ val sentence = "a b " * 100 + "a c " * 10
+ val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+
+ val codes = Map(
+ "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),
+ "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
+ "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351)
+ )
+ val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) }
+
+ val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+ val model = new Word2Vec()
+ .setVectorSize(3)
+ .setInputCol("text")
+ .setOutputCol("result")
+ .setSeed(42L)
+ .fit(docDF)
+
+ val realVectors = model.getVectors.sort("word").select("vector").map {
+ case Row(v: Vector) => v
+ }.collect()
+
+ realVectors.zip(expectedVectors).foreach {
+ case (real, expected) =>
+ assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.")
+ }
+ }
+
+ test("findSynonyms") {
+
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ val sentence = "a b " * 100 + "a c " * 10
+ val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+ val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+ val model = new Word2Vec()
+ .setVectorSize(3)
+ .setInputCol("text")
+ .setOutputCol("result")
+ .setSeed(42L)
+ .fit(docDF)
+
+ val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644)
+ val (synonyms, similarity) = model.findSynonyms("a", 2).map {
+ case Row(w: String, sim: Double) => (w, sim)
+ }.collect().unzip
+
+ assert(synonyms.toArray === Array("b", "c"))
+ expectedSimilarity.zip(similarity).map {
+ case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
+ }
+
+ }
}