aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-08-03 16:44:25 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-03 16:44:25 -0700
commit13675c742a71cbdc8324701c3694775ce1dd5c62 (patch)
treee2979af1ae6bf87427f2850d9e8f6bc8791f7560 /mllib/src/main/scala
parenta2409d1c8e8ddec04b529ac6f6a12b5993f0eeda (diff)
downloadspark-13675c742a71cbdc8324701c3694775ce1dd5c62.tar.gz
spark-13675c742a71cbdc8324701c3694775ce1dd5c62.tar.bz2
spark-13675c742a71cbdc8324701c3694775ce1dd5c62.zip
[SPARK-8874] [ML] Add missing methods in Word2Vec
Add missing methods 1. getVectors 2. findSynonyms to W2Vec scala and python API mengxr Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #7263 from MechCoder/missing_methods_w2vec and squashes the following commits: 149d5ca [MechCoder] minor doc 69d91b7 [MechCoder] [SPARK-8874] [ML] Add missing methods in Word2Vec
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala38
1 files changed, 37 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 6ea6590956..b4f46cef79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -18,15 +18,17 @@
package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
+import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types._
/**
@@ -146,6 +148,40 @@ class Word2VecModel private[ml] (
wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {
+
+ /**
+  * Returns a dataframe with two fields, "word" and "vector", with "word" being a String
+  * and "vector" the DenseVector that it is mapped to. Built lazily on first access.
+  */
+ @transient lazy val getVectors: DataFrame = {
+   val sc = SparkContext.getOrCreate()
+   val sqlContext = SQLContext.getOrCreate(sc)
+   import sqlContext.implicits._
+   val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble)))
+   sc.parallelize(wordVec.toSeq).toDF("word", "vector")
+ }
+
+ /**
+  * Converts the given word to its vector and finds the "num" words closest to it in
+  * similarity. Returns a dataframe with the synonyms and their cosine similarities.
+  */
+ def findSynonyms(word: String, num: Int): DataFrame = {
+   val wordVector = wordVectors.transform(word)
+   findSynonyms(wordVector, num)
+ }
+
+ /**
+  * Find the "num" words closest in similarity to the given word vector. Returns a dataframe
+  * with columns "word" and "similarity" (the cosine similarity to the given vector).
+  */
+ def findSynonyms(word: Vector, num: Int): DataFrame = {
+   val sparkCtx = SparkContext.getOrCreate()
+   val sqlCtx = SQLContext.getOrCreate(sparkCtx)
+   import sqlCtx.implicits._
+   val synonyms = wordVectors.findSynonyms(word, num)
+   sparkCtx.parallelize(synonyms).toDF("word", "similarity")
+ }
+
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)