aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMichael Giannakopoulos <miccagiann@gmail.com>2014-08-05 16:30:32 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-05 16:30:32 -0700
commit1aad9114c93c5763030c14a2328f6426d9e5bcb6 (patch)
tree8a8085d64428993c23961042c8b430baaa61b204 /mllib
parentacff9a7f13b98f10a08aea1d11cfa685c3419367 (diff)
downloadspark-1aad9114c93c5763030c14a2328f6426d9e5bcb6.tar.gz
spark-1aad9114c93c5763030c14a2328f6426d9e5bcb6.tar.bz2
spark-1aad9114c93c5763030c14a2328f6426d9e5bcb6.zip
[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods
Related to Jira Issue: [SPARK-2550](https://issues.apache.org/jira/browse/SPARK-2550?jql=project%20%3D%20SPARK%20AND%20resolution%20%3D%20Unresolved%20AND%20priority%20%3D%20Major%20ORDER%20BY%20key%20DESC) Author: Michael Giannakopoulos <miccagiann@gmail.com> Closes #1775 from miccagiann/linearMethodsReg and squashes the following commits: cb774c3 [Michael Giannakopoulos] MiniBatchFraction added in related PythonMLLibAPI java stubs. 81fcbc6 [Michael Giannakopoulos] Fixing a typo-error. 8ad263e [Michael Giannakopoulos] Adding regularizer type and intercept parameters to LogisticRegressionWithSGD and SVMWithSGD.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala55
1 files changed, 40 insertions, 15 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 1d5d3762ed..fd0b9556c7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -271,6 +271,7 @@ class PythonMLLibAPI extends Serializable {
.setNumIterations(numIterations)
.setRegParam(regParam)
.setStepSize(stepSize)
+ .setMiniBatchFraction(miniBatchFraction)
if (regType == "l2") {
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
@@ -341,16 +342,27 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeightsBA: Array[Byte],
+ regType: String,
+ intercept: Boolean): java.util.List[java.lang.Object] = {
+ val SVMAlg = new SVMWithSGD()
+ SVMAlg.setIntercept(intercept)
+ SVMAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setStepSize(stepSize)
+ .setMiniBatchFraction(miniBatchFraction)
+ if (regType == "l2") {
+ SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
+ } else if (regType == "l1") {
+ SVMAlg.optimizer.setUpdater(new L1Updater)
+ } else if (regType != "none") {
+ throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: [l1, l2, none].")
+ }
trainRegressionModel(
(data, initialWeights) =>
- SVMWithSGD.train(
- data,
- numIterations,
- stepSize,
- regParam,
- miniBatchFraction,
- initialWeights),
+ SVMAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
@@ -363,15 +375,28 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeightsBA: Array[Byte],
+ regParam: Double,
+ regType: String,
+ intercept: Boolean): java.util.List[java.lang.Object] = {
+ val LogRegAlg = new LogisticRegressionWithSGD()
+ LogRegAlg.setIntercept(intercept)
+ LogRegAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setStepSize(stepSize)
+ .setMiniBatchFraction(miniBatchFraction)
+ if (regType == "l2") {
+ LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
+ } else if (regType == "l1") {
+ LogRegAlg.optimizer.setUpdater(new L1Updater)
+ } else if (regType != "none") {
+ throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: [l1, l2, none].")
+ }
trainRegressionModel(
(data, initialWeights) =>
- LogisticRegressionWithSGD.train(
- data,
- numIterations,
- stepSize,
- miniBatchFraction,
- initialWeights),
+ LogRegAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}