aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-11-05 09:56:18 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-05 09:56:18 -0800
commit9da7ceed81b0afce7deb8f39f3a6d565d401a391 (patch)
treee6c70d2cc92cfdfd48ee366ad6230f9b87795230 /python/pyspark/ml/regression.py
parentb072ff4d1d05fc212cd7036d1897a032a395f0b3 (diff)
downloadspark-9da7ceed81b0afce7deb8f39f3a6d565d401a391.tar.gz
spark-9da7ceed81b0afce7deb8f39f3a6d565d401a391.tar.bz2
spark-9da7ceed81b0afce7deb8f39f3a6d565d401a391.zip
[SPARK-11473][ML] R-like summary statistics with intercept for OLS via normal equation solver
Follow up [SPARK-9836](https://issues.apache.org/jira/browse/SPARK-9836), we should also support summary statistics for ```intercept```. Author: Yanbo Liang <ybliang8@gmail.com> Closes #9485 from yanboliang/spark-11473.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index d7b4fd92c3..7648bf1326 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -55,15 +55,15 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
>>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal")
>>> model = lr.fit(df)
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
- >>> model.transform(test0).head().prediction
- -1.0
- >>> model.weights
- DenseVector([1.0])
- >>> model.intercept
- 0.0
+ >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001
+ True
+ >>> abs(model.coefficients[0] - 1.0) < 0.001
+ True
+ >>> abs(model.intercept - 0.0) < 0.001
+ True
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
- >>> model.transform(test1).head().prediction
- 1.0
+ >>> abs(model.transform(test1).head().prediction - 1.0) < 0.001
+ True
>>> lr.setParams("vector")
Traceback (most recent call last):
...