author     MechCoder <manojkumarsivaraj334@gmail.com>    2015-08-03 16:44:25 -0700
committer  Joseph K. Bradley <joseph@databricks.com>     2015-08-03 16:44:25 -0700
commit     13675c742a71cbdc8324701c3694775ce1dd5c62 (patch)
tree       e2979af1ae6bf87427f2850d9e8f6bc8791f7560 /mllib
parent     a2409d1c8e8ddec04b529ac6f6a12b5993f0eeda (diff)
[SPARK-8874] [ML] Add missing methods in Word2Vec
Add missing methods to the Word2Vec Scala and Python APIs:
1. getVectors
2. findSynonyms

mengxr

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7263 from MechCoder/missing_methods_w2vec and squashes the following commits:

149d5ca [MechCoder] minor doc
69d91b7 [MechCoder] [SPARK-8874] [ML] Add missing methods in Word2Vec
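For context, a minimal usage sketch (not part of the patch) of the two methods this commit adds to a fitted ml.feature.Word2VecModel; it assumes a Spark 1.5-era shell with an existing SparkContext `sc`, and the toy corpus mirrors the test suite below:

import org.apache.spark.ml.feature.Word2Vec
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)  // `sc` is the pre-existing SparkContext

// Toy corpus, mirroring Word2VecSuite: word counts are well above the default minCount.
val sentence = "a b " * 100 + "a c " * 10
val docDF = sqlContext.createDataFrame(
  Seq(sentence, sentence).map(line => Tuple1(line.split(" ")))
).toDF("text")

val model = new Word2Vec()
  .setVectorSize(3)
  .setInputCol("text")
  .setOutputCol("result")
  .setSeed(42L)
  .fit(docDF)

// New in this patch: one row per vocabulary word, columns "word" and "vector".
model.getVectors.show()

// New in this patch: the two words most similar to "a", columns "word" and "similarity".
model.findSynonyms("a", 2).show()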
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala       38
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala  62
2 files changed, 99 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 6ea6590956..b4f46cef79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -18,15 +18,17 @@
package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
+import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types._
/**
@@ -146,6 +148,40 @@ class Word2VecModel private[ml] (
wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {
+
+ /**
+ * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
+ * the vector the DenseVector that it is mapped to.
+ */
+ val getVectors: DataFrame = {
+ val sc = SparkContext.getOrCreate()
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+ val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble)))
+ sc.parallelize(wordVec.toSeq).toDF("word", "vector")
+ }
+
+ /**
+ * Find "num" number of words closest in similarity to the given word.
+ * Returns a dataframe with the words and the cosine similarities between the
+ * synonyms and the given word.
+ */
+ def findSynonyms(word: String, num: Int): DataFrame = {
+ findSynonyms(wordVectors.transform(word), num)
+ }
+
+ /**
+ * Find "num" number of words closest to similarity to the given vector representation
+ * of the word. Returns a dataframe with the words and the cosine similarities between the
+ * synonyms and the given word vector.
+ */
+ def findSynonyms(word: Vector, num: Int): DataFrame = {
+ val sc = SparkContext.getOrCreate()
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+ sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
+ }
+
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
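As a side note, a standalone sketch (assuming it runs inside an already-started Spark application) of the pattern getVectors uses above to turn the underlying word -> Array[Float] map into a DataFrame of (word, DenseVector) rows:

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.mllib.linalg.Vectors

val sc = SparkContext.getOrCreate()          // reuse the active SparkContext
val sqlContext = SQLContext.getOrCreate(sc)  // reuse (or create) the active SQLContext
import sqlContext.implicits._

// Illustrative stand-in for wordVectors.getVectors: a Map[String, Array[Float]]
val raw = Map("a" -> Array(0.1f, 0.2f, 0.3f), "b" -> Array(0.4f, 0.5f, 0.6f))
val wordVec = raw.mapValues(vec => Vectors.dense(vec.map(_.toDouble)))
sc.parallelize(wordVec.toSeq).toDF("word", "vector").show()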
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index aa6ce533fd..adcda0e623 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -67,5 +67,67 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.")
}
}
+
+ test("getVectors") {
+
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ val sentence = "a b " * 100 + "a c " * 10
+ val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+
+ val codes = Map(
+ "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),
+ "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
+ "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351)
+ )
+ val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) }
+
+ val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+ val model = new Word2Vec()
+ .setVectorSize(3)
+ .setInputCol("text")
+ .setOutputCol("result")
+ .setSeed(42L)
+ .fit(docDF)
+
+ val realVectors = model.getVectors.sort("word").select("vector").map {
+ case Row(v: Vector) => v
+ }.collect()
+
+ realVectors.zip(expectedVectors).foreach {
+ case (real, expected) =>
+ assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.")
+ }
+ }
+
+ test("findSynonyms") {
+
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ val sentence = "a b " * 100 + "a c " * 10
+ val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+ val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+ val model = new Word2Vec()
+ .setVectorSize(3)
+ .setInputCol("text")
+ .setOutputCol("result")
+ .setSeed(42L)
+ .fit(docDF)
+
+ val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644)
+ val (synonyms, similarity) = model.findSynonyms("a", 2).map {
+ case Row(w: String, sim: Double) => (w, sim)
+ }.collect().unzip
+
+ assert(synonyms.toArray === Array("b", "c"))
+ expectedSimilarity.zip(similarity).map {
+ case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
+ }
+
+ }
}
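The suite covers getVectors and the String overload of findSynonyms; for completeness, a hedged sketch of the Vector overload added above, reusing the fitted `model` from these tests:

import org.apache.spark.mllib.linalg.Vectors

// Query synonyms for an arbitrary point in the 3-dimensional embedding space
// (illustrative query vector; any vector of the model's size works).
val queryVec = Vectors.dense(-0.3, 0.0, 0.3)
model.findSynonyms(queryVec, 2).show()  // columns: "word", "similarity"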