aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-14 18:13:58 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-14 18:13:58 -0700
commit723853edab18d28515af22097b76e4e6574b228e (patch)
tree205a3c30104da6d6784cf68cdc6424fc6e76540f /python/pyspark/ml/regression.py
parentb208f998b5800bdba4ce6651f172c26a8d7d351b (diff)
downloadspark-723853edab18d28515af22097b76e4e6574b228e.tar.gz
spark-723853edab18d28515af22097b76e4e6574b228e.tar.bz2
spark-723853edab18d28515af22097b76e4e6574b228e.zip
[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml
Otherwise, users can only use `transform` on the models. brkyvz Author: Xiangrui Meng <meng@databricks.com> Closes #6156 from mengxr/SPARK-7647 and squashes the following commits: 1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 0ab5c6c3d2..2803864ff4 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
-1.0
+ >>> model.weights
+ DenseVector([1.0])
+ >>> model.intercept
+ 0.0
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
@@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel):
Model fitted by LinearRegression.
"""
+ @property
+ def weights(self):
+ """
+ Model weights.
+ """
+ return self._call_java("weights")
+
+ @property
+ def intercept(self):
+ """
+ Model intercept.
+ """
+ return self._call_java("intercept")
+
class TreeRegressorParams(object):
"""