aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-06-17 22:08:38 -0700
committerDavies Liu <davies@databricks.com>2015-06-17 22:08:38 -0700
commit22732e1eca730929345e440ba831386ee7446b74 (patch)
treebdd891792462222c50ebd324efb154980c320d5a /python/pyspark/mllib/tests.py
parent4817ccdf50ef6ee24192800f9924d9ef3bb74e12 (diff)
downloadspark-22732e1eca730929345e440ba831386ee7446b74.tar.gz
spark-22732e1eca730929345e440ba831386ee7446b74.tar.bz2
spark-22732e1eca730929345e440ba831386ee7446b74.zip
[SPARK-7605] [MLLIB] [PYSPARK] Python API for ElementwiseProduct
Python API for org.apache.spark.mllib.feature.ElementwiseProduct Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #6346 from MechCoder/spark-7605 and squashes the following commits: 79d1ef5 [MechCoder] Consistent and support list / array types 5f81d81 [MechCoder] [SPARK-7605] [MLlib] Python API for ElementwiseProduct
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--python/pyspark/mllib/tests.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f4c997261e..c482e6b068 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -46,6 +46,7 @@ from pyspark.mllib.stat import Statistics
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
from pyspark.mllib.feature import StandardScaler
+from pyspark.mllib.feature import ElementwiseProduct
from pyspark.serializers import PickleSerializer
from pyspark.sql import SQLContext
@@ -850,6 +851,18 @@ class StandardScalerTests(MLlibTestCase):
self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0]))
+class ElementwiseProductTests(MLlibTestCase):
+ def test_model_transform(self):
+ weight = Vectors.dense([3, 2, 1])
+
+ densevec = Vectors.dense([4, 5, 6])
+ sparsevec = Vectors.sparse(3, [0], [1])
+ eprod = ElementwiseProduct(weight)
+ self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6]))
+ self.assertEqual(
+ eprod.transform(sparsevec), SparseVector(3, [0], [3]))
+
+
if __name__ == "__main__":
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")