aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-03-25 13:38:33 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-03-25 13:38:33 -0700
commit435337381f093f95248c8f0204e60c0b366edc81 (patch)
tree60e820022e209f5bc5771d71d0a147bdd391c545 /mllib/src
parentc1b74df6042b33b2b061cb07c2fbd82dba9074bb (diff)
downloadspark-435337381f093f95248c8f0204e60c0b366edc81.tar.gz
spark-435337381f093f95248c8f0204e60c0b366edc81.tar.bz2
spark-435337381f093f95248c8f0204e60c0b366edc81.zip
[SPARK-6256] [MLlib] MLlib Python API parity check for regression
MLlib Python API parity check for regression. The major disparities that need to be added to the Python API are the following: ```scala LinearRegressionWithSGD setValidateData LassoWithSGD setIntercept setValidateData RidgeRegressionWithSGD setIntercept setValidateData ``` setFeatureScaling is an MLlib-private function and does not need to be exposed in PySpark. Author: Yanbo Liang <ybliang8@gmail.com> Closes #4997 from yanboliang/spark-6256 and squashes the following commits: 102f498 [Yanbo Liang] fix intercept issue & add doc test 1fb7b4f [Yanbo Liang] change 'intercept' to 'addIntercept' de5ecbc [Yanbo Liang] MLlib Python API parity check for regression
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala16
1 files changed, 13 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 15ca2547d5..e391567347 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -111,9 +111,11 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): JList[Object] = {
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
+ .setValidateData(validateData)
lrAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
@@ -135,8 +137,12 @@ private[python] class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Vector): JList[Object] = {
+ initialWeights: Vector,
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val lassoAlg = new LassoWithSGD()
+ lassoAlg.setIntercept(intercept)
+ .setValidateData(validateData)
lassoAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
@@ -157,8 +163,12 @@ private[python] class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Vector): JList[Object] = {
+ initialWeights: Vector,
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val ridgeAlg = new RidgeRegressionWithSGD()
+ ridgeAlg.setIntercept(intercept)
+ .setValidateData(validateData)
ridgeAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)