aboutsummaryrefslogblamecommitdiff
path: root/python/pyspark/ml/tests.py
blob: ba6478dcd58a902a2d20c24577d7bc45c5d67d2b (plain) (tree)


































                                                                                       
                                                           
                                                                       











                                               
                                               














                                                     
                                               












                                                     
                                        

































                                                                                                  















                                                               
                                                                        


























                                                                 
                                                                                                

 

                          
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Unit tests for Spark ML Python APIs.
"""

import sys

# On Python 2.6 or earlier, use the unittest2 package under the name
# `unittest`; if it cannot be imported, print a hint to stderr and exit with
# status 1. Later interpreters use the standard-library unittest module.
if sys.version_info[:2] <= (2, 6):
    try:
        import unittest2 as unittest
    except ImportError:
        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
        sys.exit(1)
else:
    import unittest

from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.sql import DataFrame
from pyspark.ml.param import Param
from pyspark.ml.param.shared import HasMaxIter, HasInputCol
from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer


class MockDataset(DataFrame):
    """
    A DataFrame subclass used as a stand-in dataset by the pipeline test.

    The constructor does not call DataFrame.__init__; it only sets a counter.
    """

    def __init__(self):
        # Counter read by the mock stages and incremented by
        # MockTransformer.transform.
        self.index = 0


class MockTransformer(Transformer):
    """
    A Transformer stub that records how it was invoked.

    After transform has been called, dataset_index holds the dataset's index
    as seen on entry, and fake_param_value holds the value supplied for the
    "fake" param, if one was present in the params passed in.
    """

    def __init__(self):
        super(MockTransformer, self).__init__()
        self.fake = Param(self, "fake", "fake")
        # Nothing has been recorded until transform is called.
        self.dataset_index = self.fake_param_value = None

    def transform(self, dataset, params={}):
        """
        Record the call, increment dataset.index by one and return the same
        dataset object.
        """
        # The default dict is only read here, never mutated.
        fake = self.fake
        self.dataset_index = dataset.index
        if fake in params:
            self.fake_param_value = params[fake]
        dataset.index += 1
        return dataset


class MockEstimator(Estimator):
    """
    An Estimator stub that records how it was invoked.

    After fit has been called, dataset_index holds the dataset's index as
    seen on entry, fake_param_value holds the value supplied for the "fake"
    param (if any), and model holds the MockModel that fit returned.
    """

    def __init__(self):
        super(MockEstimator, self).__init__()
        self.fake = Param(self, "fake", "fake")
        # Nothing has been recorded until fit is called.
        self.dataset_index = self.fake_param_value = None
        self.model = None

    def fit(self, dataset, params={}):
        """
        Record the call and return a new MockModel, which is also kept on
        self.model. The dataset's index is left unchanged.
        """
        # The default dict is only read here, never mutated.
        fake = self.fake
        self.dataset_index = dataset.index
        if fake in params:
            self.fake_param_value = params[fake]
        self.model = MockModel()
        return self.model


class MockModel(MockTransformer, Model):
    """
    The model type returned by MockEstimator.fit: a MockTransformer that is
    also a Model.
    """

    def __init__(self):
        super(MockModel, self).__init__()


class PipelineTests(PySparkTestCase):
    """
    Checks how Pipeline drives its stages during fit and transform.
    """

    def test_pipeline(self):
        """
        Fit and then apply a four-stage pipeline of mocks, verifying the
        dataset index each stage observed and the params each stage received.
        """
        data = MockDataset()
        est0 = MockEstimator()
        trans1 = MockTransformer()
        est2 = MockEstimator()
        trans3 = MockTransformer()

        stages = [est0, trans1, est2, trans3]
        pipeline = Pipeline().setStages(stages)
        param_map = {est0.fake: 0, trans1.fake: 1}
        fitted = pipeline.fit(data, param_map)

        # During fit: est0 sees index 0 and so does the model it produced.
        self.assertEqual(0, est0.dataset_index)
        self.assertEqual(0, est0.fake_param_value)
        model0 = est0.model
        self.assertEqual(0, model0.dataset_index)
        # trans1 then sees index 1 and est2 sees index 2.
        self.assertEqual(1, trans1.dataset_index)
        self.assertEqual(1, trans1.fake_param_value)
        self.assertEqual(2, est2.dataset_index)
        model2 = est2.model
        not_called_msg = ("The model produced by the last estimator should "
                          "not be called during fit.")
        self.assertIsNone(model2.dataset_index, not_called_msg)

        # During transform: each stage of the fitted pipeline sees the next
        # index in turn, continuing from where fit left the counter.
        result = fitted.transform(data)
        expectations = [(2, model0), (3, trans1), (4, model2), (5, trans3)]
        for expected_index, stage in expectations:
            self.assertEqual(expected_index, stage.dataset_index)
        self.assertEqual(6, result.index)


class TestParams(HasMaxIter, HasInputCol):
    """
    A subclass of Params mixed with HasMaxIter and HasInputCol.
    """

    def __init__(self):
        super(TestParams, self).__init__()
        # Give maxIter a default of 10; no default is set for inputCol here.
        self._setDefault(maxIter=10)


class ParamTests(PySparkTestCase):
    """
    Tests for Param metadata and the Params accessors, exercised via TestParams.
    """

    def test_param(self):
        """
        A Param exposes its name and doc, and its parent is the owning instance.
        """
        testParams = TestParams()
        maxIter = testParams.maxIter
        self.assertEqual(maxIter.name, "maxIter")
        self.assertEqual(maxIter.doc, "max number of iterations (>= 0)")
        # assertIs reports both objects on failure, unlike assertTrue(a is b).
        self.assertIs(maxIter.parent, testParams)

    def test_params(self):
        """
        Check the params listing, the default/set/defined state of each param,
        the getters and setters, and the text produced by explainParams.
        """
        testParams = TestParams()
        maxIter = testParams.maxIter
        inputCol = testParams.inputCol

        params = testParams.params
        self.assertEqual(params, [inputCol, maxIter])

        # maxIter starts with only its default of 10, then is set to 100.
        self.assertTrue(testParams.hasDefault(maxIter))
        self.assertFalse(testParams.isSet(maxIter))
        self.assertTrue(testParams.isDefined(maxIter))
        self.assertEqual(testParams.getMaxIter(), 10)
        testParams.setMaxIter(100)
        self.assertTrue(testParams.isSet(maxIter))
        # assertEqual replaces the deprecated alias assertEquals.
        self.assertEqual(testParams.getMaxIter(), 100)

        # inputCol has neither a default nor a value, so its getter raises.
        self.assertFalse(testParams.hasDefault(inputCol))
        self.assertFalse(testParams.isSet(inputCol))
        self.assertFalse(testParams.isDefined(inputCol))
        with self.assertRaises(KeyError):
            testParams.getInputCol()

        self.assertEqual(
            testParams.explainParams(),
            "\n".join(["inputCol: input column name (undefined)",
                       "maxIter: max number of iterations (>= 0) (default: 10, current: 100)"]))


# Run this module's tests when it is executed as a script.
if __name__ == "__main__":
    unittest.main()