about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
author lewuathe <lewuathe@me.com> 2015-04-12 22:17:16 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-04-12 22:17:16 -0700
commit fc17661475443d9f0a8d28e3439feeb7a7bca67b (patch)
tree c76c67f64147ddcce967cab715dbb27621bc6d09 /python
parent a1fe59dae50f551d02dd18676308eca054ff6b07 (diff)
download spark-fc17661475443d9f0a8d28e3439feeb7a7bca67b.tar.gz
spark-fc17661475443d9f0a8d28e3439feeb7a7bca67b.tar.bz2
spark-fc17661475443d9f0a8d28e3439feeb7a7bca67b.zip
[SPARK-6643][MLLIB] Implement StandardScalerModel missing methods
This is the sub-task of SPARK-6254. Wrap missing method for `StandardScalerModel`. Author: lewuathe <lewuathe@me.com> Closes #5310 from Lewuathe/SPARK-6643 and squashes the following commits: fafd690 [lewuathe] Fix for lint-python bd31a64 [lewuathe] Merge branch 'master' into SPARK-6643 578f5ee [lewuathe] Remove unnecessary class a38f155 [lewuathe] Merge master 66bb2ab [lewuathe] Fix typos 82683a0 [lewuathe] [SPARK-6643] Implement StandardScalerModel missing methods
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/feature.py | 16
-rw-r--r-- python/pyspark/mllib/tests.py | 24
2 files changed, 40 insertions, 0 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 3cda1205e1..8be819acee 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -132,6 +132,22 @@ class StandardScalerModel(JavaVectorTransformer):
"""
return JavaVectorTransformer.transform(self, vector)
+ def setWithMean(self, withMean):
+ """
+ Setter of the boolean which decides
+ whether it uses mean or not
+ """
+ self.call("setWithMean", withMean)
+ return self
+
+ def setWithStd(self, withStd):
+ """
+ Setter of the boolean which decides
+ whether it uses std or not
+ """
+ self.call("setWithStd", withStd)
+ return self
+
class StandardScaler(object):
"""
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 3b40158c12..8eaddcf8b9 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -44,6 +44,7 @@ from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
+from pyspark.mllib.feature import StandardScaler
from pyspark.serializers import PickleSerializer
from pyspark.sql import SQLContext
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -745,6 +746,29 @@ class Word2VecTests(PySparkTestCase):
model = Word2Vec().fit(self.sc.parallelize(data))
self.assertEquals(len(model.getVectors()), 3)
+
+class StandardScalerTests(PySparkTestCase):
+ def test_model_setters(self):
+ data = [
+ [1.0, 2.0, 3.0],
+ [2.0, 3.0, 4.0],
+ [3.0, 4.0, 5.0]
+ ]
+ model = StandardScaler().fit(self.sc.parallelize(data))
+ self.assertIsNotNone(model.setWithMean(True))
+ self.assertIsNotNone(model.setWithStd(True))
+ self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0]))
+
+ def test_model_transform(self):
+ data = [
+ [1.0, 2.0, 3.0],
+ [2.0, 3.0, 4.0],
+ [3.0, 4.0, 5.0]
+ ]
+ model = StandardScaler().fit(self.sc.parallelize(data))
+ self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0]))
+
+
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"