diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-06-17 22:08:38 -0700 |
---|---|---|
committer | Davies Liu <davies@databricks.com> | 2015-06-17 22:08:38 -0700 |
commit | 22732e1eca730929345e440ba831386ee7446b74 (patch) | |
tree | bdd891792462222c50ebd324efb154980c320d5a /python | |
parent | 4817ccdf50ef6ee24192800f9924d9ef3bb74e12 (diff) | |
download | spark-22732e1eca730929345e440ba831386ee7446b74.tar.gz spark-22732e1eca730929345e440ba831386ee7446b74.tar.bz2 spark-22732e1eca730929345e440ba831386ee7446b74.zip |
[SPARK-7605] [MLLIB] [PYSPARK] Python API for ElementwiseProduct
Python API for org.apache.spark.mllib.feature.ElementwiseProduct
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #6346 from MechCoder/spark-7605 and squashes the following commits:
79d1ef5 [MechCoder] Consistent and support list / array types
5f81d81 [MechCoder] [SPARK-7605] [MLlib] Python API for ElementwiseProduct
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/feature.py | 37 |
-rw-r--r-- | python/pyspark/mllib/tests.py | 13 |
2 files changed, 48 insertions, 2 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index da90554f41..cf5fdf2cf9 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -33,12 +33,13 @@ from py4j.protocol import Py4JJavaError from pyspark import SparkContext from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector +from pyspark.mllib.linalg import ( + Vector, Vectors, DenseVector, SparseVector, _convert_to_vector) from pyspark.mllib.regression import LabeledPoint __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel', - 'ChiSqSelector', 'ChiSqSelectorModel'] + 'ChiSqSelector', 'ChiSqSelectorModel', 'ElementwiseProduct'] class VectorTransformer(object): @@ -520,6 +521,38 @@ class Word2Vec(object): return Word2VecModel(jmodel) +class ElementwiseProduct(VectorTransformer): + """ + .. note:: Experimental + + Scales each column of the vector, with the supplied weight vector. + i.e the elementwise product. + + >>> weight = Vectors.dense([1.0, 2.0, 3.0]) + >>> eprod = ElementwiseProduct(weight) + >>> a = Vectors.dense([2.0, 1.0, 3.0]) + >>> eprod.transform(a) + DenseVector([2.0, 2.0, 9.0]) + >>> b = Vectors.dense([9.0, 3.0, 4.0]) + >>> rdd = sc.parallelize([a, b]) + >>> eprod.transform(rdd).collect() + [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])] + """ + def __init__(self, scalingVector): + self.scalingVector = _convert_to_vector(scalingVector) + + def transform(self, vector): + """ + Computes the Hadamard product of the vector. 
+ """ + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + + else: + vector = _convert_to_vector(vector) + return callMLlibFunc("elementwiseProductVector", self.scalingVector, vector) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f4c997261e..c482e6b068 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -46,6 +46,7 @@ from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF from pyspark.mllib.feature import StandardScaler +from pyspark.mllib.feature import ElementwiseProduct from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext @@ -850,6 +851,18 @@ class StandardScalerTests(MLlibTestCase): self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) +class ElementwiseProductTests(MLlibTestCase): + def test_model_transform(self): + weight = Vectors.dense([3, 2, 1]) + + densevec = Vectors.dense([4, 5, 6]) + sparsevec = Vectors.sparse(3, [0], [1]) + eprod = ElementwiseProduct(weight) + self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) + self.assertEqual( + eprod.transform(sparsevec), SparseVector(3, [0], [3])) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") |