diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-14 18:13:58 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-14 18:13:58 -0700 |
commit | 723853edab18d28515af22097b76e4e6574b228e (patch) | |
tree | 205a3c30104da6d6784cf68cdc6424fc6e76540f | |
parent | b208f998b5800bdba4ce6651f172c26a8d7d351b (diff) | |
download | spark-723853edab18d28515af22097b76e4e6574b228e.tar.gz spark-723853edab18d28515af22097b76e4e6574b228e.tar.bz2 spark-723853edab18d28515af22097b76e4e6574b228e.zip |
[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml
Otherwise, users can only use `transform` on the models. brkyvz
Author: Xiangrui Meng <meng@databricks.com>
Closes #6156 from mengxr/SPARK-7647 and squashes the following commits:
1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python
f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel
-rw-r--r-- | python/pyspark/ml/classification.py | 18 | ||||
-rw-r--r-- | python/pyspark/ml/regression.py | 18 | ||||
-rw-r--r-- | python/pyspark/ml/wrapper.py | 8 |
3 files changed, 43 insertions, 1 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 96d29058a3..8c9a55e79a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> model.transform(test0).head().prediction 0.0 + >>> model.weights + DenseVector([5.5...]) + >>> model.intercept + -2.68... >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel): Model fitted by LogisticRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeClassifierParams(object): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0ab5c6c3d2..2803864ff4 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction -1.0 + >>> model.weights + DenseVector([1.0]) + >>> model.intercept + 0.0 >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 @@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel): Model fitted by LinearRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeRegressorParams(object): """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index f5ac2a3986..dda6c6aba3 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,7 +21,7 @@ from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model -from pyspark.mllib.common import inherit_doc +from pyspark.mllib.common import inherit_doc, _java2py, _py2java def _jvm(): @@ -149,6 +149,12 @@ class JavaModel(Model, JavaTransformer): def _java_obj(self): return self._java_model + def _call_java(self, name, *args): + m = getattr(self._java_model, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) + @inherit_doc class JavaEvaluator(Evaluator, JavaWrapper): |