about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- python/pyspark/mllib/feature.py 6
-rw-r--r-- python/pyspark/mllib/tests.py 14
2 files changed, 20 insertions, 0 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 0ffe092a07..4bfe3014ef 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -244,6 +244,12 @@ class IDFModel(JavaVectorTransformer):
x = _convert_to_vector(x)
return JavaVectorTransformer.transform(self, x)
+ def idf(self):
+ """
+ Returns this model's IDF vector, obtained via self.call('idf').
+ """
+ return self.call('idf')
+
class IDF(object):
"""
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 155019638f..3bb0f0ca68 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -41,6 +41,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.mllib.feature import IDF
from pyspark.serializers import PickleSerializer
from pyspark.sql import SQLContext
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -620,6 +621,19 @@ class ChiSqTestTests(PySparkTestCase):
self.assertEqual(len(chi), num_cols)
self.assertIsNotNone(chi[1000])
+
+class FeatureTest(PySparkTestCase):
+ def test_idf_model(self):
+ data = [  # four dense vectors, 11 entries each
+ Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]),
+ Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]),
+ Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]),
+ Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9])
+ ]
+ model = IDF().fit(self.sc.parallelize(data, 2))  # fit on the data distributed with parallelize(data, 2)
+ idf = model.idf()
+ self.assertEqual(len(idf), 11)  # expect one IDF entry per input vector position
+
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"