diff options
Diffstat (limited to 'python/pyspark/mllib/feature.py')
-rw-r--r-- | python/pyspark/mllib/feature.py | 31 |
1 files changed, 24 insertions, 7 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 44bf6f269d..9ec28079ae 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -25,7 +25,7 @@ from py4j.protocol import Py4JJavaError from pyspark import RDD, SparkContext from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg import Vectors, _convert_to_vector __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] @@ -81,12 +81,16 @@ class Normalizer(VectorTransformer): """ Applies unit length normalization on a vector. - :param vector: vector to be normalized. + :param vector: vector or RDD of vector to be normalized. :return: normalized vector. If the norm of the input is zero, it will return the input vector. """ sc = SparkContext._active_spark_context assert sc is not None, "SparkContext should be initialized first" + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + else: + vector = _convert_to_vector(vector) return callMLlibFunc("normalizeVector", self.p, vector) @@ -95,8 +99,12 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): Wrapper for the model in JVM """ - def transform(self, dataset): - return self.call("transform", dataset) + def transform(self, vector): + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + else: + vector = _convert_to_vector(vector) + return self.call("transform", vector) class StandardScalerModel(JavaVectorTransformer): @@ -109,7 +117,7 @@ class StandardScalerModel(JavaVectorTransformer): """ Applies standardization transformation on a vector. - :param vector: Vector to be standardized. + :param vector: Vector or RDD of Vector to be standardized. :return: Standardized vector. If the variance of a column is zero, it will return default `0.0` for the column with zero variance. """ @@ -154,6 +162,7 @@ class StandardScaler(object): the transformation model. :return: a StandardScalarModel """ + dataset = dataset.map(_convert_to_vector) jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset) return StandardScalerModel(jmodel) @@ -211,6 +220,8 @@ class IDFModel(JavaVectorTransformer): :param dataset: an RDD of term frequency vectors :return: an RDD of TF-IDF vectors """ + if not isinstance(dataset, RDD): + raise TypeError("dataset should be an RDD of term frequency vectors") return JavaVectorTransformer.transform(self, dataset) @@ -255,7 +266,9 @@ class IDF(object): :param dataset: an RDD of term frequency vectors """ - jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset) + if not isinstance(dataset, RDD): + raise TypeError("dataset should be an RDD of term frequency vectors") + jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset.map(_convert_to_vector)) return IDFModel(jmodel) @@ -287,6 +300,8 @@ class Word2VecModel(JavaVectorTransformer): Note: local use only """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) @@ -374,9 +389,11 @@ class Word2Vec(object): """ Computes the vector representation of each word in vocabulary. - :param data: training data. RDD of subtype of Iterable[String] + :param data: training data. RDD of list of string :return: Word2VecModel instance """ + if not isinstance(data, RDD): + raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), int(self.numIterations), long(self.seed)) |