aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--python/pyspark/mllib/tests.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f4c997261e..c482e6b068 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -46,6 +46,7 @@ from pyspark.mllib.stat import Statistics
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
from pyspark.mllib.feature import StandardScaler
+from pyspark.mllib.feature import ElementwiseProduct
from pyspark.serializers import PickleSerializer
from pyspark.sql import SQLContext
@@ -850,6 +851,18 @@ class StandardScalerTests(MLlibTestCase):
self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0]))
+class ElementwiseProductTests(MLlibTestCase):
+ def test_model_transform(self):
+ weight = Vectors.dense([3, 2, 1])
+
+ densevec = Vectors.dense([4, 5, 6])
+ sparsevec = Vectors.sparse(3, [0], [1])
+ eprod = ElementwiseProduct(weight)
+ self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6]))
+ self.assertEqual(
+ eprod.transform(sparsevec), SparseVector(3, [0], [3]))
+
+
if __name__ == "__main__":
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")