-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala  139
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala  238
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala  248
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Model.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala  24
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala  228
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala  185
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala  249
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala  20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala  14
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala  218
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala  167
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala  22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala  46
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java  10
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java  100
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java  103
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java  26
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java  99
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java  102
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala  136
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala  166
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala  10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala  137
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala  122
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala  6
31 files changed, 2658 insertions(+), 174 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 2cd515c89d..9002e99d82 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -22,10 +22,9 @@ import scala.language.reflectiveCalls
import scopt.OptionParser
-import org.apache.spark.ml.tree.DecisionTreeModel
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
-import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
@@ -64,8 +63,6 @@ object DecisionTreeExample {
maxBins: Int = 32,
minInstancesPerNode: Int = 1,
minInfoGain: Double = 0.0,
- numTrees: Int = 1,
- featureSubsetStrategy: String = "auto",
fracTest: Double = 0.2,
cacheNodeIds: Boolean = false,
checkpointDir: Option[String] = None,
@@ -123,8 +120,8 @@ object DecisionTreeExample {
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
- if (params.fracTest < 0 || params.fracTest > 1) {
- failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
} else {
success
}
@@ -200,9 +197,18 @@ object DecisionTreeExample {
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
}
- val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
+ val dataframes = splits.map(_.toDF()).map(labelsToStrings)
+ val training = dataframes(0).cache()
+ val test = dataframes(1).cache()
- (dataframes(0), dataframes(1))
+ val numTraining = training.count()
+ val numTest = test.count()
+ val numFeatures = training.select("features").first().getAs[Vector](0).size
+ println("Loaded data:")
+ println(s" numTraining = $numTraining, numTest = $numTest")
+ println(s" numFeatures = $numFeatures")
+
+ (training, test)
}
def run(params: Params) {
@@ -217,13 +223,6 @@ object DecisionTreeExample {
val (training: DataFrame, test: DataFrame) =
loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
- val numTraining = training.count()
- val numTest = test.count()
- val numFeatures = training.select("features").first().getAs[Vector](0).size
- println("Loaded data:")
- println(s" numTraining = $numTraining, numTest = $numTest")
- println(s" numFeatures = $numFeatures")
-
// Set up Pipeline
val stages = new mutable.ArrayBuffer[PipelineStage]()
// (1) For classification, re-index classes.
@@ -241,7 +240,7 @@ object DecisionTreeExample {
.setOutputCol("indexedFeatures")
.setMaxCategories(10)
stages += featuresIndexer
- // (3) Learn DecisionTree
+ // (3) Learn Decision Tree
val dt = algo match {
case "classification" =>
new DecisionTreeClassifier()
@@ -275,62 +274,86 @@ object DecisionTreeExample {
println(s"Training time: $elapsedTime seconds")
// Get the trained Decision Tree from the fitted PipelineModel
- val treeModel: DecisionTreeModel = algo match {
+ algo match {
case "classification" =>
- pipelineModel.getModel[DecisionTreeClassificationModel](
+ val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
dt.asInstanceOf[DecisionTreeClassifier])
+ if (treeModel.numNodes < 20) {
+ println(treeModel.toDebugString) // Print full model.
+ } else {
+ println(treeModel) // Print model summary.
+ }
case "regression" =>
- pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor])
- case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
- }
- if (treeModel.numNodes < 20) {
- println(treeModel.toDebugString) // Print full model.
- } else {
- println(treeModel) // Print model summary.
- }
-
- // Predict on training
- val trainingFullPredictions = pipelineModel.transform(training).cache()
- val trainingPredictions = trainingFullPredictions.select("prediction")
- .map(_.getDouble(0))
- val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0))
- // Predict on test data
- val testFullPredictions = pipelineModel.transform(test).cache()
- val testPredictions = testFullPredictions.select("prediction")
- .map(_.getDouble(0))
- val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0))
-
- // For classification, print number of classes for reference.
- if (algo == "classification") {
- val numClasses =
- MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match {
- case Some(n) => n
- case None => throw new RuntimeException(
- "DecisionTreeExample had unknown failure when indexing labels for classification.")
+ val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
+ dt.asInstanceOf[DecisionTreeRegressor])
+ if (treeModel.numNodes < 20) {
+ println(treeModel.toDebugString) // Print full model.
+ } else {
+ println(treeModel) // Print model summary.
}
- println(s"numClasses = $numClasses.")
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
// Evaluate model on training, test data
algo match {
case "classification" =>
- val trainingAccuracy =
- new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision
- println(s"Train accuracy = $trainingAccuracy")
- val testAccuracy =
- new MulticlassMetrics(testPredictions.zip(testLabels)).precision
- println(s"Test accuracy = $testAccuracy")
+ println("Training data results:")
+ evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ evaluateClassificationModel(pipelineModel, test, labelColName)
case "regression" =>
- val trainingRMSE =
- new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError
- println(s"Training root mean squared error (RMSE) = $trainingRMSE")
- val testRMSE =
- new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError
- println(s"Test root mean squared error (RMSE) = $testRMSE")
+ println("Training data results:")
+ evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ evaluateRegressionModel(pipelineModel, test, labelColName)
case _ =>
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
sc.stop()
}
+
+ /**
+ * Evaluate the given ClassificationModel on data. Print the results.
+ * @param model Must fit ClassificationModel abstraction
+ * @param data DataFrame with "prediction" and labelColName columns
+ * @param labelColName Name of the labelCol parameter for the model
+ *
+ * TODO: Change model type to ClassificationModel once that API is public. SPARK-5995
+ */
+ private[ml] def evaluateClassificationModel(
+ model: Transformer,
+ data: DataFrame,
+ labelColName: String): Unit = {
+ val fullPredictions = model.transform(data).cache()
+ val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
+ val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
+ // Print number of classes for reference
+ val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match {
+ case Some(n) => n
+ case None => throw new RuntimeException(
+ "Unknown failure when indexing labels for classification.")
+ }
+ val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision
+ println(s" Accuracy ($numClasses classes): $accuracy")
+ }
+
+ /**
+ * Evaluate the given RegressionModel on data. Print the results.
+ * @param model Must fit RegressionModel abstraction
+ * @param data DataFrame with "prediction" and labelColName columns
+ * @param labelColName Name of the labelCol parameter for the model
+ *
+ * TODO: Change model type to RegressionModel once that API is public. SPARK-5995
+ */
+ private[ml] def evaluateRegressionModel(
+ model: Transformer,
+ data: DataFrame,
+ labelColName: String): Unit = {
+ val fullPredictions = model.transform(data).cache()
+ val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
+ val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
+ val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError
+ println(s" Root mean squared error (RMSE): $RMSE")
+ }
}
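The evaluateClassificationModel and evaluateRegressionModel helpers factored out above are reused by the new GBTExample and RandomForestExample below. A minimal sketch of the call pattern, assuming a fitted pipelineModel, a cached test DataFrame, and labelColName as in run():

  println("Test data results:")
  DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)  // classification
  // or, for regression:
  DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)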
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
new file mode 100644
index 0000000000..5fccb142d4
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
+import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
+import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * An example runner for Gradient-Boosted Trees (GBTs). Run with
+ * {{{
+ * ./bin/run-example ml.GBTExample [options]
+ * }}}
+ * Decision Trees and ensembles can take a large amount of memory. If the run-example command
+ * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.GBTExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object GBTExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "classification",
+ maxDepth: Int = 5,
+ maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
+ maxIter: Int = 10,
+ fracTest: Double = 0.2,
+ cacheNodeIds: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("GBTExample") {
+ head("GBTExample: an example Gradient-Boosted Trees app.")
+ opt[String]("algo")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("maxBins")
+ .text(s"max number of bins, default: ${defaultParams.maxBins}")
+ .action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
+ opt[Int]("maxIter")
+ .text(s"number of trees in ensemble, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("cacheNodeIds")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.cacheNodeIds}")
+ .action((x, c) => c.copy(cacheNodeIds = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${
+ defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }
+ }")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"GBTExample with $params")
+ val sc = new SparkContext(conf)
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+ val algo = params.algo.toLowerCase
+
+ println(s"GBTExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, algo, params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+ // (1) For classification, re-index classes.
+ val labelColName = if (algo == "classification") "indexedLabel" else "label"
+ if (algo == "classification") {
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
+ stages += labelIndexer
+ }
+ // (2) Identify categorical features using VectorIndexer.
+ // Features with more than maxCategories values will be treated as continuous.
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
+ stages += featuresIndexer
+ // (3) Learn GBT
+ val dt = algo match {
+ case "classification" =>
+ new GBTClassifier()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setMaxIter(params.maxIter)
+ case "regression" =>
+ new GBTRegressor()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setMaxIter(params.maxIter)
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+ stages += dt
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Get the trained GBT from the fitted PipelineModel
+ algo match {
+ case "classification" =>
+ val gbtModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier])
+ if (gbtModel.totalNumNodes < 30) {
+ println(gbtModel.toDebugString) // Print full model.
+ } else {
+ println(gbtModel) // Print model summary.
+ }
+ case "regression" =>
+ val gbtModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor])
+ if (gbtModel.totalNumNodes < 30) {
+ println(gbtModel.toDebugString) // Print full model.
+ } else {
+ println(gbtModel) // Print model summary.
+ }
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ // Evaluate model on training, test data
+ algo match {
+ case "classification" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
+ case "regression" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
+ case _ =>
+ throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ sc.stop()
+ }
+}
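Beyond this runner, the GBTClassifier added in this patch can also be used directly on a pre-indexed DataFrame. A minimal sketch; the DataFrame name indexedTrainingData is hypothetical, and the column names match the StringIndexer/VectorIndexer stages built above:

  import org.apache.spark.ml.classification.GBTClassifier

  val gbt = new GBTClassifier()
    .setLabelCol("indexedLabel")        // label column carrying numClasses metadata (from StringIndexer)
    .setFeaturesCol("indexedFeatures")  // feature column indexed by VectorIndexer
    .setMaxDepth(5)
    .setMaxIter(10)                     // number of trees in the ensemble
  val model = gbt.fit(indexedTrainingData)  // hypothetical training DataFrame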
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
new file mode 100644
index 0000000000..9b909324ec
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
+import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
+import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * An example runner for Random Forests. Run with
+ * {{{
+ * ./bin/run-example ml.RandomForestExample [options]
+ * }}}
+ * Decision Trees and ensembles can take a large amount of memory. If the run-example command
+ * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object RandomForestExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "classification",
+ maxDepth: Int = 5,
+ maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
+ numTrees: Int = 10,
+ featureSubsetStrategy: String = "auto",
+ fracTest: Double = 0.2,
+ cacheNodeIds: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("RandomForestExample") {
+ head("RandomForestExample: an example random forest app.")
+ opt[String]("algo")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("maxBins")
+ .text(s"max number of bins, default: ${defaultParams.maxBins}")
+ .action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
+ opt[Int]("numTrees")
+ .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}")
+ .action((x, c) => c.copy(numTrees = x))
+ opt[String]("featureSubsetStrategy")
+ .text(s"number of features to use per node (supported:" +
+ s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," +
+ s" default: ${defaultParams.numTrees}")
+ .action((x, c) => c.copy(featureSubsetStrategy = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("cacheNodeIds")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.cacheNodeIds}")
+ .action((x, c) => c.copy(cacheNodeIds = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${
+ defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }
+ }")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"RandomForestExample with $params")
+ val sc = new SparkContext(conf)
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+ val algo = params.algo.toLowerCase
+
+ println(s"RandomForestExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, algo, params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+ // (1) For classification, re-index classes.
+ val labelColName = if (algo == "classification") "indexedLabel" else "label"
+ if (algo == "classification") {
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
+ stages += labelIndexer
+ }
+ // (2) Identify categorical features using VectorIndexer.
+ // Features with more than maxCategories values will be treated as continuous.
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
+ stages += featuresIndexer
+ // (3) Learn Random Forest
+ val dt = algo match {
+ case "classification" =>
+ new RandomForestClassifier()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setFeatureSubsetStrategy(params.featureSubsetStrategy)
+ .setNumTrees(params.numTrees)
+ case "regression" =>
+ new RandomForestRegressor()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setFeatureSubsetStrategy(params.featureSubsetStrategy)
+ .setNumTrees(params.numTrees)
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+ stages += dt
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Get the trained Random Forest from the fitted PipelineModel
+ algo match {
+ case "classification" =>
+ val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
+ dt.asInstanceOf[RandomForestClassifier])
+ if (rfModel.totalNumNodes < 30) {
+ println(rfModel.toDebugString) // Print full model.
+ } else {
+ println(rfModel) // Print model summary.
+ }
+ case "regression" =>
+ val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
+ dt.asInstanceOf[RandomForestRegressor])
+ if (rfModel.totalNumNodes < 30) {
+ println(rfModel.toDebugString) // Print full model.
+ } else {
+ println(rfModel) // Print model summary.
+ }
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ // Evaluate model on training, test data
+ algo match {
+ case "classification" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
+ case "regression" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
+ case _ =>
+ throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ sc.stop()
+ }
+}
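The RandomForestClassifier can likewise be used outside the runner. A minimal sketch under the same assumptions (indexedTrainingData is hypothetical; column names follow the pipeline stages above):

  import org.apache.spark.ml.classification.RandomForestClassifier

  val rf = new RandomForestClassifier()
    .setLabelCol("indexedLabel")
    .setFeaturesCol("indexedFeatures")
    .setNumTrees(10)
    .setFeatureSubsetStrategy("auto")   // one of RandomForestClassifier.supportedFeatureSubsetStrategies
  val model = rf.fit(indexedTrainingData)  // hypothetical training DataFrame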
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
index 431ead8c0c..0763a77363 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.util.Utils
+
/**
* An example runner for Gradient Boosting using decision trees as weak learners. Run with
* {{{
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index cae5082b51..a491bc7ee8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -30,11 +30,13 @@ import org.apache.spark.ml.param.ParamMap
abstract class Model[M <: Model[M]] extends Transformer {
/**
* The parent estimator that produced this model.
+ * Note: For ensembles' component Models, this value can be null.
*/
val parent: Estimator[M]
/**
* Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
+ * Note: For ensembles' component Models, this value can be null.
*/
val fittingParamMap: ParamMap
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 3855e396b5..ee2a8dc6db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -43,8 +43,7 @@ import org.apache.spark.sql.DataFrame
@AlphaComponent
final class DecisionTreeClassifier
extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
- with DecisionTreeParams
- with TreeClassifierParams {
+ with DecisionTreeParams with TreeClassifierParams {
// Override parameter setters from parent trait for Java API compatibility.
@@ -59,11 +58,9 @@ final class DecisionTreeClassifier
override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
- override def setCacheNodeIds(value: Boolean): this.type =
- super.setCacheNodeIds(value)
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
- override def setCheckpointInterval(value: Int): this.type =
- super.setCheckpointInterval(value)
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
override def setImpurity(value: String): this.type = super.setImpurity(value)
@@ -75,8 +72,9 @@ final class DecisionTreeClassifier
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
case Some(n: Int) => n
case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
- s" with invalid label column, without the number of classes specified.")
- // TODO: Automatically index labels.
+ s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ " specified. See StringIndexer.")
+ // TODO: Automatically index labels: SPARK-7126
}
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
@@ -85,18 +83,16 @@ final class DecisionTreeClassifier
}
/** (private[ml]) Create a Strategy instance to use with the old API. */
- override private[ml] def getOldStrategy(
+ private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
numClasses: Int): OldStrategy = {
- val strategy = super.getOldStrategy(categoricalFeatures, numClasses)
- strategy.algo = OldAlgo.Classification
- strategy.setImpurity(getOldImpurity)
- strategy
+ super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
+ subsamplingRate = 1.0)
}
}
object DecisionTreeClassifier {
- /** Accessor for supported impurities */
+ /** Accessor for supported impurities: entropy, gini */
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
}
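The consolidated getOldStrategy signature is what the single-tree and ensemble estimators now share; a sketch of how a caller fills it in, mirroring the DecisionTreeClassifier call in the hunk above:

  // Single trees train on the full data set, so subsamplingRate is fixed at 1.0.
  val strategy = getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification,
    getOldImpurity, subsamplingRate = 1.0)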
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
new file mode 100644
index 0000000000..d2e052fbbb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Param, Params, ParamMap}
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss, LogLoss => OldLogLoss}
+import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * learning algorithm for classification.
+ * It supports binary labels, as well as both continuous and categorical features.
+ * Note: Multiclass labels are not currently supported.
+ */
+@AlphaComponent
+final class GBTClassifier
+ extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
+ with GBTParams with TreeClassifierParams with Logging {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ // Parameters from TreeClassifierParams:
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+ /**
+ * The impurity setting is ignored for GBT models.
+ * Individual trees are built using impurity "Variance."
+ */
+ override def setImpurity(value: String): this.type = {
+ logWarning("GBTClassifier.setImpurity should NOT be used")
+ this
+ }
+
+ // Parameters from TreeEnsembleParams:
+
+ override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+ override def setSeed(value: Long): this.type = {
+ logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
+ super.setSeed(value)
+ }
+
+ // Parameters from GBTParams:
+
+ override def setMaxIter(value: Int): this.type = super.setMaxIter(value)
+
+ override def setStepSize(value: Double): this.type = super.setStepSize(value)
+
+ // Parameters for GBTClassifier:
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "logistic"
+ * (default = logistic)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTClassifier.supportedLossTypes.mkString(", ")}")
+
+ setDefault(lossType -> "logistic")
+
+ /** @group setParam */
+ def setLossType(value: String): this.type = {
+ val lossStr = value.toLowerCase
+ require(GBTClassifier.supportedLossTypes.contains(lossStr), "GBTClassifier was given bad loss" +
+ s" type: $value. Supported options: ${GBTClassifier.supportedLossTypes.mkString(", ")}")
+ set(lossType, lossStr)
+ this
+ }
+
+ /** @group getParam */
+ def getLossType: String = getOrDefault(lossType)
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "logistic" => OldLogLoss
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
+ }
+ }
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): GBTClassificationModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ case Some(n: Int) => n
+ case None => throw new IllegalArgumentException("GBTClassifier was given input" +
+ s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ " specified. See StringIndexer.")
+ // TODO: Automatically index labels: SPARK-7126
+ }
+ require(numClasses == 2,
+ s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
+ val oldGBT = new OldGBT(boostingStrategy)
+ val oldModel = oldGBT.run(oldDataset)
+ GBTClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+}
+
+object GBTClassifier {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: logistic */
+ final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * model for classification.
+ * It supports binary labels, as well as both continuous and categorical features.
+ * Note: Multiclass labels are not currently supported.
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ */
+@AlphaComponent
+final class GBTClassificationModel(
+ override val parent: GBTClassifier,
+ override val fittingParamMap: ParamMap,
+ private val _trees: Array[DecisionTreeRegressionModel],
+ private val _treeWeights: Array[Double])
+ extends PredictionModel[Vector, GBTClassificationModel]
+ with TreeEnsembleModel with Serializable {
+
+ require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.")
+ require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
+ s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+
+ override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+ override def treeWeights: Array[Double] = _treeWeights
+
+ override protected def predict(features: Vector): Double = {
+ // TODO: Override transform() to broadcast model: SPARK-7127
+ // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
+ // Classifies by thresholding sum of weighted tree predictions
+ val treePredictions = _trees.map(_.rootNode.predict(features))
+ val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+ if (prediction > 0.0) 1.0 else 0.0
+ }
+
+ override protected def copy(): GBTClassificationModel = {
+ val m = new GBTClassificationModel(parent, fittingParamMap, _trees, _treeWeights)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"GBTClassificationModel with $numTrees trees"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldGBTModel = {
+ new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
+ }
+}
+
+private[ml] object GBTClassificationModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldGBTModel,
+ parent: GBTClassifier,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): GBTClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
+ s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+ }
+ new GBTClassificationModel(parent, fittingParamMap, newTrees, oldModel.treeWeights)
+ }
+}
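GBTClassificationModel.predict above thresholds a weighted sum of the component trees' raw (regression) predictions at 0.0. A small worked sketch with hypothetical values:

  val treePredictions = Array(0.8, -0.4, 0.6)  // hypothetical raw predictions from 3 trees
  val treeWeights = Array(1.0, 0.5, 0.5)       // hypothetical tree weights
  val margin = treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum  // 0.8 - 0.2 + 0.3 = 0.9
  val label = if (margin > 0.0) 1.0 else 0.0   // margin > 0, so predict class 1.0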
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
new file mode 100644
index 0000000000..cfd6508fce
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for
+ * classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class RandomForestClassifier
+ extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
+ with RandomForestParams with TreeClassifierParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ // Parameters from TreeClassifierParams:
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ // Parameters from TreeEnsembleParams:
+
+ override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+ override def setSeed(value: Long): this.type = super.setSeed(value)
+
+ // Parameters from RandomForestParams:
+
+ override def setNumTrees(value: Int): this.type = super.setNumTrees(value)
+
+ override def setFeatureSubsetStrategy(value: String): this.type =
+ super.setFeatureSubsetStrategy(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): RandomForestClassificationModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ case Some(n: Int) => n
+ case None => throw new IllegalArgumentException("RandomForestClassifier was given input" +
+ s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ " specified. See StringIndexer.")
+ // TODO: Automatically index labels: SPARK-7126
+ }
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy =
+ super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
+ val oldModel = OldRandomForest.trainClassifier(
+ oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
+ RandomForestClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+}
+
+object RandomForestClassifier {
+ /** Accessor for supported impurity settings: entropy, gini */
+ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+
+ /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ RandomForestParams.supportedFeatureSubsetStrategies
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ * @param _trees Decision trees in the ensemble.
+ * Warning: These have null parents.
+ */
+@AlphaComponent
+final class RandomForestClassificationModel private[ml] (
+ override val parent: RandomForestClassifier,
+ override val fittingParamMap: ParamMap,
+ private val _trees: Array[DecisionTreeClassificationModel])
+ extends PredictionModel[Vector, RandomForestClassificationModel]
+ with TreeEnsembleModel with Serializable {
+
+ require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
+
+ override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+ // Note: We may add support for weights (based on tree performance) later on.
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+
+ override def treeWeights: Array[Double] = _treeWeights
+
+ override protected def predict(features: Vector): Double = {
+ // TODO: Override transform() to broadcast model. SPARK-7127
+ // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
+ // Classifies using majority votes.
+ // Ignore the weights since all are 1.0 for now.
+ val votes = mutable.Map.empty[Int, Double]
+ _trees.view.foreach { tree =>
+ val prediction = tree.rootNode.predict(features).toInt
+ votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight
+ }
+ votes.maxBy(_._2)._1
+ }
+
+ override protected def copy(): RandomForestClassificationModel = {
+ val m = new RandomForestClassificationModel(parent, fittingParamMap, _trees)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"RandomForestClassificationModel with $numTrees trees"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldRandomForestModel = {
+ new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
+ }
+}
+
+private[ml] object RandomForestClassificationModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldRandomForestModel,
+ parent: RandomForestClassifier,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
+ s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ DecisionTreeClassificationModel.fromOld(tree, null, null, categoricalFeatures)
+ }
+ new RandomForestClassificationModel(parent, fittingParamMap, newTrees)
+ }
+}
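RandomForestClassificationModel.predict above takes a plain majority vote over the component trees (all tree weights are 1.0 for now). A small sketch with hypothetical per-tree class predictions:

  import scala.collection.mutable

  val perTreePredictions = Seq(2, 0, 2, 1, 2)  // hypothetical class indices from 5 trees
  val votes = mutable.Map.empty[Int, Double]
  perTreePredictions.foreach { p => votes(p) = votes.getOrElse(p, 0.0) + 1.0 }
  val prediction = votes.maxBy(_._2)._1        // class 2 wins with 3 of 5 votes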
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
index eb2609faef..ab6281b9b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -20,9 +20,12 @@ package org.apache.spark.ml.impl.tree
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.impl.estimator.PredictorParams
import org.apache.spark.ml.param._
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.ml.param.shared.{HasSeed, HasMaxIter}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo,
+ BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
Impurity => OldImpurity, Variance => OldVariance}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
/**
@@ -117,79 +120,68 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
def setMaxDepth(value: Int): this.type = {
require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
set(maxDepth, value)
- this
}
/** @group getParam */
- def getMaxDepth: Int = getOrDefault(maxDepth)
+ final def getMaxDepth: Int = getOrDefault(maxDepth)
/** @group setParam */
def setMaxBins(value: Int): this.type = {
require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value")
set(maxBins, value)
- this
}
/** @group getParam */
- def getMaxBins: Int = getOrDefault(maxBins)
+ final def getMaxBins: Int = getOrDefault(maxBins)
/** @group setParam */
def setMinInstancesPerNode(value: Int): this.type = {
require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value")
set(minInstancesPerNode, value)
- this
}
/** @group getParam */
- def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
+ final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
/** @group setParam */
- def setMinInfoGain(value: Double): this.type = {
- set(minInfoGain, value)
- this
- }
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
/** @group getParam */
- def getMinInfoGain: Double = getOrDefault(minInfoGain)
+ final def getMinInfoGain: Double = getOrDefault(minInfoGain)
/** @group expertSetParam */
def setMaxMemoryInMB(value: Int): this.type = {
require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value")
set(maxMemoryInMB, value)
- this
}
/** @group expertGetParam */
- def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
+ final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
/** @group expertSetParam */
- def setCacheNodeIds(value: Boolean): this.type = {
- set(cacheNodeIds, value)
- this
- }
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
/** @group expertGetParam */
- def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
+ final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
/** @group expertSetParam */
def setCheckpointInterval(value: Int): this.type = {
require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value")
set(checkpointInterval, value)
- this
}
/** @group expertGetParam */
- def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+ final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
- /**
- * Create a Strategy instance to use with the old API.
- * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0,
- * the default for single trees).
- */
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
- numClasses: Int): OldStrategy = {
- val strategy = OldStrategy.defaultStategy(OldAlgo.Classification)
+ numClasses: Int,
+ oldAlgo: OldAlgo.Algo,
+ oldImpurity: OldImpurity,
+ subsamplingRate: Double): OldStrategy = {
+ val strategy = OldStrategy.defaultStategy(oldAlgo)
+ strategy.impurity = oldImpurity
strategy.checkpointInterval = getCheckpointInterval
strategy.maxBins = getMaxBins
strategy.maxDepth = getMaxDepth
@@ -199,13 +191,13 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
strategy.useNodeIdCache = getCacheNodeIds
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
- strategy.subsamplingRate = 1.0 // default for individual trees
+ strategy.subsamplingRate = subsamplingRate
strategy
}
}
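
As a quick illustration of the builder-style setters above and their validation, here is a minimal sketch assuming the public DecisionTreeClassifier (which mixes in DecisionTreeParams); the values are illustrative only:

  import org.apache.spark.ml.classification.DecisionTreeClassifier

  val dt = new DecisionTreeClassifier()
    .setMaxDepth(5)              // must be >= 0
    .setMaxBins(32)              // must be >= 2
    .setMinInstancesPerNode(2)   // must be >= 1
  // dt.setMaxBins(1) would throw IllegalArgumentException from the require(...) check.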
/**
- * (private trait) Parameters for Decision Tree-based classification algorithms.
+ * Parameters for Decision Tree-based classification algorithms.
*/
private[ml] trait TreeClassifierParams extends Params {
@@ -215,7 +207,7 @@ private[ml] trait TreeClassifierParams extends Params {
* (default = gini)
* @group param
*/
- val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
" information gain calculation (case-insensitive). Supported options:" +
s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
@@ -228,11 +220,10 @@ private[ml] trait TreeClassifierParams extends Params {
s"Tree-based classifier was given unrecognized impurity: $value." +
s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
set(impurity, impurityStr)
- this
}
/** @group getParam */
- def getImpurity: String = getOrDefault(impurity)
+ final def getImpurity: String = getOrDefault(impurity)
/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
@@ -249,11 +240,11 @@ private[ml] trait TreeClassifierParams extends Params {
private[ml] object TreeClassifierParams {
// These options should be lowercase.
- val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+ final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
}
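
A small sketch of the case-insensitive impurity handling, again assuming DecisionTreeClassifier as the concrete class mixing in TreeClassifierParams:

  import org.apache.spark.ml.classification.DecisionTreeClassifier

  val dt = new DecisionTreeClassifier().setImpurity("Gini")  // stored lowercase
  assert(dt.getImpurity == "gini")
  // setImpurity("mse") would be rejected: only "entropy" and "gini" are supported here.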
/**
- * (private trait) Parameters for Decision Tree-based regression algorithms.
+ * Parameters for Decision Tree-based regression algorithms.
*/
private[ml] trait TreeRegressorParams extends Params {
@@ -263,7 +254,7 @@ private[ml] trait TreeRegressorParams extends Params {
* (default = variance)
* @group param
*/
- val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
" information gain calculation (case-insensitive). Supported options:" +
s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
@@ -276,11 +267,10 @@ private[ml] trait TreeRegressorParams extends Params {
s"Tree-based regressor was given unrecognized impurity: $value." +
s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
set(impurity, impurityStr)
- this
}
/** @group getParam */
- def getImpurity: String = getOrDefault(impurity)
+ final def getImpurity: String = getOrDefault(impurity)
/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
@@ -296,5 +286,186 @@ private[ml] trait TreeRegressorParams extends Params {
private[ml] object TreeRegressorParams {
// These options should be lowercase.
- val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+ final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based ensemble algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
+
+ /**
+ * Fraction of the training data used for learning each decision tree.
+ * (default = 1.0)
+ * @group param
+ */
+ final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
+ "Fraction of the training data used for learning each decision tree.")
+
+ setDefault(subsamplingRate -> 1.0)
+
+ /** @group setParam */
+ def setSubsamplingRate(value: Double): this.type = {
+ require(value > 0.0 && value <= 1.0,
+ s"Subsampling rate must be in range (0,1]. Bad rate: $value")
+ set(subsamplingRate, value)
+ }
+
+ /** @group getParam */
+ final def getSubsamplingRate: Double = getOrDefault(subsamplingRate)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /**
+ * Create a Strategy instance to use with the old API.
+ * NOTE: The caller should set impurity and seed.
+ */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ oldAlgo: OldAlgo.Algo,
+ oldImpurity: OldImpurity): OldStrategy = {
+ super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
+ }
+}
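+
A hedged sketch of the new subsamplingRate param in use, via RandomForestRegressor (one of the classes in this patch that mixes in TreeEnsembleParams); the rate shown is arbitrary:

  import org.apache.spark.ml.regression.RandomForestRegressor

  val rf = new RandomForestRegressor()   // default subsamplingRate is 1.0
    .setSubsamplingRate(0.75)            // must lie in (0, 1]
  assert(rf.getSubsamplingRate == 0.75)
  // rf.setSubsamplingRate(0.0) would fail the require(...) check.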
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Random Forest algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+ /**
+ * Number of trees to train (>= 1).
+ * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ * @group param
+ */
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)")
+
+ /**
+ * The number of features to consider for splits at each tree node.
+ * Supported options:
+ * - "auto": Choose automatically for task:
+ * If numTrees == 1, set to "all".
+ * If numTrees > 1 (forest), set to "sqrt" for classification and
+ * to "onethird" for regression.
+ * - "all": use all features
+ * - "onethird": use 1/3 of the features
+ * - "sqrt": use sqrt(number of features)
+ * - "log2": use log2(number of features)
+ * (default = "auto")
+ *
+ * These various settings are based on the following references:
+ * - log2: tested in Breiman (2001)
+ * - sqrt: recommended by Breiman manual for random forests
+ * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
+ * package.
+ * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]]
+ * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for
+ * random forests]]
+ *
+ * @group param
+ */
+ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
+ "The number of features to consider for splits at each tree node." +
+ s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}")
+
+ setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
+
+ /** @group setParam */
+ def setNumTrees(value: Int): this.type = {
+ require(value >= 1, s"Random Forest numTrees parameter cannot be $value; it must be >= 1.")
+ set(numTrees, value)
+ }
+
+ /** @group getParam */
+ final def getNumTrees: Int = getOrDefault(numTrees)
+
+ /** @group setParam */
+ def setFeatureSubsetStrategy(value: String): this.type = {
+ val strategyStr = value.toLowerCase
+ require(RandomForestParams.supportedFeatureSubsetStrategies.contains(strategyStr),
+ s"RandomForestParams was given unrecognized featureSubsetStrategy: $value. Supported" +
+ s" options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}")
+ set(featureSubsetStrategy, strategyStr)
+ }
+
+ /** @group getParam */
+ final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy)
+}
+
+private[ml] object RandomForestParams {
+ // These options should be lowercase.
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+}
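+
A rough sketch of how the strategies above translate into a per-node feature count; this mirrors the documented behavior of the old-API RandomForest (exact rounding may differ) and is not code from this patch:

  def resolvedFeatureCount(strategy: String, numTrees: Int, numFeatures: Int,
      isClassification: Boolean): Int = strategy match {
    case "all"      => numFeatures
    case "sqrt"     => math.sqrt(numFeatures).ceil.toInt
    case "log2"     => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
    case "onethird" => math.max(1, (numFeatures / 3.0).ceil.toInt)
    case "auto" =>
      if (numTrees == 1) numFeatures
      else if (isClassification) math.sqrt(numFeatures).ceil.toInt
      else math.max(1, (numFeatures / 3.0).ceil.toInt)
  }
  // e.g. resolvedFeatureCount("auto", numTrees = 10, numFeatures = 100,
  //   isClassification = true) == 10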
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Gradient-Boosted Tree algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
+
+ /**
+ * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
+ * estimator.
+ * (default = 0.1)
+ * @group param
+ */
+ final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
+ " learning rate) in interval (0, 1] for shrinking the contribution of each estimator")
+
+ /* TODO: Add this doc when we add this param. SPARK-7132
+ * Threshold for stopping early when runWithValidation is used.
+ * If the error rate on the validation input changes by less than the validationTol,
+ * then learning will stop early (before [[numIterations]]).
+ * This parameter is ignored when run is used.
+ * (default = 1e-5)
+ * @group param
+ */
+ // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
+ // validationTol -> 1e-5
+
+ setDefault(maxIter -> 20, stepSize -> 0.1)
+
+ /** @group setParam */
+ def setMaxIter(value: Int): this.type = {
+ require(value >= 1, s"Gradient Boosting maxIter parameter cannot be $value; it must be >= 1.")
+ set(maxIter, value)
+ }
+
+ /** @group setParam */
+ def setStepSize(value: Double): this.type = {
+ require(value > 0.0 && value <= 1.0,
+ s"GBT given invalid step size ($value). Value should be in (0,1].")
+ set(stepSize, value)
+ }
+
+ /** @group getParam */
+ final def getStepSize: Double = getOrDefault(stepSize)
+
+ /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
+ private[ml] def getOldBoostingStrategy(
+ categoricalFeatures: Map[Int, Int],
+ oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
+ // NOTE: The old API does not support "seed" so we ignore it.
+ new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
+ }
+
+ /** Get old Gradient Boosting Loss type */
+ private[ml] def getOldLossType: OldLoss
}
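
For reference, a hedged sketch of what getOldBoostingStrategy assembles, written directly against the old mllib API; the maxDepth, iteration count, and learning rate below are illustrative values, not defaults introduced by this patch:

  import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
  import org.apache.spark.mllib.tree.impurity.Variance
  import org.apache.spark.mllib.tree.loss.SquaredError

  val treeStrategy = Strategy.defaultStategy(Algo.Regression)
  treeStrategy.impurity = Variance
  treeStrategy.maxDepth = 4
  // args: tree strategy, loss, numIterations, learningRate
  val boosting = new BoostingStrategy(treeStrategy, SquaredError, 20, 0.1)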
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 95d7e64790..e88c48741e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -45,7 +45,8 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name"),
ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
- ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")))
+ ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
+ ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
@@ -154,6 +155,7 @@ private[shared] object SharedParamsCodeGen {
|
|import org.apache.spark.annotation.DeveloperApi
|import org.apache.spark.ml.param._
+ |import org.apache.spark.util.Utils
|
|// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
|
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 72b08bf276..a860b8834c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.param.shared
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param._
+import org.apache.spark.util.Utils
// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
@@ -256,4 +257,23 @@ trait HasFitIntercept extends Params {
/** @group getParam */
final def getFitIntercept: Boolean = getOrDefault(fitIntercept)
}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param seed (default: Utils.random.nextLong()).
+ */
+@DeveloperApi
+trait HasSeed extends Params {
+
+ /**
+ * Param for random seed.
+ * @group param
+ */
+ final val seed: LongParam = new LongParam(this, "seed", "random seed")
+
+ setDefault(seed, Utils.random.nextLong())
+
+ /** @group getParam */
+ final def getSeed: Long = getOrDefault(seed)
+}
// scalastyle:on
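
A minimal sketch of the new shared param in use, via RandomForestRegressor (which picks up HasSeed through RandomForestParams):

  import org.apache.spark.ml.regression.RandomForestRegressor

  val rf = new RandomForestRegressor().setSeed(12345L)
  assert(rf.getSeed == 12345L)
  // If setSeed is never called, getSeed falls back to the default drawn from
  // Utils.random.nextLong() when the trait is constructed.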
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 49a8b77acf..756725a64b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -42,8 +42,7 @@ import org.apache.spark.sql.DataFrame
@AlphaComponent
final class DecisionTreeRegressor
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
- with DecisionTreeParams
- with TreeRegressorParams {
+ with DecisionTreeParams with TreeRegressorParams {
// Override parameter setters from parent trait for Java API compatibility.
@@ -60,8 +59,7 @@ final class DecisionTreeRegressor
override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
- override def setCheckpointInterval(value: Int): this.type =
- super.setCheckpointInterval(value)
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
override def setImpurity(value: String): this.type = super.setImpurity(value)
@@ -78,15 +76,13 @@ final class DecisionTreeRegressor
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
- val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
- strategy.algo = OldAlgo.Regression
- strategy.setImpurity(getOldImpurity)
- strategy
+ super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
+ subsamplingRate = 1.0)
}
}
object DecisionTreeRegressor {
- /** Accessor for supported impurities */
+ /** Accessor for supported impurities: variance */
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
new file mode 100644
index 0000000000..c784cf39ed
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap, Param}
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
+ SquaredError => OldSquaredError}
+import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * learning algorithm for regression.
+ * It supports both continuous and categorical features.
+ */
+@AlphaComponent
+final class GBTRegressor
+ extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
+ with GBTParams with TreeRegressorParams with Logging {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ // Parameters from TreeRegressorParams:
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+ /**
+ * The impurity setting is ignored for GBT models.
+ * Individual trees are built using impurity "Variance".
+ */
+ override def setImpurity(value: String): this.type = {
+ logWarning("GBTRegressor.setImpurity should NOT be used")
+ this
+ }
+
+ // Parameters from TreeEnsembleParams:
+
+ override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+ override def setSeed(value: Long): this.type = {
+ logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
+ super.setSeed(value)
+ }
+
+ // Parameters from GBTParams:
+
+ override def setMaxIter(value: Int): this.type = super.setMaxIter(value)
+
+ override def setStepSize(value: Double): this.type = super.setStepSize(value)
+
+ // Parameters for GBTRegressor:
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "squared" (L2) and "absolute" (L1)
+ * (default = squared)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTRegressor.supportedLossTypes.mkString(", ")}")
+
+ setDefault(lossType -> "squared")
+
+ /** @group setParam */
+ def setLossType(value: String): this.type = {
+ val lossStr = value.toLowerCase
+ require(GBTRegressor.supportedLossTypes.contains(lossStr), "GBTRegressor was given bad loss" +
+ s" type: $value. Supported options: ${GBTRegressor.supportedLossTypes.mkString(", ")}")
+ set(lossType, lossStr)
+ this
+ }
+
+ /** @group getParam */
+ def getLossType: String = getOrDefault(lossType)
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "squared" => OldSquaredError
+ case "absolute" => OldAbsoluteError
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
+ }
+ }
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): GBTRegressionModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
+ val oldGBT = new OldGBT(boostingStrategy)
+ val oldModel = oldGBT.run(oldDataset)
+ GBTRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+}
+
+object GBTRegressor {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: squared (L2), absolute (L1) */
+ final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+}
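+
A hedged end-to-end sketch of the estimator above; trainingDF and testDF are assumed DataFrames with "label" and "features" columns (e.g. built from LabeledPoints with toDF()), not data from this patch:

  import org.apache.spark.ml.regression.GBTRegressor

  val gbt = new GBTRegressor()
    .setMaxIter(50)
    .setStepSize(0.05)
    .setLossType("absolute")   // case-insensitive; "squared" is the default
    .setMaxDepth(4)
  val model = gbt.fit(trainingDF)           // trainingDF assumed to exist
  val predicted = model.transform(testDF)   // adds a "prediction" column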
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * model for regression.
+ * It supports both continuous and categorical features.
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ */
+@AlphaComponent
+final class GBTRegressionModel(
+ override val parent: GBTRegressor,
+ override val fittingParamMap: ParamMap,
+ private val _trees: Array[DecisionTreeRegressionModel],
+ private val _treeWeights: Array[Double])
+ extends PredictionModel[Vector, GBTRegressionModel]
+ with TreeEnsembleModel with Serializable {
+
+ require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
+ require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
+ s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+
+ override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+ override def treeWeights: Array[Double] = _treeWeights
+
+ override protected def predict(features: Vector): Double = {
+ // TODO: Override transform() to broadcast model. SPARK-7127
+ // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
+ // Predict the regression label as the weighted sum of the per-tree predictions.
+ val treePredictions = _trees.map(_.rootNode.predict(features))
+ blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+ }
+
+ override protected def copy(): GBTRegressionModel = {
+ val m = new GBTRegressionModel(parent, fittingParamMap, _trees, _treeWeights)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"GBTRegressionModel with $numTrees trees"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldGBTModel = {
+ new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
+ }
+}
+
+private[ml] object GBTRegressionModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldGBTModel,
+ parent: GBTRegressor,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): GBTRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
+ s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+ }
+ new GBTRegressionModel(parent, fittingParamMap, newTrees, oldModel.treeWeights)
+ }
+}
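+
The prediction above is just a dot product of the per-tree predictions with the tree weights; a plain-Scala sketch of the same arithmetic with illustrative numbers:

  val treePreds = Array(1.2, 0.8, 1.5)   // per-tree predictions for one instance
  val weights   = Array(1.0, 0.1, 0.1)   // tree weights
  val prediction = treePreds.zip(weights).map { case (p, w) => p * w }.sum
  // == 1.2*1.0 + 0.8*0.1 + 1.5*0.1, about 1.43; equivalent to
  // blas.ddot(3, treePreds, 1, weights, 1)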
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
new file mode 100644
index 0000000000..2171ef3d32
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams}
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression.
+ * It supports both continuous and categorical features.
+ */
+@AlphaComponent
+final class RandomForestRegressor
+ extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
+ with RandomForestParams with TreeRegressorParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ // Parameters from TreeRegressorParams:
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ // Parameters from TreeEnsembleParams:
+
+ override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+ override def setSeed(value: Long): this.type = super.setSeed(value)
+
+ // Parameters from RandomForestParams:
+
+ override def setNumTrees(value: Int): this.type = super.setNumTrees(value)
+
+ override def setFeatureSubsetStrategy(value: String): this.type =
+ super.setFeatureSubsetStrategy(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): RandomForestRegressionModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy =
+ super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
+ val oldModel = OldRandomForest.trainRegressor(
+ oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
+ RandomForestRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+}
+
+object RandomForestRegressor {
+ /** Accessor for supported impurity settings: variance */
+ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+
+ /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ RandomForestParams.supportedFeatureSubsetStrategies
+}
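+
A hedged sketch of the regressor in use; trainingDF is an assumed DataFrame with "label" and "features" columns, and the parameter values are illustrative:

  import org.apache.spark.ml.regression.RandomForestRegressor

  val rf = new RandomForestRegressor()
    .setNumTrees(30)
    .setFeatureSubsetStrategy("onethird")
    .setSubsamplingRate(0.8)
    .setSeed(17L)
  val model = rf.fit(trainingDF)
  println(s"forest: ${model.numTrees} trees, ${model.totalNumNodes} nodes total")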
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
+ * It supports both continuous and categorical features.
+ * @param _trees Decision trees in the ensemble.
+ */
+@AlphaComponent
+final class RandomForestRegressionModel private[ml] (
+ override val parent: RandomForestRegressor,
+ override val fittingParamMap: ParamMap,
+ private val _trees: Array[DecisionTreeRegressionModel])
+ extends PredictionModel[Vector, RandomForestRegressionModel]
+ with TreeEnsembleModel with Serializable {
+
+ require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
+
+ override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+ // Note: We may add support for weights (based on tree performance) later on.
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+
+ override def treeWeights: Array[Double] = _treeWeights
+
+ override protected def predict(features: Vector): Double = {
+ // TODO: Override transform() to broadcast model. SPARK-7127
+ // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
+ // Predict average of tree predictions.
+ // Ignore the weights since all are 1.0 for now.
+ _trees.map(_.rootNode.predict(features)).sum / numTrees
+ }
+
+ override protected def copy(): RandomForestRegressionModel = {
+ val m = new RandomForestRegressionModel(parent, fittingParamMap, _trees)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"RandomForestRegressionModel with $numTrees trees"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldRandomForestModel = {
+ new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
+ }
+}
+
+private[ml] object RandomForestRegressionModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldRandomForestModel,
+ parent: RandomForestRegressor,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
+ s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+ }
+ new RandomForestRegressionModel(parent, fittingParamMap, newTrees)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index d6e2203d9f..d2dec0c76c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -28,9 +28,9 @@ import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformation
sealed abstract class Node extends Serializable {
// TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
- // code into the new API and deprecate the old API.
+ // code into the new API and deprecate the old API. SPARK-3727
- /** Prediction this node makes (or would make, if it is an internal node) */
+ /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */
def prediction: Double
/** Impurity measure at this node (for training data) */
@@ -194,7 +194,7 @@ private object InternalNode {
s"$featureStr > ${contSplit.threshold}"
}
case catSplit: CategoricalSplit =>
- val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}")
+ val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}")
if (left) {
s"$featureStr in $categoriesStr"
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 708c769087..90f1d05276 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -44,7 +44,7 @@ private[tree] object Split {
oldSplit.featureType match {
case OldFeatureType.Categorical =>
new CategoricalSplit(featureIndex = oldSplit.feature,
- leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
+ _leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
case OldFeatureType.Continuous =>
new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold)
}
@@ -54,30 +54,30 @@ private[tree] object Split {
/**
* Split which tests a categorical feature.
* @param featureIndex Index of the feature to test
- * @param leftCategories If the feature value is in this set of categories, then the split goes
- * left. Otherwise, it goes right.
+ * @param _leftCategories If the feature value is in this set of categories, then the split goes
+ * left. Otherwise, it goes right.
* @param numCategories Number of categories for this feature.
*/
final class CategoricalSplit private[ml] (
override val featureIndex: Int,
- leftCategories: Array[Double],
+ _leftCategories: Array[Double],
private val numCategories: Int)
extends Split {
- require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
- s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}")
+ require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
+ s" (should be in range [0, $numCategories)): ${_leftCategories.mkString(",")}")
/**
* If true, then "categories" is the set of categories for splitting to the left, and vice versa.
*/
- private val isLeft: Boolean = leftCategories.length <= numCategories / 2
+ private val isLeft: Boolean = _leftCategories.length <= numCategories / 2
/** Set of categories determining the splitting rule, along with [[isLeft]]. */
private val categories: Set[Double] = {
if (isLeft) {
- leftCategories.toSet
+ _leftCategories.toSet
} else {
- setComplement(leftCategories.toSet)
+ setComplement(_leftCategories.toSet)
}
}
@@ -107,13 +107,13 @@ final class CategoricalSplit private[ml] (
}
/** Get sorted categories which split to the left */
- def getLeftCategories: Array[Double] = {
+ def leftCategories: Array[Double] = {
val cats = if (isLeft) categories else setComplement(categories)
cats.toArray.sorted
}
/** Get sorted categories which split to the right */
- def getRightCategories: Array[Double] = {
+ def rightCategories: Array[Double] = {
val cats = if (isLeft) setComplement(categories) else categories
cats.toArray.sorted
}
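
The left/right accessors above are complements of each other over the feature's category set; a plain-Scala sketch of that relationship (CategoricalSplit itself has a private[ml] constructor, so this is conceptual rather than a direct use of the class):

  val numCategories   = 4
  val leftCategories  = Set(1.0, 3.0)
  val rightCategories = (0 until numCategories).map(_.toDouble).toSet -- leftCategories
  // rightCategories == Set(0.0, 2.0); a feature value in leftCategories goes left,
  // anything else goes right.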
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 8e3bc3849d..1929f9d021 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,18 +17,13 @@
package org.apache.spark.ml.tree
-import org.apache.spark.annotation.AlphaComponent
-
/**
- * :: AlphaComponent ::
- *
* Abstraction for Decision Tree models.
*
- * TODO: Add support for predicting probabilities and raw predictions
+ * TODO: Add support for predicting probabilities and raw predictions SPARK-3727
*/
-@AlphaComponent
-trait DecisionTreeModel {
+private[ml] trait DecisionTreeModel {
/** Root of the decision tree */
def rootNode: Node
@@ -58,3 +53,40 @@ trait DecisionTreeModel {
header + rootNode.subtreeToString(2)
}
}
+
+/**
+ * Abstraction for models which are ensembles of decision trees
+ *
+ * TODO: Add support for predicting probabilities and raw predictions SPARK-3727
+ */
+private[ml] trait TreeEnsembleModel {
+
+ // Note: trees is exposed as a def (returning Array[DecisionTreeModel]) because subclasses
+ // of TreeEnsembleModel store arrays of more specific DecisionTreeModel subclasses.
+
+ /** Trees in this ensemble. Warning: These have null parent Estimators. */
+ def trees: Array[DecisionTreeModel]
+
+ /** Weights for each tree, zippable with [[trees]] */
+ def treeWeights: Array[Double]
+
+ /** Summary of the model */
+ override def toString: String = {
+ // Implementing classes should generally override this method to be more descriptive.
+ s"TreeEnsembleModel with $numTrees trees"
+ }
+
+ /** Full description of model */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + trees.zip(treeWeights).zipWithIndex.map { case ((tree, weight), treeIndex) =>
+ s" Tree $treeIndex (weight $weight):\n" + tree.rootNode.subtreeToString(4)
+ }.fold("")(_ + _)
+ }
+
+ /** Number of trees in ensemble */
+ val numTrees: Int = trees.length
+
+ /** Total number of nodes, summed over all trees in the ensemble. */
+ lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
+}
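+
A small sketch of how a trained ensemble can be inspected through this trait; model is assumed to be any fitted ensemble model from this patch, e.g. a RandomForestRegressionModel:

  model.trees.zip(model.treeWeights).zipWithIndex.foreach { case ((tree, weight), i) =>
    println(s"tree $i: weight=$weight, nodes=${tree.numNodes}")
  }
  println(s"total nodes across ${model.numTrees} trees: ${model.totalNumNodes}")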
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 43b8787f9d..60f25e5cce 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.ml.classification;
-import java.io.File;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
@@ -32,7 +31,6 @@ import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.util.Utils;
public class JavaDecisionTreeClassifierSuite implements Serializable {
@@ -57,7 +55,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -71,8 +69,8 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
.setCacheNodeIds(false)
.setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) {
- dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]);
+ for (String impurity: DecisionTreeClassifier.supportedImpurities()) {
+ dt.setImpurity(impurity);
}
DecisionTreeClassificationModel model = dt.fit(dataFrame);
@@ -82,7 +80,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
model.toDebugString();
/*
- // TODO: Add test once save/load are implemented.
+ // TODO: Add test once save/load are implemented. SPARK-6725
File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
String path = tempDir.toURI().toString();
try {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
new file mode 100644
index 0000000000..3c69467fa1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaGBTClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaGBTClassifierSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ GBTClassifier rf = new GBTClassifier()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setMaxIter(3)
+ .setStepSize(0.1)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String lossType: GBTClassifier.supportedLossTypes()) {
+ rf.setLossType(lossType);
+ }
+ GBTClassificationModel model = rf.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model3.save(sc.sc(), path);
+ GBTClassificationModel sameModel = GBTClassificationModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model3, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
new file mode 100644
index 0000000000..32d0b3856b
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaRandomForestClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ RandomForestClassifier rf = new RandomForestClassifier()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setNumTrees(3)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity: RandomForestClassifier.supportedImpurities()) {
+ rf.setImpurity(impurity);
+ }
+ for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
+ rf.setFeatureSubsetStrategy(featureSubsetStrategy);
+ }
+ RandomForestClassificationModel model = rf.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model3.save(sc.sc(), path);
+ RandomForestClassificationModel sameModel =
+ RandomForestClassificationModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model3, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index a3a339004f..71b041818d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.ml.regression;
-import java.io.File;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
@@ -32,7 +31,6 @@ import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.util.Utils;
public class JavaDecisionTreeRegressorSuite implements Serializable {
@@ -57,22 +55,22 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
- .setMaxDepth(2)
- .setMaxBins(10)
- .setMinInstancesPerNode(5)
- .setMinInfoGain(0.0)
- .setMaxMemoryInMB(256)
- .setCacheNodeIds(false)
- .setCheckpointInterval(10)
- .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) {
- dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]);
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity: DecisionTreeRegressor.supportedImpurities()) {
+ dt.setImpurity(impurity);
}
DecisionTreeRegressionModel model = dt.fit(dataFrame);
@@ -82,7 +80,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
model.toDebugString();
/*
- // TODO: Add test once save/load are implemented.
+ // TODO: Add test once save/load are implemented. SPARK-6725
File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
String path = tempDir.toURI().toString();
try {
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
new file mode 100644
index 0000000000..fc8c13db07
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaGBTRegressorSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaGBTRegressorSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+
+ GBTRegressor rf = new GBTRegressor()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setMaxIter(3)
+ .setStepSize(0.1)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String lossType: GBTRegressor.supportedLossTypes()) {
+ rf.setLossType(lossType);
+ }
+ GBTRegressionModel model = rf.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model2.save(sc.sc(), path);
+ GBTRegressionModel sameModel = GBTRegressionModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model2, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
new file mode 100644
index 0000000000..e306ebadfe
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaRandomForestRegressorSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+
+ // This tests setters. Training with various options is tested in Scala.
+ RandomForestRegressor rf = new RandomForestRegressor()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setNumTrees(3)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity: RandomForestRegressor.supportedImpurities()) {
+ rf.setImpurity(impurity);
+ }
+ for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
+ rf.setFeatureSubsetStrategy(featureSubsetStrategy);
+ }
+ RandomForestRegressionModel model = rf.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model2.save(sc.sc(), path);
+ RandomForestRegressionModel sameModel = RandomForestRegressionModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model2, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index af88595df5..9b31adecdc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -230,7 +230,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented
+ // TODO: Reinstate test once save/load are implemented SPARK-6725
/*
test("model save/load") {
val tempDir = Utils.createTempDir()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
new file mode 100644
index 0000000000..e6ccc2c93c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * Test suite for [[GBTClassifier]].
+ */
+class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext {
+
+ import GBTClassifierSuite.compareAPIs
+
+ // Combinations of (maxIter, learningRate, subsamplingRate) to test
+ private val testCombinations =
+ Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
+
+ private var data: RDD[LabeledPoint] = _
+ private var trainData: RDD[LabeledPoint] = _
+ private var validationData: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2)
+ trainData =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2)
+ validationData =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
+ }
+
+ test("Binary classification with continuous features: Log Loss") {
+ val categoricalFeatures = Map.empty[Int, Int]
+ testCombinations.foreach {
+ case (maxIter, learningRate, subsamplingRate) =>
+ val gbt = new GBTClassifier()
+ .setMaxDepth(2)
+ .setSubsamplingRate(subsamplingRate)
+ .setLossType("logistic")
+ .setMaxIter(maxIter)
+ .setStepSize(learningRate)
+ compareAPIs(data, None, gbt, categoricalFeatures)
+ }
+ }
+
+ // TODO: Reinstate test once runWithValidation is implemented SPARK-7132
+ /*
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ val categoricalFeatures = Map.empty[Int, Int]
+ // Set maxIter large enough so that it stops early.
+ val maxIter = 20
+ GBTClassifier.supportedLossTypes.foreach { loss =>
+ val gbt = new GBTClassifier()
+ .setMaxIter(maxIter)
+ .setMaxDepth(2)
+ .setLossType(loss)
+ .setValidationTol(0.0)
+ compareAPIs(trainData, None, gbt, categoricalFeatures)
+ compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures)
+ }
+ }
+ */
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented SPARK-6725
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
+ val treeWeights = Array(0.1, 0.3, 1.1)
+ val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights)
+ val newModel = GBTClassificationModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = GBTClassificationModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private object GBTClassifierSuite {
+
+ /**
+ * Train 2 models on the given dataset, one using the old API and one using the new API.
+ * Convert the old model to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ validationData: Option[RDD[LabeledPoint]],
+ gbt: GBTClassifier,
+ categoricalFeatures: Map[Int, Int]): Unit = {
+ val oldBoostingStrategy =
+ gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
+ val oldGBT = new OldGBT(oldBoostingStrategy)
+ val oldModel = oldGBT.run(data)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
+ val newModel = gbt.fit(newData)
+ // Use parent and fittingParamMap from newModel, since these are not checked anyway.
+ val oldModelAsNew = GBTClassificationModel.fromOld(oldModel, newModel.parent,
+ newModel.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldModelAsNew, newModel)
+ }
+}
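
Additional parameter combinations can be validated against the old API through the same compareAPIs helper; a hypothetical extra test (parameter values chosen arbitrarily for illustration) would follow the pattern above:

  test("Binary classification with continuous features and subsampling") {
    val gbt = new GBTClassifier()
      .setMaxDepth(3)
      .setLossType("logistic")
      .setMaxIter(5)
      .setStepSize(0.1)
      .setSubsamplingRate(0.75)
    compareAPIs(data, None, gbt, Map.empty[Int, Int])
  }
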
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
new file mode 100644
index 0000000000..ed41a9664f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * Test suite for [[RandomForestClassifier]].
+ */
+class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext {
+
+ import RandomForestClassifierSuite.compareAPIs
+
+ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
+ private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ orderedLabeledPoints50_1000 =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000))
+ orderedLabeledPoints5_20 =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier) {
+ val categoricalFeatures = Map.empty[Int, Int]
+ val numClasses = 2
+ val newRF = rf
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setNumTrees(1)
+ .setFeatureSubsetStrategy("auto")
+ .setSeed(123)
+ compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
+ }
+
+ test("Binary classification with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val rf = new RandomForestClassifier()
+ binaryClassificationTestWithContinuousFeatures(rf)
+ }
+
+ test("Binary classification with continuous features and node Id cache:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val rf = new RandomForestClassifier()
+ .setCacheNodeIds(true)
+ binaryClassificationTestWithContinuousFeatures(rf)
+ }
+
+ test("alternating categorical and continuous features with multiclass labels to test indexing") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)),
+ LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))
+ )
+ val rdd = sc.parallelize(arr)
+ val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4)
+ val numClasses = 3
+
+ val rf = new RandomForestClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(5)
+ .setNumTrees(2)
+ .setFeatureSubsetStrategy("sqrt")
+ .setSeed(12345)
+ compareAPIs(rdd, rf, categoricalFeatures, numClasses)
+ }
+
+ test("subsampling rate in RandomForest"){
+ val rdd = orderedLabeledPoints5_20
+ val categoricalFeatures = Map.empty[Int, Int]
+ val numClasses = 2
+
+ val rf1 = new RandomForestClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setCacheNodeIds(true)
+ .setNumTrees(3)
+ .setFeatureSubsetStrategy("auto")
+ .setSeed(123)
+ compareAPIs(rdd, rf1, categoricalFeatures, numClasses)
+
+ val rf2 = rf1.setSubsamplingRate(0.5)
+ compareAPIs(rdd, rf2, categoricalFeatures, numClasses)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented SPARK-6725
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees =
+ Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray
+ val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees)
+ val newModel = RandomForestClassificationModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = RandomForestClassificationModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private object RandomForestClassifierSuite {
+
+ /**
+ * Train 2 models on the given dataset, one using the old API and one using the new API.
+ * Convert the old model to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ rf: RandomForestClassifier,
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): Unit = {
+ val oldStrategy =
+ rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity)
+ val oldModel = OldRandomForest.trainClassifier(
+ data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+ val newModel = rf.fit(newData)
+ // Use parent and fittingParamMap from newModel, since these are not checked anyway.
+ val oldModelAsNew = RandomForestClassificationModel.fromOld(oldModel, newModel.parent,
+ newModel.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldModelAsNew, newModel)
+ }
+}
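
Throughout these suites, the categoricalFeatures map passed to compareAPIs and TreeTests.setMetadata maps a feature index to its arity, with category values assumed to run from 0 to arity - 1; the map used in the multiclass-indexing test above reads as follows (comments added here only for clarity):

  val categoricalFeatures = Map(
    0 -> 3,  // feature 0 is categorical with values {0.0, 1.0, 2.0}
    2 -> 2,  // feature 2 is binary
    4 -> 4)  // feature 4 has four categories; features 1 and 3 remain continuous
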
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 2e57d4ce37..1505ad8725 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -23,8 +23,7 @@ import org.scalatest.FunSuite
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
-import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node}
+import org.apache.spark.ml.tree._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, DataFrame}
@@ -111,22 +110,19 @@ private[ml] object TreeTests extends FunSuite {
}
}
- // TODO: Reinstate after adding ensembles
/**
* Check if the two models are exactly the same.
* If the models are not equal, this throws an exception.
*/
- /*
def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
try {
- a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) =>
+ a.trees.zip(b.trees).foreach { case (treeA, treeB) =>
TreeTests.checkEqual(treeA, treeB)
}
- assert(a.getTreeWeights === b.getTreeWeights)
+ assert(a.treeWeights === b.treeWeights)
} catch {
case ex: Exception => throw new AssertionError(
"checkEqual failed since the two tree ensembles were not identical")
}
}
- */
}
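
The suites above rely on TreeTests.setMetadata to turn an RDD[LabeledPoint] into a DataFrame whose label column carries ML attribute metadata (numClasses = 0 marks a continuous label for regression). A condensed sketch of that idea, using an illustrative helper name and omitting the per-feature attributes the real helper also sets, is:

  import org.apache.spark.ml.attribute.{NominalAttribute, NumericAttribute}
  import org.apache.spark.sql.DataFrame

  // Sketch only: attach label metadata so tree estimators can read the number of classes.
  def withLabelMetadata(df: DataFrame, numClasses: Int): DataFrame = {
    val labelAttr = if (numClasses == 0) {
      NumericAttribute.defaultAttr.withName("label")                            // regression
    } else {
      NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)  // classification
    }
    df.select(df("features"), df("label").as("label", labelAttr.toMetadata()))
  }
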
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 0b40fe33fa..c87a171b4b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -66,7 +66,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: test("model save/load")
+ // TODO: test("model save/load") SPARK-6725
}
private[ml] object DecisionTreeRegressorSuite extends FunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
new file mode 100644
index 0000000000..4aec36948a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * Test suite for [[GBTRegressor]].
+ */
+class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext {
+
+ import GBTRegressorSuite.compareAPIs
+
+ // Combinations of (maxIter, learningRate, subsamplingRate) to test
+ private val testCombinations =
+ Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
+
+ private var data: RDD[LabeledPoint] = _
+ private var trainData: RDD[LabeledPoint] = _
+ private var validationData: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2)
+ trainData =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2)
+ validationData =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
+ }
+
+ test("Regression with continuous features: SquaredError") {
+ val categoricalFeatures = Map.empty[Int, Int]
+ GBTRegressor.supportedLossTypes.foreach { loss =>
+ testCombinations.foreach {
+ case (maxIter, learningRate, subsamplingRate) =>
+ val gbt = new GBTRegressor()
+ .setMaxDepth(2)
+ .setSubsamplingRate(subsamplingRate)
+ .setLossType(loss)
+ .setMaxIter(maxIter)
+ .setStepSize(learningRate)
+ compareAPIs(data, None, gbt, categoricalFeatures)
+ }
+ }
+ }
+
+ // TODO: Reinstate test once runWithValidation is implemented SPARK-7132
+ /*
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ val categoricalFeatures = Map.empty[Int, Int]
+ // Set maxIter large enough so that it stops early.
+ val maxIter = 20
+ GBTRegressor.supportedLossTypes.foreach { loss =>
+ val gbt = new GBTRegressor()
+ .setMaxIter(maxIter)
+ .setMaxDepth(2)
+ .setLossType(loss)
+ .setValidationTol(0.0)
+ compareAPIs(trainData, None, gbt, categoricalFeatures)
+ compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures)
+ }
+ }
+ */
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented SPARK-6725
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
+ val treeWeights = Array(0.1, 0.3, 1.1)
+ val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights)
+ val newModel = GBTRegressionModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = GBTRegressionModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private object GBTRegressorSuite {
+
+ /**
+ * Train 2 models on the given dataset, one using the old API and one using the new API.
+ * Convert the old model to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ validationData: Option[RDD[LabeledPoint]],
+ gbt: GBTRegressor,
+ categoricalFeatures: Map[Int, Int]): Unit = {
+ val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
+ val oldGBT = new OldGBT(oldBoostingStrategy)
+ val oldModel = oldGBT.run(data)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newModel = gbt.fit(newData)
+ // Use parent and fittingParamMap from newModel, since these are not checked anyway.
+ val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent,
+ newModel.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldModelAsNew, newModel)
+ }
+}
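
The regression test above iterates over every supported loss type; if a case pinned to a single loss is wanted, a hypothetical test in the same style could fix it explicitly (this assumes "squared" is among GBTRegressor.supportedLossTypes):

  test("Regression with continuous features: squared loss only") {
    val gbt = new GBTRegressor()
      .setMaxDepth(2)
      .setLossType("squared")
      .setMaxIter(10)
      .setStepSize(0.1)
    compareAPIs(data, None, gbt, Map.empty[Int, Int])
  }
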
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
new file mode 100644
index 0000000000..c6dc1cc29b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * Test suite for [[RandomForestRegressor]].
+ */
+class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
+
+ import RandomForestRegressorSuite.compareAPIs
+
+ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ orderedLabeledPoints50_1000 =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ def regressionTestWithContinuousFeatures(rf: RandomForestRegressor) {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val newRF = rf
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setNumTrees(1)
+ .setFeatureSubsetStrategy("auto")
+ .setSeed(123)
+ compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeaturesInfo)
+ }
+
+ test("Regression with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val rf = new RandomForestRegressor()
+ regressionTestWithContinuousFeatures(rf)
+ }
+
+ test("Regression with continuous features and node Id cache :" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val rf = new RandomForestRegressor()
+ .setCacheNodeIds(true)
+ regressionTestWithContinuousFeatures(rf)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented SPARK-6725
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
+ val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees)
+ val newModel = RandomForestRegressionModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = RandomForestRegressionModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private object RandomForestRegressorSuite extends FunSuite {
+
+ /**
+ * Train 2 models on the given dataset, one using the old API and one using the new API.
+ * Convert the old model to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ rf: RandomForestRegressor,
+ categoricalFeatures: Map[Int, Int]): Unit = {
+ val oldStrategy =
+ rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity)
+ val oldModel = OldRandomForest.trainRegressor(
+ data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newModel = rf.fit(newData)
+ // Use parent and fittingParamMap from newModel, since these are not checked anyway.
+ val oldModelAsNew = RandomForestRegressionModel.fromOld(oldModel, newModel.parent,
+ newModel.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldModelAsNew, newModel)
+ }
+}
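
Unlike the classifier suite, the regressor suite does not yet vary subsamplingRate; a hypothetical counterpart to the classifier-side subsampling test, reusing only the setters shown in this patch, could look like:

  test("subsampling rate in RandomForestRegressor") {
    val categoricalFeatures = Map.empty[Int, Int]
    val rf1 = new RandomForestRegressor()
      .setImpurity("variance")
      .setMaxDepth(2)
      .setCacheNodeIds(true)
      .setNumTrees(3)
      .setFeatureSubsetStrategy("auto")
      .setSeed(123)
    compareAPIs(orderedLabeledPoints50_1000, rf1, categoricalFeatures)

    val rf2 = rf1.setSubsamplingRate(0.5)
    compareAPIs(orderedLabeledPoints50_1000, rf2, categoricalFeatures)
  }
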
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 249b8eae19..ce983eb27f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -998,7 +998,7 @@ object DecisionTreeSuite extends FunSuite {
node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical,
categories = List(0.0, 1.0)))
}
- // TODO: The information gain stats should be consistent with the same info stored in children.
+ // TODO: The information gain stats should be consistent with info in children: SPARK-7131
node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2,
leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6)))
node
@@ -1006,9 +1006,9 @@ object DecisionTreeSuite extends FunSuite {
/**
* Create a tree model. This is deterministic and contains a variety of node and feature types.
- * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.)
+ * TODO: Update to be a correct tree (with matching probabilities, impurities, etc.): SPARK-7131
*/
- private[mllib] def createModel(algo: Algo): DecisionTreeModel = {
+ private[spark] def createModel(algo: Algo): DecisionTreeModel = {
val topNode = createInternalNode(id = 1, Continuous)
val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))