aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
authorHolden Karau <holden@us.ibm.com>2016-08-22 12:21:22 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-08-22 12:21:22 +0200
commitb264cbb16fb97116e630fb593adf5898a5a0e8fa (patch)
tree056d664ec1ea34ee88489b2b8c6bbc0dc43b8c03 /python/pyspark/ml/regression.py
parentbd9655063bdba8836b4ec96ed115e5653e246b65 (diff)
downloadspark-b264cbb16fb97116e630fb593adf5898a5a0e8fa.tar.gz
spark-b264cbb16fb97116e630fb593adf5898a5a0e8fa.tar.bz2
spark-b264cbb16fb97116e630fb593adf5898a5a0e8fa.zip
[SPARK-15113][PYSPARK][ML] Add missing num features num classes
## What changes were proposed in this pull request? Add missing `numFeatures` and `numClasses` to the wrapped Java models in PySpark ML pipelines. Also tag `DecisionTreeClassificationModel` as Expiremental to match Scala doc. ## How was this patch tested? Extended doctests Author: Holden Karau <holden@us.ibm.com> Closes #12889 from holdenk/SPARK-15113-add-missing-numFeatures-numClasses.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py22
1 files changed, 17 insertions, 5 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 1ae2bd4e40..56312f672f 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -88,6 +88,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
True
>>> model.intercept == model2.intercept
True
+ >>> model.numFeatures
+ 1
.. versionadded:: 1.4.0
"""
@@ -126,7 +128,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
return LinearRegressionModel(java_model)
-class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`LinearRegression`.
@@ -654,6 +656,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
3
>>> model.featureImportances
SparseVector(1, {0: 1.0})
+ >>> model.numFeatures
+ 1
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -719,7 +723,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
@inherit_doc
-class DecisionTreeModel(JavaModel):
+class DecisionTreeModel(JavaModel, JavaPredictionModel):
"""
Abstraction for Decision Tree models.
@@ -843,6 +847,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
+ >>> model.numFeatures
+ 1
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
>>> model.getNumTrees
@@ -909,7 +915,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return RandomForestRegressionModel(java_model)
-class RandomForestRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
+class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
+ JavaMLReadable):
"""
Model fitted by :class:`RandomForestRegressor`.
@@ -958,6 +965,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
>>> model = gbt.fit(df)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
+ >>> model.numFeatures
+ 1
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -1047,7 +1056,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
return self.getOrDefault(self.lossType)
-class GBTRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
+class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`GBTRegressor`.
@@ -1307,6 +1316,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
+ >>> model.numFeatures
+ 2
>>> abs(model.intercept - 1.5) < 0.001
True
>>> glr_path = temp_path + "/glr"
@@ -1412,7 +1423,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
return self.getOrDefault(self.link)
-class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
+ JavaMLReadable):
"""
.. note:: Experimental