aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
authorHolden Karau <holden@us.ibm.com>2016-08-22 12:21:22 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-08-22 12:21:22 +0200
commitb264cbb16fb97116e630fb593adf5898a5a0e8fa (patch)
tree056d664ec1ea34ee88489b2b8c6bbc0dc43b8c03 /python/pyspark/ml/classification.py
parentbd9655063bdba8836b4ec96ed115e5653e246b65 (diff)
downloadspark-b264cbb16fb97116e630fb593adf5898a5a0e8fa.tar.gz
spark-b264cbb16fb97116e630fb593adf5898a5a0e8fa.tar.bz2
spark-b264cbb16fb97116e630fb593adf5898a5a0e8fa.zip
[SPARK-15113][PYSPARK][ML] Add missing num features num classes
## What changes were proposed in this pull request? Add missing `numFeatures` and `numClasses` to the wrapped Java models in PySpark ML pipelines. Also tag `DecisionTreeClassificationModel` as Expiremental to match Scala doc. ## How was this patch tested? Extended doctests Author: Holden Karau <holden@us.ibm.com> Closes #12889 from holdenk/SPARK-15113-add-missing-numFeatures-numClasses.
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py37
1 files changed, 31 insertions, 6 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 6468007045..33ada27454 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -44,6 +44,23 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel',
@inherit_doc
+class JavaClassificationModel(JavaPredictionModel):
+ """
+ (Private) Java Model produced by a ``Classifier``.
+ Classes are indexed {0, 1, ..., numClasses - 1}.
+ To be mixed in with class:`pyspark.ml.JavaModel`
+ """
+
+ @property
+ @since("2.1.0")
+ def numClasses(self):
+ """
+ Number of classes (values which the label can take).
+ """
+ return self._call_java("numClasses")
+
+
+@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
@@ -212,7 +229,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
-class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LogisticRegression.
@@ -522,6 +539,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
+ >>> model.numFeatures
+ 1
+ >>> model.numClasses
+ 2
>>> print(model.toDebugString)
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -595,7 +616,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
@inherit_doc
-class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
+class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
+ JavaMLReadable):
"""
Model fitted by DecisionTreeClassifier.
@@ -722,7 +744,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
return RandomForestClassificationModel(java_model)
-class RandomForestClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
+class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
+ JavaMLReadable):
"""
Model fitted by RandomForestClassifier.
@@ -873,7 +896,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
return self.getOrDefault(self.lossType)
-class GBTClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
+class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
+ JavaMLReadable):
"""
Model fitted by GBTClassifier.
@@ -1027,7 +1051,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
return self.getOrDefault(self.modelType)
-class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by NaiveBayes.
@@ -1226,7 +1250,8 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
return self.getOrDefault(self.initialWeights)
-class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, JavaMLWritable,
+ JavaMLReadable):
"""
.. note:: Experimental