aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorAsher Krim <akrim@hubspot.com>2017-03-07 20:36:46 -0800
committerJoseph K. Bradley <joseph@databricks.com>2017-03-07 20:36:46 -0800
commit56e1bd337ccb03cb01702e4260e4be59d2aa0ead (patch)
tree333318184a59e5bc6cdcf482aad717ed84408e72 /mllib/src/main
parentd8830c5039d9c7c5ef03631904c32873ab558e22 (diff)
downloadspark-56e1bd337ccb03cb01702e4260e4be59d2aa0ead.tar.gz
spark-56e1bd337ccb03cb01702e4260e4be59d2aa0ead.tar.bz2
spark-56e1bd337ccb03cb01702e4260e4be59d2aa0ead.zip
[SPARK-17629][ML] methods to return synonyms directly
## What changes were proposed in this pull request? provide methods to return synonyms directly, without wrapping them in a dataframe In performance sensitive applications (such as user facing apis) the roundtrip to and from dataframes is costly and unnecessary The methods are named ``findSynonymsArray`` to make the return type clear, which also implies a local datastructure ## How was this patch tested? updated word2vec tests Author: Asher Krim <akrim@hubspot.com> Closes #16811 from Krimit/w2vFindSynonymsLocal.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala37
1 files changed, 31 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 42e8a66a62..4ca062c0b5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -227,25 +227,50 @@ class Word2VecModel private[ml] (
/**
* Find "num" number of words closest in similarity to the given word, not
- * including the word itself. Returns a dataframe with the words and the
- * cosine similarities between the synonyms and the given word.
+ * including the word itself.
+ * @return a dataframe with columns "word" and "similarity" of the word and the cosine
+ * similarities between the synonyms and the given word vector.
*/
@Since("1.5.0")
def findSynonyms(word: String, num: Int): DataFrame = {
val spark = SparkSession.builder().getOrCreate()
- spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
+ spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", "similarity")
}
/**
- * Find "num" number of words whose vector representation most similar to the supplied vector.
+ * Find "num" number of words whose vector representation is most similar to the supplied vector.
* If the supplied vector is the vector representation of a word in the model's vocabulary,
- * that word will be in the results. Returns a dataframe with the words and the cosine
+ * that word will be in the results.
+ * @return a dataframe with columns "word" and "similarity" of the word and the cosine
* similarities between the synonyms and the given word vector.
*/
@Since("2.0.0")
def findSynonyms(vec: Vector, num: Int): DataFrame = {
val spark = SparkSession.builder().getOrCreate()
- spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity")
+ spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", "similarity")
+ }
+
+ /**
+ * Find "num" number of words whose vector representation is most similar to the supplied vector.
+ * If the supplied vector is the vector representation of a word in the model's vocabulary,
+ * that word will be in the results.
+ * @return an array of the words and the cosine similarities between the synonyms given
+ * word vector.
+ */
+ @Since("2.2.0")
+ def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = {
+ wordVectors.findSynonyms(vec, num)
+ }
+
+ /**
+ * Find "num" number of words closest in similarity to the given word, not
+ * including the word itself.
+ * @return an array of the words and the cosine similarities between the synonyms given
+ * word vector.
+ */
+ @Since("2.2.0")
+ def findSynonymsArray(word: String, num: Int): Array[(String, Double)] = {
+ wordVectors.findSynonyms(word, num)
}
/** @group setParam */