-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala  57
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala  12
-rw-r--r--  python/docs/pyspark.mllib.rst  8
-rw-r--r--  python/pyspark/mllib/feature.py  193
-rwxr-xr-x  python/run-tests  1
5 files changed, 264 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e9f4175858..f7251e65e0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -29,6 +29,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.feature.Word2Vec
+import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
@@ -42,9 +44,9 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-
/**
* :: DeveloperApi ::
* The Java stubs necessary for the Python mllib bindings.
@@ -288,6 +290,59 @@ class PythonMLLibAPI extends Serializable {
}
/**
+ * Java stub for Python mllib Word2Vec fit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ * @param dataJRDD input JavaRDD
+ * @param vectorSize size of vector
+ * @param learningRate initial learning rate
+ * @param numPartitions number of partitions
+ * @param numIterations number of iterations
+ * @param seed initial seed for the random generator
+ * @return a handle to the Java Word2VecModelWrapper instance on the Python side
+ */
+ def trainWord2Vec(
+ dataJRDD: JavaRDD[java.util.ArrayList[String]],
+ vectorSize: Int,
+ learningRate: Double,
+ numPartitions: Int,
+ numIterations: Int,
+ seed: Long): Word2VecModelWrapper = {
+ val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
+ val word2vec = new Word2Vec()
+ .setVectorSize(vectorSize)
+ .setLearningRate(learningRate)
+ .setNumPartitions(numPartitions)
+ .setNumIterations(numIterations)
+ .setSeed(seed)
+ val model = word2vec.fit(data)
+ data.unpersist()
+ new Word2VecModelWrapper(model)
+ }
+
+ private[python] class Word2VecModelWrapper(model: Word2VecModel) {
+ def transform(word: String): Vector = {
+ model.transform(word)
+ }
+
+ def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
+ val vec = transform(word)
+ findSynonyms(vec, num)
+ }
+
+ def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
+ val result = model.findSynonyms(vector, num)
+ val similarity = Vectors.dense(result.map(_._2))
+ val words = result.map(_._1)
+ val ret = new java.util.LinkedList[java.lang.Object]()
+ ret.add(words)
+ ret.add(similarity)
+ ret
+ }
+ }
+
+ /**
* Java stub for Python mllib DecisionTree.train().
* This stub returns a handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
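For context, a minimal sketch of how the Python side is expected to drive this stub over Py4J, mirroring the pyspark.mllib.feature code added later in this diff (the SparkContext sc and the tokenized RDD doc are assumed):

    # Hedged sketch only: obtain a handle to the Java Word2VecModelWrapper and
    # unpack the [words, similarities] pair returned by findSynonyms.
    from pyspark.serializers import PickleSerializer

    ser = PickleSerializer()
    java_model = sc._jvm.PythonMLLibAPI().trainWord2Vec(
        doc._to_java_object_rdd(),  # JavaRDD[java.util.ArrayList[String]]
        100,                        # vectorSize
        0.025,                      # learningRate
        1,                          # numPartitions
        1,                          # numIterations
        42L)                        # seed
    jlist = java_model.findSynonyms("spark", 5)
    words, similarity = ser.loads(str(sc._jvm.SerDe.dumps(jlist)))
    synonyms = zip(words, similarity)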
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index fc14447053..d321994c2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -67,7 +67,7 @@ private case class VocabWord(
class Word2Vec extends Serializable with Logging {
private var vectorSize = 100
- private var startingAlpha = 0.025
+ private var learningRate = 0.025
private var numPartitions = 1
private var numIterations = 1
private var seed = Utils.random.nextLong()
@@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging {
* Sets initial learning rate (default: 0.025).
*/
def setLearningRate(learningRate: Double): this.type = {
- this.startingAlpha = learningRate
+ this.learningRate = learningRate
this
}
@@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging {
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
- var alpha = startingAlpha
+ var alpha = learningRate
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
@@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging {
lwc = wordCount
// TODO: discount by iteration?
alpha =
- startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
- if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+ learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
+ if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
wc += sentence.size
@@ -437,7 +437,7 @@ class Word2VecModel private[mllib] (
* Find synonyms of a word
* @param word a word
* @param num number of synonyms to find
- * @return array of (word, similarity)
+ * @return array of (word, cosineSimilarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
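For reference, the renamed learningRate only fixes the starting step size; the decay schedule applied inside the training loop above is unchanged by this patch. A hedged sketch of that schedule (names follow the Scala code; the helper itself is illustrative only):

    # alpha starts at learningRate and shrinks linearly with the number of
    # words processed, floored at 0.01% of the initial rate.
    def current_alpha(learning_rate, num_partitions, word_count, train_words_count):
        alpha = learning_rate * (1 - num_partitions * float(word_count) / (train_words_count + 1))
        return max(alpha, learning_rate * 0.0001)

    # With the default rate 0.025, one partition, and 10,000 training words:
    #   current_alpha(0.025, 1, 0, 10000)     -> 0.025   (start of training)
    #   current_alpha(0.025, 1, 5000, 10000)  -> ~0.0125 (about halfway through)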
diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst
index e95d19e97f..4548b8739e 100644
--- a/python/docs/pyspark.mllib.rst
+++ b/python/docs/pyspark.mllib.rst
@@ -20,6 +20,14 @@ pyspark.mllib.clustering module
:undoc-members:
:show-inheritance:
+pyspark.mllib.feature module
+-------------------------------
+
+.. automodule:: pyspark.mllib.feature
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
pyspark.mllib.linalg module
---------------------------
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
new file mode 100644
index 0000000000..a44a27fd3b
--- /dev/null
+++ b/python/pyspark/mllib/feature.py
@@ -0,0 +1,193 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python bindings for the MLlib feature package.
+"""
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+
+from pyspark.mllib.linalg import _convert_to_vector
+
+__all__ = ['Word2Vec', 'Word2VecModel']
+
+
+class Word2VecModel(object):
+ """
+    Class for the Word2Vec model.
+ """
+ def __init__(self, sc, java_model):
+ """
+ :param sc: Spark context
+ :param java_model: Handle to Java model object
+ """
+ self._sc = sc
+ self._java_model = java_model
+
+ def __del__(self):
+ self._sc._gateway.detach(self._java_model)
+
+ def transform(self, word):
+ """
+        Transforms a word to its vector representation.
+        :param word: a word
+        :return: vector representation of word
+
+        Note: local use only
+ """
+ # TODO: make transform usable in RDD operations from python side
+ result = self._java_model.transform(word)
+ return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result)))
+
+ def findSynonyms(self, x, num):
+ """
+        Find synonyms of a word.
+        :param x: a word or a vector representation of a word
+        :param num: number of synonyms to find
+        :return: array of (word, cosineSimilarity)
+
+        Note: local use only
+ """
+ # TODO: make findSynonyms usable in RDD operations from python side
+ ser = PickleSerializer()
+ if type(x) == str:
+ jlist = self._java_model.findSynonyms(x, num)
+ else:
+ bytes = bytearray(ser.dumps(_convert_to_vector(x)))
+ vec = self._sc._jvm.SerDe.loads(bytes)
+ jlist = self._java_model.findSynonyms(vec, num)
+ words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist)))
+ return zip(words, similarity)
+
+
+class Word2Vec(object):
+ """
+ Word2Vec creates vector representation of words in a text corpus.
+ The algorithm first constructs a vocabulary from the corpus
+ and then learns vector representation of words in the vocabulary.
+ The vector representation can be used as features in
+ natural language processing and machine learning algorithms.
+
+    We use the skip-gram model in our implementation, with the hierarchical
+    softmax method to train the model. The variable names in the implementation
+    match the original C implementation.
+    For the original C implementation, see https://code.google.com/p/word2vec/
+    For the research papers, see
+ Efficient Estimation of Word Representations in Vector Space
+ and
+ Distributed Representations of Words and Phrases and their Compositionality.
+
+ >>> sentence = "a b " * 100 + "a c " * 10
+ >>> localDoc = [sentence, sentence]
+ >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
+ >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
+ >>> syms = model.findSynonyms("a", 2)
+ >>> str(syms[0][0])
+ 'b'
+ >>> str(syms[1][0])
+ 'c'
+ >>> len(syms)
+ 2
+ >>> vec = model.transform("a")
+ >>> len(vec)
+ 10
+ >>> syms = model.findSynonyms(vec, 2)
+ >>> str(syms[0][0])
+ 'b'
+ >>> str(syms[1][0])
+ 'c'
+ >>> len(syms)
+ 2
+ """
+ def __init__(self):
+ """
+ Construct Word2Vec instance
+ """
+ self.vectorSize = 100
+ self.learningRate = 0.025
+ self.numPartitions = 1
+ self.numIterations = 1
+ self.seed = 42L
+
+ def setVectorSize(self, vectorSize):
+ """
+ Sets vector size (default: 100).
+ """
+ self.vectorSize = vectorSize
+ return self
+
+ def setLearningRate(self, learningRate):
+ """
+ Sets initial learning rate (default: 0.025).
+ """
+ self.learningRate = learningRate
+ return self
+
+ def setNumPartitions(self, numPartitions):
+ """
+        Sets the number of partitions (default: 1). Use a small number for accuracy.
+ """
+ self.numPartitions = numPartitions
+ return self
+
+ def setNumIterations(self, numIterations):
+ """
+        Sets the number of iterations (default: 1), which should be smaller than
+        or equal to the number of partitions.
+ """
+ self.numIterations = numIterations
+ return self
+
+ def setSeed(self, seed):
+ """
+ Sets random seed.
+ """
+ self.seed = seed
+ return self
+
+ def fit(self, data):
+ """
+ Computes the vector representation of each word in vocabulary.
+
+        :param data: training data. RDD of a subtype of Iterable[String]
+        :return: Python Word2VecModel instance
+ """
+ sc = data.context
+ ser = PickleSerializer()
+ vectorSize = self.vectorSize
+ learningRate = self.learningRate
+ numPartitions = self.numPartitions
+ numIterations = self.numIterations
+ seed = self.seed
+
+ model = sc._jvm.PythonMLLibAPI().trainWord2Vec(
+ data._to_java_object_rdd(), vectorSize,
+ learningRate, numPartitions, numIterations, seed)
+ return Word2VecModel(sc, model)
+
+
+def _test():
+ import doctest
+ from pyspark import SparkContext
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
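Taken together, a hedged end-to-end usage sketch of the new module, assuming a local SparkContext (it simply mirrors the doctest in the Word2Vec class above):

    from pyspark import SparkContext
    from pyspark.mllib.feature import Word2Vec

    sc = SparkContext("local[4]", "word2vec-example")
    doc = sc.parallelize(["a b " * 100 + "a c " * 10] * 2) \
            .map(lambda line: line.split(" "))
    model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
    print model.findSynonyms("a", 2)   # [('b', ...), ('c', ...)]
    print len(model.transform("a"))    # 10
    sc.stop()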
diff --git a/python/run-tests b/python/run-tests
index c713861eb7..63395f7278 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -69,6 +69,7 @@ function run_mllib_tests() {
echo "Run mllib tests ..."
run_test "pyspark/mllib/classification.py"
run_test "pyspark/mllib/clustering.py"
+ run_test "pyspark/mllib/feature.py"
run_test "pyspark/mllib/linalg.py"
run_test "pyspark/mllib/random.py"
run_test "pyspark/mllib/recommendation.py"