diff options
author | lewuathe <lewuathe@me.com> | 2015-04-12 22:17:16 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-04-12 22:17:16 -0700 |
commit | fc17661475443d9f0a8d28e3439feeb7a7bca67b (patch) | |
tree | c76c67f64147ddcce967cab715dbb27621bc6d09 | |
parent | a1fe59dae50f551d02dd18676308eca054ff6b07 (diff) | |
download | spark-fc17661475443d9f0a8d28e3439feeb7a7bca67b.tar.gz spark-fc17661475443d9f0a8d28e3439feeb7a7bca67b.tar.bz2 spark-fc17661475443d9f0a8d28e3439feeb7a7bca67b.zip |
[SPARK-6643][MLLIB] Implement StandardScalerModel missing methods
This is the sub-task of SPARK-6254.
Wrap missing method for `StandardScalerModel`.
Author: lewuathe <lewuathe@me.com>
Closes #5310 from Lewuathe/SPARK-6643 and squashes the following commits:
fafd690 [lewuathe] Fix for lint-python
bd31a64 [lewuathe] Merge branch 'master' into SPARK-6643
578f5ee [lewuathe] Remove unnecessary class
a38f155 [lewuathe] Merge master
66bb2ab [lewuathe] Fix typos
82683a0 [lewuathe] [SPARK-6643] Implement StandardScalerModel missing methods
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 4 | ||||
-rw-r--r-- | python/pyspark/mllib/feature.py | 16 | ||||
-rw-r--r-- | python/pyspark/mllib/tests.py | 24 |
3 files changed, 42 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 1faa3def0e..ab15f0f36a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -450,9 +450,9 @@ private[python] class PythonMLLibAPI extends Serializable { def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { new Normalizer(p).transform(rdd) } - + /** - * Java stub for IDF.fit(). This stub returns a + * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 3cda1205e1..8be819acee 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -132,6 +132,22 @@ class StandardScalerModel(JavaVectorTransformer): """ return JavaVectorTransformer.transform(self, vector) + def setWithMean(self, withMean): + """ + Setter of the boolean which decides + whether it uses mean or not + """ + self.call("setWithMean", withMean) + return self + + def setWithStd(self, withStd): + """ + Setter of the boolean which decides + whether it uses std or not + """ + self.call("setWithStd", withStd) + return self + class StandardScaler(object): """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 3b40158c12..8eaddcf8b9 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -44,6 +44,7 @@ from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF +from pyspark.mllib.feature import StandardScaler from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -745,6 +746,29 @@ class Word2VecTests(PySparkTestCase): model = Word2Vec().fit(self.sc.parallelize(data)) self.assertEquals(len(model.getVectors()), 3) + +class StandardScalerTests(PySparkTestCase): + def test_model_setters(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertIsNotNone(model.setWithMean(True)) + self.assertIsNotNone(model.setWithStd(True)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) + + def test_model_transform(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) + + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" |