aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-06-23 12:43:32 -0700
committerXiangrui Meng <meng@databricks.com>2015-06-23 12:43:32 -0700
commitf2022fa0d375c804eca7803e172543b23ecbb9b7 (patch)
tree1c4c51b7950cfb4a78a6d1ae3fb944275546492a /mllib
parent2b1111dd0b8deb9ad8d43fec792e60e3d0c4de75 (diff)
downloadspark-f2022fa0d375c804eca7803e172543b23ecbb9b7.tar.gz
spark-f2022fa0d375c804eca7803e172543b23ecbb9b7.tar.bz2
spark-f2022fa0d375c804eca7803e172543b23ecbb9b7.zip
[SPARK-8265] [MLLIB] [PYSPARK] Add LinearDataGenerator to pyspark.mllib.utils
It is useful to generate linear data for easy testing of linear models and in general. Scala already has it. This is just a wrapper around the Scala code. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #6715 from MechCoder/generate_linear_input and squashes the following commits: 6182884 [MechCoder] Minor changes 8bda047 [MechCoder] Minor style fixes 0f1053c [MechCoder] [SPARK-8265] Add LinearDataGenerator to pyspark.mllib.utils
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala32
1 file changed, 31 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index f9a271f47e..c4bea7c2ca 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -51,6 +51,7 @@ import org.apache.spark.mllib.tree.loss.Losses
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel}
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel
@@ -972,7 +973,7 @@ private[python] class PythonMLLibAPI extends Serializable {
def estimateKernelDensity(
sample: JavaRDD[Double],
bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = {
- return new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
+ new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
points.asScala.toArray)
}
@@ -991,6 +992,35 @@ private[python] class PythonMLLibAPI extends Serializable {
List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava
}
+ /**
+ * Wrapper around the generateLinearInput method of LinearDataGenerator.
+ */
+ def generateLinearInputWrapper(
+ intercept: Double,
+ weights: JList[Double],
+ xMean: JList[Double],
+ xVariance: JList[Double],
+ nPoints: Int,
+ seed: Int,
+ eps: Double): Array[LabeledPoint] = {
+ LinearDataGenerator.generateLinearInput(
+ intercept, weights.asScala.toArray, xMean.asScala.toArray,
+ xVariance.asScala.toArray, nPoints, seed, eps).toArray
+ }
+
+ /**
+ * Wrapper around the generateLinearRDD method of LinearDataGenerator.
+ */
+ def generateLinearRDDWrapper(
+ sc: JavaSparkContext,
+ nexamples: Int,
+ nfeatures: Int,
+ eps: Double,
+ nparts: Int,
+ intercept: Double): JavaRDD[LabeledPoint] = {
+ LinearDataGenerator.generateLinearRDD(
+ sc, nexamples, nfeatures, eps, nparts, intercept)
+ }
}
/**