 mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala   | 74
 mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala |  2
 python/pyspark/ml/feature.py                                         | 12
 3 files changed, 44 insertions(+), 44 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index dee898827f..3241ebeb22 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -76,6 +76,18 @@ class Word2Vec extends Serializable with Logging {
private var numIterations = 1
private var seed = Utils.random.nextLong()
private var minCount = 5
+ private var maxSentenceLength = 1000
+
+ /**
+ * Sets the maximum length (in words) of each sentence in the input data.
+ * Any sentence longer than this threshold will be divided into chunks of
+ * at most `maxSentenceLength` words (default: 1000).
+ */
+ @Since("2.0.0")
+ def setMaxSentenceLength(maxSentenceLength: Int): this.type = {
+ this.maxSentenceLength = maxSentenceLength
+ this
+ }
/**
* Sets vector size (default: 100).
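For context, a minimal usage sketch of the new setter against the RDD-based mllib API, assuming an in-scope SparkContext `sc` (the input path, tokenization, and parameter values here are placeholders, not part of this patch):

    import org.apache.spark.mllib.feature.Word2Vec
    import org.apache.spark.rdd.RDD

    // Hypothetical corpus: one tokenized sentence per line.
    val sentences: RDD[Seq[String]] =
      sc.textFile("corpus.txt").map(_.split(" ").toSeq)

    val model = new Word2Vec()
      .setVectorSize(100)
      .setMaxSentenceLength(500) // sentences longer than 500 words are chunked
      .fit(sentences)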
@@ -146,7 +158,6 @@ class Word2Vec extends Serializable with Logging {
private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
- private val MAX_SENTENCE_LENGTH = 1000
/** context words from [-window, window] */
private var window = 5
@@ -156,7 +167,9 @@ class Word2Vec extends Serializable with Logging {
@transient private var vocab: Array[VocabWord] = null
@transient private var vocabHash = mutable.HashMap.empty[String, Int]
- private def learnVocab(words: RDD[String]): Unit = {
+ private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {
+ val words = dataset.flatMap(x => x)
+
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
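The generic signature lets learnVocab consume the same RDD of sentences that fit receives. A minimal local analogue of the counting step above, using plain Scala collections in place of the RDD operations (the sample data and minCount value are illustrative):

    val dataset = Seq(Seq("spark", "is", "fast"), Seq("spark", "fast"))
    val minCount = 2
    val vocab = dataset.flatten                    // the flatMap(x => x) step
      .groupBy(identity)
      .map { case (w, ws) => (w, ws.size) }        // the (w, 1) reduceByKey count
      .filter { case (_, n) => n >= minCount }     // the minCount threshold
    // vocab == Map("spark" -> 2, "fast" -> 2); "is" falls below minCount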
@@ -272,15 +285,14 @@ class Word2Vec extends Serializable with Logging {
/**
* Computes the vector representation of each word in vocabulary.
- * @param dataset an RDD of words
+ * @param dataset an RDD of sentences,
+ * each sentence is expressed as an iterable collection of words
* @return a Word2VecModel
*/
@Since("1.1.0")
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
- val words = dataset.flatMap(x => x)
-
- learnVocab(words)
+ learnVocab(dataset)
createBinaryTree()
@@ -289,25 +301,15 @@ class Word2Vec extends Serializable with Logging {
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)
-
- val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
- new Iterator[Array[Int]] {
- def hasNext: Boolean = iter.hasNext
-
- def next(): Array[Int] = {
- val sentence = ArrayBuilder.make[Int]
- var sentenceLength = 0
- while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
- val word = bcVocabHash.value.get(iter.next())
- word match {
- case Some(w) =>
- sentence += w
- sentenceLength += 1
- case None =>
- }
- }
- sentence.result()
- }
+ // Each partition is a collection of sentences,
+ // which will be translated into arrays of word indices
+ val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
+ // Each sentence will map to 0 or more Array[Int]
+ sentenceIter.flatMap { sentence =>
+ // Map each word in the sentence to its vocabulary index, dropping unknown words
+ val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
+ // Break wordIndexes into chunks of at most maxSentenceLength words
+ wordIndexes.grouped(maxSentenceLength).map(_.toArray)
}
}
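A local sketch of what the rewritten pipeline does to one partition, with plain collections standing in for the broadcast vocabulary hash and the RDD (the sample values are illustrative):

    val vocabHash = Map("spark" -> 0, "fast" -> 1) // stand-in for bcVocabHash.value
    val maxSentenceLength = 2
    val sentence = Seq("spark", "is", "fast", "spark", "fast")

    // Unknown words ("is") are dropped, then the index sequence is split
    // into chunks of at most maxSentenceLength:
    val chunks = sentence
      .flatMap(vocabHash.get)
      .grouped(maxSentenceLength)
      .map(_.toArray)
      .toList
    // chunks == List(Array(0, 1), Array(0, 1))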
@@ -477,15 +479,6 @@ class Word2VecModel private[spark] (
this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
}
- private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
- require(v1.length == v2.length, "Vectors should have the same length")
- val n = v1.length
- val norm1 = blas.snrm2(n, v1, 1)
- val norm2 = blas.snrm2(n, v2, 1)
- if (norm1 == 0 || norm2 == 0) return 0.0
- blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
- }
-
override protected def formatVersion = "1.0"
@Since("1.4.0")
@@ -542,6 +535,7 @@ class Word2VecModel private[spark] (
// Need not divide with the norm of the given vector since it is constant.
val cosVec = cosineVec.map(_.toDouble)
var ind = 0
+ val vecNorm = blas.snrm2(vectorSize, fVector, 1)
while (ind < numWords) {
val norm = wordVecNorms(ind)
if (norm == 0.0) {
@@ -551,12 +545,17 @@ class Word2VecModel private[spark] (
}
ind += 1
}
- wordList.zip(cosVec)
+ var topResults = wordList.zip(cosVec)
.toSeq
- .sortBy(- _._2)
+ .sortBy(-_._2)
.take(num + 1)
.tail
- .toArray
+ if (vecNorm != 0.0f) {
+ topResults = topResults.map { case (word, cosVal) =>
+ (word, cosVal / vecNorm)
+ }
+ }
+ topResults.toArray
}
/**
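After the loop, each entry of cosVec holds dot(query, word) / ||word||; the ranking is unaffected by the query's own norm (hence the unchanged comment before the loop), but dividing the reported values by vecNorm turns them into true cosine similarities in [-1, 1], which is what makes the private cosineSimilarity helper removed above redundant. A small numeric sketch of that final step (pure Scala, no BLAS; values are illustrative):

    def dot(a: Array[Float], b: Array[Float]): Double =
      a.zip(b).map { case (x, y) => x.toDouble * y }.sum
    def norm(a: Array[Float]): Double = math.sqrt(dot(a, a))

    val query = Array(1f, 0f)
    val word  = Array(3f, 4f)
    val partial = dot(query, word) / norm(word) // what cosVec holds per word
    val cosine  = partial / norm(query)         // the new division by vecNorm
    // cosine == 0.6, the true cosine similarity

This also explains why the expected similarity values change in the test suite and the Python doctest below: the old code reported unnormalized scores, the new code reports proper cosines.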
@@ -568,6 +567,7 @@ class Word2VecModel private[spark] (
(word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
}
}
+
}
@Since("1.4.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index a73b565125..f094c550e5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -133,7 +133,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.setSeed(42L)
.fit(docDF)
- val expectedSimilarity = Array(0.18032623242822343, -0.5717976464798823)
+ val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078)
val (synonyms, similarity) = model.findSynonyms("a", 2).map {
case Row(w: String, sim: Double) => (w, sim)
}.collect().unzip
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index d017a23188..464c9446f2 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1836,12 +1836,12 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
+----+--------------------+
...
>>> model.findSynonyms("a", 2).show()
- +----+--------------------+
- |word| similarity|
- +----+--------------------+
- | b| 0.16782984556103436|
- | c|-0.46761559092107646|
- +----+--------------------+
+ +----+-------------------+
+ |word| similarity|
+ +----+-------------------+
+ | b| 0.2505344027513247|
+ | c|-0.6980510075367647|
+ +----+-------------------+
...
>>> model.transform(doc).head().model
DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461])