aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@appier.com>2015-12-14 09:59:42 -0800
committerDavies Liu <davies.liu@gmail.com>2015-12-14 09:59:42 -0800
commitb51a4cdff3a7e640a8a66f7a9c17021f3056fd34 (patch)
tree8b550b38428e02e6913e437f3f04e8f45a775149
parente25f1fe42747be71c6b6e6357ca214f9544e3a46 (diff)
downloadspark-b51a4cdff3a7e640a8a66f7a9c17021f3056fd34.tar.gz
spark-b51a4cdff3a7e640a8a66f7a9c17021f3056fd34.tar.bz2
spark-b51a4cdff3a7e640a8a66f7a9c17021f3056fd34.zip
[SPARK-12016] [MLLIB] [PYSPARK] Wrap Word2VecModel when loading it in pyspark
JIRA: https://issues.apache.org/jira/browse/SPARK-12016 We should not directly use Word2VecModel in pyspark. We need to wrap it in a Word2VecModelWrapper when loading it in pyspark. Author: Liang-Chi Hsieh <viirya@appier.com> Closes #10100 from viirya/fix-load-py-wordvecmodel.
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala33
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala62
-rw-r--r--python/pyspark/mllib/feature.py6
3 files changed, 67 insertions, 34 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 8d546e3d60..29160a10e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -680,39 +680,6 @@ private[python] class PythonMLLibAPI extends Serializable {
}
}
- private[python] class Word2VecModelWrapper(model: Word2VecModel) {
- def transform(word: String): Vector = {
- model.transform(word)
- }
-
- /**
- * Transforms an RDD of words to its vector representation
- * @param rdd an RDD of words
- * @return an RDD of vector representations of words
- */
- def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
- rdd.rdd.map(model.transform)
- }
-
- def findSynonyms(word: String, num: Int): JList[Object] = {
- val vec = transform(word)
- findSynonyms(vec, num)
- }
-
- def findSynonyms(vector: Vector, num: Int): JList[Object] = {
- val result = model.findSynonyms(vector, num)
- val similarity = Vectors.dense(result.map(_._2))
- val words = result.map(_._1)
- List(words, similarity).map(_.asInstanceOf[Object]).asJava
- }
-
- def getVectors: JMap[String, JList[Float]] = {
- model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
- }
-
- def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
- }
-
/**
* Java stub for Python mllib DecisionTree.train().
* This stub returns a handle to the Java object instead of the content of the Java object.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
new file mode 100644
index 0000000000..0f55980481
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.api.python
+
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
+import scala.collection.JavaConverters._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.feature.Word2VecModel
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+
+/**
+ * Wrapper around Word2VecModel to provide helper methods in Python
+ */
+private[python] class Word2VecModelWrapper(model: Word2VecModel) {
+ def transform(word: String): Vector = {
+ model.transform(word)
+ }
+
+ /**
+ * Transforms an RDD of words to its vector representation
+ * @param rdd an RDD of words
+ * @return an RDD of vector representations of words
+ */
+ def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
+ rdd.rdd.map(model.transform)
+ }
+
+ def findSynonyms(word: String, num: Int): JList[Object] = {
+ val vec = transform(word)
+ findSynonyms(vec, num)
+ }
+
+ def findSynonyms(vector: Vector, num: Int): JList[Object] = {
+ val result = model.findSynonyms(vector, num)
+ val similarity = Vectors.dense(result.map(_._2))
+ val words = result.map(_._1)
+ List(words, similarity).map(_.asInstanceOf[Object]).asJava
+ }
+
+ def getVectors: JMap[String, JList[Float]] = {
+ model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
+ }
+
+ def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
+}
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 7b077b058c..7254679ebb 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -504,7 +504,8 @@ class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
"""
jmodel = sc._jvm.org.apache.spark.mllib.feature \
.Word2VecModel.load(sc._jsc.sc(), path)
- return Word2VecModel(jmodel)
+ model = sc._jvm.Word2VecModelWrapper(jmodel)
+ return Word2VecModel(model)
@ignore_unicode_prefix
@@ -546,6 +547,9 @@ class Word2Vec(object):
>>> sameModel = Word2VecModel.load(sc, path)
>>> model.transform("a") == sameModel.transform("a")
True
+ >>> syms = sameModel.findSynonyms("a", 2)
+ >>> [s[0] for s in syms]
+ [u'b', u'c']
>>> from shutil import rmtree
>>> try:
... rmtree(path)