-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala           |  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala  |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala                |  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala       |  5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala                        | 14
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala                    |  2

6 files changed, 22 insertions(+), 19 deletions(-)
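
Taken together, the change parameterizes TreeEnsembleModel by the concrete tree type it stores, so each ensemble model can expose trees with its precise element type instead of casting to Array[DecisionTreeModel]. A minimal, standalone sketch of the pattern follows; the placeholder definitions (including GBTEnsembleSketch) stand in for the real Spark ML classes and are not the actual implementations.

    // Standalone sketch of the refactoring pattern only; these placeholder
    // definitions stand in for the real Spark ML classes.
    trait DecisionTreeModel {
      def numNodes: Int
    }

    final class DecisionTreeRegressionModel(val numNodes: Int) extends DecisionTreeModel

    // The trait now carries the concrete tree type as a bounded type parameter.
    trait TreeEnsembleModel[M <: DecisionTreeModel] {
      def trees: Array[M]
      def treeWeights: Array[Double]
    }

    // An ensemble declares the element type it actually stores, so the old
    // _trees.asInstanceOf[Array[DecisionTreeModel]] cast disappears.
    final class GBTEnsembleSketch(
        private val _trees: Array[DecisionTreeRegressionModel],
        private val _treeWeights: Array[Double])
      extends TreeEnsembleModel[DecisionTreeRegressionModel] {

      override def trees: Array[DecisionTreeRegressionModel] = _trees
      override def treeWeights: Array[Double] = _treeWeights
    }
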
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index bfefaf1a1a..bee90fb3a5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -24,8 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams,
- TreeEnsembleModel}
+import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
@@ -190,7 +189,7 @@ final class GBTClassificationModel private[ml](
private val _treeWeights: Array[Double],
@Since("1.6.0") override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
- with TreeEnsembleModel with Serializable {
+ with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
@@ -206,7 +205,7 @@ final class GBTClassificationModel private[ml](
this(uid, _trees, _treeWeights, -1)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 2ad893f4fa..cb42532271 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -155,8 +155,8 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
- with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable
- with Serializable {
+ with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
+ with MLWritable with Serializable {
require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
@@ -172,7 +172,7 @@ final class RandomForestClassificationModel private[ml] (
this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeClassificationModel] = _trees
// Note: We may add support for weights (based on tree performance) later on.
private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
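
One detail visible across the two classifier diffs above: the element types differ. GBTClassificationModel mixes in TreeEnsembleModel[DecisionTreeRegressionModel], since gradient boosting fits regression trees to the loss gradient even for classification, while RandomForestClassificationModel stores DecisionTreeClassificationModel trees. Below is a hedged caller-side sketch against the real API; summarizeEnsembles is only an illustrative helper and assumes already-fitted models.

    import org.apache.spark.ml.classification.{DecisionTreeClassificationModel,
      GBTClassificationModel, RandomForestClassificationModel}
    import org.apache.spark.ml.regression.DecisionTreeRegressionModel

    // Assumes gbt and rf are already-fitted models; with the typed trait,
    // no asInstanceOf is needed to recover the concrete tree types.
    def summarizeEnsembles(
        gbt: GBTClassificationModel,
        rf: RandomForestClassificationModel): Unit = {
      val gbtTrees: Array[DecisionTreeRegressionModel] = gbt.trees      // regression trees
      val rfTrees: Array[DecisionTreeClassificationModel] = rf.trees    // classification trees
      println(s"GBT stages: ${gbtTrees.length}, forest size: ${rfTrees.length}")
    }
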
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 02e124a1c0..cef7c643d7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -23,8 +23,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
- TreeRegressorParams}
+import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams}
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
@@ -177,7 +176,7 @@ final class GBTRegressionModel private[ml](
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
@@ -193,7 +192,7 @@ final class GBTRegressionModel private[ml](
this(uid, _trees, _treeWeights, -1)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index ba56b5cd3f..736cd9f776 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -142,7 +142,8 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
- with RandomForestRegressionModelParams with TreeEnsembleModel with MLWritable with Serializable {
+ with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
@@ -155,7 +156,7 @@ final class RandomForestRegressionModel private[ml] (
this(Identifiable.randomUID("rfr"), trees, numFeatures)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
// Note: We may add support for weights (based on tree performance) later on.
private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 48b8fd19ad..db0ff28d82 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.tree
+import scala.reflect.ClassTag
+
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
@@ -82,14 +84,16 @@ private[spark] trait DecisionTreeModel {
* Abstraction for models which are ensembles of decision trees
*
* TODO: Add support for predicting probabilities and raw predictions SPARK-3727
+ *
+ * @tparam M Type of tree model in this ensemble
*/
-private[ml] trait TreeEnsembleModel {
+private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
// Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of
// DecisionTreeModel.
/** Trees in this ensemble. Warning: These have null parent Estimators. */
- def trees: Array[DecisionTreeModel]
+ def trees: Array[M]
/**
* Number of trees in ensemble
@@ -148,7 +152,7 @@ private[ml] object TreeEnsembleModel {
* If -1, then numFeatures is set based on the max feature index in all trees.
* @return Feature importance values, of length numFeatures.
*/
- def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
+ def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = {
val totalImportances = new OpenHashMap[Int, Double]()
trees.foreach { tree =>
// Aggregate feature importance vector for this tree
@@ -199,7 +203,7 @@ private[ml] object TreeEnsembleModel {
* If -1, then numFeatures is set based on the max feature index in all trees.
* @return Feature importance values, of length numFeatures.
*/
- def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
+ def featureImportances[M <: DecisionTreeModel : ClassTag](tree: M, numFeatures: Int): Vector = {
featureImportances(Array(tree), numFeatures)
}
@@ -386,7 +390,7 @@ private[ml] object EnsembleModelReadWrite {
* @param path Path to which to save the ensemble model.
* @param extraMetadata Metadata such as numFeatures, numClasses, numTrees.
*/
- def saveImpl[M <: Params with TreeEnsembleModel](
+ def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]](
instance: M,
path: String,
sql: SQLContext,
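
In the TreeEnsembleModel companion object, the featureImportances helpers become generic as well, and the single-tree overload picks up a ClassTag context bound: it wraps its argument in Array(tree), and allocating an Array[M] for an abstract type M needs a ClassTag at runtime. The following is a minimal sketch of that mechanic only, using a placeholder trait and hypothetical helper names (importanceCount), not the real featureImportances implementation.

    import scala.reflect.ClassTag

    // Placeholder standing in for org.apache.spark.ml.tree.DecisionTreeModel.
    trait DecisionTreeModel

    object ImportanceSketch {
      // Array-based form: works for any M because the caller already built the array.
      def importanceCount[M <: DecisionTreeModel](trees: Array[M]): Int = trees.length

      // Single-tree form: Array(tree) must allocate an Array[M], which requires
      // a ClassTag[M]; hence the extra context bound on this overload only.
      def importanceCount[M <: DecisionTreeModel : ClassTag](tree: M): Int =
        importanceCount(Array(tree))
    }
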
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index bd5bd17147..b650a9f092 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -131,7 +131,7 @@ private[ml] object TreeTests extends SparkFunSuite {
* Check if the two models are exactly the same.
* If the models are not equal, this throws an exception.
*/
- def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
+ def checkEqual[M <: DecisionTreeModel](a: TreeEnsembleModel[M], b: TreeEnsembleModel[M]): Unit = {
try {
a.trees.zip(b.trees).foreach { case (treeA, treeB) =>
TreeTests.checkEqual(treeA, treeB)
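
With checkEqual generic in the tree type, two ensembles can only be compared when their tree element types line up, and the compiler infers M from the arguments. A hedged example of how a test might call it; assertSameForest is hypothetical, and since TreeTests is private[ml] the call would have to live under org.apache.spark.ml in the test sources.

    import org.apache.spark.ml.classification.RandomForestClassificationModel
    import org.apache.spark.ml.tree.impl.TreeTests

    // Assumes modelA and modelB are already-trained forests.  The compiler
    // infers M = DecisionTreeClassificationModel; comparing, say, a random
    // forest against a GBT model cannot satisfy a single M, so it does not
    // type-check.  checkEqual throws if the two ensembles differ.
    def assertSameForest(
        modelA: RandomForestClassificationModel,
        modelB: RandomForestClassificationModel): Unit = {
      TreeTests.checkEqual(modelA, modelB)
    }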