author    sethah <seth.hendrickson16@gmail.com>    2016-03-15 11:50:34 +0200
committer Nick Pentreath <nick.pentreath@gmail.com>    2016-03-15 11:50:34 +0200
commit    dafd70fbfe70702502ef198f2a8f529ef7557592 (patch)
tree      64fae7f84af42e8201775197479191120d89ca61 /mllib
parent    10251a7457e44fa33f13274ccb2266bc5e931363 (diff)
[SPARK-12379][ML][MLLIB] Copy GBT implementation to spark.ml
Currently, GBTs in spark.ml wrap the implementation in spark.mllib. This prevents several improvements to GBTs in spark.ml, so we need to move the implementation to spark.ml and use spark.ml decision trees in it. At first, we should make minimal changes to the implementation. Performance testing should be done to ensure there are no regressions. Performance testing results are [here](https://docs.google.com/document/d/1dYd2mnfGdUKkQ3vZe2BpzsTnI5IrpSLQ-NNKDZhUkgw/edit?usp=sharing)

Author: sethah <seth.hendrickson16@gmail.com>

Closes #10607 from sethah/SPARK-12379.
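
For context, a minimal sketch of the new internal entry point introduced by this commit (names are taken from the diff below; the data setup is illustrative, and note that ml.tree.impl.GradientBoostedTrees is private[ml], so this only compiles from code inside org.apache.spark.ml):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.ml.tree.impl.GradientBoostedTrees

// Illustrative two-point dataset; assumes a SparkContext `sc` is in scope.
val data = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0))))

// The moved implementation returns the base trees and their weights directly,
// rather than a spark.mllib GradientBoostedTreesModel.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
val (trees, weights) = GradientBoostedTrees.run(data, boostingStrategy)
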
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | 8
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala | 7
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 8
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 7
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala | 277
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala | 2
12 files changed, 306 insertions(+), 15 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 7f0397f6bd..bcbedc8bc1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -93,6 +93,14 @@ final class DecisionTreeClassifier @Since("1.4.0") (
trees.head.asInstanceOf[DecisionTreeClassificationModel]
}
+ /** (private[ml]) Train a decision tree on an RDD */
+ private[ml] def train(data: RDD[LabeledPoint],
+ oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
+ val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+ seed = 0L, parentUID = Some(uid))
+ trees.head.asInstanceOf[DecisionTreeClassificationModel]
+ }
+
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index e0ffbedf6c..82059b1d0e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -26,10 +26,10 @@ import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams,
TreeEnsembleModel}
+import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
@@ -158,9 +158,8 @@ final class GBTClassifier @Since("1.4.0") (
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
- val oldGBT = new OldGBT(boostingStrategy)
- val oldModel = oldGBT.run(oldDataset)
- GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
+ val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+ new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
}
@Since("1.4.1")
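
The user-facing API is unchanged by this refactor; a minimal sketch of fitting the estimator (1.6-era shell, assuming `sc` and `sqlContext` are in scope; data and parameter values are illustrative):

import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val training = sqlContext.createDataFrame(sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)))))

val gbt = new GBTClassifier()
  .setMaxIter(10)   // number of boosting iterations
  .setMaxDepth(3)   // depth of each base tree
// train() above now routes through ml.tree.impl.GradientBoostedTrees.
val model = gbt.fit(training)
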
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 897b23383c..6e46292451 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -87,6 +87,14 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
trees.head.asInstanceOf[DecisionTreeRegressionModel]
}
+ /** (private[ml]) Train a decision tree on an RDD */
+ private[ml] def train(data: RDD[LabeledPoint],
+ oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
+ val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+ seed = 0L, parentUID = Some(uid))
+ trees.head.asInstanceOf[DecisionTreeRegressionModel]
+ }
+
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
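
The regressor's new train overload is the one each boosting iteration uses: a single unrestricted tree (numTrees = 1, featureSubsetStrategy = "all") fit to the pseudo-residuals. Schematically, the per-iteration call made by the boosting loop in the new file further down (`data` and `treeStrategy` come from that loop):

// One boosting step, as a sketch: fit a regression tree to the residuals.
val dt = new DecisionTreeRegressor()
val model: DecisionTreeRegressionModel = dt.train(data, treeStrategy)
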
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 9c842a6c88..4cc2721aef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -25,10 +25,10 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
TreeRegressorParams}
+import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
SquaredError => OldSquaredError}
@@ -145,9 +145,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
- val oldGBT = new OldGBT(boostingStrategy)
- val oldModel = oldGBT.run(oldDataset)
- GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
+ val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+ new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
}
@Since("1.4.0")
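
GBTRegressor follows the same pattern; a minimal sketch, reusing the training DataFrame from the classifier example above:

import org.apache.spark.ml.regression.GBTRegressor

val gbtR = new GBTRegressor()
  .setMaxIter(10)
  .setLossType("squared")  // or "absolute"; maps to the Old*Error losses imported above
val regModel = gbtR.fit(training)
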
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
new file mode 100644
index 0000000000..44ab5b723b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.Logging
+import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+private[ml] object GradientBoostedTrees extends Logging {
+
+ /**
+ * Method to train a gradient boosting model
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return tuple of ensemble models and weights:
+ * (array of decision tree models, array of model weights)
+ */
+ def run(input: RDD[LabeledPoint],
+ boostingStrategy: OldBoostingStrategy
+ ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ val algo = boostingStrategy.treeStrategy.algo
+ algo match {
+ case OldAlgo.Regression =>
+ GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
+ case OldAlgo.Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
+ val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+ }
+ }
+
+ /**
+ * Method to validate a gradient boosting model
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @param validationInput Validation dataset.
+ * This dataset should be different from the training dataset,
+ * but it should follow the same distribution.
+ * E.g., these two datasets could be created from an original dataset
+ * by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+ * @return tuple of ensemble models and weights:
+ * (array of decision tree models, array of model weights)
+ */
+ def runWithValidation(
+ input: RDD[LabeledPoint],
+ validationInput: RDD[LabeledPoint],
+ boostingStrategy: OldBoostingStrategy
+ ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ val algo = boostingStrategy.treeStrategy.algo
+ algo match {
+ case OldAlgo.Regression =>
+ GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
+ case OldAlgo.Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
+ val remappedInput = input.map(
+ x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val remappedValidationInput = validationInput.map(
+ x => new LabeledPoint((x.label * 2) - 1, x.features))
+ GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
+ validate = true)
+ case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+ }
+ }
+
+ /**
+ * Compute the initial predictions and errors for a dataset for the first
+ * iteration of gradient boosting.
+ * @param data: training data.
+ * @param initTreeWeight: learning rate assigned to the first tree.
+ * @param initTree: first DecisionTreeModel.
+ * @param loss: evaluation metric.
+   * @return an RDD with each element being a zip of the prediction and error
+ * corresponding to every sample.
+ */
+ def computeInitialPredictionAndError(
+ data: RDD[LabeledPoint],
+ initTreeWeight: Double,
+ initTree: DecisionTreeRegressionModel,
+ loss: OldLoss): RDD[(Double, Double)] = {
+ data.map { lp =>
+ val pred = initTreeWeight * initTree.rootNode.predictImpl(lp.features).prediction
+ val error = loss.computeError(pred, lp.label)
+ (pred, error)
+ }
+ }
+
+ /**
+ * Update a zipped predictionError RDD
+ * (as obtained with computeInitialPredictionAndError)
+ * @param data: training data.
+ * @param predictionAndError: predictionError RDD
+ * @param treeWeight: Learning rate.
+   * @param tree: Tree with which the prediction and error should be updated.
+ * @param loss: evaluation metric.
+   * @return an RDD with each element being a zip of the prediction and error
+ * corresponding to each sample.
+ */
+ def updatePredictionError(
+ data: RDD[LabeledPoint],
+ predictionAndError: RDD[(Double, Double)],
+ treeWeight: Double,
+ tree: DecisionTreeRegressionModel,
+ loss: OldLoss): RDD[(Double, Double)] = {
+
+ val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
+ iter.map { case (lp, (pred, error)) =>
+ val newPred = pred + tree.rootNode.predictImpl(lp.features).prediction * treeWeight
+ val newError = loss.computeError(newPred, lp.label)
+ (newPred, newError)
+ }
+ }
+ newPredError
+ }
+
+ /**
+ * Internal method for performing regression using trees as base learners.
+ * @param input training dataset
+ * @param validationInput validation dataset, ignored if validate is set to false.
+ * @param boostingStrategy boosting parameters
+ * @param validate whether or not to use the validation dataset.
+ * @return tuple of ensemble models and weights:
+ * (array of decision tree models, array of model weights)
+ */
+ def boost(
+ input: RDD[LabeledPoint],
+ validationInput: RDD[LabeledPoint],
+ boostingStrategy: OldBoostingStrategy,
+ validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ val timer = new TimeTracker()
+ timer.start("total")
+ timer.start("init")
+
+ boostingStrategy.assertValid()
+
+ // Initialize gradient boosting parameters
+ val numIterations = boostingStrategy.numIterations
+ val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
+ val baseLearnerWeights = new Array[Double](numIterations)
+ val loss = boostingStrategy.loss
+ val learningRate = boostingStrategy.learningRate
+ // Prepare strategy for individual trees, which use regression with variance impurity.
+ val treeStrategy = boostingStrategy.treeStrategy.copy
+ val validationTol = boostingStrategy.validationTol
+ treeStrategy.algo = OldAlgo.Regression
+ treeStrategy.impurity = OldVariance
+ treeStrategy.assertValid()
+
+ // Cache input
+ val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
+ input.persist(StorageLevel.MEMORY_AND_DISK)
+ true
+ } else {
+ false
+ }
+
+ // Prepare periodic checkpointers
+ val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, input.sparkContext)
+ val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, input.sparkContext)
+
+ timer.stop("init")
+
+ logDebug("##########")
+ logDebug("Building tree 0")
+ logDebug("##########")
+
+ // Initialize tree
+ timer.start("building tree 0")
+ val firstTree = new DecisionTreeRegressor()
+ val firstTreeModel = firstTree.train(input, treeStrategy)
+ val firstTreeWeight = 1.0
+ baseLearners(0) = firstTreeModel
+ baseLearnerWeights(0) = firstTreeWeight
+
+ var predError: RDD[(Double, Double)] =
+ computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+ predErrorCheckpointer.update(predError)
+ logDebug("error of gbt = " + predError.values.mean())
+
+ // Note: A model of type regression is used since we require raw prediction
+ timer.stop("building tree 0")
+
+ var validatePredError: RDD[(Double, Double)] =
+ computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+ if (validate) validatePredErrorCheckpointer.update(validatePredError)
+ var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
+ var bestM = 1
+
+ var m = 1
+ var doneLearning = false
+ while (m < numIterations && !doneLearning) {
+ // Update data with pseudo-residuals
+ val data = predError.zip(input).map { case ((pred, _), point) =>
+ LabeledPoint(-loss.gradient(pred, point.label), point.features)
+ }
+
+ timer.start(s"building tree $m")
+ logDebug("###################################################")
+ logDebug("Gradient boosting tree iteration " + m)
+ logDebug("###################################################")
+ val dt = new DecisionTreeRegressor()
+ val model = dt.train(data, treeStrategy)
+ timer.stop(s"building tree $m")
+ // Update partial model
+ baseLearners(m) = model
+ // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+ // Technically, the weight should be optimized for the particular loss.
+ // However, the behavior should be reasonable, though not optimal.
+ baseLearnerWeights(m) = learningRate
+
+ predError = updatePredictionError(
+ input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+ predErrorCheckpointer.update(predError)
+ logDebug("error of gbt = " + predError.values.mean())
+
+ if (validate) {
+ // Stop training early if
+ // 1. Reduction in error is less than the validationTol or
+ // 2. If the error increases, that is if the model is overfit.
+ // We want the model returned corresponding to the best validation error.
+
+ validatePredError = updatePredictionError(
+ validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+ validatePredErrorCheckpointer.update(validatePredError)
+ val currentValidateError = validatePredError.values.mean()
+ if (bestValidateError - currentValidateError < validationTol * Math.max(
+ currentValidateError, 0.01)) {
+ doneLearning = true
+ } else if (currentValidateError < bestValidateError) {
+ bestValidateError = currentValidateError
+ bestM = m + 1
+ }
+ }
+ m += 1
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ predErrorCheckpointer.deleteAllCheckpoints()
+ validatePredErrorCheckpointer.deleteAllCheckpoints()
+ if (persistedInput) input.unpersist()
+
+ if (validate) {
+ (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
+ } else {
+ (baseLearners, baseLearnerWeights)
+ }
+ }
+}
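
A sketch of the validation path, splitting a dataset as the runWithValidation doc comment suggests (imports and the private[ml] caveat as in the first sketch; `data` is an RDD[LabeledPoint]):

val Array(trainData, validationData) = data.randomSplit(Array(0.8, 0.2), seed = 42L)
val strategy = BoostingStrategy.defaultParams("Regression")
// Training stops early once the improvement in validation error falls below
// validationTol, and the model with the best validation error is kept.
val (trees, treeWeights) =
  GradientBoostedTrees.runWithValidation(trainData, validationData, strategy)
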
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
index f31ed2aa90..145dc22b74 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -74,7 +74,7 @@ import org.apache.spark.storage.StorageLevel
*
* TODO: Move this out of MLlib?
*/
-private[mllib] class PeriodicRDDCheckpointer[T](
+private[spark] class PeriodicRDDCheckpointer[T](
checkpointInterval: Int,
sc: SparkContext)
extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
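
Widening PeriodicRDDCheckpointer to private[spark] is what lets the new spark.ml implementation reuse it. Its usage pattern, schematically (assumes `sc` has a checkpoint directory set via sc.setCheckpointDir):

import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer

val checkpointer = new PeriodicRDDCheckpointer[Int](10, sc)
var rdd = sc.parallelize(1 to 100)
for (i <- 1 to 30) {
  rdd = rdd.map(_ + 1)       // stand-in for one boosting update of predError
  checkpointer.update(rdd)   // persists each update; checkpoints every 10th
}
checkpointer.deleteAllCheckpoints()
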
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 0b118a7673..d8405d13ce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -59,7 +59,7 @@ case class BoostingStrategy @Since("1.4.0") (
* Check validity of parameters.
* Throws exception if invalid.
*/
- private[tree] def assertValid(): Unit = {
+ private[spark] def assertValid(): Unit = {
treeStrategy.algo match {
case Classification =>
require(treeStrategy.numClasses == 2,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 9e3e50192d..8a0907564e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -133,7 +133,7 @@ class Strategy @Since("1.3.0") (
* Check validity of parameters.
* Throws exception if invalid.
*/
- private[tree] def assertValid(): Unit = {
+ private[spark] def assertValid(): Unit = {
algo match {
case Classification =>
require(numClasses >= 2,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index 48a4e38a34..9b60d018d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -45,7 +45,7 @@ object AbsoluteError extends Loss {
if (label - prediction < 0) 1.0 else -1.0
}
- override private[mllib] def computeError(prediction: Double, label: Double): Double = {
+ override private[spark] def computeError(prediction: Double, label: Double): Double = {
val err = label - prediction
math.abs(err)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index b88743c0db..5d92ce495b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -47,7 +47,7 @@ object LogLoss extends Loss {
- 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
}
- override private[mllib] def computeError(prediction: Double, label: Double): Double = {
+ override private[spark] def computeError(prediction: Double, label: Double): Double = {
val margin = 2.0 * label * prediction
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 687cde325f..de14ddf024 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -61,5 +61,5 @@ trait Loss extends Serializable {
* @param label True label.
* @return Measure of model error on datapoint.
*/
- private[mllib] def computeError(prediction: Double, label: Double): Double
+ private[spark] def computeError(prediction: Double, label: Double): Double
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index cb97f6fd29..4eb6810c46 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -45,7 +45,7 @@ object SquaredError extends Loss {
- 2.0 * (label - prediction)
}
- override private[mllib] def computeError(prediction: Double, label: Double): Double = {
+ override private[spark] def computeError(prediction: Double, label: Double): Double = {
val err = label - prediction
err * err
}
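
For reference, the per-sample errors whose visibility changes here reduce to the following formulas (a plain-Scala restatement for illustration, not the library API, since computeError is now private[spark]):

def squaredError(prediction: Double, label: Double): Double = {
  val err = label - prediction
  err * err
}

def absoluteError(prediction: Double, label: Double): Double =
  math.abs(label - prediction)

def logLoss(prediction: Double, label: Double): Double = {
  // Labels are in {-1, +1}; equals 2 * log(1 + exp(-margin)), computed stably
  // (this is what MLUtils.log1pExp does).
  val margin = 2.0 * label * prediction
  if (-margin > 0) 2.0 * (-margin + math.log1p(math.exp(margin)))
  else 2.0 * math.log1p(math.exp(-margin))
}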