aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorlewuathe <lewuathe@me.com>2015-03-31 11:25:21 -0700
committerXiangrui Meng <meng@databricks.com>2015-03-31 11:25:21 -0700
commit46de6c05e0619250346f0988e296849f8f93d2b1 (patch)
treef5d626ca33660394ac699396ef5a3d003618daff /python
parentcd48ca50129e8952f487051796244e7569275416 (diff)
downloadspark-46de6c05e0619250346f0988e296849f8f93d2b1.tar.gz
spark-46de6c05e0619250346f0988e296849f8f93d2b1.tar.bz2
spark-46de6c05e0619250346f0988e296849f8f93d2b1.zip
[SPARK-6598][MLLIB] Python API for IDFModel
This is the sub-task of SPARK-6254. Wrapping IDFModel `idf` member function for pyspark. Author: lewuathe <lewuathe@me.com> Closes #5264 from Lewuathe/SPARK-6598 and squashes the following commits: 1dc522c [lewuathe] [SPARK-6598] Python API for IDFModel
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/feature.py6
-rw-r--r--python/pyspark/mllib/tests.py14
2 files changed, 20 insertions, 0 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 0ffe092a07..4bfe3014ef 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -244,6 +244,12 @@ class IDFModel(JavaVectorTransformer):
x = _convert_to_vector(x)
return JavaVectorTransformer.transform(self, x)
+ def idf(self):
+ """
+ Returns the current IDF vector.
+ """
+ return self.call('idf')
+
class IDF(object):
"""
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 155019638f..3bb0f0ca68 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -41,6 +41,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.mllib.feature import IDF
from pyspark.serializers import PickleSerializer
from pyspark.sql import SQLContext
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -620,6 +621,19 @@ class ChiSqTestTests(PySparkTestCase):
self.assertEqual(len(chi), num_cols)
self.assertIsNotNone(chi[1000])
+
+class FeatureTest(PySparkTestCase):
+ def test_idf_model(self):
+ data = [
+ Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]),
+ Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]),
+ Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]),
+ Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9])
+ ]
+ model = IDF().fit(self.sc.parallelize(data, 2))
+ idf = model.idf()
+ self.assertEqual(len(idf), 11)
+
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"