From 8f50574ab4021b9984b0017cd47ba012a894c19a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 4 Apr 2016 20:12:09 -0700 Subject: [SPARK-14386][ML] Changed spark.ml ensemble trees methods to return concrete types ## What changes were proposed in this pull request? In spark.ml, GBT and RandomForest expose the trait DecisionTreeModel in the trees method, but they should not since it is a private trait (and not ready to be made public). It will also be more useful to users if we return the concrete types. This PR: return concrete types The MIMA checks appear to be OK with this change. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #12158 from jkbradley/hide-dtm. --- .../org/apache/spark/ml/classification/GBTClassifier.scala | 7 +++---- .../spark/ml/classification/RandomForestClassifier.scala | 6 +++--- .../org/apache/spark/ml/regression/GBTRegressor.scala | 7 +++---- .../apache/spark/ml/regression/RandomForestRegressor.scala | 5 +++-- .../main/scala/org/apache/spark/ml/tree/treeModels.scala | 14 +++++++++----- 5 files changed, 21 insertions(+), 18 deletions(-) (limited to 'mllib/src/main') diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index bfefaf1a1a..bee90fb3a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -24,8 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, - TreeEnsembleModel} +import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector @@ -190,7 +189,7 @@ final class GBTClassificationModel private[ml]( private val _treeWeights: Array[Double], @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] - with TreeEnsembleModel with Serializable { + with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + @@ -206,7 +205,7 @@ final class GBTClassificationModel private[ml]( this(uid, _trees, _treeWeights, -1) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 2ad893f4fa..cb42532271 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -155,8 +155,8 @@ final class RandomForestClassificationModel private[ml] ( @Since("1.6.0") override val numFeatures: Int, @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] - with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable - with Serializable { + with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.") @@ -172,7 +172,7 @@ final class RandomForestClassificationModel private[ml] ( this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeClassificationModel] = _trees // Note: We may add support for weights (based on tree performance) later on. private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 02e124a1c0..cef7c643d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -23,8 +23,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, - TreeRegressorParams} +import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams} import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector @@ -177,7 +176,7 @@ final class GBTRegressionModel private[ml]( private val _treeWeights: Array[Double], override val numFeatures: Int) extends PredictionModel[Vector, GBTRegressionModel] - with TreeEnsembleModel with Serializable { + with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + @@ -193,7 +192,7 @@ final class GBTRegressionModel private[ml]( this(uid, _trees, _treeWeights, -1) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index ba56b5cd3f..736cd9f776 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -142,7 +142,8 @@ final class RandomForestRegressionModel private[ml] ( private val _trees: Array[DecisionTreeRegressionModel], override val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] - with RandomForestRegressionModelParams with TreeEnsembleModel with MLWritable with Serializable { + with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.") @@ -155,7 +156,7 @@ final class RandomForestRegressionModel private[ml] ( this(Identifiable.randomUID("rfr"), trees, numFeatures) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees // Note: We may add support for weights (based on tree performance) later on. private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 48b8fd19ad..db0ff28d82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tree +import scala.reflect.ClassTag + import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.jackson.JsonMethods._ @@ -82,14 +84,16 @@ private[spark] trait DecisionTreeModel { * Abstraction for models which are ensembles of decision trees * * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 + * + * @tparam M Type of tree model in this ensemble */ -private[ml] trait TreeEnsembleModel { +private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] { // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of // DecisionTreeModel. /** Trees in this ensemble. Warning: These have null parent Estimators. */ - def trees: Array[DecisionTreeModel] + def trees: Array[M] /** * Number of trees in ensemble @@ -148,7 +152,7 @@ private[ml] object TreeEnsembleModel { * If -1, then numFeatures is set based on the max feature index in all trees. * @return Feature importance values, of length numFeatures. */ - def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = { + def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = { val totalImportances = new OpenHashMap[Int, Double]() trees.foreach { tree => // Aggregate feature importance vector for this tree @@ -199,7 +203,7 @@ private[ml] object TreeEnsembleModel { * If -1, then numFeatures is set based on the max feature index in all trees. * @return Feature importance values, of length numFeatures. */ - def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = { + def featureImportances[M <: DecisionTreeModel : ClassTag](tree: M, numFeatures: Int): Vector = { featureImportances(Array(tree), numFeatures) } @@ -386,7 +390,7 @@ private[ml] object EnsembleModelReadWrite { * @param path Path to which to save the ensemble model. * @param extraMetadata Metadata such as numFeatures, numClasses, numTrees. */ - def saveImpl[M <: Params with TreeEnsembleModel]( + def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]]( instance: M, path: String, sql: SQLContext, -- cgit v1.2.3