From 8098fab06cb2be22cca4e531e8e65ab29dbb909a Mon Sep 17 00:00:00 2001 From: Yuu ISHIKAWA Date: Mon, 15 Dec 2014 13:44:15 -0800 Subject: [SPARK-4494][mllib] IDFModel.transform() add support for single vector I improved `IDFModel.transform` to allow using a single vector. [[SPARK-4494] IDFModel.transform() add support for single vector - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-4494) Author: Yuu ISHIKAWA Closes #3603 from yu-iskw/idf and squashes the following commits: 256ff3d [Yuu ISHIKAWA] Fix typo a3bf566 [Yuu ISHIKAWA] - Fix typo - Optimize import order - Aggregate the assertion tests - Modify `IDFModel.transform` API for pyspark d25e49b [Yuu ISHIKAWA] Add the implementation of `IDFModel.transform` for a term frequency vector --- python/pyspark/mllib/feature.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) (limited to 'python/pyspark') diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 8cb992df2d..741c630cbd 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -28,7 +28,7 @@ from py4j.protocol import Py4JJavaError from pyspark import RDD, SparkContext from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors, _convert_to_vector +from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] @@ -212,7 +212,7 @@ class IDFModel(JavaVectorTransformer): """ Represents an IDF model that can transform term frequency vectors. """ - def transform(self, dataset): + def transform(self, x): """ Transforms term frequency (TF) vectors to TF-IDF vectors. @@ -220,12 +220,14 @@ class IDFModel(JavaVectorTransformer): the terms which occur in fewer than `minDocFreq` documents will have an entry of 0. - :param dataset: an RDD of term frequency vectors - :return: an RDD of TF-IDF vectors + :param x: an RDD of term frequency vectors or a term frequency vector + :return: an RDD of TF-IDF vectors or a TF-IDF vector """ - if not isinstance(dataset, RDD): - raise TypeError("dataset should be an RDD of term frequency vectors") - return JavaVectorTransformer.transform(self, dataset) + if isinstance(x, RDD): + return JavaVectorTransformer.transform(self, x) + + x = _convert_to_vector(x) + return JavaVectorTransformer.transform(self, x) class IDF(object): @@ -255,6 +257,12 @@ class IDF(object): SparseVector(4, {1: 0.0, 3: 0.5754}) DenseVector([0.0, 0.0, 1.3863, 0.863]) SparseVector(4, {1: 0.0}) + >>> model.transform(Vectors.dense([0.0, 1.0, 2.0, 3.0])) + DenseVector([0.0, 0.0, 1.3863, 0.863]) + >>> model.transform([0.0, 1.0, 2.0, 3.0]) + DenseVector([0.0, 0.0, 1.3863, 0.863]) + >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0))) + SparseVector(4, {1: 0.0, 3: 0.5754}) """ def __init__(self, minDocFreq=0): """ -- cgit v1.2.3