aboutsummaryrefslogblamecommitdiff
path: root/python/pyspark/ml/tuning.py
blob: 86f4dc7368be07df883f5dff6e2a539739eeec70 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
















                                                                          
                
                  
 





                                                                       


                               
        

                                                                       
                                                                
                                 










                                                                                       



                                            






























                                                                                    


                                                                                  

 

















                                                                                       



                                                                        


































































































































                                                                                                 
                                                   






                                                        

                          













                                                                     
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import itertools
import numpy as np

from pyspark.ml.param import Params, Param
from pyspark.ml import Estimator, Model
from pyspark.ml.util import keyword_only
from pyspark.sql.functions import rand

__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']


class ParamGridBuilder(object):
    r"""
    Builder for a param grid used in grid search-based model selection.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> lr = LogisticRegression()
    >>> output = ParamGridBuilder() \
    ...     .baseOn({lr.labelCol: 'l'}) \
    ...     .baseOn([lr.predictionCol, 'p']) \
    ...     .addGrid(lr.regParam, [1.0, 2.0]) \
    ...     .addGrid(lr.maxIter, [1, 5]) \
    ...     .build()
    >>> expected = [
    ...     {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
    >>> len(output) == len(expected)
    True
    >>> all([m in expected for m in output])
    True
    """

    def __init__(self):
        # Maps each param to the list of candidate values to try for it.
        self._param_grid = {}

    def addGrid(self, param, values):
        """
        Sets the given parameter to a list of candidate values in this grid,
        replacing any values previously added for that parameter.

        :param param: the param to vary in the grid
        :param values: an iterable of candidate values; it is copied into a
            list so later changes by the caller do not affect the grid and
            :py:meth:`build` can be called more than once
        :return: this builder, for chaining
        """
        self._param_grid[param] = list(values)
        return self

    def baseOn(self, *args):
        """
        Sets the given parameters in this grid to fixed values.
        Accepts either a parameter dictionary or a list of (parameter, value) pairs.
        Calling it with no pairs (or an empty dictionary) is a no-op.

        :return: this builder, for chaining
        """
        if args and isinstance(args[0], dict):
            # A dictionary is unpacked into (param, value) pairs; any further
            # positional arguments are ignored, as before.
            args = tuple(args[0].items())
        for (param, value) in args:
            # A fixed value is just a grid axis with a single candidate.
            self.addGrid(param, [value])
        return self

    def build(self):
        """
        Builds and returns all combinations of parameters specified
        by the param grid.

        :return: a list of param maps (dicts), one per combination of the
            candidate values; ``[{}]`` if no parameters were added
        """
        keys = list(self._param_grid.keys())
        grid_values = [self._param_grid[key] for key in keys]
        return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]


class CrossValidator(Estimator):
    """
    K-fold cross validation.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.mllib.linalg import Vectors
    >>> dataset = sqlContext.createDataFrame(
    ...     [(Vectors.dense([0.0, 1.0]), 0.0),
    ...      (Vectors.dense([1.0, 2.0]), 1.0),
    ...      (Vectors.dense([0.55, 3.0]), 0.0),
    ...      (Vectors.dense([0.45, 4.0]), 1.0),
    ...      (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
    ...     ["features", "label"])
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
    >>> # SPARK-7432: The following test is flaky.
    >>> # cvModel = cv.fit(dataset)
    >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
    >>> # cvModel.transform(dataset).collect() == expected.collect()
    """

    # a placeholder to make it appear in the generated doc
    estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")

    # a placeholder to make it appear in the generated doc
    estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")

    # a placeholder to make it appear in the generated doc
    evaluator = Param(
        Params._dummy(), "evaluator",
        "evaluator used to select hyper-parameters that maximize the cross-validated metric")

    # a placeholder to make it appear in the generated doc
    numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")

    @keyword_only
    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
        """
        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
        """
        super(CrossValidator, self).__init__()
        #: param for estimator to be cross-validated
        self.estimator = Param(self, "estimator", "estimator to be cross-validated")
        #: param for estimator param maps
        self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps")
        #: param for the evaluator used to select hyper-parameters that
        #: maximize the cross-validated metric
        self.evaluator = Param(
            self, "evaluator",
            "evaluator used to select hyper-parameters that maximize the cross-validated metric")
        #: param for number of folds for cross validation
        self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
        self._setDefault(numFolds=3)
        # keyword_only records the explicitly passed keyword arguments on the
        # wrapped method, so only those are set here.
        kwargs = self.__init__._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
        """
        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
        Sets params for cross validator.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def setEstimator(self, value):
        """
        Sets the value of :py:attr:`estimator`.
        """
        self.paramMap[self.estimator] = value
        return self

    def getEstimator(self):
        """
        Gets the value of estimator or its default value.
        """
        return self.getOrDefault(self.estimator)

    def setEstimatorParamMaps(self, value):
        """
        Sets the value of :py:attr:`estimatorParamMaps`.
        """
        self.paramMap[self.estimatorParamMaps] = value
        return self

    def getEstimatorParamMaps(self):
        """
        Gets the value of estimatorParamMaps or its default value.
        """
        return self.getOrDefault(self.estimatorParamMaps)

    def setEvaluator(self, value):
        """
        Sets the value of :py:attr:`evaluator`.
        """
        self.paramMap[self.evaluator] = value
        return self

    def getEvaluator(self):
        """
        Gets the value of evaluator or its default value.
        """
        return self.getOrDefault(self.evaluator)

    def setNumFolds(self, value):
        """
        Sets the value of :py:attr:`numFolds`.
        """
        self.paramMap[self.numFolds] = value
        return self

    def getNumFolds(self):
        """
        Gets the value of numFolds or its default value.
        """
        return self.getOrDefault(self.numFolds)

    def fit(self, dataset, params=None):
        """
        Runs k-fold cross validation over the configured estimator param maps.
        It then refits the estimator on the full dataset with the param map
        whose summed validation metric is highest.

        :param dataset: input dataset to split into folds
        :param params: optional param map overriding this instance's params
        :return: a :py:class:`CrossValidatorModel` wrapping the best model
        :raises ValueError: if numFolds is less than 2 or
            estimatorParamMaps is empty
        """
        if params is None:
            params = {}
        paramMap = self.extractParamMap(params)
        est = paramMap[self.estimator]
        epm = paramMap[self.estimatorParamMaps]
        numModels = len(epm)
        if numModels == 0:
            raise ValueError("estimatorParamMaps must contain at least one param map")
        eva = paramMap[self.evaluator]
        nFolds = paramMap[self.numFolds]
        # With fewer than 2 folds there is no proper train/validation split:
        # 0 divides by zero, 1 leaves an empty training set, and a negative
        # value would skip every fold and return an unvalidated model.
        if nFolds < 2:
            raise ValueError("numFolds must be >= 2, got %s" % (nFolds,))
        h = 1.0 / nFolds
        randCol = self.uid + "_rand"
        # Tag each row with a seeded uniform value in [0, 1); fold i holds the
        # rows whose value falls in [i * h, (i + 1) * h).
        df = dataset.select("*", rand(0).alias(randCol))
        metrics = np.zeros(numModels)
        for i in range(nFolds):
            validateLB = i * h
            validateUB = (i + 1) * h
            condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
            validation = df.filter(condition)
            train = df.filter(~condition)
            for j, estParams in enumerate(epm):
                model = est.fit(train, estParams)
                # TODO: duplicate evaluator to take extra params from input
                metric = eva.evaluate(model.transform(validation, estParams))
                # Summing over folds selects the same argmax as averaging.
                metrics[j] += metric
        bestIndex = np.argmax(metrics)
        bestModel = est.fit(dataset, epm[bestIndex])
        return CrossValidatorModel(bestModel)


class CrossValidatorModel(Model):
    """
    Model from k-fold cross validation.

    Holds the best model selected by cross validation and delegates
    :py:meth:`transform` to it.
    """

    def __init__(self, bestModel):
        super(CrossValidatorModel, self).__init__()
        #: best model from cross validation
        self.bestModel = bestModel

    def transform(self, dataset, params=None):
        """
        Transforms the input dataset using the best model.

        :param dataset: input dataset
        :param params: optional param map forwarded to the best model's
            ``transform``; a fresh empty dict is used when omitted, so no
            shared default object is ever handed to the wrapped model
        :return: whatever the best model's ``transform`` returns
        """
        if params is None:
            params = {}
        return self.bestModel.transform(dataset, params)


if __name__ == "__main__":
    import doctest
    import sys
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    globs = globals().copy()
    # The doctests above reference `sc` and `sqlContext`, so run them against
    # a local SparkContext ("local[2]") injected into the doctest globals.
    sc = SparkContext("local[2]", "ml.tuning tests")
    try:
        sqlContext = SQLContext(sc)
        globs['sc'] = sc
        globs['sqlContext'] = sqlContext
        (failure_count, test_count) = doctest.testmod(
            globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Always release the SparkContext, even if setup or the run raises.
        sc.stop()
    if failure_count:
        # sys.exit is always available, unlike the site-provided exit().
        sys.exit(-1)