author      Xiangrui Meng <meng@databricks.com>    2015-05-06 01:28:43 -0700
committer   Xiangrui Meng <meng@databricks.com>    2015-05-06 01:28:43 -0700
commit      32cdc815c6fc19b5c8c4eca35f88a61302d67cd5 (patch)
tree        0ed046021a78b1f3682b13ca918d65cd08fed9b9
parent      9f019c7223bb79b8d5cd52980b2723a1601d1134 (diff)
[SPARK-6940] [MLLIB] Add CrossValidator to Python ML pipeline API
Since CrossValidator is a meta algorithm, we copy the implementation in Python. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #5926 from mengxr/SPARK-6940 and squashes the following commits:

6af181f [Xiangrui Meng] add TODOs
8285134 [Xiangrui Meng] update doc
060f7c3 [Xiangrui Meng] update doctest
acac727 [Xiangrui Meng] add keyword args
cdddecd [Xiangrui Meng] add CrossValidator in Python
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala  |   7
-rw-r--r--  python/pyspark/ml/pipeline.py                                          |  13
-rw-r--r--  python/pyspark/ml/tuning.py                                            | 183
-rw-r--r--  python/pyspark/ml/wrapper.py                                           |   4
4 files changed, 199 insertions(+), 8 deletions(-)
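
For reference, a minimal usage sketch of the Python API this patch adds, mirroring the doctest in tuning.py below. It assumes an existing SparkContext/SQLContext; the toy dataset is illustrative only.

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
    from pyspark.mllib.linalg import Vectors

    # toy dataset; `sqlContext` is assumed to already exist
    dataset = sqlContext.createDataFrame(
        [(Vectors.dense([0.0, 1.0]), 0.0),
         (Vectors.dense([1.0, 2.0]), 1.0)] * 10,
        ["features", "label"])

    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
    evaluator = BinaryClassificationEvaluator()

    # 3-fold cross validation over the maxIter grid; fit() returns a
    # CrossValidatorModel wrapping the best model refit on the full dataset
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=evaluator, numFolds=3)
    cvModel = cv.fit(dataset)
    predictions = cvModel.transform(dataset)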
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index cee2aa6e85..9208127eb1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -52,10 +52,12 @@ private[ml] trait CrossValidatorParams extends Params {
def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
/**
- * param for the evaluator for selection
+ * param for the evaluator used to select hyper-parameters that maximize the cross-validated
+ * metric
* @group param
*/
- val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
+ val evaluator: Param[Evaluator] = new Param(this, "evaluator",
+ "evaluator used to select hyper-parameters that maximize the cross-validated metric")
/** @group getParam */
def getEvaluator: Evaluator = $(evaluator)
@@ -120,6 +122,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
trainingDataset.unpersist()
var i = 0
while (i < numModels) {
+ // TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
metrics(i) += metric
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 7b875e4b71..c1b2077c98 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -22,7 +22,7 @@ from pyspark.ml.util import keyword_only
from pyspark.mllib.common import inherit_doc
-__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator']
+__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator', 'Model']
@inherit_doc
@@ -71,6 +71,15 @@ class Transformer(Params):
@inherit_doc
+class Model(Transformer):
+ """
+ Abstract class for models that are fitted by estimators.
+ """
+
+ __metaclass__ = ABCMeta
+
+
+@inherit_doc
class Pipeline(Estimator):
"""
A simple pipeline, which acts as an estimator. A Pipeline consists
@@ -154,7 +163,7 @@ class Pipeline(Estimator):
@inherit_doc
-class PipelineModel(Transformer):
+class PipelineModel(Model):
"""
Represents a compiled pipeline with transformers and fitted models.
"""
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 1773ab5bdc..f6cf2c3439 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -16,8 +16,14 @@
#
import itertools
+import numpy as np
-__all__ = ['ParamGridBuilder']
+from pyspark.ml.param import Params, Param
+from pyspark.ml import Estimator, Model
+from pyspark.ml.util import keyword_only
+from pyspark.sql.functions import rand
+
+__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']
class ParamGridBuilder(object):
@@ -79,6 +85,179 @@ class ParamGridBuilder(object):
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
+class CrossValidator(Estimator):
+ """
+ K-fold cross validation.
+
+ >>> from pyspark.ml.classification import LogisticRegression
+ >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> dataset = sqlContext.createDataFrame(
+ ... [(Vectors.dense([0.0, 1.0]), 0.0),
+ ... (Vectors.dense([1.0, 2.0]), 1.0),
+ ... (Vectors.dense([0.55, 3.0]), 0.0),
+ ... (Vectors.dense([0.45, 4.0]), 1.0),
+ ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
+ ... ["features", "label"])
+ >>> lr = LogisticRegression()
+ >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
+ >>> evaluator = BinaryClassificationEvaluator()
+ >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ >>> cvModel = cv.fit(dataset)
+ >>> expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
+ >>> cvModel.transform(dataset).collect() == expected.collect()
+ True
+ """
+
+ # a placeholder to make it appear in the generated doc
+ estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
+
+ # a placeholder to make it appear in the generated doc
+ estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
+
+ # a placeholder to make it appear in the generated doc
+ evaluator = Param(
+ Params._dummy(), "evaluator",
+ "evaluator used to select hyper-parameters that maximize the cross-validated metric")
+
+ # a placeholder to make it appear in the generated doc
+ numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
+
+ @keyword_only
+ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+ """
+ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
+ """
+ super(CrossValidator, self).__init__()
+ #: param for estimator to be cross-validated
+ self.estimator = Param(self, "estimator", "estimator to be cross-validated")
+ #: param for estimator param maps
+ self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps")
+ #: param for the evaluator used to select hyper-parameters that
+ #: maximize the cross-validated metric
+ self.evaluator = Param(
+ self, "evaluator",
+ "evaluator used to select hyper-parameters that maximize the cross-validated metric")
+ #: param for number of folds for cross validation
+ self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
+ self._setDefault(numFolds=3)
+ kwargs = self.__init__._input_kwargs
+ self._set(**kwargs)
+
+ @keyword_only
+ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+ """
+ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+ Sets params for cross validator.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def setEstimator(self, value):
+ """
+ Sets the value of :py:attr:`estimator`.
+ """
+ self.paramMap[self.estimator] = value
+ return self
+
+ def getEstimator(self):
+ """
+ Gets the value of estimator or its default value.
+ """
+ return self.getOrDefault(self.estimator)
+
+ def setEstimatorParamMaps(self, value):
+ """
+ Sets the value of :py:attr:`estimatorParamMaps`.
+ """
+ self.paramMap[self.estimatorParamMaps] = value
+ return self
+
+ def getEstimatorParamMaps(self):
+ """
+ Gets the value of estimatorParamMaps or its default value.
+ """
+ return self.getOrDefault(self.estimatorParamMaps)
+
+ def setEvaluator(self, value):
+ """
+ Sets the value of :py:attr:`evaluator`.
+ """
+ self.paramMap[self.evaluator] = value
+ return self
+
+ def getEvaluator(self):
+ """
+ Gets the value of evaluator or its default value.
+ """
+ return self.getOrDefault(self.evaluator)
+
+ def setNumFolds(self, value):
+ """
+ Sets the value of :py:attr:`numFolds`.
+ """
+ self.paramMap[self.numFolds] = value
+ return self
+
+ def getNumFolds(self):
+ """
+ Gets the value of numFolds or its default value.
+ """
+ return self.getOrDefault(self.numFolds)
+
+ def fit(self, dataset, params={}):
+ paramMap = self.extractParamMap(params)
+ est = paramMap[self.estimator]
+ epm = paramMap[self.estimatorParamMaps]
+ numModels = len(epm)
+ eva = paramMap[self.evaluator]
+ nFolds = paramMap[self.numFolds]
+ h = 1.0 / nFolds
+ randCol = self.uid + "_rand"
+ df = dataset.select("*", rand(0).alias(randCol))
+ metrics = np.zeros(numModels)
+ for i in range(nFolds):
+ validateLB = i * h
+ validateUB = (i + 1) * h
+ condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
+ validation = df.filter(condition)
+ train = df.filter(~condition)
+ for j in range(numModels):
+ model = est.fit(train, epm[j])
+ # TODO: duplicate evaluator to take extra params from input
+ metric = eva.evaluate(model.transform(validation, epm[j]))
+ metrics[j] += metric
+ bestIndex = np.argmax(metrics)
+ bestModel = est.fit(dataset, epm[bestIndex])
+ return CrossValidatorModel(bestModel)
+
+
+class CrossValidatorModel(Model):
+ """
+ Model from k-fold cross validation.
+ """
+
+ def __init__(self, bestModel):
+ #: best model from cross validation
+ self.bestModel = bestModel
+
+ def transform(self, dataset, params={}):
+ return self.bestModel.transform(dataset, params)
+
+
if __name__ == "__main__":
import doctest
- doctest.testmod()
+ from pyspark.context import SparkContext
+ from pyspark.sql import SQLContext
+ globs = globals().copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ sc = SparkContext("local[2]", "ml.tuning tests")
+ sqlContext = SQLContext(sc)
+ globs['sc'] = sc
+ globs['sqlContext'] = sqlContext
+ (failure_count, test_count) = doctest.testmod(
+ globs=globs, optionflags=doctest.ELLIPSIS)
+ sc.stop()
+ if failure_count:
+ exit(-1)
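
To make the fold logic in fit() above easier to follow: it appends a deterministic rand(0) column and carves each validation fold out of a range of that column, training on the complement. A standalone sketch of just that splitting step (the `dataset` DataFrame and the column name are assumed):

    from pyspark.sql.functions import rand

    nFolds = 3
    h = 1.0 / nFolds
    df = dataset.select("*", rand(0).alias("_rand"))   # `dataset` assumed to exist
    for i in range(nFolds):
        validateLB = i * h
        validateUB = (i + 1) * h
        condition = (df["_rand"] >= validateLB) & (df["_rand"] < validateUB)
        validation = df.filter(condition)    # roughly 1/nFolds of the rows
        train = df.filter(~condition)        # the remaining rows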
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 73741c4b40..0634254bbd 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -20,7 +20,7 @@ from abc import ABCMeta
from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer, Evaluator
+from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
from pyspark.mllib.common import inherit_doc
@@ -133,7 +133,7 @@ class JavaTransformer(Transformer, JavaWrapper):
@inherit_doc
-class JavaModel(JavaTransformer):
+class JavaModel(Model, JavaTransformer):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations.
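
With JavaModel now mixing in Model, Java-backed fitted models (such as the one returned by LogisticRegression.fit) sit in the same Model hierarchy as pure-Python models like CrossValidatorModel. A small hedged check, assuming a `dataset` DataFrame as in the doctest above:

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.pipeline import Model

    lrModel = LogisticRegression(maxIter=5).fit(dataset)   # `dataset` assumed to exist
    assert isinstance(lrModel, Model)                       # holds after this change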