diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-06-23 12:43:32 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-06-23 12:43:32 -0700 |
commit | f2022fa0d375c804eca7803e172543b23ecbb9b7 (patch) | |
tree | 1c4c51b7950cfb4a78a6d1ae3fb944275546492a /python/pyspark/mllib/tests.py | |
parent | 2b1111dd0b8deb9ad8d43fec792e60e3d0c4de75 (diff) | |
download | spark-f2022fa0d375c804eca7803e172543b23ecbb9b7.tar.gz spark-f2022fa0d375c804eca7803e172543b23ecbb9b7.tar.bz2 spark-f2022fa0d375c804eca7803e172543b23ecbb9b7.zip |
[SPARK-8265] [MLLIB] [PYSPARK] Add LinearDataGenerator to pyspark.mllib.utils
It is useful to generate linear data for easy testing of linear models and in general. Scala already has it. This is just a wrapper around the Scala code.
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #6715 from MechCoder/generate_linear_input and squashes the following commits:
6182884 [MechCoder] Minor changes
8bda047 [MechCoder] Minor style fixes
0f1053c [MechCoder] [SPARK-8265] Add LinearDataGenerator to pyspark.mllib.utils
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- | python/pyspark/mllib/tests.py | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index c8d61b9855..509faa11df 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -49,8 +49,8 @@ from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF -from pyspark.mllib.feature import StandardScaler -from pyspark.mllib.feature import ElementwiseProduct +from pyspark.mllib.feature import StandardScaler, ElementwiseProduct +from pyspark.mllib.util import LinearDataGenerator from pyspark.serializers import PickleSerializer from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext @@ -1019,6 +1019,24 @@ class StreamingKMeansTest(MLLibStreamingTestCase): self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) +class LinearDataGeneratorTests(MLlibTestCase): + def test_dim(self): + linear_data = LinearDataGenerator.generateLinearInput( + intercept=0.0, weights=[0.0, 0.0, 0.0], + xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], + nPoints=4, seed=0, eps=0.1) + self.assertEqual(len(linear_data), 4) + for point in linear_data: + self.assertEqual(len(point.features), 3) + + linear_data = LinearDataGenerator.generateLinearRDD( + sc=sc, nexamples=6, nfeatures=2, eps=0.1, + nParts=2, intercept=0.0).collect() + self.assertEqual(len(linear_data), 6) + for point in linear_data: + self.assertEqual(len(point.features), 2) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") |