aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-04-04 20:12:09 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-04 20:12:09 -0700
commit8f50574ab4021b9984b0017cd47ba012a894c19a (patch)
treef049ce43aa3852637f782d610ebb7762df983d60 /mllib/src/main
parentba24d1ee9a1d97ca82282f3b811ec011c4285b99 (diff)
downloadspark-8f50574ab4021b9984b0017cd47ba012a894c19a.tar.gz
spark-8f50574ab4021b9984b0017cd47ba012a894c19a.tar.bz2
spark-8f50574ab4021b9984b0017cd47ba012a894c19a.zip
[SPARK-14386][ML] Changed spark.ml ensemble trees methods to return concrete types
## What changes were proposed in this pull request? In spark.ml, GBT and RandomForest expose the trait DecisionTreeModel in the trees method, but they should not since it is a private trait (and not ready to be made public). It will also be more useful to users if we return the concrete types. This PR: return concrete types The MIMA checks appear to be OK with this change. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley <joseph@databricks.com> Closes #12158 from jkbradley/hide-dtm.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala7
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala7
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala14
5 files changed, 21 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index bfefaf1a1a..bee90fb3a5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -24,8 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams,
- TreeEnsembleModel}
+import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
@@ -190,7 +189,7 @@ final class GBTClassificationModel private[ml](
private val _treeWeights: Array[Double],
@Since("1.6.0") override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
- with TreeEnsembleModel with Serializable {
+ with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
@@ -206,7 +205,7 @@ final class GBTClassificationModel private[ml](
this(uid, _trees, _treeWeights, -1)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 2ad893f4fa..cb42532271 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -155,8 +155,8 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
- with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable
- with Serializable {
+ with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
+ with MLWritable with Serializable {
require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
@@ -172,7 +172,7 @@ final class RandomForestClassificationModel private[ml] (
this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeClassificationModel] = _trees
// Note: We may add support for weights (based on tree performance) later on.
private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 02e124a1c0..cef7c643d7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -23,8 +23,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
- TreeRegressorParams}
+import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams}
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
@@ -177,7 +176,7 @@ final class GBTRegressionModel private[ml](
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
@@ -193,7 +192,7 @@ final class GBTRegressionModel private[ml](
this(uid, _trees, _treeWeights, -1)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index ba56b5cd3f..736cd9f776 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -142,7 +142,8 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
- with RandomForestRegressionModelParams with TreeEnsembleModel with MLWritable with Serializable {
+ with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
@@ -155,7 +156,7 @@ final class RandomForestRegressionModel private[ml] (
this(Identifiable.randomUID("rfr"), trees, numFeatures)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
// Note: We may add support for weights (based on tree performance) later on.
private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 48b8fd19ad..db0ff28d82 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.tree
+import scala.reflect.ClassTag
+
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
@@ -82,14 +84,16 @@ private[spark] trait DecisionTreeModel {
* Abstraction for models which are ensembles of decision trees
*
* TODO: Add support for predicting probabilities and raw predictions SPARK-3727
+ *
+ * @tparam M Type of tree model in this ensemble
*/
-private[ml] trait TreeEnsembleModel {
+private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
// Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of
// DecisionTreeModel.
/** Trees in this ensemble. Warning: These have null parent Estimators. */
- def trees: Array[DecisionTreeModel]
+ def trees: Array[M]
/**
* Number of trees in ensemble
@@ -148,7 +152,7 @@ private[ml] object TreeEnsembleModel {
* If -1, then numFeatures is set based on the max feature index in all trees.
* @return Feature importance values, of length numFeatures.
*/
- def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
+ def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = {
val totalImportances = new OpenHashMap[Int, Double]()
trees.foreach { tree =>
// Aggregate feature importance vector for this tree
@@ -199,7 +203,7 @@ private[ml] object TreeEnsembleModel {
* If -1, then numFeatures is set based on the max feature index in all trees.
* @return Feature importance values, of length numFeatures.
*/
- def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
+ def featureImportances[M <: DecisionTreeModel : ClassTag](tree: M, numFeatures: Int): Vector = {
featureImportances(Array(tree), numFeatures)
}
@@ -386,7 +390,7 @@ private[ml] object EnsembleModelReadWrite {
* @param path Path to which to save the ensemble model.
* @param extraMetadata Metadata such as numFeatures, numClasses, numTrees.
*/
- def saveImpl[M <: Params with TreeEnsembleModel](
+ def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]](
instance: M,
path: String,
sql: SQLContext,