aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMichael Giannakopoulos <miccagiann@gmail.com>2014-08-01 21:00:31 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-01 21:00:31 -0700
commitc281189222e645d2c87277c269e2102c3c8ccc95 (patch)
treee56b1d46896433d8c859b68870807e0faa0cbd64 /mllib
parentf6a1899306c5ad766fea122d3ab4b83436d9f6fd (diff)
downloadspark-c281189222e645d2c87277c269e2102c3c8ccc95.tar.gz
spark-c281189222e645d2c87277c269e2102c3c8ccc95.tar.bz2
spark-c281189222e645d2c87277c269e2102c3c8ccc95.zip
[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods.
Related to issue: [SPARK-2550](https://issues.apache.org/jira/browse/SPARK-2550?jql=project%20%3D%20SPARK%20AND%20resolution%20%3D%20Unresolved%20AND%20priority%20%3D%20Major%20ORDER%20BY%20key%20DESC). Author: Michael Giannakopoulos <miccagiann@gmail.com> Closes #1624 from miccagiann/new-branch and squashes the following commits: c02e5f5 [Michael Giannakopoulos] Merge cleanly with upstream/master. 8dcb888 [Michael Giannakopoulos] Putting the if/else if statements in brackets. fed8eaa [Michael Giannakopoulos] Adding a space in the message related to the IllegalArgumentException. 44e6ff0 [Michael Giannakopoulos] Adding a blank line before python class LinearRegressionWithSGD. 8eba9c5 [Michael Giannakopoulos] Change function signatures. Exception is thrown from the scala component and not from the python one. 638be47 [Michael Giannakopoulos] Modified code to comply with code standards. ec50ee9 [Michael Giannakopoulos] Shorten the if-elif-else statement in regression.py file b962744 [Michael Giannakopoulos] Replaced the enum classes, with strings-keywords for defining the values of 'regType' parameter. 78853ec [Michael Giannakopoulos] Providing intercept and regualizer functionallity for linear methods in only one function. 3ac8874 [Michael Giannakopoulos] Added support for regularizer and intercection parameters for linear regression method.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala28
1 files changed, 21 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 122925d096..7d912737b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -23,6 +23,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
import org.apache.spark.mllib.recommendation._
@@ -252,15 +254,27 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeightsBA: Array[Byte],
+ regParam: Double,
+ regType: String,
+ intercept: Boolean): java.util.List[java.lang.Object] = {
+ val lrAlg = new LinearRegressionWithSGD()
+ lrAlg.setIntercept(intercept)
+ lrAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setStepSize(stepSize)
+ if (regType == "l2") {
+ lrAlg.optimizer.setUpdater(new SquaredL2Updater)
+ } else if (regType == "l1") {
+ lrAlg.optimizer.setUpdater(new L1Updater)
+ } else if (regType != "none") {
+ throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: [l1, l2, none].")
+ }
trainRegressionModel(
(data, initialWeights) =>
- LinearRegressionWithSGD.train(
- data,
- numIterations,
- stepSize,
- miniBatchFraction,
- initialWeights),
+ lrAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}