aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-18 12:02:18 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-18 12:02:18 -0700
commit9c7e802a5a2b8cd3eb77642f84c54a8e976fc996 (patch)
tree2e3b7e367f57b64ef46733ee8b64aa258e58cca8 /python/pyspark/ml/classification.py
parent56ede88485cfca90974425fcb603b257be47229b (diff)
downloadspark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.tar.gz
spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.tar.bz2
spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.zip
[SPARK-7380] [MLLIB] pipeline stages should be copyable in Python
This PR makes pipeline stages in Python copyable and hence simplifies some implementations. It also includes the following changes: 1. Rename `paramMap` and `defaultParamMap` to `_paramMap` and `_defaultParamMap`, respectively. 2. Accept a list of param maps in `fit`. 3. Use parent uid and name to identify param. jkbradley Author: Xiangrui Meng <meng@databricks.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #6088 from mengxr/SPARK-7380 and squashes the following commits: 413c463 [Xiangrui Meng] remove unnecessary doc 4159f35 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380 611c719 [Xiangrui Meng] fix python style 68862b8 [Xiangrui Meng] update _java_obj initialization 927ad19 [Xiangrui Meng] fix ml/tests.py 0138fc3 [Xiangrui Meng] update feature transformers and fix a bug in RegexTokenizer 9ca44fb [Xiangrui Meng] simplify Java wrappers and add tests c7d84ef [Xiangrui Meng] update ml/tests.py to test copy params 7e0d27f [Xiangrui Meng] merge master 46840fb [Xiangrui Meng] update wrappers b6db1ed [Xiangrui Meng] update all self.paramMap to self._paramMap 46cb6ed [Xiangrui Meng] merge master a163413 [Xiangrui Meng] fix style 1042e80 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380 9630eae [Xiangrui Meng] fix Identifiable._randomUID 13bd70a [Xiangrui Meng] update ml/tests.py 64a536c [Xiangrui Meng] use _fit/_transform/_evaluate to simplify the impl 02abf13 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into copyable-python 66ce18c [Joseph K. Bradley] some cleanups before sending to Xiangrui 7431272 [Joseph K. Bradley] Rebased with master
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py35
1 files changed, 20 insertions, 15 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 1411d3fd9c..4e645519c4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -55,7 +55,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
...
TypeError: Method setParams forces keyword arguments.
"""
- _java_class = "org.apache.spark.ml.classification.LogisticRegression"
+
# a placeholder to make it appear in the generated doc
elasticNetParam = \
Param(Params._dummy(), "elasticNetParam",
@@ -75,6 +75,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
threshold=0.5, probabilityCol="probability")
"""
super(LogisticRegression, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.LogisticRegression", self.uid)
#: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
# is an L2 penalty. For alpha = 1, it is an L1 penalty.
self.elasticNetParam = \
@@ -111,7 +113,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"""
Sets the value of :py:attr:`elasticNetParam`.
"""
- self.paramMap[self.elasticNetParam] = value
+ self._paramMap[self.elasticNetParam] = value
return self
def getElasticNetParam(self):
@@ -124,7 +126,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"""
Sets the value of :py:attr:`fitIntercept`.
"""
- self.paramMap[self.fitIntercept] = value
+ self._paramMap[self.fitIntercept] = value
return self
def getFitIntercept(self):
@@ -137,7 +139,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"""
Sets the value of :py:attr:`threshold`.
"""
- self.paramMap[self.threshold] = value
+ self._paramMap[self.threshold] = value
return self
def getThreshold(self):
@@ -208,7 +210,6 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1.0
"""
- _java_class = "org.apache.spark.ml.classification.DecisionTreeClassifier"
# a placeholder to make it appear in the generated doc
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
@@ -224,6 +225,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
"""
super(DecisionTreeClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
#: param for Criterion used for information gain calculation (case-insensitive).
self.impurity = \
Param(self, "impurity",
@@ -256,7 +259,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"""
Sets the value of :py:attr:`impurity`.
"""
- self.paramMap[self.impurity] = value
+ self._paramMap[self.impurity] = value
return self
def getImpurity(self):
@@ -299,7 +302,6 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1.0
"""
- _java_class = "org.apache.spark.ml.classification.RandomForestClassifier"
# a placeholder to make it appear in the generated doc
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
@@ -325,6 +327,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
numTrees=20, featureSubsetStrategy="auto", seed=42)
"""
super(RandomForestClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
#: param for Criterion used for information gain calculation (case-insensitive).
self.impurity = \
Param(self, "impurity",
@@ -370,7 +374,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"""
Sets the value of :py:attr:`impurity`.
"""
- self.paramMap[self.impurity] = value
+ self._paramMap[self.impurity] = value
return self
def getImpurity(self):
@@ -383,7 +387,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"""
Sets the value of :py:attr:`subsamplingRate`.
"""
- self.paramMap[self.subsamplingRate] = value
+ self._paramMap[self.subsamplingRate] = value
return self
def getSubsamplingRate(self):
@@ -396,7 +400,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"""
Sets the value of :py:attr:`numTrees`.
"""
- self.paramMap[self.numTrees] = value
+ self._paramMap[self.numTrees] = value
return self
def getNumTrees(self):
@@ -409,7 +413,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"""
Sets the value of :py:attr:`featureSubsetStrategy`.
"""
- self.paramMap[self.featureSubsetStrategy] = value
+ self._paramMap[self.featureSubsetStrategy] = value
return self
def getFeatureSubsetStrategy(self):
@@ -452,7 +456,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
1.0
"""
- _java_class = "org.apache.spark.ml.classification.GBTClassifier"
# a placeholder to make it appear in the generated doc
lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
@@ -476,6 +479,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
lossType="logistic", maxIter=20, stepSize=0.1)
"""
super(GBTClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.GBTClassifier", self.uid)
#: param for Loss function which GBT tries to minimize (case-insensitive).
self.lossType = Param(self, "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
@@ -517,7 +522,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
"""
Sets the value of :py:attr:`lossType`.
"""
- self.paramMap[self.lossType] = value
+ self._paramMap[self.lossType] = value
return self
def getLossType(self):
@@ -530,7 +535,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
"""
Sets the value of :py:attr:`subsamplingRate`.
"""
- self.paramMap[self.subsamplingRate] = value
+ self._paramMap[self.subsamplingRate] = value
return self
def getSubsamplingRate(self):
@@ -543,7 +548,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
"""
Sets the value of :py:attr:`stepSize`.
"""
- self.paramMap[self.stepSize] = value
+ self._paramMap[self.stepSize] = value
return self
def getStepSize(self):