author     Xiangrui Meng <meng@databricks.com>    2015-05-18 12:02:18 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-05-18 12:02:18 -0700
commit     9c7e802a5a2b8cd3eb77642f84c54a8e976fc996 (patch)
tree       2e3b7e367f57b64ef46733ee8b64aa258e58cca8 /python
parent     56ede88485cfca90974425fcb603b257be47229b (diff)
[SPARK-7380] [MLLIB] pipeline stages should be copyable in Python
This PR makes pipeline stages in Python copyable and hence simplifies some implementations. It also includes the following changes:

1. Rename `paramMap` and `defaultParamMap` to `_paramMap` and `_defaultParamMap`, respectively.
2. Accept a list of param maps in `fit`.
3. Use parent uid and name to identify param.

jkbradley

Author: Xiangrui Meng <meng@databricks.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #6088 from mengxr/SPARK-7380 and squashes the following commits:

413c463 [Xiangrui Meng] remove unnecessary doc
4159f35 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
611c719 [Xiangrui Meng] fix python style
68862b8 [Xiangrui Meng] update _java_obj initialization
927ad19 [Xiangrui Meng] fix ml/tests.py
0138fc3 [Xiangrui Meng] update feature transformers and fix a bug in RegexTokenizer
9ca44fb [Xiangrui Meng] simplify Java wrappers and add tests
c7d84ef [Xiangrui Meng] update ml/tests.py to test copy params
7e0d27f [Xiangrui Meng] merge master
46840fb [Xiangrui Meng] update wrappers
b6db1ed [Xiangrui Meng] update all self.paramMap to self._paramMap
46cb6ed [Xiangrui Meng] merge master
a163413 [Xiangrui Meng] fix style
1042e80 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
9630eae [Xiangrui Meng] fix Identifiable._randomUID
13bd70a [Xiangrui Meng] update ml/tests.py
64a536c [Xiangrui Meng] use _fit/_transform/_evaluate to simplify the impl
02abf13 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into copyable-python
66ce18c [Joseph K. Bradley] some cleanups before sending to Xiangrui
7431272 [Joseph K. Bradley] Rebased with master
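A minimal usage sketch of the behavior this change enables (not part of the commit; the training DataFrame `train_df` and the param values are illustrative assumptions), following the updated `Estimator.fit` in pipeline.py below:

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression(maxIter=10)
    # fit() now also accepts a list/tuple of param maps and returns one model per
    # map; each map is applied through copy(), so lr itself is left unchanged.
    models = lr.fit(train_df, [{lr.regParam: 0.0}, {lr.regParam: 0.1}])
    assert len(models) == 2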
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/classification.py                  35
-rw-r--r--  python/pyspark/ml/evaluation.py                        6
-rw-r--r--  python/pyspark/ml/feature.py                          91
-rw-r--r--  python/pyspark/ml/param/__init__.py                  118
-rw-r--r--  python/pyspark/ml/param/_shared_params_code_gen.py     2
-rw-r--r--  python/pyspark/ml/param/shared.py                     42
-rw-r--r--  python/pyspark/ml/pipeline.py                        109
-rw-r--r--  python/pyspark/ml/recommendation.py                   25
-rw-r--r--  python/pyspark/ml/regression.py                       30
-rw-r--r--  python/pyspark/ml/tests.py                           105
-rw-r--r--  python/pyspark/ml/tuning.py                           43
-rw-r--r--  python/pyspark/ml/util.py                             13
-rw-r--r--  python/pyspark/ml/wrapper.py                         125
13 files changed, 490 insertions, 254 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 1411d3fd9c..4e645519c4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -55,7 +55,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
...
TypeError: Method setParams forces keyword arguments.
"""
- _java_class = "org.apache.spark.ml.classification.LogisticRegression"
+
# a placeholder to make it appear in the generated doc
elasticNetParam = \
Param(Params._dummy(), "elasticNetParam",
@@ -75,6 +75,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
threshold=0.5, probabilityCol="probability")
"""
super(LogisticRegression, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.LogisticRegression", self.uid)
#: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
# is an L2 penalty. For alpha = 1, it is an L1 penalty.
self.elasticNetParam = \
@@ -111,7 +113,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"""
Sets the value of :py:attr:`elasticNetParam`.
"""
- self.paramMap[self.elasticNetParam] = value
+ self._paramMap[self.elasticNetParam] = value
return self
def getElasticNetParam(self):
@@ -124,7 +126,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"""
Sets the value of :py:attr:`fitIntercept`.
"""
- self.paramMap[self.fitIntercept] = value
+ self._paramMap[self.fitIntercept] = value
return self
def getFitIntercept(self):
@@ -137,7 +139,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
"""
Sets the value of :py:attr:`threshold`.
"""
- self.paramMap[self.threshold] = value
+ self._paramMap[self.threshold] = value
return self
def getThreshold(self):
@@ -208,7 +210,6 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1.0
"""
- _java_class = "org.apache.spark.ml.classification.DecisionTreeClassifier"
# a placeholder to make it appear in the generated doc
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
@@ -224,6 +225,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
"""
super(DecisionTreeClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
#: param for Criterion used for information gain calculation (case-insensitive).
self.impurity = \
Param(self, "impurity",
@@ -256,7 +259,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"""
Sets the value of :py:attr:`impurity`.
"""
- self.paramMap[self.impurity] = value
+ self._paramMap[self.impurity] = value
return self
def getImpurity(self):
@@ -299,7 +302,6 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1.0
"""
- _java_class = "org.apache.spark.ml.classification.RandomForestClassifier"
# a placeholder to make it appear in the generated doc
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
@@ -325,6 +327,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
numTrees=20, featureSubsetStrategy="auto", seed=42)
"""
super(RandomForestClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
#: param for Criterion used for information gain calculation (case-insensitive).
self.impurity = \
Param(self, "impurity",
@@ -370,7 +374,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"""
Sets the value of :py:attr:`impurity`.
"""
- self.paramMap[self.impurity] = value
+ self._paramMap[self.impurity] = value
return self
def getImpurity(self):
@@ -383,7 +387,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"""
Sets the value of :py:attr:`subsamplingRate`.
"""
- self.paramMap[self.subsamplingRate] = value
+ self._paramMap[self.subsamplingRate] = value
return self
def getSubsamplingRate(self):
@@ -396,7 +400,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"""
Sets the value of :py:attr:`numTrees`.
"""
- self.paramMap[self.numTrees] = value
+ self._paramMap[self.numTrees] = value
return self
def getNumTrees(self):
@@ -409,7 +413,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"""
Sets the value of :py:attr:`featureSubsetStrategy`.
"""
- self.paramMap[self.featureSubsetStrategy] = value
+ self._paramMap[self.featureSubsetStrategy] = value
return self
def getFeatureSubsetStrategy(self):
@@ -452,7 +456,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
1.0
"""
- _java_class = "org.apache.spark.ml.classification.GBTClassifier"
# a placeholder to make it appear in the generated doc
lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
@@ -476,6 +479,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
lossType="logistic", maxIter=20, stepSize=0.1)
"""
super(GBTClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.GBTClassifier", self.uid)
#: param for Loss function which GBT tries to minimize (case-insensitive).
self.lossType = Param(self, "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
@@ -517,7 +522,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
"""
Sets the value of :py:attr:`lossType`.
"""
- self.paramMap[self.lossType] = value
+ self._paramMap[self.lossType] = value
return self
def getLossType(self):
@@ -530,7 +535,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
"""
Sets the value of :py:attr:`subsamplingRate`.
"""
- self.paramMap[self.subsamplingRate] = value
+ self._paramMap[self.subsamplingRate] = value
return self
def getSubsamplingRate(self):
@@ -543,7 +548,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
"""
Sets the value of :py:attr:`stepSize`.
"""
- self.paramMap[self.stepSize] = value
+ self._paramMap[self.stepSize] = value
return self
def getStepSize(self):
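The same edit repeats for every estimator in this file: the class-level `_java_class` string is dropped and the Java peer is created per instance in `__init__`, tagged with the Python stage's uid. A rough sketch of the pattern (illustrative only; `SomeEstimator` and its Java class name are placeholders, and `_new_java_obj` lives in wrapper.py, which is part of this patch but not shown in this section):

    class SomeEstimator(JavaEstimator):
        def __init__(self):
            super(SomeEstimator, self).__init__()
            # create the JVM-side peer, identified by this Python instance's uid
            self._java_obj = self._new_java_obj(
                "org.apache.spark.ml.SomeEstimator", self.uid)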
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 02020ebff9..f4655c513c 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -42,8 +42,6 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
0.83...
"""
- _java_class = "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator"
-
# a placeholder to make it appear in the generated doc
metricName = Param(Params._dummy(), "metricName",
"metric name in evaluation (areaUnderROC|areaUnderPR)")
@@ -56,6 +54,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
metricName="areaUnderROC")
"""
super(BinaryClassificationEvaluator, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid)
#: param for metric name in evaluation (areaUnderROC|areaUnderPR)
self.metricName = Param(self, "metricName",
"metric name in evaluation (areaUnderROC|areaUnderPR)")
@@ -68,7 +68,7 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
"""
Sets the value of :py:attr:`metricName`.
"""
- self.paramMap[self.metricName] = value
+ self._paramMap[self.metricName] = value
return self
def getMetricName(self):
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 58e22190c7..c8115cb5bc 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -43,7 +43,6 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
1.0
"""
- _java_class = "org.apache.spark.ml.feature.Binarizer"
# a placeholder to make it appear in the generated doc
threshold = Param(Params._dummy(), "threshold",
"threshold in binary classification prediction, in range [0, 1]")
@@ -54,6 +53,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
__init__(self, threshold=0.0, inputCol=None, outputCol=None)
"""
super(Binarizer, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid)
self.threshold = Param(self, "threshold",
"threshold in binary classification prediction, in range [0, 1]")
self._setDefault(threshold=0.0)
@@ -73,7 +73,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`threshold`.
"""
- self.paramMap[self.threshold] = value
+ self._paramMap[self.threshold] = value
return self
def getThreshold(self):
@@ -104,7 +104,6 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
0.0
"""
- _java_class = "org.apache.spark.ml.feature.Bucketizer"
# a placeholder to make it appear in the generated doc
splits = \
Param(Params._dummy(), "splits",
@@ -121,6 +120,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
__init__(self, splits=None, inputCol=None, outputCol=None)
"""
super(Bucketizer, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
#: param for Splitting points for mapping continuous features into buckets. With n+1 splits,
# there are n buckets. A bucket defined by splits x,y holds values in the range [x,y)
# except the last bucket, which also includes y. The splits should be strictly increasing.
@@ -150,7 +150,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`splits`.
"""
- self.paramMap[self.splits] = value
+ self._paramMap[self.splits] = value
return self
def getSplits(self):
@@ -177,14 +177,13 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0})
"""
- _java_class = "org.apache.spark.ml.feature.HashingTF"
-
@keyword_only
def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
"""
__init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
"""
super(HashingTF, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid)
self._setDefault(numFeatures=1 << 18)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -217,8 +216,6 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
DenseVector([0.2877, 0.0])
"""
- _java_class = "org.apache.spark.ml.feature.IDF"
-
# a placeholder to make it appear in the generated doc
minDocFreq = Param(Params._dummy(), "minDocFreq",
"minimum of documents in which a term should appear for filtering")
@@ -229,6 +226,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
__init__(self, minDocFreq=0, inputCol=None, outputCol=None)
"""
super(IDF, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid)
self.minDocFreq = Param(self, "minDocFreq",
"minimum of documents in which a term should appear for filtering")
self._setDefault(minDocFreq=0)
@@ -248,7 +246,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`minDocFreq`.
"""
- self.paramMap[self.minDocFreq] = value
+ self._paramMap[self.minDocFreq] = value
return self
def getMinDocFreq(self):
@@ -257,6 +255,9 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
"""
return self.getOrDefault(self.minDocFreq)
+ def _create_model(self, java_model):
+ return IDFModel(java_model)
+
class IDFModel(JavaModel):
"""
@@ -285,14 +286,13 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
# a placeholder to make it appear in the generated doc
p = Param(Params._dummy(), "p", "the p norm value.")
- _java_class = "org.apache.spark.ml.feature.Normalizer"
-
@keyword_only
def __init__(self, p=2.0, inputCol=None, outputCol=None):
"""
__init__(self, p=2.0, inputCol=None, outputCol=None)
"""
super(Normalizer, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid)
self.p = Param(self, "p", "the p norm value.")
self._setDefault(p=2.0)
kwargs = self.__init__._input_kwargs
@@ -311,7 +311,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`p`.
"""
- self.paramMap[self.p] = value
+ self._paramMap[self.p] = value
return self
def getP(self):
@@ -347,8 +347,6 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
SparseVector(3, {0: 1.0})
"""
- _java_class = "org.apache.spark.ml.feature.OneHotEncoder"
-
# a placeholder to make it appear in the generated doc
includeFirst = Param(Params._dummy(), "includeFirst", "include first category")
@@ -358,6 +356,7 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
__init__(self, includeFirst=True, inputCol=None, outputCol=None)
"""
super(OneHotEncoder, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid)
self.includeFirst = Param(self, "includeFirst", "include first category")
self._setDefault(includeFirst=True)
kwargs = self.__init__._input_kwargs
@@ -376,7 +375,7 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`includeFirst`.
"""
- self.paramMap[self.includeFirst] = value
+ self._paramMap[self.includeFirst] = value
return self
def getIncludeFirst(self):
@@ -404,8 +403,6 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
"""
- _java_class = "org.apache.spark.ml.feature.PolynomialExpansion"
-
# a placeholder to make it appear in the generated doc
degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)")
@@ -415,6 +412,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
__init__(self, degree=2, inputCol=None, outputCol=None)
"""
super(PolynomialExpansion, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.feature.PolynomialExpansion", self.uid)
self.degree = Param(self, "degree", "the polynomial degree to expand (>= 1)")
self._setDefault(degree=2)
kwargs = self.__init__._input_kwargs
@@ -433,7 +432,7 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`degree`.
"""
- self.paramMap[self.degree] = value
+ self._paramMap[self.degree] = value
return self
def getDegree(self):
@@ -471,7 +470,6 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
TypeError: Method setParams forces keyword arguments.
"""
- _java_class = "org.apache.spark.ml.feature.RegexTokenizer"
# a placeholder to make it appear in the generated doc
minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)")
gaps = Param(Params._dummy(), "gaps", "Set regex to match gaps or tokens")
@@ -485,7 +483,8 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
inputCol=None, outputCol=None)
"""
super(RegexTokenizer, self).__init__()
- self.minTokenLength = Param(self, "minLength", "minimum token length (>= 0)")
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid)
+ self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)")
self.gaps = Param(self, "gaps", "Set regex to match gaps or tokens")
self.pattern = Param(self, "pattern", "regex pattern used for tokenizing")
self._setDefault(minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+")
@@ -507,7 +506,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`minTokenLength`.
"""
- self.paramMap[self.minTokenLength] = value
+ self._paramMap[self.minTokenLength] = value
return self
def getMinTokenLength(self):
@@ -520,7 +519,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`gaps`.
"""
- self.paramMap[self.gaps] = value
+ self._paramMap[self.gaps] = value
return self
def getGaps(self):
@@ -533,7 +532,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`pattern`.
"""
- self.paramMap[self.pattern] = value
+ self._paramMap[self.pattern] = value
return self
def getPattern(self):
@@ -557,8 +556,6 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
DenseVector([1.4142])
"""
- _java_class = "org.apache.spark.ml.feature.StandardScaler"
-
# a placeholder to make it appear in the generated doc
withMean = Param(Params._dummy(), "withMean", "Center data with mean")
withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation")
@@ -569,6 +566,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
__init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None)
"""
super(StandardScaler, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid)
self.withMean = Param(self, "withMean", "Center data with mean")
self.withStd = Param(self, "withStd", "Scale to unit standard deviation")
self._setDefault(withMean=False, withStd=True)
@@ -588,7 +586,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`withMean`.
"""
- self.paramMap[self.withMean] = value
+ self._paramMap[self.withMean] = value
return self
def getWithMean(self):
@@ -601,7 +599,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`withStd`.
"""
- self.paramMap[self.withStd] = value
+ self._paramMap[self.withStd] = value
return self
def getWithStd(self):
@@ -610,6 +608,9 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
"""
return self.getOrDefault(self.withStd)
+ def _create_model(self, java_model):
+ return StandardScalerModel(java_model)
+
class StandardScalerModel(JavaModel):
"""
@@ -633,14 +634,13 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
[(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
"""
- _java_class = "org.apache.spark.ml.feature.StringIndexer"
-
@keyword_only
def __init__(self, inputCol=None, outputCol=None):
"""
__init__(self, inputCol=None, outputCol=None)
"""
super(StringIndexer, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -653,6 +653,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
+ def _create_model(self, java_model):
+ return StringIndexerModel(java_model)
+
class StringIndexerModel(JavaModel):
"""
@@ -686,14 +689,13 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
TypeError: Method setParams forces keyword arguments.
"""
- _java_class = "org.apache.spark.ml.feature.Tokenizer"
-
@keyword_only
def __init__(self, inputCol=None, outputCol=None):
"""
__init__(self, inputCol=None, outputCol=None)
"""
super(Tokenizer, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -723,14 +725,13 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
DenseVector([0.0, 1.0])
"""
- _java_class = "org.apache.spark.ml.feature.VectorAssembler"
-
@keyword_only
def __init__(self, inputCols=None, outputCol=None):
"""
__init__(self, inputCols=None, outputCol=None)
"""
super(VectorAssembler, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -797,7 +798,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
DenseVector([1.0, 0.0])
"""
- _java_class = "org.apache.spark.ml.feature.VectorIndexer"
# a placeholder to make it appear in the generated doc
maxCategories = Param(Params._dummy(), "maxCategories",
"Threshold for the number of values a categorical feature can take " +
@@ -810,6 +810,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
__init__(self, maxCategories=20, inputCol=None, outputCol=None)
"""
super(VectorIndexer, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
self.maxCategories = Param(self, "maxCategories",
"Threshold for the number of values a categorical feature " +
"can take (>= 2). If a feature is found to have " +
@@ -831,7 +832,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
"""
Sets the value of :py:attr:`maxCategories`.
"""
- self.paramMap[self.maxCategories] = value
+ self._paramMap[self.maxCategories] = value
return self
def getMaxCategories(self):
@@ -840,6 +841,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
"""
return self.getOrDefault(self.maxCategories)
+ def _create_model(self, java_model):
+ return VectorIndexerModel(java_model)
+
+
+class VectorIndexerModel(JavaModel):
+ """
+ Model fitted by VectorIndexer.
+ """
+
@inherit_doc
@ignore_unicode_prefix
@@ -855,7 +865,6 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276])
"""
- _java_class = "org.apache.spark.ml.feature.Word2Vec"
# a placeholder to make it appear in the generated doc
vectorSize = Param(Params._dummy(), "vectorSize",
"the dimension of codes after transforming from words")
@@ -873,6 +882,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
seed=42, inputCol=None, outputCol=None)
"""
super(Word2Vec, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid)
self.vectorSize = Param(self, "vectorSize",
"the dimension of codes after transforming from words")
self.numPartitions = Param(self, "numPartitions",
@@ -900,7 +910,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
"""
Sets the value of :py:attr:`vectorSize`.
"""
- self.paramMap[self.vectorSize] = value
+ self._paramMap[self.vectorSize] = value
return self
def getVectorSize(self):
@@ -913,7 +923,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
"""
Sets the value of :py:attr:`numPartitions`.
"""
- self.paramMap[self.numPartitions] = value
+ self._paramMap[self.numPartitions] = value
return self
def getNumPartitions(self):
@@ -926,7 +936,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
"""
Sets the value of :py:attr:`minCount`.
"""
- self.paramMap[self.minCount] = value
+ self._paramMap[self.minCount] = value
return self
def getMinCount(self):
@@ -935,6 +945,9 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
"""
return self.getOrDefault(self.minCount)
+ def _create_model(self, java_model):
+ return Word2VecModel(java_model)
+
class Word2VecModel(JavaModel):
"""
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 49c20b4cf7..67fb6e3dc7 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -16,6 +16,7 @@
#
from abc import ABCMeta
+import copy
from pyspark.ml.util import Identifiable
@@ -29,9 +30,9 @@ class Param(object):
"""
def __init__(self, parent, name, doc):
- if not isinstance(parent, Params):
- raise TypeError("Parent must be a Params but got type %s." % type(parent))
- self.parent = parent
+ if not isinstance(parent, Identifiable):
+ raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
+ self.parent = parent.uid
self.name = str(name)
self.doc = str(doc)
@@ -41,6 +42,15 @@ class Param(object):
def __repr__(self):
return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)
+ def __hash__(self):
+ return hash(str(self))
+
+ def __eq__(self, other):
+ if isinstance(other, Param):
+ return self.parent == other.parent and self.name == other.name
+ else:
+ return False
+
class Params(Identifiable):
"""
@@ -51,10 +61,13 @@ class Params(Identifiable):
__metaclass__ = ABCMeta
#: internal param map for user-supplied values param map
- paramMap = {}
+ _paramMap = {}
#: internal param map for default values
- defaultParamMap = {}
+ _defaultParamMap = {}
+
+ #: value returned by :py:func:`params`
+ _params = None
@property
def params(self):
@@ -63,10 +76,12 @@ class Params(Identifiable):
uses :py:func:`dir` to get all attributes of type
:py:class:`Param`.
"""
- return list(filter(lambda attr: isinstance(attr, Param),
- [getattr(self, x) for x in dir(self) if x != "params"]))
+ if self._params is None:
+ self._params = list(filter(lambda attr: isinstance(attr, Param),
+ [getattr(self, x) for x in dir(self) if x != "params"]))
+ return self._params
- def _explain(self, param):
+ def explainParam(self, param):
"""
Explains a single param and returns its name, doc, and optional
default value and user-supplied value in a string.
@@ -74,10 +89,10 @@ class Params(Identifiable):
param = self._resolveParam(param)
values = []
if self.isDefined(param):
- if param in self.defaultParamMap:
- values.append("default: %s" % self.defaultParamMap[param])
- if param in self.paramMap:
- values.append("current: %s" % self.paramMap[param])
+ if param in self._defaultParamMap:
+ values.append("default: %s" % self._defaultParamMap[param])
+ if param in self._paramMap:
+ values.append("current: %s" % self._paramMap[param])
else:
values.append("undefined")
valueStr = "(" + ", ".join(values) + ")"
@@ -88,7 +103,7 @@ class Params(Identifiable):
Returns the documentation of all params with their optionally
default values and user-supplied values.
"""
- return "\n".join([self._explain(param) for param in self.params])
+ return "\n".join([self.explainParam(param) for param in self.params])
def getParam(self, paramName):
"""
@@ -105,56 +120,76 @@ class Params(Identifiable):
Checks whether a param is explicitly set by user.
"""
param = self._resolveParam(param)
- return param in self.paramMap
+ return param in self._paramMap
def hasDefault(self, param):
"""
Checks whether a param has a default value.
"""
param = self._resolveParam(param)
- return param in self.defaultParamMap
+ return param in self._defaultParamMap
def isDefined(self, param):
"""
- Checks whether a param is explicitly set by user or has a default value.
+ Checks whether a param is explicitly set by user or has
+ a default value.
"""
return self.isSet(param) or self.hasDefault(param)
+ def hasParam(self, paramName):
+ """
+ Tests whether this instance contains a param with a given
+ (string) name.
+ """
+ param = self._resolveParam(paramName)
+ return param in self.params
+
def getOrDefault(self, param):
"""
Gets the value of a param in the user-supplied param map or its
default value. Raises an error if either is set.
"""
- if isinstance(param, Param):
- if param in self.paramMap:
- return self.paramMap[param]
- else:
- return self.defaultParamMap[param]
- elif isinstance(param, str):
- return self.getOrDefault(self.getParam(param))
+ param = self._resolveParam(param)
+ if param in self._paramMap:
+ return self._paramMap[param]
else:
- raise KeyError("Cannot recognize %r as a param." % param)
+ return self._defaultParamMap[param]
- def extractParamMap(self, extraParamMap={}):
+ def extractParamMap(self, extra={}):
"""
Extracts the embedded default param values and user-supplied
values, and then merges them with extra values from input into
a flat param map, where the latter value is used if there exist
conflicts, i.e., with ordering: default param values <
- user-supplied values < extraParamMap.
- :param extraParamMap: extra param values
+ user-supplied values < extra.
+ :param extra: extra param values
:return: merged param map
"""
- paramMap = self.defaultParamMap.copy()
- paramMap.update(self.paramMap)
- paramMap.update(extraParamMap)
+ paramMap = self._defaultParamMap.copy()
+ paramMap.update(self._paramMap)
+ paramMap.update(extra)
return paramMap
+ def copy(self, extra={}):
+ """
+ Creates a copy of this instance with the same uid and some
+ extra params. The default implementation creates a
+ shallow copy using :py:func:`copy.copy`, and then copies the
+ embedded and extra parameters over and returns the copy.
+ Subclasses should override this method if the default approach
+ is not sufficient.
+ :param extra: Extra parameters to copy to the new instance
+ :return: Copy of this instance
+ """
+ that = copy.copy(self)
+ that._paramMap = self.extractParamMap(extra)
+ return that
+
def _shouldOwn(self, param):
"""
Validates that the input param belongs to this Params instance.
"""
- if param.parent is not self:
+ if not (self.uid == param.parent and self.hasParam(param.name)):
raise ValueError("Param %r does not belong to %r." % (param, self))
def _resolveParam(self, param):
@@ -175,7 +210,8 @@ class Params(Identifiable):
@staticmethod
def _dummy():
"""
- Returns a dummy Params instance used as a placeholder to generate docs.
+ Returns a dummy Params instance used as a placeholder to
+ generate docs.
"""
dummy = Params()
dummy.uid = "undefined"
@@ -186,7 +222,7 @@ class Params(Identifiable):
Sets user-supplied params.
"""
for param, value in kwargs.items():
- self.paramMap[getattr(self, param)] = value
+ self._paramMap[getattr(self, param)] = value
return self
def _setDefault(self, **kwargs):
@@ -194,5 +230,19 @@ class Params(Identifiable):
Sets default params.
"""
for param, value in kwargs.items():
- self.defaultParamMap[getattr(self, param)] = value
+ self._defaultParamMap[getattr(self, param)] = value
return self
+
+ def _copyValues(self, to, extra={}):
+ """
+ Copies param values from this instance to another instance for
+ params shared by them.
+ :param to: the target instance
+ :param extra: extra params to be copied
+ :return: the target instance with param values copied
+ """
+ paramMap = self.extractParamMap(extra)
+ for p in self.params:
+ if p in paramMap and to.hasParam(p.name):
+ to._set(**{p.name: paramMap[p]})
+ return to
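As a rough illustration of the new Param identity and copy semantics (a sketch, not part of the patch; it mirrors the Binarizer test added in ml/tests.py below):

    from pyspark.ml.feature import Binarizer

    b = Binarizer(threshold=1.0, inputCol="input", outputCol="output")
    bc = b.copy({b.threshold: 2.0})
    assert bc.uid == b.uid              # copies keep the parent uid
    assert bc.threshold == b.threshold  # Param equality is by (parent uid, name)
    assert b.getThreshold() == 1.0 and bc.getThreshold() == 2.0  # only the copy sees the extra value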
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 6fa9b8c2cf..91e45ec373 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -83,7 +83,7 @@ def _gen_param_code(name, doc, defaultValueStr):
"""
Sets the value of :py:attr:`$name`.
"""
- self.paramMap[self.$name] = value
+ self._paramMap[self.$name] = value
return self
def get$Name(self):
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index b116f05a06..a5dc9b7ef2 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -39,7 +39,7 @@ class HasMaxIter(Params):
"""
Sets the value of :py:attr:`maxIter`.
"""
- self.paramMap[self.maxIter] = value
+ self._paramMap[self.maxIter] = value
return self
def getMaxIter(self):
@@ -68,7 +68,7 @@ class HasRegParam(Params):
"""
Sets the value of :py:attr:`regParam`.
"""
- self.paramMap[self.regParam] = value
+ self._paramMap[self.regParam] = value
return self
def getRegParam(self):
@@ -97,7 +97,7 @@ class HasFeaturesCol(Params):
"""
Sets the value of :py:attr:`featuresCol`.
"""
- self.paramMap[self.featuresCol] = value
+ self._paramMap[self.featuresCol] = value
return self
def getFeaturesCol(self):
@@ -126,7 +126,7 @@ class HasLabelCol(Params):
"""
Sets the value of :py:attr:`labelCol`.
"""
- self.paramMap[self.labelCol] = value
+ self._paramMap[self.labelCol] = value
return self
def getLabelCol(self):
@@ -155,7 +155,7 @@ class HasPredictionCol(Params):
"""
Sets the value of :py:attr:`predictionCol`.
"""
- self.paramMap[self.predictionCol] = value
+ self._paramMap[self.predictionCol] = value
return self
def getPredictionCol(self):
@@ -184,7 +184,7 @@ class HasProbabilityCol(Params):
"""
Sets the value of :py:attr:`probabilityCol`.
"""
- self.paramMap[self.probabilityCol] = value
+ self._paramMap[self.probabilityCol] = value
return self
def getProbabilityCol(self):
@@ -213,7 +213,7 @@ class HasRawPredictionCol(Params):
"""
Sets the value of :py:attr:`rawPredictionCol`.
"""
- self.paramMap[self.rawPredictionCol] = value
+ self._paramMap[self.rawPredictionCol] = value
return self
def getRawPredictionCol(self):
@@ -242,7 +242,7 @@ class HasInputCol(Params):
"""
Sets the value of :py:attr:`inputCol`.
"""
- self.paramMap[self.inputCol] = value
+ self._paramMap[self.inputCol] = value
return self
def getInputCol(self):
@@ -271,7 +271,7 @@ class HasInputCols(Params):
"""
Sets the value of :py:attr:`inputCols`.
"""
- self.paramMap[self.inputCols] = value
+ self._paramMap[self.inputCols] = value
return self
def getInputCols(self):
@@ -300,7 +300,7 @@ class HasOutputCol(Params):
"""
Sets the value of :py:attr:`outputCol`.
"""
- self.paramMap[self.outputCol] = value
+ self._paramMap[self.outputCol] = value
return self
def getOutputCol(self):
@@ -329,7 +329,7 @@ class HasNumFeatures(Params):
"""
Sets the value of :py:attr:`numFeatures`.
"""
- self.paramMap[self.numFeatures] = value
+ self._paramMap[self.numFeatures] = value
return self
def getNumFeatures(self):
@@ -358,7 +358,7 @@ class HasCheckpointInterval(Params):
"""
Sets the value of :py:attr:`checkpointInterval`.
"""
- self.paramMap[self.checkpointInterval] = value
+ self._paramMap[self.checkpointInterval] = value
return self
def getCheckpointInterval(self):
@@ -387,7 +387,7 @@ class HasSeed(Params):
"""
Sets the value of :py:attr:`seed`.
"""
- self.paramMap[self.seed] = value
+ self._paramMap[self.seed] = value
return self
def getSeed(self):
@@ -416,7 +416,7 @@ class HasTol(Params):
"""
Sets the value of :py:attr:`tol`.
"""
- self.paramMap[self.tol] = value
+ self._paramMap[self.tol] = value
return self
def getTol(self):
@@ -445,7 +445,7 @@ class HasStepSize(Params):
"""
Sets the value of :py:attr:`stepSize`.
"""
- self.paramMap[self.stepSize] = value
+ self._paramMap[self.stepSize] = value
return self
def getStepSize(self):
@@ -487,7 +487,7 @@ class DecisionTreeParams(Params):
"""
Sets the value of :py:attr:`maxDepth`.
"""
- self.paramMap[self.maxDepth] = value
+ self._paramMap[self.maxDepth] = value
return self
def getMaxDepth(self):
@@ -500,7 +500,7 @@ class DecisionTreeParams(Params):
"""
Sets the value of :py:attr:`maxBins`.
"""
- self.paramMap[self.maxBins] = value
+ self._paramMap[self.maxBins] = value
return self
def getMaxBins(self):
@@ -513,7 +513,7 @@ class DecisionTreeParams(Params):
"""
Sets the value of :py:attr:`minInstancesPerNode`.
"""
- self.paramMap[self.minInstancesPerNode] = value
+ self._paramMap[self.minInstancesPerNode] = value
return self
def getMinInstancesPerNode(self):
@@ -526,7 +526,7 @@ class DecisionTreeParams(Params):
"""
Sets the value of :py:attr:`minInfoGain`.
"""
- self.paramMap[self.minInfoGain] = value
+ self._paramMap[self.minInfoGain] = value
return self
def getMinInfoGain(self):
@@ -539,7 +539,7 @@ class DecisionTreeParams(Params):
"""
Sets the value of :py:attr:`maxMemoryInMB`.
"""
- self.paramMap[self.maxMemoryInMB] = value
+ self._paramMap[self.maxMemoryInMB] = value
return self
def getMaxMemoryInMB(self):
@@ -552,7 +552,7 @@ class DecisionTreeParams(Params):
"""
Sets the value of :py:attr:`cacheNodeIds`.
"""
- self.paramMap[self.cacheNodeIds] = value
+ self._paramMap[self.cacheNodeIds] = value
return self
def getCacheNodeIds(self):
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a328bcf84a..0f38e02127 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -31,18 +31,40 @@ class Estimator(Params):
__metaclass__ = ABCMeta
@abstractmethod
- def fit(self, dataset, params={}):
+ def _fit(self, dataset):
"""
- Fits a model to the input dataset with optional parameters.
+ Fits a model to the input dataset. This is called by the
+ default implementation of fit.
:param dataset: input dataset, which is an instance of
:py:class:`pyspark.sql.DataFrame`
- :param params: an optional param map that overwrites embedded
- params
:returns: fitted model
"""
raise NotImplementedError()
+ def fit(self, dataset, params={}):
+ """
+ Fits a model to the input dataset with optional parameters.
+
+ :param dataset: input dataset, which is an instance of
+ :py:class:`pyspark.sql.DataFrame`
+ :param params: an optional param map that overrides embedded
+ params. If a list/tuple of param maps is given,
+ this calls fit on each param map and returns a
+ list of models.
+ :returns: fitted model(s)
+ """
+ if isinstance(params, (list, tuple)):
+ return [self.fit(dataset, paramMap) for paramMap in params]
+ elif isinstance(params, dict):
+ if params:
+ return self.copy(params)._fit(dataset)
+ else:
+ return self._fit(dataset)
+ else:
+ raise ValueError("Params must be either a param map or a list/tuple of param maps, "
+ "but got %s." % type(params))
+
@inherit_doc
class Transformer(Params):
@@ -54,18 +76,34 @@ class Transformer(Params):
__metaclass__ = ABCMeta
@abstractmethod
- def transform(self, dataset, params={}):
+ def _transform(self, dataset):
"""
Transforms the input dataset with optional parameters.
:param dataset: input dataset, which is an instance of
:py:class:`pyspark.sql.DataFrame`
- :param params: an optional param map that overwrites embedded
- params
:returns: transformed dataset
"""
raise NotImplementedError()
+ def transform(self, dataset, params={}):
+ """
+ Transforms the input dataset with optional parameters.
+
+ :param dataset: input dataset, which is an instance of
+ :py:class:`pyspark.sql.DataFrame`
+ :param params: an optional param map that overrides embedded
+ params.
+ :returns: transformed dataset
+ """
+ if isinstance(params, dict):
+ if params:
+ return self.copy(params,)._transform(dataset)
+ else:
+ return self._transform(dataset)
+ else:
+ raise ValueError("Params must be either a param map but got %s." % type(params))
+
@inherit_doc
class Model(Transformer):
@@ -113,15 +151,15 @@ class Pipeline(Estimator):
:param value: a list of transformers or estimators
:return: the pipeline instance
"""
- self.paramMap[self.stages] = value
+ self._paramMap[self.stages] = value
return self
def getStages(self):
"""
Get pipeline stages.
"""
- if self.stages in self.paramMap:
- return self.paramMap[self.stages]
+ if self.stages in self._paramMap:
+ return self._paramMap[self.stages]
@keyword_only
def setParams(self, stages=[]):
@@ -132,9 +170,8 @@ class Pipeline(Estimator):
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
- def fit(self, dataset, params={}):
- paramMap = self.extractParamMap(params)
- stages = paramMap[self.stages]
+ def _fit(self, dataset):
+ stages = self.getStages()
for stage in stages:
if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
raise TypeError(
@@ -148,16 +185,21 @@ class Pipeline(Estimator):
if i <= indexOfLastEstimator:
if isinstance(stage, Transformer):
transformers.append(stage)
- dataset = stage.transform(dataset, paramMap)
+ dataset = stage.transform(dataset)
else: # must be an Estimator
- model = stage.fit(dataset, paramMap)
+ model = stage.fit(dataset)
transformers.append(model)
if i < indexOfLastEstimator:
- dataset = model.transform(dataset, paramMap)
+ dataset = model.transform(dataset)
else:
transformers.append(stage)
return PipelineModel(transformers)
+ def copy(self, extra={}):
+ that = Params.copy(self, extra)
+ stages = [stage.copy(extra) for stage in that.getStages()]
+ return that.setStages(stages)
+
@inherit_doc
class PipelineModel(Model):
@@ -165,16 +207,19 @@ class PipelineModel(Model):
Represents a compiled pipeline with transformers and fitted models.
"""
- def __init__(self, transformers):
+ def __init__(self, stages):
super(PipelineModel, self).__init__()
- self.transformers = transformers
+ self.stages = stages
- def transform(self, dataset, params={}):
- paramMap = self.extractParamMap(params)
- for t in self.transformers:
- dataset = t.transform(dataset, paramMap)
+ def _transform(self, dataset):
+ for t in self.stages:
+ dataset = t.transform(dataset)
return dataset
+ def copy(self, extra={}):
+ stages = [stage.copy(extra) for stage in self.stages]
+ return PipelineModel(stages)
+
class Evaluator(Params):
"""
@@ -184,14 +229,30 @@ class Evaluator(Params):
__metaclass__ = ABCMeta
@abstractmethod
- def evaluate(self, dataset, params={}):
+ def _evaluate(self, dataset):
"""
Evaluates the output.
:param dataset: a dataset that contains labels/observations and
+ predictions
+ :return: metric
+ """
+ raise NotImplementedError()
+
+ def evaluate(self, dataset, params={}):
+ """
+ Evaluates the output with optional parameters.
+
+ :param dataset: a dataset that contains labels/observations and
predictions
:param params: an optional param map that overrides embedded
params
:return: metric
"""
- raise NotImplementedError()
+ if isinstance(params, dict):
+ if params:
+ return self.copy(params)._evaluate(dataset)
+ else:
+ return self._evaluate(dataset)
+ else:
+ raise ValueError("Params must be a param map but got %s." % type(params))
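With the split into public fit/transform/evaluate and protected _fit/_transform/_evaluate above, user-defined stages now override only the underscore method; the base class handles param maps by copying the stage first. A minimal sketch (the `AddOne` class and column names are illustrative, not from the patch):

    from pyspark.ml.pipeline import Transformer

    class AddOne(Transformer):
        """Toy transformer: appends a column holding value + 1."""
        def _transform(self, dataset):
            return dataset.withColumn("plus_one", dataset["value"] + 1)

    # AddOne().transform(df, {param: value}) copies the stage, applies the extra
    # params to the copy, and calls _transform on it, leaving the original unchanged.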
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index b2439cbd96..39c2527543 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -74,7 +74,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
>>> predictions[2]
Row(user=2, item=0, prediction=-1.15...)
"""
- _java_class = "org.apache.spark.ml.recommendation.ALS"
+
# a placeholder to make it appear in the generated doc
rank = Param(Params._dummy(), "rank", "rank of the factorization")
numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks")
@@ -97,6 +97,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
ratingCol="rating", nonnegative=false, checkpointInterval=10)
"""
super(ALS, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
self.rank = Param(self, "rank", "rank of the factorization")
self.numUserBlocks = Param(self, "numUserBlocks", "number of user blocks")
self.numItemBlocks = Param(self, "numItemBlocks", "number of item blocks")
@@ -133,7 +134,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
Sets the value of :py:attr:`rank`.
"""
- self.paramMap[self.rank] = value
+ self._paramMap[self.rank] = value
return self
def getRank(self):
@@ -146,7 +147,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
Sets the value of :py:attr:`numUserBlocks`.
"""
- self.paramMap[self.numUserBlocks] = value
+ self._paramMap[self.numUserBlocks] = value
return self
def getNumUserBlocks(self):
@@ -159,7 +160,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
Sets the value of :py:attr:`numItemBlocks`.
"""
- self.paramMap[self.numItemBlocks] = value
+ self._paramMap[self.numItemBlocks] = value
return self
def getNumItemBlocks(self):
@@ -172,14 +173,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value.
"""
- self.paramMap[self.numUserBlocks] = value
- self.paramMap[self.numItemBlocks] = value
+ self._paramMap[self.numUserBlocks] = value
+ self._paramMap[self.numItemBlocks] = value
def setImplicitPrefs(self, value):
"""
Sets the value of :py:attr:`implicitPrefs`.
"""
- self.paramMap[self.implicitPrefs] = value
+ self._paramMap[self.implicitPrefs] = value
return self
def getImplicitPrefs(self):
@@ -192,7 +193,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
Sets the value of :py:attr:`alpha`.
"""
- self.paramMap[self.alpha] = value
+ self._paramMap[self.alpha] = value
return self
def getAlpha(self):
@@ -205,7 +206,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
Sets the value of :py:attr:`userCol`.
"""
- self.paramMap[self.userCol] = value
+ self._paramMap[self.userCol] = value
return self
def getUserCol(self):
@@ -218,7 +219,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
Sets the value of :py:attr:`itemCol`.
"""
- self.paramMap[self.itemCol] = value
+ self._paramMap[self.itemCol] = value
return self
def getItemCol(self):
@@ -231,7 +232,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
Sets the value of :py:attr:`ratingCol`.
"""
- self.paramMap[self.ratingCol] = value
+ self._paramMap[self.ratingCol] = value
return self
def getRatingCol(self):
@@ -244,7 +245,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
Sets the value of :py:attr:`nonnegative`.
"""
- self.paramMap[self.nonnegative] = value
+ self._paramMap[self.nonnegative] = value
return self
def getNonnegative(self):
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index ef77e19327..ff809cdafd 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -62,7 +62,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
...
TypeError: Method setParams forces keyword arguments.
"""
- _java_class = "org.apache.spark.ml.regression.LinearRegression"
+
# a placeholder to make it appear in the generated doc
elasticNetParam = \
Param(Params._dummy(), "elasticNetParam",
@@ -77,6 +77,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6)
"""
super(LinearRegression, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.regression.LinearRegression", self.uid)
#: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
# is an L2 penalty. For alpha = 1, it is an L1 penalty.
self.elasticNetParam = \
@@ -105,7 +107,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
"""
Sets the value of :py:attr:`elasticNetParam`.
"""
- self.paramMap[self.elasticNetParam] = value
+ self._paramMap[self.elasticNetParam] = value
return self
def getElasticNetParam(self):
@@ -178,7 +180,6 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
1.0
"""
- _java_class = "org.apache.spark.ml.regression.DecisionTreeRegressor"
# a placeholder to make it appear in the generated doc
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
@@ -194,6 +195,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
"""
super(DecisionTreeRegressor, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid)
#: param for Criterion used for information gain calculation (case-insensitive).
self.impurity = \
Param(self, "impurity",
@@ -226,7 +229,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
"""
Sets the value of :py:attr:`impurity`.
"""
- self.paramMap[self.impurity] = value
+ self._paramMap[self.impurity] = value
return self
def getImpurity(self):
@@ -264,7 +267,6 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
0.5
"""
- _java_class = "org.apache.spark.ml.regression.RandomForestRegressor"
# a placeholder to make it appear in the generated doc
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
@@ -290,6 +292,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
impurity="variance", numTrees=20, featureSubsetStrategy="auto", seed=42)
"""
super(RandomForestRegressor, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.regression.RandomForestRegressor", self.uid)
#: param for Criterion used for information gain calculation (case-insensitive).
self.impurity = \
Param(self, "impurity",
@@ -335,7 +339,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
"""
Sets the value of :py:attr:`impurity`.
"""
- self.paramMap[self.impurity] = value
+ self._paramMap[self.impurity] = value
return self
def getImpurity(self):
@@ -348,7 +352,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
"""
Sets the value of :py:attr:`subsamplingRate`.
"""
- self.paramMap[self.subsamplingRate] = value
+ self._paramMap[self.subsamplingRate] = value
return self
def getSubsamplingRate(self):
@@ -361,7 +365,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
"""
Sets the value of :py:attr:`numTrees`.
"""
- self.paramMap[self.numTrees] = value
+ self._paramMap[self.numTrees] = value
return self
def getNumTrees(self):
@@ -374,7 +378,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
"""
Sets the value of :py:attr:`featureSubsetStrategy`.
"""
- self.paramMap[self.featureSubsetStrategy] = value
+ self._paramMap[self.featureSubsetStrategy] = value
return self
def getFeatureSubsetStrategy(self):
@@ -412,7 +416,6 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
1.0
"""
- _java_class = "org.apache.spark.ml.regression.GBTRegressor"
# a placeholder to make it appear in the generated doc
lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
@@ -436,6 +439,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
lossType="squared", maxIter=20, stepSize=0.1)
"""
super(GBTRegressor, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
#: param for Loss function which GBT tries to minimize (case-insensitive).
self.lossType = Param(self, "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
@@ -477,7 +481,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
"""
Sets the value of :py:attr:`lossType`.
"""
- self.paramMap[self.lossType] = value
+ self._paramMap[self.lossType] = value
return self
def getLossType(self):
@@ -490,7 +494,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
"""
Sets the value of :py:attr:`subsamplingRate`.
"""
- self.paramMap[self.subsamplingRate] = value
+ self._paramMap[self.subsamplingRate] = value
return self
def getSubsamplingRate(self):
@@ -503,7 +507,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
"""
Sets the value of :py:attr:`stepSize`.
"""
- self.paramMap[self.stepSize] = value
+ self._paramMap[self.stepSize] = value
return self
def getStepSize(self):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index ba6478dcd5..10fe0ef8db 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -31,10 +31,12 @@ else:
import unittest
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame
-from pyspark.ml.param import Param
+from pyspark.sql import DataFrame, SQLContext
+from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasMaxIter, HasInputCol
-from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer
+from pyspark.ml import Estimator, Model, Pipeline, Transformer
+from pyspark.ml.feature import *
+from pyspark.mllib.linalg import DenseVector
class MockDataset(DataFrame):
@@ -43,44 +45,43 @@ class MockDataset(DataFrame):
self.index = 0
-class MockTransformer(Transformer):
+class HasFake(Params):
+
+ def __init__(self):
+ super(HasFake, self).__init__()
+ self.fake = Param(self, "fake", "fake param")
+
+ def getFake(self):
+ return self.getOrDefault(self.fake)
+
+
+class MockTransformer(Transformer, HasFake):
def __init__(self):
super(MockTransformer, self).__init__()
- self.fake = Param(self, "fake", "fake")
self.dataset_index = None
- self.fake_param_value = None
- def transform(self, dataset, params={}):
+ def _transform(self, dataset):
self.dataset_index = dataset.index
- if self.fake in params:
- self.fake_param_value = params[self.fake]
dataset.index += 1
return dataset
-class MockEstimator(Estimator):
+class MockEstimator(Estimator, HasFake):
def __init__(self):
super(MockEstimator, self).__init__()
- self.fake = Param(self, "fake", "fake")
self.dataset_index = None
- self.fake_param_value = None
- self.model = None
- def fit(self, dataset, params={}):
+ def _fit(self, dataset):
self.dataset_index = dataset.index
- if self.fake in params:
- self.fake_param_value = params[self.fake]
model = MockModel()
- self.model = model
+ self._copyValues(model)
return model
-class MockModel(MockTransformer, Model):
-
- def __init__(self):
- super(MockModel, self).__init__()
+class MockModel(MockTransformer, Model, HasFake):
+ pass
class PipelineTests(PySparkTestCase):
@@ -91,19 +92,17 @@ class PipelineTests(PySparkTestCase):
transformer1 = MockTransformer()
estimator2 = MockEstimator()
transformer3 = MockTransformer()
- pipeline = Pipeline() \
- .setStages([estimator0, transformer1, estimator2, transformer3])
+ pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3])
pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
- self.assertEqual(0, estimator0.dataset_index)
- self.assertEqual(0, estimator0.fake_param_value)
- model0 = estimator0.model
+ model0, transformer1, model2, transformer3 = pipeline_model.stages
self.assertEqual(0, model0.dataset_index)
+ self.assertEqual(0, model0.getFake())
self.assertEqual(1, transformer1.dataset_index)
- self.assertEqual(1, transformer1.fake_param_value)
- self.assertEqual(2, estimator2.dataset_index)
- model2 = estimator2.model
- self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should "
- "not be called during fit.")
+ self.assertEqual(1, transformer1.getFake())
+ self.assertEqual(2, dataset.index)
+ self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.")
+ self.assertIsNone(transformer3.dataset_index,
+ "The last transformer shouldn't be called in fit.")
dataset = pipeline_model.transform(dataset)
self.assertEqual(2, model0.dataset_index)
self.assertEqual(3, transformer1.dataset_index)
@@ -129,7 +128,7 @@ class ParamTests(PySparkTestCase):
maxIter = testParams.maxIter
self.assertEqual(maxIter.name, "maxIter")
self.assertEqual(maxIter.doc, "max number of iterations (>= 0)")
- self.assertTrue(maxIter.parent is testParams)
+ self.assertTrue(maxIter.parent == testParams.uid)
def test_params(self):
testParams = TestParams()
@@ -139,6 +138,7 @@ class ParamTests(PySparkTestCase):
params = testParams.params
self.assertEqual(params, [inputCol, maxIter])
+ self.assertTrue(testParams.hasParam(maxIter))
self.assertTrue(testParams.hasDefault(maxIter))
self.assertFalse(testParams.isSet(maxIter))
self.assertTrue(testParams.isDefined(maxIter))
@@ -147,6 +147,7 @@ class ParamTests(PySparkTestCase):
self.assertTrue(testParams.isSet(maxIter))
self.assertEquals(testParams.getMaxIter(), 100)
+ self.assertTrue(testParams.hasParam(inputCol))
self.assertFalse(testParams.hasDefault(inputCol))
self.assertFalse(testParams.isSet(inputCol))
self.assertFalse(testParams.isDefined(inputCol))
@@ -159,5 +160,45 @@ class ParamTests(PySparkTestCase):
"maxIter: max number of iterations (>= 0) (default: 10, current: 100)"]))
+class FeatureTests(PySparkTestCase):
+
+ def test_binarizer(self):
+ b0 = Binarizer()
+ self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold])
+ self.assertTrue(all([not b0.isSet(p) for p in b0.params]))
+ self.assertTrue(b0.hasDefault(b0.threshold))
+ self.assertEqual(b0.getThreshold(), 0.0)
+ b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0)
+ self.assertTrue(all([b0.isSet(p) for p in b0.params]))
+ self.assertEqual(b0.getThreshold(), 1.0)
+ self.assertEqual(b0.getInputCol(), "input")
+ self.assertEqual(b0.getOutputCol(), "output")
+
+ b0c = b0.copy({b0.threshold: 2.0})
+ self.assertEqual(b0c.uid, b0.uid)
+ self.assertListEqual(b0c.params, b0.params)
+ self.assertEqual(b0c.getThreshold(), 2.0)
+
+ b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output")
+ self.assertNotEqual(b1.uid, b0.uid)
+ self.assertEqual(b1.getThreshold(), 2.0)
+ self.assertEqual(b1.getInputCol(), "input")
+ self.assertEqual(b1.getOutputCol(), "output")
+
+ def test_idf(self):
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame([
+ (DenseVector([1.0, 2.0]),),
+ (DenseVector([0.0, 1.0]),),
+ (DenseVector([3.0, 0.2]),)], ["tf"])
+ idf0 = IDF(inputCol="tf")
+ self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol])
+ idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"})
+ self.assertEqual(idf0m.uid, idf0.uid,
+ "Model should inherit the UID from its parent estimator.")
+ output = idf0m.transform(dataset)
+ self.assertIsNotNone(output.head().idf)
+
+
if __name__ == "__main__":
unittest.main()
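
As a quick illustration of the copy semantics exercised in test_binarizer above, a minimal sketch (it assumes an active SparkContext so the Binarizer's Java companion object can be constructed; the column names are arbitrary):

    from pyspark.ml.feature import Binarizer

    b = Binarizer(inputCol="input", outputCol="output")  # threshold keeps its default, 0.0
    c = b.copy({b.threshold: 2.0})                       # same uid, deep-copied param map

    assert c.uid == b.uid            # copies keep the parent's uid
    assert c.getThreshold() == 2.0   # the extra param applies to the copy...
    assert b.getThreshold() == 0.0   # ...while the original is left untouched
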
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 86f4dc7368..497841b6c8 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -155,7 +155,7 @@ class CrossValidator(Estimator):
"""
Sets the value of :py:attr:`estimator`.
"""
- self.paramMap[self.estimator] = value
+ self._paramMap[self.estimator] = value
return self
def getEstimator(self):
@@ -168,7 +168,7 @@ class CrossValidator(Estimator):
"""
Sets the value of :py:attr:`estimatorParamMaps`.
"""
- self.paramMap[self.estimatorParamMaps] = value
+ self._paramMap[self.estimatorParamMaps] = value
return self
def getEstimatorParamMaps(self):
@@ -181,7 +181,7 @@ class CrossValidator(Estimator):
"""
Sets the value of :py:attr:`evaluator`.
"""
- self.paramMap[self.evaluator] = value
+ self._paramMap[self.evaluator] = value
return self
def getEvaluator(self):
@@ -194,7 +194,7 @@ class CrossValidator(Estimator):
"""
Sets the value of :py:attr:`numFolds`.
"""
- self.paramMap[self.numFolds] = value
+ self._paramMap[self.numFolds] = value
return self
def getNumFolds(self):
@@ -203,13 +203,12 @@ class CrossValidator(Estimator):
"""
return self.getOrDefault(self.numFolds)
- def fit(self, dataset, params={}):
- paramMap = self.extractParamMap(params)
- est = paramMap[self.estimator]
- epm = paramMap[self.estimatorParamMaps]
+ def _fit(self, dataset):
+ est = self.getOrDefault(self.estimator)
+ epm = self.getOrDefault(self.estimatorParamMaps)
numModels = len(epm)
- eva = paramMap[self.evaluator]
- nFolds = paramMap[self.numFolds]
+ eva = self.getOrDefault(self.evaluator)
+ nFolds = self.getOrDefault(self.numFolds)
h = 1.0 / nFolds
randCol = self.uid + "_rand"
df = dataset.select("*", rand(0).alias(randCol))
@@ -229,6 +228,15 @@ class CrossValidator(Estimator):
bestModel = est.fit(dataset, epm[bestIndex])
return CrossValidatorModel(bestModel)
+ def copy(self, extra={}):
+ newCV = Params.copy(self, extra)
+ if self.isSet(self.estimator):
+ newCV.setEstimator(self.getEstimator().copy(extra))
+ # estimatorParamMaps remain the same
+ if self.isSet(self.evaluator):
+ newCV.setEvaluator(self.getEvaluator().copy(extra))
+ return newCV
+
class CrossValidatorModel(Model):
"""
@@ -240,8 +248,19 @@ class CrossValidatorModel(Model):
#: best model from cross validation
self.bestModel = bestModel
- def transform(self, dataset, params={}):
- return self.bestModel.transform(dataset, params)
+ def _transform(self, dataset):
+ return self.bestModel.transform(dataset)
+
+ def copy(self, extra={}):
+ """
+ Creates a copy of this instance with a randomly generated uid
+ and some extra params. It copies the underlying bestModel via
+ bestModel.copy(extra), which deep-copies the embedded paramMap
+ and applies the extra params to the copied model.
+ :param extra: Extra parameters to copy to the new instance
+ :return: Copy of this instance
+ """
+ return CrossValidatorModel(self.bestModel.copy(extra))
if __name__ == "__main__":
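
As a rough end-to-end sketch of how the refactored CrossValidator is driven (the LogisticRegression/BinaryClassificationEvaluator pairing and the `train`/`test` DataFrames are illustrative assumptions, not part of this diff): Estimator.fit now dispatches to the _fit shown above, and CrossValidatorModel._transform simply delegates to bestModel.

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    lr = LogisticRegression(maxIter=5)
    grid = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=BinaryClassificationEvaluator(), numFolds=3)

    cvModel = cv.fit(train)              # Estimator.fit -> CrossValidator._fit
    scored = cvModel.transform(test)     # delegates to cvModel.bestModel.transform
    cvCopy = cv.copy({lr.maxIter: 10})   # copy() also copies the embedded estimator/evaluator
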
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index d3cb100a9e..cee9d67b05 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -39,9 +39,16 @@ class Identifiable(object):
"""
def __init__(self):
- #: A unique id for the object. The default implementation
- #: concatenates the class name, "_", and 8 random hex chars.
- self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8]
+ #: A unique id for the object.
+ self.uid = self._randomUID()
def __repr__(self):
return self.uid
+
+ @classmethod
+ def _randomUID(cls):
+ """
+ Generate a unique id for the object. The default implementation
+ concatenates the class name, "_", and 12 random hex chars.
+ """
+ return cls.__name__ + "_" + uuid.uuid4().hex[:12]
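
For reference, the resulting uid format and the unchanged __repr__ behaviour (the hex suffix below is a made-up example; an active SparkContext is assumed since Tokenizer is a Java wrapper):

    from pyspark.ml.feature import Tokenizer

    t = Tokenizer()
    t.uid      # e.g. 'Tokenizer_4f9d8bb7a3c2' -- class name, "_", 12 random hex chars
    repr(t)    # the same string; __repr__ just returns the uid
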
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index dda6c6aba3..4419e16184 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -45,46 +45,61 @@ class JavaWrapper(Params):
__metaclass__ = ABCMeta
- #: Fully-qualified class name of the wrapped Java component.
- _java_class = None
+ #: The wrapped Java companion object. Subclasses should initialize
+ #: it properly. The param values in the Java object should be
+ #: synced with the Python wrapper in fit/transform/evaluate/copy.
+ _java_obj = None
- def _java_obj(self):
+ @staticmethod
+ def _new_java_obj(java_class, *args):
"""
- Returns or creates a Java object.
+ Constructs a new Java object.
"""
+ sc = SparkContext._active_spark_context
java_obj = _jvm()
- for name in self._java_class.split("."):
+ for name in java_class.split("."):
java_obj = getattr(java_obj, name)
- return java_obj()
+ java_args = [_py2java(sc, arg) for arg in args]
+ return java_obj(*java_args)
- def _transfer_params_to_java(self, params, java_obj):
+ def _make_java_param_pair(self, param, value):
"""
- Transforms the embedded params and additional params to the
- input Java object.
- :param params: additional params (overwriting embedded values)
- :param java_obj: Java object to receive the params
+ Makes a Java param pair.
+ """
+ sc = SparkContext._active_spark_context
+ param = self._resolveParam(param)
+ java_param = self._java_obj.getParam(param.name)
+ java_value = _py2java(sc, value)
+ return java_param.w(java_value)
+
+ def _transfer_params_to_java(self):
+ """
+ Transfers the embedded params to the companion Java object.
"""
- paramMap = self.extractParamMap(params)
+ paramMap = self.extractParamMap()
for param in self.params:
if param in paramMap:
- value = paramMap[param]
- java_param = java_obj.getParam(param.name)
- java_obj.set(java_param.w(value))
+ pair = self._make_java_param_pair(param, paramMap[param])
+ self._java_obj.set(pair)
+
+ def _transfer_params_from_java(self):
+ """
+ Transfers the embedded params from the companion Java object.
+ """
+ sc = SparkContext._active_spark_context
+ for param in self.params:
+ if self._java_obj.hasParam(param.name):
+ java_param = self._java_obj.getParam(param.name)
+ value = _java2py(sc, self._java_obj.getOrDefault(java_param))
+ self._paramMap[param] = value
- def _empty_java_param_map(self):
+ @staticmethod
+ def _empty_java_param_map():
"""
Returns an empty Java ParamMap reference.
"""
return _jvm().org.apache.spark.ml.param.ParamMap()
- def _create_java_param_map(self, params, java_obj):
- paramMap = self._empty_java_param_map()
- for param, value in params.items():
- if param.parent is self:
- java_param = java_obj.getParam(param.name)
- paramMap.put(java_param.w(value))
- return paramMap
-
@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
@@ -99,9 +114,9 @@ class JavaEstimator(Estimator, JavaWrapper):
"""
Creates a model from the input Java model reference.
"""
- return JavaModel(java_model)
+ raise NotImplementedError()
- def _fit_java(self, dataset, params={}):
+ def _fit_java(self, dataset):
"""
Fits a Java model to the input dataset.
:param dataset: input dataset, which is an instance of
@@ -109,12 +124,11 @@ class JavaEstimator(Estimator, JavaWrapper):
:param params: additional params (overwriting embedded values)
:return: fitted Java model
"""
- java_obj = self._java_obj()
- self._transfer_params_to_java(params, java_obj)
- return java_obj.fit(dataset._jdf, self._empty_java_param_map())
+ self._transfer_params_to_java()
+ return self._java_obj.fit(dataset._jdf)
- def fit(self, dataset, params={}):
- java_model = self._fit_java(dataset, params)
+ def _fit(self, dataset):
+ java_model = self._fit_java(dataset)
return self._create_model(java_model)
@@ -127,30 +141,47 @@ class JavaTransformer(Transformer, JavaWrapper):
__metaclass__ = ABCMeta
- def transform(self, dataset, params={}):
- java_obj = self._java_obj()
- self._transfer_params_to_java(params, java_obj)
- return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)
+ def _transform(self, dataset):
+ self._transfer_params_to_java()
+ return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
@inherit_doc
class JavaModel(Model, JavaTransformer):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
- implementations.
+ implementations. Subclasses should inherit this class before
+ param mix-ins, because this sets the UID from the Java model.
"""
__metaclass__ = ABCMeta
def __init__(self, java_model):
- super(JavaTransformer, self).__init__()
- self._java_model = java_model
+ """
+ Initialize this instance with a Java model object.
+ Subclasses should call this constructor, initialize params,
+ and then call _transfer_params_from_java.
+ """
+ super(JavaModel, self).__init__()
+ self._java_obj = java_model
+ self.uid = java_model.uid()
- def _java_obj(self):
- return self._java_model
+ def copy(self, extra={}):
+ """
+ Creates a copy of this instance with the same uid and some
+ extra params. This implementation first calls Params.copy and
+ then makes a copy of the companion Java model with extra params.
+ So both the Python wrapper and the Java model get copied.
+ :param extra: Extra parameters to copy to the new instance
+ :return: Copy of this instance
+ """
+ that = super(JavaModel, self).copy(extra)
+ that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+ that._transfer_params_to_java()
+ return that
def _call_java(self, name, *args):
- m = getattr(self._java_model, name)
+ m = getattr(self._java_obj, name)
sc = SparkContext._active_spark_context
java_args = [_py2java(sc, arg) for arg in args]
return _java2py(sc, m(*java_args))
@@ -165,7 +196,11 @@ class JavaEvaluator(Evaluator, JavaWrapper):
__metaclass__ = ABCMeta
- def evaluate(self, dataset, params={}):
- java_obj = self._java_obj()
- self._transfer_params_to_java(params, java_obj)
- return java_obj.evaluate(dataset._jdf, self._empty_java_param_map())
+ def _evaluate(self, dataset):
+ """
+ Evaluates the output.
+ :param dataset: a dataset that contains labels/observations and predictions.
+ :return: evaluation metric
+ """
+ self._transfer_params_to_java()
+ return self._java_obj.evaluate(dataset._jdf)
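
Taken together, a concrete wrapper now only supplies its companion Java object and the private _fit/_transform/_evaluate hooks. A minimal sketch of the estimator/model pattern under the new contract (MyScaler/MyScalerModel are illustrative names; the real wrappers of this shape live in feature.py, and the Java class referenced is the existing org.apache.spark.ml.feature.StandardScaler):

    from pyspark.ml.param.shared import HasInputCol, HasOutputCol
    from pyspark.ml.wrapper import JavaEstimator, JavaModel


    class MyScaler(JavaEstimator, HasInputCol, HasOutputCol):
        """Illustrative estimator wrapper following the new _java_obj contract."""

        def __init__(self):
            super(MyScaler, self).__init__()
            # The wrapper constructs its Java companion itself and passes
            # its own uid so both sides share one identity.
            self._java_obj = self._new_java_obj(
                "org.apache.spark.ml.feature.StandardScaler", self.uid)

        def _create_model(self, java_model):
            # _fit() syncs params to Java, calls _fit_java(), and hands the
            # resulting Java model here.
            return MyScalerModel(java_model)


    class MyScalerModel(JavaModel):
        """Model returned by MyScaler.fit(); JavaModel.__init__ stores the
        Java model as _java_obj and adopts its uid."""
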