about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 18
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala | 21
-rw-r--r-- python/pyspark/ml/feature.py | 15
3 files changed, 37 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 5b079fce3a..7e6c367970 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -540,14 +540,16 @@ class Word2VecModel private[spark] (
val cosineVec = Array.fill[Float](numWords)(0)
val alpha: Float = 1
val beta: Float = 0
-
+ // Normalize input vector before blas.sgemv to avoid Inf value
+ val vecNorm = blas.snrm2(vectorSize, fVector, 1)
+ if (vecNorm != 0.0f) {
+ blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1)
+ }
blas.sgemv(
"T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
- // Need not divide with the norm of the given vector since it is constant.
val cosVec = cosineVec.map(_.toDouble)
var ind = 0
- val vecNorm = blas.snrm2(vectorSize, fVector, 1)
while (ind < numWords) {
val norm = wordVecNorms(ind)
if (norm == 0.0) {
@@ -557,17 +559,13 @@ class Word2VecModel private[spark] (
}
ind += 1
}
- var topResults = wordList.zip(cosVec)
+
+ wordList.zip(cosVec)
.toSeq
.sortBy(-_._2)
.take(num + 1)
.tail
- if (vecNorm != 0.0f) {
- topResults = topResults.map { case (word, cosVal) =>
- (word, cosVal / vecNorm)
- }
- }
- topResults.toArray
+ .toArray
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 4fcf417d5f..6d699440f2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -108,5 +108,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("test similarity for word vectors with large values is not Infinity or NaN") {
+ val vecA = Array(-4.331467827487745E21, -5.26707742075006E21, // magnitudes ~1e21
+ 5.63551690626524E21, 2.833692188614257E21, -1.9688159903619345E21, -4.933950659913092E21,
+ -2.7401535502536787E21, -1.418671793782632E20).map(_.toFloat) // Double literals -> Float
+ val vecB = Array(-3.9850175451103232E16, -3.4829783883841536E16, // magnitudes ~1e16
+ 9.421469251534848E15, 4.4069684466679808E16, 7.20936298872832E15, -4.2883302830374912E16,
+ -3.605579947835392E16, -2.8151294422155264E16).map(_.toFloat)
+ val vecC = Array(-1.9227381025734656E16, -3.907009342603264E16, // magnitudes ~1e16
+ 2.110207626838016E15, -4.8770066610651136E16, -1.9734964555743232E16, -3.2206001247617024E16,
+ 2.7725358220443648E16, 3.1618718156980224E16).map(_.toFloat)
+ val wordMapIn = Map( // word -> vector entries handed to the model constructor
+ ("A", vecA),
+ ("B", vecB),
+ ("C", vecC)
+ )
+
+ val model = new Word2VecModel(wordMapIn)
+ model.findSynonyms("A", 5).foreach { pair => // asks for 5 synonyms; only 3 words exist
+ assert(!(pair._2.isInfinite || pair._2.isNaN)) // pair._2 is the similarity score
+ }
+ }
}
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 610d167f3a..1b059a7199 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2186,13 +2186,14 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
| c|[-0.3794820010662...|
+----+--------------------+
...
- >>> model.findSynonyms("a", 2).show()
- +----+-------------------+
- |word| similarity|
- +----+-------------------+
- | b| 0.2505344027513247|
- | c|-0.6980510075367647|
- +----+-------------------+
+ >>> from pyspark.sql.functions import format_number as fmt
+ >>> model.findSynonyms("a", 2).select("word", fmt("similarity", 5).alias("similarity")).show()
+ +----+----------+
+ |word|similarity|
+ +----+----------+
+ | b| 0.25053|
+ | c| -0.69805|
+ +----+----------+
...
>>> model.transform(doc).head().model
DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461])