diff options
author | Holden Karau <holden@us.ibm.com> | 2016-08-22 12:21:22 +0200 |
---|---|---|
committer | Nick Pentreath <nickp@za.ibm.com> | 2016-08-22 12:21:22 +0200 |
commit | b264cbb16fb97116e630fb593adf5898a5a0e8fa (patch) | |
tree | 056d664ec1ea34ee88489b2b8c6bbc0dc43b8c03 /python/pyspark/ml/regression.py | |
parent | bd9655063bdba8836b4ec96ed115e5653e246b65 (diff) | |
download | spark-b264cbb16fb97116e630fb593adf5898a5a0e8fa.tar.gz spark-b264cbb16fb97116e630fb593adf5898a5a0e8fa.tar.bz2 spark-b264cbb16fb97116e630fb593adf5898a5a0e8fa.zip |
[SPARK-15113][PYSPARK][ML] Add missing num features num classes
## What changes were proposed in this pull request?
Add missing `numFeatures` and `numClasses` to the wrapped Java models in PySpark ML pipelines. Also tag `DecisionTreeClassificationModel` as Expiremental to match Scala doc.
## How was this patch tested?
Extended doctests
Author: Holden Karau <holden@us.ibm.com>
Closes #12889 from holdenk/SPARK-15113-add-missing-numFeatures-numClasses.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 1ae2bd4e40..56312f672f 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -88,6 +88,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction True >>> model.intercept == model2.intercept True + >>> model.numFeatures + 1 .. versionadded:: 1.4.0 """ @@ -126,7 +128,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction return LinearRegressionModel(java_model) -class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): +class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): """ Model fitted by :class:`LinearRegression`. @@ -654,6 +656,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi 3 >>> model.featureImportances SparseVector(1, {0: 1.0}) + >>> model.numFeatures + 1 >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -719,7 +723,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi @inherit_doc -class DecisionTreeModel(JavaModel): +class DecisionTreeModel(JavaModel, JavaPredictionModel): """ Abstraction for Decision Tree models. @@ -843,6 +847,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 + >>> model.numFeatures + 1 >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] >>> model.getNumTrees @@ -909,7 +915,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi return RandomForestRegressionModel(java_model) -class RandomForestRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable): +class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, + JavaMLReadable): """ Model fitted by :class:`RandomForestRegressor`. @@ -958,6 +965,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> model = gbt.fit(df) >>> model.featureImportances SparseVector(1, {0: 1.0}) + >>> model.numFeatures + 1 >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) True >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) @@ -1047,7 +1056,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, return self.getOrDefault(self.lossType) -class GBTRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable): +class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): """ Model fitted by :class:`GBTRegressor`. @@ -1307,6 +1316,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha True >>> model.coefficients DenseVector([1.5..., -1.0...]) + >>> model.numFeatures + 2 >>> abs(model.intercept - 1.5) < 0.001 True >>> glr_path = temp_path + "/glr" @@ -1412,7 +1423,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha return self.getOrDefault(self.link) -class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): +class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, + JavaMLReadable): """ .. note:: Experimental |