aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala57
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala12
2 files changed, 62 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e9f4175858..f7251e65e0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -29,6 +29,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.feature.Word2Vec
+import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
@@ -42,9 +44,9 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-
/**
* :: DeveloperApi ::
* The Java stubs necessary for the Python mllib bindings.
@@ -288,6 +290,59 @@ class PythonMLLibAPI extends Serializable {
}
/**
+ * Java stub for Python mllib Word2Vec fit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ * @param dataJRDD input JavaRDD
+ * @param vectorSize size of vector
+ * @param learningRate initial learning rate
+ * @param numPartitions number of partitions
+ * @param numIterations number of iterations
+ * @param seed initial seed for random generator
+ * @return A handle to java Word2VecModelWrapper instance at python side
+ */
+ def trainWord2Vec(
+ dataJRDD: JavaRDD[java.util.ArrayList[String]],
+ vectorSize: Int,
+ learningRate: Double,
+ numPartitions: Int,
+ numIterations: Int,
+ seed: Long): Word2VecModelWrapper = {
+ val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
+ val word2vec = new Word2Vec()
+ .setVectorSize(vectorSize)
+ .setLearningRate(learningRate)
+ .setNumPartitions(numPartitions)
+ .setNumIterations(numIterations)
+ .setSeed(seed)
+ val model = word2vec.fit(data)
+ data.unpersist()
+ new Word2VecModelWrapper(model)
+ }
+
+ private[python] class Word2VecModelWrapper(model: Word2VecModel) {
+ def transform(word: String): Vector = {
+ model.transform(word)
+ }
+
+ def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
+ val vec = transform(word)
+ findSynonyms(vec, num)
+ }
+
+ def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
+ val result = model.findSynonyms(vector, num)
+ val similarity = Vectors.dense(result.map(_._2))
+ val words = result.map(_._1)
+ val ret = new java.util.LinkedList[java.lang.Object]()
+ ret.add(words)
+ ret.add(similarity)
+ ret
+ }
+ }
+
+ /**
* Java stub for Python mllib DecisionTree.train().
* This stub returns a handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index fc14447053..d321994c2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -67,7 +67,7 @@ private case class VocabWord(
class Word2Vec extends Serializable with Logging {
private var vectorSize = 100
- private var startingAlpha = 0.025
+ private var learningRate = 0.025
private var numPartitions = 1
private var numIterations = 1
private var seed = Utils.random.nextLong()
@@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging {
* Sets initial learning rate (default: 0.025).
*/
def setLearningRate(learningRate: Double): this.type = {
- this.startingAlpha = learningRate
+ this.learningRate = learningRate
this
}
@@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging {
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
- var alpha = startingAlpha
+ var alpha = learningRate
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
@@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging {
lwc = wordCount
// TODO: discount by iteration?
alpha =
- startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
- if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+ learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
+ if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
wc += sentence.size
@@ -437,7 +437,7 @@ class Word2VecModel private[mllib] (
* Find synonyms of a word
* @param word a word
* @param num number of synonyms to find
- * @return array of (word, similarity)
+ * @return array of (word, cosineSimilarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)