aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
blob: fecf372c3d843cff2032f6bdf8d33dc590330f13 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tree.impl

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{GradientBoostedTreesSuite => OldGBTSuite}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError}
import org.apache.spark.mllib.util.MLlibTestSparkContext

/**
 * Test suite for [[GradientBoostedTrees]].
 *
 * Exercises `runWithValidation` against `run` on the shared train/validation fixtures of the
 * old MLlib GBT suite, and cross-checks the outcome with `evaluateEachIteration`.
 */
class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {

  test("runWithValidation stops early and performs better on a validation dataset") {
    // Set numIterations large enough so that it stops early.
    val numIterations = 20
    val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2)
    val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2)

    val algos = Array(Regression, Regression, Classification)
    val losses = Array(SquaredError, AbsoluteError, LogLoss)
    algos.zip(losses).foreach { case (algo, loss) =>
      val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
        categoricalFeaturesInfo = Map.empty)
      val boostingStrategy =
        new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
      val (validateTrees, validateTreeWeights) = GradientBoostedTrees
        .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L)
      val numTrees = validateTrees.length
      // Early stopping must have kicked in; this also keeps evaluationArray(numTrees) below
      // within bounds.
      assert(numTrees !== numIterations)

      // Test that it performs better on the validation dataset.
      val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L)
      // For classification the labels are remapped with 2 * label - 1 (0/1 becomes -1/+1)
      // before computing the error; regression uses the validation data as is.
      val errorRdd =
        if (algo == Classification) {
          validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
        } else {
          validateRdd
        }
      val errorWithoutValidation =
        GradientBoostedTrees.computeError(errorRdd, trees, treeWeights, loss)
      val errorWithValidation =
        GradientBoostedTrees.computeError(errorRdd, validateTrees, validateTreeWeights, loss)
      assert(errorWithValidation <= errorWithoutValidation)

      // Test that results from evaluateEachIteration comply with runWithValidation.
      // Note that validationTol is set to 0.0, so the validation error is expected to be
      // non-increasing over the trees that were kept and to rise at the first dropped one.
      val evaluationArray = GradientBoostedTrees
        .evaluateEachIteration(validateRdd, trees, treeWeights, loss, algo)
      assert(evaluationArray.length === numIterations)
      assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
      (1 until numTrees).foreach { i =>
        assert(evaluationArray(i) <= evaluationArray(i - 1))
      }
    }
  }

}