aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
author    sethah <seth.hendrickson16@gmail.com>  2016-04-06 17:13:34 -0700
committer Joseph K. Bradley <joseph@databricks.com>  2016-04-06 17:13:34 -0700
commit bb873754b4700104755ab969694bf30945557dc3 (patch)
tree   02b5b39b530827fea0871ade32e4b8927edb7e9a /mllib/src/test/scala/org/apache
parent 864d1b4d665e2cc1d40b53502a4ddf26c1fbfc1d (diff)
downloadspark-bb873754b4700104755ab969694bf30945557dc3.tar.gz
spark-bb873754b4700104755ab969694bf30945557dc3.tar.bz2
spark-bb873754b4700104755ab969694bf30945557dc3.zip
[SPARK-12382][ML] Remove mllib GBT implementation and wrap ml
## What changes were proposed in this pull request?

This patch removes the implementation of gradient boosted trees in mllib/tree/GradientBoostedTrees.scala and changes mllib GBTs to call the implementation in spark.ML.

Primary changes:
* Removed `boost` method in mllib GradientBoostedTrees.scala
* Created new test suite GradientBoostedTreesSuite in ML, which contains unit tests that were specific to GBT internals from mllib

Other changes:
* Added an `updatePrediction` method in GradientBoostedTrees package. This method is added to provide consistency for methods that build predictions from boosted models. There are several methods that hard code the method of predicting as: sum_{i=1}^{numTrees} (treePrediction*treeWeight). Calling this function ensures that test methods that check accuracy use the same prediction method that the algorithm uses during training
* Added methods that were previously only used in testing, but were public methods, to GradientBoostedTrees. This includes `computeError` (previously part of `Loss` trait) and `evaluateEachIteration`. These are used in the new spark.ML unit tests. They are left in mllib as well so as to not break the API.

## How was this patch tested?

Existing unit tests which compare ML and MLlib ensure that mllib GBTs have not changed. Only a single unit test was moved to ML, which verifies that `runWithValidation` performs as expected.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12050 from sethah/SPARK-12382.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala85
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala45
3 files changed, 87 insertions, 45 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 914818f41f..3c11631f98 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -53,7 +53,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
}
- test("Regression with continuous features: SquaredError") {
+ test("Regression with continuous features") {
val categoricalFeatures = Map.empty[Int, Int]
GBTRegressor.supportedLossTypes.foreach { loss =>
testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
new file mode 100644
index 0000000000..fecf372c3d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{GradientBoostedTreesSuite => OldGBTSuite}
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Test suite for [[GradientBoostedTrees]].
+ */
+class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
+
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ // Set numIterations large enough so that it stops early.
+ val numIterations = 20
+ val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2)
+ val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2)
+ val trainDF = sqlContext.createDataFrame(trainRdd)
+ val validateDF = sqlContext.createDataFrame(validateRdd)
+
+ val algos = Array(Regression, Regression, Classification)
+ val losses = Array(SquaredError, AbsoluteError, LogLoss)
+ algos.zip(losses).foreach { case (algo, loss) =>
+ val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+ val (validateTrees, validateTreeWeights) = GradientBoostedTrees
+ .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L)
+ val numTrees = validateTrees.length
+ assert(numTrees !== numIterations)
+
+ // Test that it performs better on the validation dataset.
+ val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L)
+ val (errorWithoutValidation, errorWithValidation) = {
+ if (algo == Classification) {
+ val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+ (GradientBoostedTrees.computeError(remappedRdd, trees, treeWeights, loss),
+ GradientBoostedTrees.computeError(remappedRdd, validateTrees,
+ validateTreeWeights, loss))
+ } else {
+ (GradientBoostedTrees.computeError(validateRdd, trees, treeWeights, loss),
+ GradientBoostedTrees.computeError(validateRdd, validateTrees,
+ validateTreeWeights, loss))
+ }
+ }
+ assert(errorWithValidation <= errorWithoutValidation)
+
+ // Test that results from evaluateEachIteration comply with runWithValidation.
+ // Note that convergenceTol is set to 0.0
+ val evaluationArray = GradientBoostedTrees
+ .evaluateEachIteration(validateRdd, trees, treeWeights, loss, algo)
+ assert(evaluationArray.length === numIterations)
+ assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+ var i = 1
+ while (i < numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
+ }
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 747c267b4f..c61f89322d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -158,49 +158,6 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
}
}
- test("runWithValidation stops early and performs better on a validation dataset") {
- // Set numIterations large enough so that it stops early.
- val numIterations = 20
- val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
- val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
-
- val algos = Array(Regression, Regression, Classification)
- val losses = Array(SquaredError, AbsoluteError, LogLoss)
- algos.zip(losses).foreach { case (algo, loss) =>
- val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
- val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0)
- .runWithValidation(trainRdd, validateRdd)
- val numTrees = gbtValidate.numTrees
- assert(numTrees !== numIterations)
-
- // Test that it performs better on the validation dataset.
- val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd)
- val (errorWithoutValidation, errorWithValidation) = {
- if (algo == Classification) {
- val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
- (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
- } else {
- (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
- }
- }
- assert(errorWithValidation <= errorWithoutValidation)
-
- // Test that results from evaluateEachIteration comply with runWithValidation.
- // Note that convergenceTol is set to 0.0
- val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
- assert(evaluationArray.length === numIterations)
- assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
- var i = 1
- while (i < numTrees) {
- assert(evaluationArray(i) <= evaluationArray(i - 1))
- i += 1
- }
- }
- }
-
test("Checkpointing") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
@@ -220,7 +177,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
}
-private object GradientBoostedTreesSuite {
+private[spark] object GradientBoostedTreesSuite {
// Combinations for estimators, learning rates and subsamplingRate
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))