author     Joseph K. Bradley <joseph@databricks.com>  2016-04-26 14:44:39 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2016-04-26 14:44:39 -0700
commit     6c5a837c509233d4008cffeaede111f17fea5289 (patch)
tree       56e56013c4b18cb3cecb34f2a145382ef884a6d9
parent     e88476c8c6a74a92eb70b4bda8394936e0036729 (diff)
[SPARK-12301][ML] Made all tree and ensemble classes not final
## What changes were proposed in this pull request?

There have been continuing requests (e.g., SPARK-7131) for allowing users to extend and modify MLlib models and algorithms. This PR makes the tree and ensemble classes, Node types, and Split types in spark.ml no longer final. This matches most other spark.ml algorithms. Constructors for models are still private, since we may need to refactor how stats are maintained in tree nodes.

## How was this patch tested?

Existing unit tests.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12711 from jkbradley/final-trees.
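For illustration, removing `final` from the estimators lets users subclass them directly; the models remain effectively closed because their constructors are still `private[ml]`. A minimal sketch of what this enables (the subclass name, UID prefix, and chosen defaults below are hypothetical, not part of this patch):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.util.Identifiable

// Hypothetical subclass, possible only once DecisionTreeClassifier is no longer final.
class AuditedDecisionTreeClassifier(uid: String)
    extends DecisionTreeClassifier(uid) {

  def this() = this(Identifiable.randomUID("auditedDtc"))

  // Pin organization-specific defaults; the public setters are inherited from the parent.
  setMaxDepth(3)
  setImpurity("entropy")
}
```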
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala  | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala           | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala  | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala       | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala                | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala       | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala                              | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala                             | 4
8 files changed, 16 insertions(+), 16 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 1943a4a747..ecb218e2a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.Dataset
*/
@Since("1.4.0")
@Experimental
-final class DecisionTreeClassifier @Since("1.4.0") (
+class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeClassifierParams with DefaultParamsWritable {
@@ -138,7 +138,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi
*/
@Since("1.4.0")
@Experimental
-final class DecisionTreeClassificationModel private[ml] (
+class DecisionTreeClassificationModel private[ml] (
@Since("1.4.0")override val uid: String,
@Since("1.4.0")override val rootNode: Node,
@Since("1.6.0")override val numFeatures: Int,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 1bd6dae7dd..e736f01cc6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -57,7 +57,7 @@ import org.apache.spark.sql.functions._
*/
@Since("1.4.0")
@Experimental
-final class GBTClassifier @Since("1.4.0") (
+class GBTClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
with GBTClassifierParams with DefaultParamsWritable with Logging {
@@ -170,7 +170,7 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
*/
@Since("1.6.0")
@Experimental
-final class GBTClassificationModel private[ml](
+class GBTClassificationModel private[ml](
@Since("1.6.0") override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index c04ecc88ae..28364c2593 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.functions._
*/
@Since("1.4.0")
@Experimental
-final class RandomForestClassifier @Since("1.4.0") (
+class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestClassifierParams with DefaultParamsWritable {
@@ -149,7 +149,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi
*/
@Since("1.4.0")
@Experimental
-final class RandomForestClassificationModel private[ml] (
+class RandomForestClassificationModel private[ml] (
@Since("1.5.0") override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel],
@Since("1.6.0") override val numFeatures: Int,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index c04c416aaf..339a8cf486 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.functions._
*/
@Since("1.4.0")
@Experimental
-final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
+class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
with DecisionTreeRegressorParams with DefaultParamsWritable {
@@ -129,7 +129,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
*/
@Since("1.4.0")
@Experimental
-final class DecisionTreeRegressionModel private[ml] (
+class DecisionTreeRegressionModel private[ml] (
override val uid: String,
override val rootNode: Node,
override val numFeatures: Int)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index da51cb7800..c41fb4b062 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -57,7 +57,7 @@ import org.apache.spark.sql.functions._
*/
@Since("1.4.0")
@Experimental
-final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
+class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
with GBTRegressorParams with DefaultParamsWritable with Logging {
@@ -157,7 +157,7 @@ object GBTRegressor extends DefaultParamsReadable[GBTRegressor] {
*/
@Since("1.4.0")
@Experimental
-final class GBTRegressionModel private[ml](
+class GBTRegressionModel private[ml](
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 8eaed8b682..b6ab2fd625 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -43,7 +43,7 @@ import org.apache.spark.sql.functions._
*/
@Since("1.4.0")
@Experimental
-final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
+class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
with RandomForestRegressorParams with DefaultParamsWritable {
@@ -137,7 +137,7 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor
*/
@Since("1.4.0")
@Experimental
-final class RandomForestRegressionModel private[ml] (
+class RandomForestRegressionModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index b5cb378829..f71d28cf59 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -115,7 +115,7 @@ private[ml] object Node {
* @param impurity Impurity measure at this node (for training data)
*/
@DeveloperApi
-final class LeafNode private[ml] (
+class LeafNode private[ml] (
override val prediction: Double,
override val impurity: Double,
override private[ml] val impurityStats: ImpurityCalculator) extends Node {
@@ -158,7 +158,7 @@ final class LeafNode private[ml] (
* @param split Information about the test used to split to the left or right child.
*/
@DeveloperApi
-final class InternalNode private[ml] (
+class InternalNode private[ml] (
override val prediction: Double,
override val impurity: Double,
val gain: Double,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 5d11ed0971..a4287483d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -75,7 +75,7 @@ private[tree] object Split {
* @param numCategories Number of categories for this feature.
*/
@DeveloperApi
-final class CategoricalSplit private[ml] (
+class CategoricalSplit private[ml] (
override val featureIndex: Int,
_leftCategories: Array[Double],
@Since("2.0.0") val numCategories: Int)
@@ -160,7 +160,7 @@ final class CategoricalSplit private[ml] (
* Otherwise, it goes right.
*/
@DeveloperApi
-final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
+class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
extends Split {
override private[ml] def shouldGoLeft(features: Vector): Boolean = {
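For context on the Node and Split changes above: LeafNode, InternalNode, CategoricalSplit, and ContinuousSplit keep their `private[ml]` constructors, so the main effect of dropping `final` is on how the public hierarchy can be consumed and extended inside spark.ml. A short sketch (assuming a fitted tree model, e.g. `model.rootNode`) of the usual way these public Node types are traversed:

```scala
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

// Recursively compute the depth of a trained tree by matching on the public Node types.
def depth(node: Node): Int = node match {
  case _: LeafNode     => 0
  case n: InternalNode => 1 + math.max(depth(n.leftChild), depth(n.rightChild))
}
```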