aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorLiquan Pei <liquanpei@gmail.com>2014-10-07 16:43:34 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-07 16:43:34 -0700
commit098c7344e64e69dffdcf0d95fe1c9e65a54e98f3 (patch)
treef092d22cf1eb086f298ee3fa8a782fa3481bdf88 /mllib
parent3d7b36e0de26049e8b36b6705d8ff4224bde9eb1 (diff)
downloadspark-098c7344e64e69dffdcf0d95fe1c9e65a54e98f3.tar.gz
spark-098c7344e64e69dffdcf0d95fe1c9e65a54e98f3.tar.bz2
spark-098c7344e64e69dffdcf0d95fe1c9e65a54e98f3.zip
[SPARK-3486][MLlib][PySpark] PySpark support for Word2Vec
mengxr Added PySpark support for Word2Vec Change list (1) PySpark support for Word2Vec (2) SerDe support of string sequence both on python side and JVM side (3) Test for SerDe of string sequence on JVM side Author: Liquan Pei <liquanpei@gmail.com> Closes #2356 from Ishiihara/Word2Vec-python and squashes the following commits: 476ea34 [Liquan Pei] style fixes b13a0b9 [Liquan Pei] resolve merge conflicts and minor fixes 8671eba [Liquan Pei] Merge remote-tracking branch 'upstream/master' into Word2Vec-python daf88a6 [Liquan Pei] modification according to feedback a73fa19 [Liquan Pei] clean up 3d8007b [Liquan Pei] fix findSynonyms for vector 1bdcd2e [Liquan Pei] minor fixes cdef9f4 [Liquan Pei] add missing comments b7447eb [Liquan Pei] modify according to feedback b9a7383 [Liquan Pei] cache words RDD in fit 89490bf [Liquan Pei] add tests and Word2VecModelWrapper 78bbb53 [Liquan Pei] use pickle for seq string SerDe a264b08 [Liquan Pei] Merge remote-tracking branch 'upstream/master' into Word2Vec-python ca1e5ff [Liquan Pei] fix test 68e7276 [Liquan Pei] minor style fixes 48d5e72 [Liquan Pei] Functionality improvement 0ad3ac1 [Liquan Pei] minor fix c867fdf [Liquan Pei] add Word2Vec to pyspark
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala57
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala12
2 files changed, 62 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e9f4175858..f7251e65e0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -29,6 +29,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.feature.Word2Vec
+import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
@@ -42,9 +44,9 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-
/**
* :: DeveloperApi ::
* The Java stubs necessary for the Python mllib bindings.
@@ -288,6 +290,59 @@ class PythonMLLibAPI extends Serializable {
}
/**
+ * Java stub for Python mllib Word2Vec fit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ * @param dataJRDD input JavaRDD
+ * @param vectorSize size of vector
+ * @param learningRate initial learning rate
+ * @param numPartitions number of partitions
+ * @param numIterations number of iterations
+ * @param seed initial seed for random generator
+ * @return A handle to java Word2VecModelWrapper instance at python side
+ */
+ def trainWord2Vec(
+ dataJRDD: JavaRDD[java.util.ArrayList[String]],
+ vectorSize: Int,
+ learningRate: Double,
+ numPartitions: Int,
+ numIterations: Int,
+ seed: Long): Word2VecModelWrapper = {
+ val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
+ val word2vec = new Word2Vec()
+ .setVectorSize(vectorSize)
+ .setLearningRate(learningRate)
+ .setNumPartitions(numPartitions)
+ .setNumIterations(numIterations)
+ .setSeed(seed)
+ val model = word2vec.fit(data)
+ data.unpersist()
+ new Word2VecModelWrapper(model)
+ }
+
+ private[python] class Word2VecModelWrapper(model: Word2VecModel) {
+ def transform(word: String): Vector = {
+ model.transform(word)
+ }
+
+ def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
+ val vec = transform(word)
+ findSynonyms(vec, num)
+ }
+
+ def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
+ val result = model.findSynonyms(vector, num)
+ val similarity = Vectors.dense(result.map(_._2))
+ val words = result.map(_._1)
+ val ret = new java.util.LinkedList[java.lang.Object]()
+ ret.add(words)
+ ret.add(similarity)
+ ret
+ }
+ }
+
+ /**
* Java stub for Python mllib DecisionTree.train().
* This stub returns a handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index fc14447053..d321994c2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -67,7 +67,7 @@ private case class VocabWord(
class Word2Vec extends Serializable with Logging {
private var vectorSize = 100
- private var startingAlpha = 0.025
+ private var learningRate = 0.025
private var numPartitions = 1
private var numIterations = 1
private var seed = Utils.random.nextLong()
@@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging {
* Sets initial learning rate (default: 0.025).
*/
def setLearningRate(learningRate: Double): this.type = {
- this.startingAlpha = learningRate
+ this.learningRate = learningRate
this
}
@@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging {
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
- var alpha = startingAlpha
+ var alpha = learningRate
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
@@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging {
lwc = wordCount
// TODO: discount by iteration?
alpha =
- startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
- if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+ learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
+ if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
wc += sentence.size
@@ -437,7 +437,7 @@ class Word2VecModel private[mllib] (
* Find synonyms of a word
* @param word a word
* @param num number of synonyms to find
- * @return array of (word, similarity)
+ * @return array of (word, cosineSimilarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)