aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--python/pyspark/mllib/tests.py45
1 files changed, 40 insertions, 5 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 6e9c68ec8a..dd3b66ce67 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -42,6 +42,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
from pyspark.serializers import PickleSerializer
from pyspark.sql import SQLContext
@@ -630,6 +631,12 @@ class ChiSqTestTests(PySparkTestCase):
self.assertIsNotNone(chi[1000])
+class SerDeTest(PySparkTestCase):
+ def test_to_java_object_rdd(self): # SPARK-6660
+ data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
+ self.assertEqual(_to_java_object_rdd(data).count(), 10)
+
+
class FeatureTest(PySparkTestCase):
def test_idf_model(self):
data = [
@@ -643,11 +650,39 @@ class FeatureTest(PySparkTestCase):
self.assertEqual(len(idf), 11)
-class SerDeTest(PySparkTestCase):
- def test_to_java_object_rdd(self): # SPARK-6660
- data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
- self.assertEqual(_to_java_object_rdd(data).count(), 10)
-
+class Word2VecTests(PySparkTestCase):
+ def test_word2vec_setters(self):
+ data = [
+ ["I", "have", "a", "pen"],
+ ["I", "like", "soccer", "very", "much"],
+ ["I", "live", "in", "Tokyo"]
+ ]
+ model = Word2Vec() \
+ .setVectorSize(2) \
+ .setLearningRate(0.01) \
+ .setNumPartitions(2) \
+ .setNumIterations(10) \
+ .setSeed(1024) \
+ .setMinCount(3)
+ self.assertEquals(model.vectorSize, 2)
+ self.assertTrue(model.learningRate < 0.02)
+ self.assertEquals(model.numPartitions, 2)
+ self.assertEquals(model.numIterations, 10)
+ self.assertEquals(model.seed, 1024)
+ self.assertEquals(model.minCount, 3)
+
+ def test_word2vec_get_vectors(self):
+ data = [
+ ["a", "b", "c", "d", "e", "f", "g"],
+ ["a", "b", "c", "d", "e", "f"],
+ ["a", "b", "c", "d", "e"],
+ ["a", "b", "c", "d"],
+ ["a", "b", "c"],
+ ["a", "b"],
+ ["a"]
+ ]
+ model = Word2Vec().fit(self.sc.parallelize(data))
+ self.assertEquals(len(model.getVectors()), 3)
if __name__ == "__main__":
if not _have_scipy: