Diffstat (limited to 'python/pyspark/mllib/feature.py')
-rw-r--r--  python/pyspark/mllib/feature.py | 31
1 file changed, 24 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 44bf6f269d..9ec28079ae 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -25,7 +25,7 @@ from py4j.protocol import Py4JJavaError
from pyspark import RDD, SparkContext
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.linalg import Vectors, _convert_to_vector
__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
@@ -81,12 +81,16 @@ class Normalizer(VectorTransformer):
"""
Applies unit length normalization on a vector.
- :param vector: vector to be normalized.
+ :param vector: vector or RDD of vectors to be normalized.
:return: normalized vector. If the norm of the input is zero, it
will return the input vector.
"""
sc = SparkContext._active_spark_context
assert sc is not None, "SparkContext should be initialized first"
+ if isinstance(vector, RDD):
+ vector = vector.map(_convert_to_vector)
+ else:
+ vector = _convert_to_vector(vector)
return callMLlibFunc("normalizeVector", self.p, vector)
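After this hunk, Normalizer.transform accepts either a single vector or an RDD of vectors. A minimal usage sketch of both call patterns (the SparkContext `sc` and the sample data are assumptions for illustration, not part of the patch):

from pyspark.mllib.feature import Normalizer
from pyspark.mllib.linalg import Vectors

nor = Normalizer(p=2.0)                      # L2 normalization
v = Vectors.dense([3.0, 4.0])
print(nor.transform(v))                      # single vector -> [0.6, 0.8]
rdd = sc.parallelize([Vectors.dense([3.0, 4.0]), Vectors.dense([0.0, 1.0])])
print(nor.transform(rdd).collect())          # RDD of vectors -> RDD of normalized vectors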
@@ -95,8 +99,12 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer):
Wrapper for the model in JVM
"""
- def transform(self, dataset):
- return self.call("transform", dataset)
+ def transform(self, vector):
+ if isinstance(vector, RDD):
+ vector = vector.map(_convert_to_vector)
+ else:
+ vector = _convert_to_vector(vector)
+ return self.call("transform", vector)
class StandardScalerModel(JavaVectorTransformer):
@@ -109,7 +117,7 @@ class StandardScalerModel(JavaVectorTransformer):
"""
Applies standardization transformation on a vector.
- :param vector: Vector to be standardized.
+ :param vector: Vector or RDD of Vector to be standardized.
:return: Standardized vector. If the variance of a column is zero,
it will return default `0.0` for the column with zero variance.
"""
@@ -154,6 +162,7 @@ class StandardScaler(object):
the transformation model.
:return: a StandardScalerModel
"""
+ dataset = dataset.map(_convert_to_vector)
jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset)
return StandardScalerModel(jmodel)
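A sketch of the fit/transform flow these two hunks enable: fit consumes an RDD (converted via _convert_to_vector), and the resulting model can transform either an RDD or a single vector. The data values and `sc` are illustrative assumptions:

from pyspark.mllib.feature import StandardScaler
from pyspark.mllib.linalg import Vectors

data = sc.parallelize([Vectors.dense([1.0, 2.0]), Vectors.dense([3.0, 6.0])])
scaler = StandardScaler(withMean=False, withStd=True)
model = scaler.fit(data)                               # dataset mapped through _convert_to_vector
print(model.transform(Vectors.dense([1.0, 2.0])))      # single vector
print(model.transform(data).collect())                 # RDD of vectors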
@@ -211,6 +220,8 @@ class IDFModel(JavaVectorTransformer):
:param dataset: an RDD of term frequency vectors
:return: an RDD of TF-IDF vectors
"""
+ if not isinstance(dataset, RDD):
+ raise TypeError("dataset should be an RDD of term frequency vectors")
return JavaVectorTransformer.transform(self, dataset)
@@ -255,7 +266,9 @@ class IDF(object):
:param dataset: an RDD of term frequency vectors
"""
- jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset)
+ if not isinstance(dataset, RDD):
+ raise TypeError("dataset should be an RDD of term frequency vectors")
+ jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset.map(_convert_to_vector))
return IDFModel(jmodel)
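With these hunks, both IDF.fit and IDFModel.transform insist on an RDD of term-frequency vectors and raise a TypeError otherwise. A hedged sketch of the usual pipeline, with HashingTF producing the term-frequency vectors (corpus content and `sc` are illustrative):

from pyspark.mllib.feature import HashingTF, IDF

docs = sc.parallelize([["spark", "mllib"], ["spark", "python"]])
tf = HashingTF().transform(docs)             # RDD of term-frequency vectors
tf.cache()
idf_model = IDF(minDocFreq=0).fit(tf)        # TypeError if tf were a single vector
tfidf = idf_model.transform(tf)              # RDD of TF-IDF vectors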
@@ -287,6 +300,8 @@ class Word2VecModel(JavaVectorTransformer):
Note: local use only
"""
+ if not isinstance(word, basestring):
+ word = _convert_to_vector(word)
words, similarity = self.call("findSynonyms", word, num)
return zip(words, similarity)
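This hunk lets findSynonyms take either a word (string) or a word vector; anything that is not a string is converted with _convert_to_vector. A sketch assuming an already trained Word2VecModel named `model` (see the training example after the next hunk):

synonyms = model.findSynonyms("spark", 2)        # look up by word
vec = model.transform("spark")                   # word -> its vector
synonyms_by_vec = model.findSynonyms(vec, 2)     # look up by vector
for word, sim in synonyms:
    print(word, sim)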
@@ -374,9 +389,11 @@ class Word2Vec(object):
"""
Computes the vector representation of each word in vocabulary.
- :param data: training data. RDD of subtype of Iterable[String]
+ :param data: training data. RDD of list of string
:return: Word2VecModel instance
"""
+ if not isinstance(data, RDD):
+ raise TypeError("data should be an RDD of list of string")
jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
float(self.learningRate), int(self.numPartitions),
int(self.numIterations), long(self.seed))
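Word2Vec.fit now rejects non-RDD input; the training data must be an RDD whose elements are lists of strings. A minimal training sketch under that assumption (the corpus is repeated so words clear the default minCount; data and `sc` are illustrative):

from pyspark.mllib.feature import Word2Vec

corpus = sc.parallelize([["apache", "spark", "mllib"], ["spark", "python", "api"]] * 50)
model = Word2Vec().setVectorSize(10).setSeed(42).fit(corpus)
print(model.transform("spark"))              # vector representation of "spark"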