diff options
author | lewuathe <lewuathe@me.com> | 2015-03-31 11:25:21 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-31 11:25:21 -0700 |
commit | 46de6c05e0619250346f0988e296849f8f93d2b1 (patch) | |
tree | f5d626ca33660394ac699396ef5a3d003618daff /python | |
parent | cd48ca50129e8952f487051796244e7569275416 (diff) | |
download | spark-46de6c05e0619250346f0988e296849f8f93d2b1.tar.gz spark-46de6c05e0619250346f0988e296849f8f93d2b1.tar.bz2 spark-46de6c05e0619250346f0988e296849f8f93d2b1.zip |
[SPARK-6598][MLLIB] Python API for IDFModel
This is the sub-task of SPARK-6254.
Wrapping IDFModel `idf` member function for pyspark.
Author: lewuathe <lewuathe@me.com>
Closes #5264 from Lewuathe/SPARK-6598 and squashes the following commits:
1dc522c [lewuathe] [SPARK-6598] Python API for IDFModel
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/feature.py | 6 | ||||
-rw-r--r-- | python/pyspark/mllib/tests.py | 14 |
2 files changed, 20 insertions, 0 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 0ffe092a07..4bfe3014ef 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -244,6 +244,12 @@ class IDFModel(JavaVectorTransformer): x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) + def idf(self): + """ + Returns the current IDF vector. + """ + return self.call('idf') + class IDF(object): """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 155019638f..3bb0f0ca68 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -41,6 +41,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _ from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import IDF from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -620,6 +621,19 @@ class ChiSqTestTests(PySparkTestCase): self.assertEqual(len(chi), num_cols) self.assertIsNotNone(chi[1000]) + +class FeatureTest(PySparkTestCase): + def test_idf_model(self): + data = [ + Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), + Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), + Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), + Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) + ] + model = IDF().fit(self.sc.parallelize(data, 2)) + idf = model.idf() + self.assertEqual(len(idf), 11) + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" |