Diffstat (limited to 'python/pyspark')
 python/pyspark/mllib/feature.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 8cb992df2d..741c630cbd 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -28,7 +28,7 @@ from py4j.protocol import Py4JJavaError
 
 from pyspark import RDD, SparkContext
 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Vectors, _convert_to_vector
+from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector
 
 __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
            'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
@@ -212,7 +212,7 @@ class IDFModel(JavaVectorTransformer):
     """
     Represents an IDF model that can transform term frequency vectors.
     """
-    def transform(self, dataset):
+    def transform(self, x):
         """
         Transforms term frequency (TF) vectors to TF-IDF vectors.
 
@@ -220,12 +220,14 @@ class IDFModel(JavaVectorTransformer):
         the terms which occur in fewer than `minDocFreq`
         documents will have an entry of 0.
 
-        :param dataset: an RDD of term frequency vectors
-        :return: an RDD of TF-IDF vectors
+        :param x: an RDD of term frequency vectors or a term frequency vector
+        :return: an RDD of TF-IDF vectors or a TF-IDF vector
         """
-        if not isinstance(dataset, RDD):
-            raise TypeError("dataset should be an RDD of term frequency vectors")
-        return JavaVectorTransformer.transform(self, dataset)
+        if isinstance(x, RDD):
+            return JavaVectorTransformer.transform(self, x)
+
+        x = _convert_to_vector(x)
+        return JavaVectorTransformer.transform(self, x)
 
 
 class IDF(object):
@@ -255,6 +257,12 @@ class IDF(object):
     SparseVector(4, {1: 0.0, 3: 0.5754})
     DenseVector([0.0, 0.0, 1.3863, 0.863])
     SparseVector(4, {1: 0.0})
+    >>> model.transform(Vectors.dense([0.0, 1.0, 2.0, 3.0]))
+    DenseVector([0.0, 0.0, 1.3863, 0.863])
+    >>> model.transform([0.0, 1.0, 2.0, 3.0])
+    DenseVector([0.0, 0.0, 1.3863, 0.863])
+    >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0)))
+    SparseVector(4, {1: 0.0, 3: 0.5754})
     """
     def __init__(self, minDocFreq=0):
         """