diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-14 18:13:58 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-14 18:13:58 -0700 |
commit | 723853edab18d28515af22097b76e4e6574b228e (patch) | |
tree | 205a3c30104da6d6784cf68cdc6424fc6e76540f /python/pyspark/ml/regression.py | |
parent | b208f998b5800bdba4ce6651f172c26a8d7d351b (diff) | |
download | spark-723853edab18d28515af22097b76e4e6574b228e.tar.gz spark-723853edab18d28515af22097b76e4e6574b228e.tar.bz2 spark-723853edab18d28515af22097b76e4e6574b228e.zip |
[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml
Otherwise, users can only use `transform` on the models. brkyvz
Author: Xiangrui Meng <meng@databricks.com>
Closes #6156 from mengxr/SPARK-7647 and squashes the following commits:
1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python
f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0ab5c6c3d2..2803864ff4 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction -1.0 + >>> model.weights + DenseVector([1.0]) + >>> model.intercept + 0.0 >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 @@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel): Model fitted by LinearRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeRegressorParams(object): """ |