From 22732e1eca730929345e440ba831386ee7446b74 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 17 Jun 2015 22:08:38 -0700 Subject: [SPARK-7605] [MLLIB] [PYSPARK] Python API for ElementwiseProduct Python API for org.apache.spark.mllib.feature.ElementwiseProduct Author: MechCoder Closes #6346 from MechCoder/spark-7605 and squashes the following commits: 79d1ef5 [MechCoder] Consistent and support list / array types 5f81d81 [MechCoder] [SPARK-7605] [MLlib] Python API for ElementwiseProduct --- python/pyspark/mllib/tests.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'python/pyspark/mllib/tests.py') diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f4c997261e..c482e6b068 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -46,6 +46,7 @@ from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF from pyspark.mllib.feature import StandardScaler +from pyspark.mllib.feature import ElementwiseProduct from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext @@ -850,6 +851,18 @@ class StandardScalerTests(MLlibTestCase): self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) +class ElementwiseProductTests(MLlibTestCase): + def test_model_transform(self): + weight = Vectors.dense([3, 2, 1]) + + densevec = Vectors.dense([4, 5, 6]) + sparsevec = Vectors.sparse(3, [0], [1]) + eprod = ElementwiseProduct(weight) + self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) + self.assertEqual( + eprod.transform(sparsevec), SparseVector(3, [0], [3])) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") -- cgit v1.2.3