From 13675c742a71cbdc8324701c3694775ce1dd5c62 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 3 Aug 2015 16:44:25 -0700 Subject: [SPARK-8874] [ML] Add missing methods in Word2Vec Add missing methods 1. getVectors 2. findSynonyms to W2Vec scala and python API mengxr Author: MechCoder Closes #7263 from MechCoder/missing_methods_w2vec and squashes the following commits: 149d5ca [MechCoder] minor doc 69d91b7 [MechCoder] [SPARK-8874] [ML] Add missing methods in Word2Vec --- .../org/apache/spark/ml/feature/Word2Vec.scala | 38 ++++++++++++- .../apache/spark/ml/feature/Word2VecSuite.scala | 62 ++++++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 6ea6590956..b4f46cef79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -18,15 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types._ /** @@ -146,6 +148,40 @@ class Word2VecModel private[ml] ( wordVectors: feature.Word2VecModel) extends Model[Word2VecModel] with Word2VecBase { + + /** + * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and + * and the vector the DenseVector that it is mapped to. + */ + val getVectors: DataFrame = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble))) + sc.parallelize(wordVec.toSeq).toDF("word", "vector") + } + + /** + * Find "num" number of words closest in similarity to the given word. + * Returns a dataframe with the words and the cosine similarities between the + * synonyms and the given word. + */ + def findSynonyms(word: String, num: Int): DataFrame = { + findSynonyms(wordVectors.transform(word), num) + } + + /** + * Find "num" number of words closest to similarity to the given vector representation + * of the word. Returns a dataframe with the words and the cosine similarities between the + * synonyms and the given word vector. + */ + def findSynonyms(word: Vector, num: Int): DataFrame = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + } + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index aa6ce533fd..adcda0e623 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -67,5 +67,67 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") } } + + test("getVectors") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + + val codes = Map( + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) + ) + val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) } + + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val realVectors = model.getVectors.sort("word").select("vector").map { + case Row(v: Vector) => v + }.collect() + + realVectors.zip(expectedVectors).foreach { + case (real, expected) => + assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.") + } + } + + test("findSynonyms") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644) + val (synonyms, similarity) = model.findSynonyms("a", 2).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + assert(synonyms.toArray === Array("b", "c")) + expectedSimilarity.zip(similarity).map { + case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) + } + + } } -- cgit v1.2.3