aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
author: MechCoder <manojkumarsivaraj334@gmail.com> 2015-06-23 12:43:32 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-06-23 12:43:32 -0700
commit: f2022fa0d375c804eca7803e172543b23ecbb9b7 (patch)
tree: 1c4c51b7950cfb4a78a6d1ae3fb944275546492a /python/pyspark/mllib/tests.py
parent: 2b1111dd0b8deb9ad8d43fec792e60e3d0c4de75 (diff)
download: spark-f2022fa0d375c804eca7803e172543b23ecbb9b7.tar.gz
spark-f2022fa0d375c804eca7803e172543b23ecbb9b7.tar.bz2
spark-f2022fa0d375c804eca7803e172543b23ecbb9b7.zip
[SPARK-8265] [MLLIB] [PYSPARK] Add LinearDataGenerator to pyspark.mllib.utils
It is useful to generate linear data for easy testing of linear models and in general. Scala already has it. This is just a wrapper around the Scala code.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #6715 from MechCoder/generate_linear_input and squashes the following commits:

6182884 [MechCoder] Minor changes
8bda047 [MechCoder] Minor style fixes
0f1053c [MechCoder] [SPARK-8265] Add LinearDataGenerator to pyspark.mllib.utils
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- python/pyspark/mllib/tests.py 22
1 files changed, 20 insertions, 2 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index c8d61b9855..509faa11df 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -49,8 +49,8 @@ from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
-from pyspark.mllib.feature import StandardScaler
-from pyspark.mllib.feature import ElementwiseProduct
+from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
+from pyspark.mllib.util import LinearDataGenerator
from pyspark.serializers import PickleSerializer
from pyspark.streaming import StreamingContext
from pyspark.sql import SQLContext
@@ -1019,6 +1019,24 @@ class StreamingKMeansTest(MLLibStreamingTestCase):
self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
+class LinearDataGeneratorTests(MLlibTestCase):  # size checks for LinearDataGenerator output
+ def test_dim(self):  # checks point count and per-point feature count for both generator calls
+ linear_data = LinearDataGenerator.generateLinearInput(
+ intercept=0.0, weights=[0.0, 0.0, 0.0],
+ xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33],
+ nPoints=4, seed=0, eps=0.1)
+ self.assertEqual(len(linear_data), 4)  # matches nPoints=4
+ for point in linear_data:
+ self.assertEqual(len(point.features), 3)  # weights/xMean/xVariance each have 3 entries
+
+ linear_data = LinearDataGenerator.generateLinearRDD(  # .collect() is called on the result below
+ sc=sc, nexamples=6, nfeatures=2, eps=0.1,  # NOTE(review): `sc` is defined outside this hunk; presumably a module-level SparkContext - confirm
+ nParts=2, intercept=0.0).collect()
+ self.assertEqual(len(linear_data), 6)  # matches nexamples=6
+ for point in linear_data:
+ self.assertEqual(len(point.features), 2)  # matches nfeatures=2
+
+
if __name__ == "__main__":
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")