diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-14 18:13:58 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-14 18:13:58 -0700 |
commit | 723853edab18d28515af22097b76e4e6574b228e (patch) | |
tree | 205a3c30104da6d6784cf68cdc6424fc6e76540f /python/pyspark/ml/classification.py | |
parent | b208f998b5800bdba4ce6651f172c26a8d7d351b (diff) | |
download | spark-723853edab18d28515af22097b76e4e6574b228e.tar.gz spark-723853edab18d28515af22097b76e4e6574b228e.tar.bz2 spark-723853edab18d28515af22097b76e4e6574b228e.zip |
[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml
Otherwise, users can only use `transform` on the models. brkyvz
Author: Xiangrui Meng <meng@databricks.com>
Closes #6156 from mengxr/SPARK-7647 and squashes the following commits:
1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python
f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r-- | python/pyspark/ml/classification.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 96d29058a3..8c9a55e79a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> model.transform(test0).head().prediction 0.0 + >>> model.weights + DenseVector([5.5...]) + >>> model.intercept + -2.68... >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel): Model fitted by LogisticRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeClassifierParams(object): """ |