aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIlya Matiach <ilmat@microsoft.com>2017-01-18 15:33:41 -0800
committerJoseph K. Bradley <joseph@databricks.com>2017-01-18 15:33:41 -0800
commitfe409f31d966d99fcf57137581d1fb682c1c072a (patch)
tree3dcf3ff29c501a756957c8fc6af5a8281c809861
parenta81e336f1eddc2c6245d807aae2c81ddc60eabf9 (diff)
downloadspark-fe409f31d966d99fcf57137581d1fb682c1c072a.tar.gz
spark-fe409f31d966d99fcf57137581d1fb682c1c072a.tar.bz2
spark-fe409f31d966d99fcf57137581d1fb682c1c072a.zip
[SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces
## What changes were proposed in this pull request? For all of the classifiers in MLLib we can predict probabilities except for GBTClassifier. Also, all classifiers inherit from ProbabilisticClassifier but GBTClassifier strangely inherits from Predictor, which is a bug. This change corrects the interface and adds the ability for the classifier to give a probabilities vector. ## How was this patch tested? The basic ML tests were run after making the changes. I've marked this as WIP as I need to add more tests. Author: Ilya Matiach <ilmat@microsoft.com> Closes #16441 from imatiach-msft/ilmat/fix-GBT.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala94
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala161
5 files changed, 248 insertions, 29 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c9bbd37a67..ade0960f87 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -23,9 +23,8 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
@@ -33,6 +32,7 @@ import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.LogLoss
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -58,7 +58,7 @@ import org.apache.spark.sql.functions._
@Since("1.4.0")
class GBTClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
- extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
+ extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel]
with GBTClassifierParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
@@ -158,12 +158,19 @@ class GBTClassifier @Since("1.4.0") (
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
+ val numClasses = 2
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
instr.logNumFeatures(numFeatures)
- instr.logNumClasses(2)
+ instr.logNumClasses(numClasses)
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
$(seed))
@@ -202,8 +209,9 @@ class GBTClassificationModel private[ml](
@Since("1.6.0") override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
- @Since("1.6.0") override val numFeatures: Int)
- extends PredictionModel[Vector, GBTClassificationModel]
+ @Since("1.6.0") override val numFeatures: Int,
+ @Since("2.2.0") override val numClasses: Int)
+ extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {
@@ -216,10 +224,24 @@ class GBTClassificationModel private[ml](
*
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
+ * @param numFeatures The number of features.
+ */
+ private[ml] def this(
+ uid: String,
+ _trees: Array[DecisionTreeRegressionModel],
+ _treeWeights: Array[Double],
+ numFeatures: Int) =
+ this(uid, _trees, _treeWeights, numFeatures, 2)
+
+ /**
+ * Construct a GBTClassificationModel
+ *
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
*/
@Since("1.6.0")
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
- this(uid, _trees, _treeWeights, -1)
+ this(uid, _trees, _treeWeights, -1, 2)
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
@@ -242,11 +264,29 @@ class GBTClassificationModel private[ml](
}
override protected def predict(features: Vector): Double = {
- // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
- // Classifies by thresholding sum of weighted tree predictions
- val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
- val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
- if (prediction > 0.0) 1.0 else 0.0
+ // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
+ if (isDefined(thresholds)) {
+ super.predict(features)
+ } else {
+ if (margin(features) > 0.0) 1.0 else 0.0
+ }
+ }
+
+ override protected def predictRaw(features: Vector): Vector = {
+ val prediction: Double = margin(features)
+ Vectors.dense(Array(-prediction, prediction))
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ dv.values(0) = loss.computeProbability(dv.values(0))
+ dv.values(1) = 1.0 - dv.values(0)
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in GBTClassificationModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
}
/** Number of trees in ensemble */
@@ -254,7 +294,7 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def copy(extra: ParamMap): GBTClassificationModel = {
- copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
+ copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
extra).setParent(parent)
}
@@ -276,11 +316,20 @@ class GBTClassificationModel private[ml](
@Since("2.0.0")
lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+ /** Raw prediction for the positive class. */
+ private def margin(features: Vector): Double = {
+ val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
+ blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+ }
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
}
+ // hard coded loss, which is not meant to be changed in the model
+ private val loss = getOldLossType
+
@Since("2.0.0")
override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
}
@@ -288,6 +337,9 @@ class GBTClassificationModel private[ml](
@Since("2.0.0")
object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
+ private val numFeaturesKey: String = "numFeatures"
+ private val numTreesKey: String = "numTrees"
+
@Since("2.0.0")
override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
@@ -300,8 +352,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
override protected def saveImpl(path: String): Unit = {
val extraMetadata: JObject = Map(
- "numFeatures" -> instance.numFeatures,
- "numTrees" -> instance.getNumTrees)
+ numFeaturesKey -> instance.numFeatures,
+ numTreesKey -> instance.getNumTrees)
EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
}
}
@@ -316,8 +368,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
- val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
- val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+ val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
+ val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
val trees: Array[DecisionTreeRegressionModel] = treesData.map {
case (treeMetadata, root) =>
@@ -328,7 +380,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
}
require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
s" trees based on metadata but found ${trees.length} trees.")
- val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures)
+ val model = new GBTClassificationModel(metadata.uid,
+ trees, treeWeights, numFeatures)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
@@ -339,7 +392,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
oldModel: OldGBTModel,
parent: GBTClassifier,
categoricalFeatures: Map[Int, Int],
- numFeatures: Int = -1): GBTClassificationModel = {
+ numFeatures: Int = -1,
+ numClasses: Int = 2): GBTClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
@@ -347,6 +401,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
- new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
+ new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index c7a8f76eca..5eb707dfe7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
@@ -531,7 +531,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam
def getLossType: String = $(lossType).toLowerCase
/** (private[ml]) Convert new loss to old loss. */
- override private[ml] def getOldLossType: OldLoss = {
+ override private[ml] def getOldLossType: OldClassificationLoss = {
getLossType match {
case "logistic" => OldLogLoss
case _ =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 5d92ce495b..9339f0a23c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.mllib.util.MLUtils
-
/**
* :: DeveloperApi ::
* Class for log loss calculation (for classification).
@@ -32,7 +31,7 @@ import org.apache.spark.mllib.util.MLUtils
*/
@Since("1.2.0")
@DeveloperApi
-object LogLoss extends Loss {
+object LogLoss extends ClassificationLoss {
/**
* Method to calculate the loss gradients for the gradient boosting calculation for binary
@@ -52,4 +51,11 @@ object LogLoss extends Loss {
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}
+
+ /**
+ * Returns the estimated probability of a label of 1.0.
+ */
+ override private[spark] def computeProbability(margin: Double): Double = {
+ 1.0 / (1.0 + math.exp(-2.0 * margin))
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 09274a2e1b..e7ffb3f8f5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -22,7 +22,6 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
-
/**
* :: DeveloperApi ::
* Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -67,3 +66,10 @@ trait Loss extends Serializable {
*/
private[spark] def computeError(prediction: Double, label: Double): Double
}
+
+private[spark] trait ClassificationLoss extends Loss {
+ /**
+ * Computes the class probability given the margin.
+ */
+ private[spark] def computeProbability(margin: Double): Double
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 7c36745ab2..0598943c3d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -17,20 +17,24 @@
package org.apache.spark.ml.classification
+import com.github.fommil.netlib.BLAS
+
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.LogLoss
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.util.Utils
/**
@@ -49,6 +53,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
private var data: RDD[LabeledPoint] = _
private var trainData: RDD[LabeledPoint] = _
private var validationData: RDD[LabeledPoint] = _
+ private val eps: Double = 1e-5
+ private val absEps: Double = 1e-8
override def beforeAll() {
super.beforeAll()
@@ -66,10 +72,156 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
ParamsSuite.checkParams(new GBTClassifier)
val model = new GBTClassificationModel("gbtc",
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
- Array(1.0), 1)
+ Array(1.0), 1, 2)
ParamsSuite.checkParams(model)
}
+ test("GBTClassifier: default params") {
+ val gbt = new GBTClassifier
+ assert(gbt.getLabelCol === "label")
+ assert(gbt.getFeaturesCol === "features")
+ assert(gbt.getPredictionCol === "prediction")
+ assert(gbt.getRawPredictionCol === "rawPrediction")
+ assert(gbt.getProbabilityCol === "probability")
+ val df = trainData.toDF()
+ val model = gbt.fit(df)
+ model.transform(df)
+ .select("label", "probability", "prediction", "rawPrediction")
+ .collect()
+ intercept[NoSuchElementException] {
+ model.getThresholds
+ }
+ assert(model.getFeaturesCol === "features")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.getRawPredictionCol === "rawPrediction")
+ assert(model.getProbabilityCol === "probability")
+ assert(model.hasParent)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+ }
+
+ test("setThreshold, getThreshold") {
+ val gbt = new GBTClassifier
+
+ // default
+ withClue("GBTClassifier should not have thresholds set by default.") {
+ intercept[NoSuchElementException] {
+ gbt.getThresholds
+ }
+ }
+
+ // Set via thresholds
+ val gbt2 = new GBTClassifier
+ val threshold = Array(0.3, 0.7)
+ gbt2.setThresholds(threshold)
+ assert(gbt2.getThresholds === threshold)
+ }
+
+ test("thresholds prediction") {
+ val gbt = new GBTClassifier
+ val df = trainData.toDF()
+ val binaryModel = gbt.fit(df)
+
+ // should predict all zeros
+ binaryModel.setThresholds(Array(0.0, 1.0))
+ val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect()
+ assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0))
+
+ // should predict all ones
+ binaryModel.setThresholds(Array(1.0, 0.0))
+ val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect()
+ assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
+
+
+ val gbtBase = new GBTClassifier
+ val model = gbtBase.fit(df)
+ val basePredictions = model.transform(df).select("prediction").collect()
+
+ // constant threshold scaling is the same as no thresholds
+ binaryModel.setThresholds(Array(1.0, 1.0))
+ val scaledPredictions = binaryModel.transform(df).select("prediction").collect()
+ assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
+ scaled.getDouble(0) === base.getDouble(0)
+ })
+
+ // force it to use the predict method
+ model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1))
+ val predictionsWithPredict = model.transform(df).select("prediction").collect()
+ assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0))
+ }
+
+ test("GBTClassifier: Predictor, Classifier methods") {
+ val rawPredictionCol = "rawPrediction"
+ val predictionCol = "prediction"
+ val labelCol = "label"
+ val featuresCol = "features"
+ val probabilityCol = "probability"
+
+ val gbt = new GBTClassifier().setSeed(123)
+ val trainingDataset = trainData.toDF(labelCol, featuresCol)
+ val gbtModel = gbt.fit(trainingDataset)
+ assert(gbtModel.numClasses === 2)
+ val numFeatures = trainingDataset.select(featuresCol).first().getAs[Vector](0).size
+ assert(gbtModel.numFeatures === numFeatures)
+
+ val blas = BLAS.getInstance()
+
+ val validationDataset = validationData.toDF(labelCol, featuresCol)
+ val results = gbtModel.transform(validationDataset)
+ // check that raw prediction is tree predictions dot tree weights
+ results.select(rawPredictionCol, featuresCol).collect().foreach {
+ case Row(raw: Vector, features: Vector) =>
+ assert(raw.size === 2)
+ val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
+ val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
+ assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
+ }
+
+ // Compare rawPrediction with probability
+ results.select(rawPredictionCol, probabilityCol).collect().foreach {
+ case Row(raw: Vector, prob: Vector) =>
+ assert(raw.size === 2)
+ assert(prob.size === 2)
+ // Note: we should check other loss types for classification if they are added
+ val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value))
+ assert(prob(0) ~== predFromRaw(0) relTol eps)
+ assert(prob(1) ~== predFromRaw(1) relTol eps)
+ assert(prob(0) + prob(1) ~== 1.0 absTol absEps)
+ }
+
+ // Compare prediction with probability
+ results.select(predictionCol, probabilityCol).collect().foreach {
+ case Row(pred: Double, prob: Vector) =>
+ val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
+ assert(pred == predFromProb)
+ }
+
+ // force it to use raw2prediction
+ gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("")
+ val resultsUsingRaw2Predict =
+ gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
+ resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
+ case (pred1, pred2) => assert(pred1 === pred2)
+ }
+
+ // force it to use probability2prediction
+ gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol)
+ val resultsUsingProb2Predict =
+ gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
+ resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
+ case (pred1, pred2) => assert(pred1 === pred2)
+ }
+
+ // force it to use predict
+ gbtModel.setRawPredictionCol("").setProbabilityCol("")
+ val resultsUsingPredict =
+ gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
+ resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
+ case (pred1, pred2) => assert(pred1 === pred2)
+ }
+ }
+
test("GBT parameter stepSize should be in interval (0, 1]") {
withClue("GBT parameter stepSize should be in interval (0, 1]") {
intercept[IllegalArgumentException] {
@@ -246,7 +398,8 @@ private object GBTClassifierSuite extends SparkFunSuite {
val newModel = gbt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = GBTClassificationModel.fromOld(
- oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures)
+ oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures,
+ numFeatures, numClasses = 2)
TreeTests.checkEqual(oldModelAsNew, newModel)
assert(newModel.numFeatures === numFeatures)
assert(oldModelAsNew.numFeatures === numFeatures)