From 1aad9114c93c5763030c14a2328f6426d9e5bcb6 Mon Sep 17 00:00:00 2001 From: Michael Giannakopoulos Date: Tue, 5 Aug 2014 16:30:32 -0700 Subject: [SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods Related to Jira Issue: [SPARK-2550](https://issues.apache.org/jira/browse/SPARK-2550?jql=project%20%3D%20SPARK%20AND%20resolution%20%3D%20Unresolved%20AND%20priority%20%3D%20Major%20ORDER%20BY%20key%20DESC) Author: Michael Giannakopoulos Closes #1775 from miccagiann/linearMethodsReg and squashes the following commits: cb774c3 [Michael Giannakopoulos] MiniBatchFraction added in related PythonMLLibAPI java stubs. 81fcbc6 [Michael Giannakopoulos] Fixing a typo-error. 8ad263e [Michael Giannakopoulos] Adding regularizer type and intercept parameters to LogisticRegressionWithSGD and SVMWithSGD. --- .../spark/mllib/api/python/PythonMLLibAPI.scala | 55 ++++++++++++++++------ 1 file changed, 40 insertions(+), 15 deletions(-) (limited to 'mllib/src/main') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 1d5d3762ed..fd0b9556c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -271,6 +271,7 @@ class PythonMLLibAPI extends Serializable { .setNumIterations(numIterations) .setRegParam(regParam) .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) if (regType == "l2") { lrAlg.optimizer.setUpdater(new SquaredL2Updater) } else if (regType == "l1") { @@ -341,16 +342,27 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val SVMAlg = new SVMWithSGD() + SVMAlg.setIntercept(intercept) + SVMAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + SVMAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + SVMAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - SVMWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + SVMAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -363,15 +375,28 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regParam: Double, + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val LogRegAlg = new LogisticRegressionWithSGD() + LogRegAlg.setIntercept(intercept) + LogRegAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + LogRegAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - LogisticRegressionWithSGD.train( - data, - numIterations, - stepSize, - miniBatchFraction, - initialWeights), + LogRegAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } -- cgit v1.2.3