From fe409f31d966d99fcf57137581d1fb682c1c072a Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 18 Jan 2017 15:33:41 -0800 Subject: [SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces ## What changes were proposed in this pull request? For all of the classifiers in MLLib we can predict probabilities except for GBTClassifier. Also, all classifiers inherit from ProbabilisticClassifier but GBTClassifier strangely inherits from Predictor, which is a bug. This change corrects the interface and adds the ability for the classifier to give a probabilities vector. ## How was this patch tested? The basic ML tests were run after making the changes. I've marked this as WIP as I need to add more tests. Author: Ilya Matiach Closes #16441 from imatiach-msft/ilmat/fix-GBT. --- .../spark/ml/classification/GBTClassifier.scala | 94 +++++++++--- .../org/apache/spark/ml/tree/treeParams.scala | 4 +- .../org/apache/spark/mllib/tree/loss/LogLoss.scala | 10 +- .../org/apache/spark/mllib/tree/loss/Loss.scala | 8 +- .../ml/classification/GBTClassifierSuite.scala | 161 ++++++++++++++++++++- 5 files changed, 248 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c9bbd37a67..ade0960f87 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -23,9 +23,8 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging -import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ @@ -33,6 +32,7 @@ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -58,7 +58,7 @@ import org.apache.spark.sql.functions._ @Since("1.4.0") class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) - extends Predictor[Vector, GBTClassifier, GBTClassificationModel] + extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel] with GBTClassifierParams with DefaultParamsWritable with Logging { @Since("1.4.0") @@ -158,12 +158,19 @@ class GBTClassifier @Since("1.4.0") ( val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val numClasses = 2 + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval) instr.logNumFeatures(numFeatures) - instr.logNumClasses(2) + instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) @@ -202,8 +209,9 @@ class GBTClassificationModel private[ml]( @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], - @Since("1.6.0") override val numFeatures: Int) - extends PredictionModel[Vector, GBTClassificationModel] + @Since("1.6.0") override val numFeatures: Int, + @Since("2.2.0") override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, GBTClassificationModel] with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable with Serializable { @@ -211,6 +219,20 @@ class GBTClassificationModel private[ml]( require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + /** + * Construct a GBTClassificationModel + * + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + * @param numFeatures The number of features. + */ + private[ml] def this( + uid: String, + _trees: Array[DecisionTreeRegressionModel], + _treeWeights: Array[Double], + numFeatures: Int) = + this(uid, _trees, _treeWeights, numFeatures, 2) + /** * Construct a GBTClassificationModel * @@ -219,7 +241,7 @@ class GBTClassificationModel private[ml]( */ @Since("1.6.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = - this(uid, _trees, _treeWeights, -1) + this(uid, _trees, _treeWeights, -1, 2) @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees @@ -242,11 +264,29 @@ class GBTClassificationModel private[ml]( } override protected def predict(features: Vector): Double = { - // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 - // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) - val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) - if (prediction > 0.0) 1.0 else 0.0 + // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization + if (isDefined(thresholds)) { + super.predict(features) + } else { + if (margin(features) > 0.0) 1.0 else 0.0 + } + } + + override protected def predictRaw(features: Vector): Vector = { + val prediction: Double = margin(features) + Vectors.dense(Array(-prediction, prediction)) + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + dv.values(0) = loss.computeProbability(dv.values(0)) + dv.values(1) = 1.0 - dv.values(0) + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in GBTClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } } /** Number of trees in ensemble */ @@ -254,7 +294,7 @@ class GBTClassificationModel private[ml]( @Since("1.4.0") override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses), extra).setParent(parent) } @@ -276,11 +316,20 @@ class GBTClassificationModel private[ml]( @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + /** Raw prediction for the positive class. */ + private def margin(features: Vector): Double = { + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) + blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + } + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) } + // hard coded loss, which is not meant to be changed in the model + private val loss = getOldLossType + @Since("2.0.0") override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } @@ -288,6 +337,9 @@ class GBTClassificationModel private[ml]( @Since("2.0.0") object GBTClassificationModel extends MLReadable[GBTClassificationModel] { + private val numFeaturesKey: String = "numFeatures" + private val numTreesKey: String = "numTrees" + @Since("2.0.0") override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader @@ -300,8 +352,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override protected def saveImpl(path: String): Unit = { val extraMetadata: JObject = Map( - "numFeatures" -> instance.numFeatures, - "numTrees" -> instance.getNumTrees) + numFeaturesKey -> instance.numFeatures, + numTreesKey -> instance.getNumTrees) EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -316,8 +368,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) - val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val numTrees = (metadata.metadata \ "numTrees").extract[Int] + val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] + val numTrees = (metadata.metadata \ numTreesKey).extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => @@ -328,7 +380,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { } require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") - val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) + val model = new GBTClassificationModel(metadata.uid, + trees, treeWeights, numFeatures) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -339,7 +392,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { oldModel: OldGBTModel, parent: GBTClassifier, categoricalFeatures: Map[Int, Int], - numFeatures: Int = -1): GBTClassificationModel = { + numFeatures: Int = -1, + numClasses: Int = 2): GBTClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -347,6 +401,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") - new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures) + new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index c7a8f76eca..5eb707dfe7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} -import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** @@ -531,7 +531,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam def getLossType: String = $(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { + override private[ml] def getOldLossType: OldClassificationLoss = { getLossType match { case "logistic" => OldLogLoss case _ => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 5d92ce495b..9339f0a23c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -20,7 +20,6 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.util.MLUtils - /** * :: DeveloperApi :: * Class for log loss calculation (for classification). @@ -32,7 +31,7 @@ import org.apache.spark.mllib.util.MLUtils */ @Since("1.2.0") @DeveloperApi -object LogLoss extends Loss { +object LogLoss extends ClassificationLoss { /** * Method to calculate the loss gradients for the gradient boosting calculation for binary @@ -52,4 +51,11 @@ object LogLoss extends Loss { // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. 2.0 * MLUtils.log1pExp(-margin) } + + /** + * Returns the estimated probability of a label of 1.0. + */ + override private[spark] def computeProbability(margin: Double): Double = { + 1.0 / (1.0 + math.exp(-2.0 * margin)) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 09274a2e1b..e7ffb3f8f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -22,7 +22,6 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD - /** * :: DeveloperApi :: * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. @@ -67,3 +66,10 @@ trait Loss extends Serializable { */ private[spark] def computeError(prediction: Double, label: Double): Double } + +private[spark] trait ClassificationLoss extends Loss { + /** + * Computes the class probability given the margin. + */ + private[spark] def computeProbability(margin: Double): Double +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 7c36745ab2..0598943c3d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -17,20 +17,24 @@ package org.apache.spark.ml.classification +import com.github.fommil.netlib.BLAS + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.util.Utils /** @@ -49,6 +53,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext private var data: RDD[LabeledPoint] = _ private var trainData: RDD[LabeledPoint] = _ private var validationData: RDD[LabeledPoint] = _ + private val eps: Double = 1e-5 + private val absEps: Double = 1e-8 override def beforeAll() { super.beforeAll() @@ -66,10 +72,156 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), - Array(1.0), 1) + Array(1.0), 1, 2) ParamsSuite.checkParams(model) } + test("GBTClassifier: default params") { + val gbt = new GBTClassifier + assert(gbt.getLabelCol === "label") + assert(gbt.getFeaturesCol === "features") + assert(gbt.getPredictionCol === "prediction") + assert(gbt.getRawPredictionCol === "rawPrediction") + assert(gbt.getProbabilityCol === "probability") + val df = trainData.toDF() + val model = gbt.fit(df) + model.transform(df) + .select("label", "probability", "prediction", "rawPrediction") + .collect() + intercept[NoSuchElementException] { + model.getThresholds + } + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.getRawPredictionCol === "rawPrediction") + assert(model.getProbabilityCol === "probability") + assert(model.hasParent) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + } + + test("setThreshold, getThreshold") { + val gbt = new GBTClassifier + + // default + withClue("GBTClassifier should not have thresholds set by default.") { + intercept[NoSuchElementException] { + gbt.getThresholds + } + } + + // Set via thresholds + val gbt2 = new GBTClassifier + val threshold = Array(0.3, 0.7) + gbt2.setThresholds(threshold) + assert(gbt2.getThresholds === threshold) + } + + test("thresholds prediction") { + val gbt = new GBTClassifier + val df = trainData.toDF() + val binaryModel = gbt.fit(df) + + // should predict all zeros + binaryModel.setThresholds(Array(0.0, 1.0)) + val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect() + assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + + // should predict all ones + binaryModel.setThresholds(Array(1.0, 0.0)) + val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect() + assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) + + + val gbtBase = new GBTClassifier + val model = gbtBase.fit(df) + val basePredictions = model.transform(df).select("prediction").collect() + + // constant threshold scaling is the same as no thresholds + binaryModel.setThresholds(Array(1.0, 1.0)) + val scaledPredictions = binaryModel.transform(df).select("prediction").collect() + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + + // force it to use the predict method + model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1)) + val predictionsWithPredict = model.transform(df).select("prediction").collect() + assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + } + + test("GBTClassifier: Predictor, Classifier methods") { + val rawPredictionCol = "rawPrediction" + val predictionCol = "prediction" + val labelCol = "label" + val featuresCol = "features" + val probabilityCol = "probability" + + val gbt = new GBTClassifier().setSeed(123) + val trainingDataset = trainData.toDF(labelCol, featuresCol) + val gbtModel = gbt.fit(trainingDataset) + assert(gbtModel.numClasses === 2) + val numFeatures = trainingDataset.select(featuresCol).first().getAs[Vector](0).size + assert(gbtModel.numFeatures === numFeatures) + + val blas = BLAS.getInstance() + + val validationDataset = validationData.toDF(labelCol, featuresCol) + val results = gbtModel.transform(validationDataset) + // check that raw prediction is tree predictions dot tree weights + results.select(rawPredictionCol, featuresCol).collect().foreach { + case Row(raw: Vector, features: Vector) => + assert(raw.size === 2) + val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction) + val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1) + assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps) + } + + // Compare rawPrediction with probability + results.select(rawPredictionCol, probabilityCol).collect().foreach { + case Row(raw: Vector, prob: Vector) => + assert(raw.size === 2) + assert(prob.size === 2) + // Note: we should check other loss types for classification if they are added + val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value)) + assert(prob(0) ~== predFromRaw(0) relTol eps) + assert(prob(1) ~== predFromRaw(1) relTol eps) + assert(prob(0) + prob(1) ~== 1.0 absTol absEps) + } + + // Compare prediction with probability + results.select(predictionCol, probabilityCol).collect().foreach { + case Row(pred: Double, prob: Vector) => + val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 + assert(pred == predFromProb) + } + + // force it to use raw2prediction + gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("") + val resultsUsingRaw2Predict = + gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() + resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + + // force it to use probability2prediction + gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol) + val resultsUsingProb2Predict = + gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() + resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + + // force it to use predict + gbtModel.setRawPredictionCol("").setProbabilityCol("") + val resultsUsingPredict = + gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() + resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + } + test("GBT parameter stepSize should be in interval (0, 1]") { withClue("GBT parameter stepSize should be in interval (0, 1]") { intercept[IllegalArgumentException] { @@ -246,7 +398,8 @@ private object GBTClassifierSuite extends SparkFunSuite { val newModel = gbt.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTClassificationModel.fromOld( - oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures) + oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, + numFeatures, numClasses = 2) TreeTests.checkEqual(oldModelAsNew, newModel) assert(newModel.numFeatures === numFeatures) assert(oldModelAsNew.numFeatures === numFeatures) -- cgit v1.2.3