aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorYuu ISHIKAWA <yuu.ishikawa@gmail.com>2014-12-15 13:44:15 -0800
committerXiangrui Meng <meng@databricks.com>2014-12-15 13:44:15 -0800
commit8098fab06cb2be22cca4e531e8e65ab29dbb909a (patch)
treef419594f9e6671f1bb4af54d17544d0ee78ca7e3 /python/pyspark
parent4c0673879b5c504797dafb11607d14b04c1bf47d (diff)
downloadspark-8098fab06cb2be22cca4e531e8e65ab29dbb909a.tar.gz
spark-8098fab06cb2be22cca4e531e8e65ab29dbb909a.tar.bz2
spark-8098fab06cb2be22cca4e531e8e65ab29dbb909a.zip
[SPARK-4494][mllib] IDFModel.transform() add support for single vector
I improved `IDFModel.transform` to allow using a single vector. [[SPARK-4494] IDFModel.transform() add support for single vector - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-4494) Author: Yuu ISHIKAWA <yuu.ishikawa@gmail.com> Closes #3603 from yu-iskw/idf and squashes the following commits: 256ff3d [Yuu ISHIKAWA] Fix typo a3bf566 [Yuu ISHIKAWA] - Fix typo - Optimize import order - Aggregate the assertion tests - Modify `IDFModel.transform` API for pyspark d25e49b [Yuu ISHIKAWA] Add the implementation of `IDFModel.transform` for a term frequency vector
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/mllib/feature.py22
1 files changed, 15 insertions, 7 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 8cb992df2d..741c630cbd 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -28,7 +28,7 @@ from py4j.protocol import Py4JJavaError
from pyspark import RDD, SparkContext
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Vectors, _convert_to_vector
+from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector
__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
@@ -212,7 +212,7 @@ class IDFModel(JavaVectorTransformer):
"""
Represents an IDF model that can transform term frequency vectors.
"""
- def transform(self, dataset):
+ def transform(self, x):
"""
Transforms term frequency (TF) vectors to TF-IDF vectors.
@@ -220,12 +220,14 @@ class IDFModel(JavaVectorTransformer):
the terms which occur in fewer than `minDocFreq`
documents will have an entry of 0.
- :param dataset: an RDD of term frequency vectors
- :return: an RDD of TF-IDF vectors
+ :param x: an RDD of term frequency vectors or a term frequency vector
+ :return: an RDD of TF-IDF vectors or a TF-IDF vector
"""
- if not isinstance(dataset, RDD):
- raise TypeError("dataset should be an RDD of term frequency vectors")
- return JavaVectorTransformer.transform(self, dataset)
+ if isinstance(x, RDD):
+ return JavaVectorTransformer.transform(self, x)
+
+ x = _convert_to_vector(x)
+ return JavaVectorTransformer.transform(self, x)
class IDF(object):
@@ -255,6 +257,12 @@ class IDF(object):
SparseVector(4, {1: 0.0, 3: 0.5754})
DenseVector([0.0, 0.0, 1.3863, 0.863])
SparseVector(4, {1: 0.0})
+ >>> model.transform(Vectors.dense([0.0, 1.0, 2.0, 3.0]))
+ DenseVector([0.0, 0.0, 1.3863, 0.863])
+ >>> model.transform([0.0, 1.0, 2.0, 3.0])
+ DenseVector([0.0, 0.0, 1.3863, 0.863])
+ >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0)))
+ SparseVector(4, {1: 0.0, 3: 0.5754})
"""
def __init__(self, minDocFreq=0):
"""