From 723853edab18d28515af22097b76e4e6574b228e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 14 May 2015 18:13:58 -0700 Subject: [SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml Otherwise, users can only use `transform` on the models. brkyvz Author: Xiangrui Meng Closes #6156 from mengxr/SPARK-7647 and squashes the following commits: 1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel --- python/pyspark/ml/classification.py | 18 ++++++++++++++++++ python/pyspark/ml/regression.py | 18 ++++++++++++++++++ python/pyspark/ml/wrapper.py | 8 +++++++- 3 files changed, 43 insertions(+), 1 deletion(-) (limited to 'python/pyspark/ml') diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 96d29058a3..8c9a55e79a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> model.transform(test0).head().prediction 0.0 + >>> model.weights + DenseVector([5.5...]) + >>> model.intercept + -2.68... >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel): Model fitted by LogisticRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeClassifierParams(object): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0ab5c6c3d2..2803864ff4 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction -1.0 + >>> model.weights + DenseVector([1.0]) + >>> model.intercept + 0.0 >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 @@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel): Model fitted by LinearRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeRegressorParams(object): """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index f5ac2a3986..dda6c6aba3 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,7 +21,7 @@ from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model -from pyspark.mllib.common import inherit_doc +from pyspark.mllib.common import inherit_doc, _java2py, _py2java def _jvm(): @@ -149,6 +149,12 @@ class JavaModel(Model, JavaTransformer): def _java_obj(self): return self._java_model + def _call_java(self, name, *args): + m = getattr(self._java_model, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) + @inherit_doc class JavaEvaluator(Evaluator, JavaWrapper): -- cgit v1.2.3