aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAsher Krim <akrim@hubspot.com>2017-03-07 20:36:46 -0800
committerJoseph K. Bradley <joseph@databricks.com>2017-03-07 20:36:46 -0800
commit56e1bd337ccb03cb01702e4260e4be59d2aa0ead (patch)
tree333318184a59e5bc6cdcf482aad717ed84408e72
parentd8830c5039d9c7c5ef03631904c32873ab558e22 (diff)
downloadspark-56e1bd337ccb03cb01702e4260e4be59d2aa0ead.tar.gz
spark-56e1bd337ccb03cb01702e4260e4be59d2aa0ead.tar.bz2
spark-56e1bd337ccb03cb01702e4260e4be59d2aa0ead.zip
[SPARK-17629][ML] methods to return synonyms directly
## What changes were proposed in this pull request? Provide methods to return synonyms directly, without wrapping them in a DataFrame. In performance-sensitive applications (such as user-facing APIs) the round trip to and from DataFrames is costly and unnecessary. The methods are named ``findSynonymsArray`` to make the return type clear, which also implies a local data structure. ## How was this patch tested? Updated the Word2Vec tests. Author: Asher Krim <akrim@hubspot.com> Closes #16811 from Krimit/w2vFindSynonymsLocal.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala37
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala20
2 files changed, 45 insertions, 12 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 42e8a66a62..4ca062c0b5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -227,25 +227,50 @@ class Word2VecModel private[ml] (
/**
* Find "num" number of words closest in similarity to the given word, not
- * including the word itself. Returns a dataframe with the words and the
- * cosine similarities between the synonyms and the given word.
+ * including the word itself.
+ * @return a dataframe with columns "word" and "similarity" of the word and the cosine
+ * similarities between the synonyms and the given word.
*/
@Since("1.5.0")
def findSynonyms(word: String, num: Int): DataFrame = {
val spark = SparkSession.builder().getOrCreate()
- spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
+ spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", "similarity")
}
/**
- * Find "num" number of words whose vector representation most similar to the supplied vector.
+ * Find "num" number of words whose vector representation is most similar to the supplied vector.
* If the supplied vector is the vector representation of a word in the model's vocabulary,
- * that word will be in the results. Returns a dataframe with the words and the cosine
+ * that word will be in the results.
+ * @return a dataframe with columns "word" and "similarity" of the word and the cosine
* similarities between the synonyms and the given word vector.
*/
@Since("2.0.0")
def findSynonyms(vec: Vector, num: Int): DataFrame = {
val spark = SparkSession.builder().getOrCreate()
- spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity")
+ spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", "similarity")
+ }
+
+ /**
+ * Find "num" number of words whose vector representation is most similar to the supplied vector.
+ * If the supplied vector is the vector representation of a word in the model's vocabulary,
+ * that word will be in the results.
+ * @return an array of the words and the cosine similarities between the synonyms and the
+ * given word vector.
+ */
+ @Since("2.2.0")
+ def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = {
+ wordVectors.findSynonyms(vec, num)
+ }
+
+ /**
+ * Find "num" number of words closest in similarity to the given word, not
+ * including the word itself.
+ * @return an array of the words and the cosine similarities between the synonyms and the
+ * given word.
+ */
+ @Since("2.2.0")
+ def findSynonymsArray(word: String, num: Int): Array[(String, Double)] = {
+ wordVectors.findSynonyms(word, num)
}
/** @group setParam */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 613cc3d60b..2043a16c15 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -133,14 +133,22 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.setSeed(42L)
.fit(docDF)
- val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078)
- val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map {
+ val expected = Map(("b", 0.2608488929093532), ("c", -0.8271274846926078))
+ val findSynonymsResult = model.findSynonyms("a", 2).rdd.map {
case Row(w: String, sim: Double) => (w, sim)
- }.collect().unzip
+ }.collectAsMap()
+
+ expected.foreach {
+ case (expectedSynonym, expectedSimilarity) =>
+ assert(findSynonymsResult.contains(expectedSynonym))
+ assert(expectedSimilarity ~== findSynonymsResult.get(expectedSynonym).get absTol 1E-5)
+ }
- assert(synonyms === Array("b", "c"))
- expectedSimilarity.zip(similarity).foreach {
- case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
+ val findSynonymsArrayResult = model.findSynonymsArray("a", 2).toMap
+ findSynonymsResult.foreach {
+ case (expectedSynonym, expectedSimilarity) =>
+ assert(findSynonymsArrayResult.contains(expectedSynonym))
+ assert(expectedSimilarity ~== findSynonymsArrayResult.get(expectedSynonym).get absTol 1E-5)
}
}