aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-06-23 12:43:32 -0700
committerXiangrui Meng <meng@databricks.com>2015-06-23 12:43:32 -0700
commitf2022fa0d375c804eca7803e172543b23ecbb9b7 (patch)
tree1c4c51b7950cfb4a78a6d1ae3fb944275546492a /python
parent2b1111dd0b8deb9ad8d43fec792e60e3d0c4de75 (diff)
downloadspark-f2022fa0d375c804eca7803e172543b23ecbb9b7.tar.gz
spark-f2022fa0d375c804eca7803e172543b23ecbb9b7.tar.bz2
spark-f2022fa0d375c804eca7803e172543b23ecbb9b7.zip
[SPARK-8265] [MLLIB] [PYSPARK] Add LinearDataGenerator to pyspark.mllib.utils
It is useful to generate linear data for easy testing of linear models and in general. Scala already has it. This is just a wrapper around the Scala code. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #6715 from MechCoder/generate_linear_input and squashes the following commits: 6182884 [MechCoder] Minor changes 8bda047 [MechCoder] Minor style fixes 0f1053c [MechCoder] [SPARK-8265] Add LinearDataGenerator to pyspark.mllib.utils
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/tests.py22
-rw-r--r--python/pyspark/mllib/util.py35
2 files changed, 55 insertions, 2 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index c8d61b9855..509faa11df 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -49,8 +49,8 @@ from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
-from pyspark.mllib.feature import StandardScaler
-from pyspark.mllib.feature import ElementwiseProduct
+from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
+from pyspark.mllib.util import LinearDataGenerator
from pyspark.serializers import PickleSerializer
from pyspark.streaming import StreamingContext
from pyspark.sql import SQLContext
@@ -1019,6 +1019,24 @@ class StreamingKMeansTest(MLLibStreamingTestCase):
self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
+class LinearDataGeneratorTests(MLlibTestCase):
+ def test_dim(self):
+ linear_data = LinearDataGenerator.generateLinearInput(
+ intercept=0.0, weights=[0.0, 0.0, 0.0],
+ xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33],
+ nPoints=4, seed=0, eps=0.1)
+ self.assertEqual(len(linear_data), 4)
+ for point in linear_data:
+ self.assertEqual(len(point.features), 3)
+
+ linear_data = LinearDataGenerator.generateLinearRDD(
+ sc=sc, nexamples=6, nfeatures=2, eps=0.1,
+ nParts=2, intercept=0.0).collect()
+ self.assertEqual(len(linear_data), 6)
+ for point in linear_data:
+ self.assertEqual(len(point.features), 2)
+
+
if __name__ == "__main__":
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 16a90db146..348238319e 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -257,6 +257,41 @@ class JavaLoader(Loader):
return cls(java_model)
+class LinearDataGenerator(object):
+ """Utils for generating linear data"""
+
+ @staticmethod
+ def generateLinearInput(intercept, weights, xMean, xVariance,
+ nPoints, seed, eps):
+ """
+ :param: intercept bias factor, the term c in X'w + c
+ :param: weights feature vector, the term w in X'w + c
+ :param: xMean Point around which the data X is centered.
+ :param: xVariance Variance of the given data
+ :param: nPoints Number of points to be generated
+ :param: seed Random Seed
+ :param: eps Used to scale the noise. If eps is set high,
+ the amount of gaussian noise added is more.
+ Returns a list of LabeledPoints of length nPoints
+ """
+ weights = [float(weight) for weight in weights]
+ xMean = [float(mean) for mean in xMean]
+ xVariance = [float(var) for var in xVariance]
+ return list(callMLlibFunc(
+ "generateLinearInputWrapper", float(intercept), weights, xMean,
+ xVariance, int(nPoints), int(seed), float(eps)))
+
+ @staticmethod
+ def generateLinearRDD(sc, nexamples, nfeatures, eps,
+ nParts=2, intercept=0.0):
+ """
+ Generate a RDD of LabeledPoints.
+ """
+ return callMLlibFunc(
+ "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures),
+ float(eps), int(nParts), float(intercept))
+
+
def _test():
import doctest
from pyspark.context import SparkContext