author    Joseph K. Bradley <joseph@databricks.com>    2015-04-17 13:15:36 -0700
committer Xiangrui Meng <meng@databricks.com>    2015-04-17 13:15:36 -0700
commit    a83571acc938582865efb41645aa1e414f339e46 (patch)
tree      cd11b3bea5e50946c015d22672f269970f664f9b
parent    50ab8a6543ad5c31e89c16df374d0cb13222fd1e (diff)
[SPARK-6113] [ml] Stabilize DecisionTree API
This is a PR for cleaning up and finalizing the DecisionTree API. PRs for ensembles will follow once this is merged.

### Goal

Here is the description copied from the JIRA (for both trees and ensembles):

> **Issue**: The APIs for DecisionTree and ensembles (RandomForests and GradientBoostedTrees) have been experimental for a long time. The API has become very convoluted because trees and ensembles have many, many variants, some of which we have added incrementally without a long-term design.
> **Proposal**: This JIRA is for discussing changes required to finalize the APIs. After we discuss, I will make a PR to update the APIs and make them non-Experimental. This will require making many breaking changes; see the design doc for details.
> **[Design doc](https://docs.google.com/document/d/1rJ_DZinyDG3PkYkAKSsQlY0QgCeefn4hUv7GsPkzBP4)**: This outlines current issues and the proposed API.

Overall code layout:
* The old API in mllib.tree.* will remain the same.
* The new API will reside in ml.classification.* and ml.regression.*

### Summary of changes

Old API
* Exactly the same, except I made 1 method in Loss private (but that is not a breaking change since that method was introduced after the Spark 1.3 release).

New APIs
* Under Pipeline API
* The new API preserves functionality, except:
  * New API does NOT store prob (probability of label in classification). I want to have it store the full vector of probabilities, but feel that should be in a later PR.
* Use abstractions for parameters, estimators, and models to avoid code duplication
* Limit parameters to relevant algorithms
* For enum-like types, only expose Strings
  * We can make these pluggable later on by adding new parameters. That is a far-future item.

Test suites
* I organized DecisionTreeSuite, but I made absolutely no changes to the tests themselves.
* The test suites for the new API only test (a) similarity with the results of the old API and (b) elements of the new API.
* After code is moved to this new API, we should move the tests from the old suites which test the internals.

### Details

#### Changed names

Parameters
* useNodeIdCache -> cacheNodeIds

#### Other changes
* Split: Changed categories to set instead of list

#### Non-decision tree changes
* AttributeGroup
  * Added parentheses to toMetadata, toStructField methods (These were removed in a previous PR, but I ran into 1 issue with the Scala compiler not being able to disambiguate between a toMetadata method with no parentheses and a toMetadata method which takes 1 argument.)
* Attributes
  * Renamed: toMetadata -> toMetadataImpl
  * Added toMetadata methods which return ML metadata (keyed with "ML_ATTR")
* NominalAttribute: Added getNumValues method which examines both numValues and values.
* Params.inheritValues: Checks whether the parent param really belongs to the child (to allow Estimator-Model pairs with different sets of parameters)

### Questions for reviewers

* Is "DecisionTreeClassificationModel" too long a name?
* Is this OK in the docs?
```
class DecisionTreeRegressor
  extends TreeRegressor[DecisionTreeRegressionModel]
  with DecisionTreeParams[DecisionTreeRegressor]
  with TreeRegressorParams[DecisionTreeRegressor]
```

### Future

We should open up the abstractions at some point. E.g., it would be useful to be able to set tree-related parameters in 1 place and then pass those to multiple tree-based algorithms.

Follow-up JIRAs will be (in this order):
* Tree ensembles
* Deprecate old tree code
* Move DecisionTree implementation code to new API.
* Move tests from the old suites which test the internals.
* Update programming guide
* Python API
* Change RandomForest* to always use bootstrapping, even when numTrees = 1
* Provide the probability of the predicted label for classification. After we move code to the new API and update it to maintain probabilities for all labels, then we can add the probabilities to the new API.

CC: mengxr manishamde codedeft chouqin MechCoder

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5530 from jkbradley/dt-api-dt and squashes the following commits:

6aae255 [Joseph K. Bradley] Changed tree abstractions not to take type parameters, and for setters to return this.type instead
ec17947 [Joseph K. Bradley] Updates based on code review. Main changes were: moving public types from ml.impl.tree to ml.tree, modifying CategoricalSplit to take an Array of categories but store a Set internally, making more types sealed or final
5626c81 [Joseph K. Bradley] style fixes
f8fbd24 [Joseph K. Bradley] imported reorg of DecisionTreeSuite from old PR. small cleanups
7ef63ed [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example (for real this time)
e11673f [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example
119f407 [Joseph K. Bradley] added DecisionTreeClassifier example
0bdc486 [Joseph K. Bradley] fixed issues after param PR was merged
f9fbb60 [Joseph K. Bradley] Done with DecisionTreeClassifier, but no save/load yet. Need to add example as well
2532c9a [Joseph K. Bradley] partial move to spark.ml API, not done yet
c72c1a0 [Joseph K. Bradley] Copied changes for common items, plus DecisionTreeClassifier from original PR
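To make the new Pipeline-based API described above concrete, here is a minimal sketch of how the DecisionTreeClassifier introduced in this commit might be wired into a Pipeline; it mirrors the DecisionTreeExample added in this patch. The DataFrame `training`, its "labelString" and "features" columns, and the specific parameter values are illustrative assumptions, not part of the change itself.
```
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
import org.apache.spark.sql.DataFrame

// Index string labels into numeric label indices.
val labelIndexer = new StringIndexer()
  .setInputCol("labelString")
  .setOutputCol("indexedLabel")
// Identify categorical features; features with > 10 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(10)
// Configure the tree using the new parameter names (e.g. cacheNodeIds, formerly useNodeIdCache).
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setMaxDepth(5)
  .setCacheNodeIds(false)
val stages: Array[PipelineStage] = Array(labelIndexer, featureIndexer, dt)
val pipeline = new Pipeline().setStages(stages)

// `training` is assumed to be a DataFrame with "labelString" and "features" columns.
def fitModel(training: DataFrame) = pipeline.fit(training)
```
Calling `transform` on the fitted PipelineModel then adds a "prediction" column, which the example below uses to compute accuracy via MulticlassMetrics.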
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala | 322
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala | 43
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | 155
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala | 300
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/package.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 145
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala | 205
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala | 151
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala | 60
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala | 82
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala | 32
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java | 98
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java | 97
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala | 4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala | 42
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 274
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala | 132
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 91
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 373
33 files changed, 2426 insertions, 263 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
new file mode 100644
index 0000000000..d4cc8dede0
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -0,0 +1,322 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.ml.tree.DecisionTreeModel
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
+import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
+import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.{SQLContext, DataFrame}
+
+
+/**
+ * An example runner for decision trees. Run with
+ * {{{
+ * ./bin/run-example ml.DecisionTreeExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DecisionTreeExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "Classification",
+ maxDepth: Int = 5,
+ maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
+ numTrees: Int = 1,
+ featureSubsetStrategy: String = "auto",
+ fracTest: Double = 0.2,
+ cacheNodeIds: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("DecisionTreeExample") {
+ head("DecisionTreeExample: an example decision tree app.")
+ opt[String]("algo")
+ .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("maxBins")
+ .text(s"max number of bins, default: ${defaultParams.maxBins}")
+ .action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("cacheNodeIds")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.cacheNodeIds}")
+ .action((x, c) => c.copy(cacheNodeIds = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }}")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ /** Load a dataset from the given path, using the given format */
+ private[ml] def loadData(
+ sc: SparkContext,
+ path: String,
+ format: String,
+ expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
+ format match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, path)
+ case "libsvm" => expectedNumFeatures match {
+ case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
+ case None => MLUtils.loadLibSVMFile(sc, path)
+ }
+ case _ => throw new IllegalArgumentException(s"Bad data format: $format")
+ }
+ }
+
+ /**
+ * Load training and test data from files.
+ * @param input Path to input dataset.
+ * @param dataFormat "libsvm" or "dense"
+ * @param testInput Path to test dataset.
+ * @param algo Classification or Regression
+ * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given.
+ * @return (training dataset, test dataset)
+ */
+ private[ml] def loadDatasets(
+ sc: SparkContext,
+ input: String,
+ dataFormat: String,
+ testInput: String,
+ algo: String,
+ fracTest: Double): (DataFrame, DataFrame) = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Load training data
+ val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
+
+ // Load or create test set
+ val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
+ // Load testInput.
+ val numFeatures = origExamples.take(1)(0).features.size
+ val origTestExamples: RDD[LabeledPoint] = loadData(sc, testInput, dataFormat, Some(numFeatures))
+ Array(origExamples, origTestExamples)
+ } else {
+ // Split input into training, test.
+ origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345)
+ }
+
+ // For classification, convert labels to Strings since we will index them later with
+ // StringIndexer.
+ def labelsToStrings(data: DataFrame): DataFrame = {
+ algo.toLowerCase match {
+ case "classification" =>
+ data.withColumn("labelString", data("label").cast(StringType))
+ case "regression" =>
+ data
+ case _ =>
+ throw new IllegalArgumentException(s"Algo $algo not supported.")
+ }
+ }
+ val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
+
+ (dataframes(0), dataframes(1))
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params")
+ val sc = new SparkContext(conf)
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+ val algo = params.algo.toLowerCase
+
+ println(s"DecisionTreeExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) =
+ loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
+
+ val numTraining = training.count()
+ val numTest = test.count()
+ val numFeatures = training.select("features").first().getAs[Vector](0).size
+ println("Loaded data:")
+ println(s" numTraining = $numTraining, numTest = $numTest")
+ println(s" numFeatures = $numFeatures")
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+ // (1) For classification, re-index classes.
+ val labelColName = if (algo == "classification") "indexedLabel" else "label"
+ if (algo == "classification") {
+ val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
+ stages += labelIndexer
+ }
+ // (2) Identify categorical features using VectorIndexer.
+ // Features with more than maxCategories values will be treated as continuous.
+ val featuresIndexer = new VectorIndexer().setInputCol("features")
+ .setOutputCol("indexedFeatures").setMaxCategories(10)
+ stages += featuresIndexer
+ // (3) Learn DecisionTree
+ val dt = algo match {
+ case "classification" =>
+ new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ case "regression" =>
+ new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+ stages += dt
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Get the trained Decision Tree from the fitted PipelineModel
+ val treeModel: DecisionTreeModel = algo match {
+ case "classification" =>
+ pipelineModel.getModel[DecisionTreeClassificationModel](
+ dt.asInstanceOf[DecisionTreeClassifier])
+ case "regression" =>
+ pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor])
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+ if (treeModel.numNodes < 20) {
+ println(treeModel.toDebugString) // Print full model.
+ } else {
+ println(treeModel) // Print model summary.
+ }
+
+ // Predict on training
+ val trainingFullPredictions = pipelineModel.transform(training).cache()
+ val trainingPredictions = trainingFullPredictions.select("prediction")
+ .map(_.getDouble(0))
+ val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0))
+ // Predict on test data
+ val testFullPredictions = pipelineModel.transform(test).cache()
+ val testPredictions = testFullPredictions.select("prediction")
+ .map(_.getDouble(0))
+ val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0))
+
+ // For classification, print number of classes for reference.
+ if (algo == "classification") {
+ val numClasses =
+ MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match {
+ case Some(n) => n
+ case None => throw new RuntimeException(
+ "DecisionTreeExample had unknown failure when indexing labels for classification.")
+ }
+ println(s"numClasses = $numClasses.")
+ }
+
+ // Evaluate model on training, test data
+ algo match {
+ case "classification" =>
+ val trainingAccuracy =
+ new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision
+ println(s"Train accuracy = $trainingAccuracy")
+ val testAccuracy =
+ new MulticlassMetrics(testPredictions.zip(testLabels)).precision
+ println(s"Test accuracy = $testAccuracy")
+ case "regression" =>
+ val trainingRMSE =
+ new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError
+ println(s"Training root mean squared error (RMSE) = $trainingRMSE")
+ val testRMSE =
+ new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError
+ println(s"Test root mean squared error (RMSE) = $testRMSE")
+ case _ =>
+ throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ sc.stop()
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
index aa27a668f1..d7dee8fed2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -117,12 +117,12 @@ class AttributeGroup private (
case numeric: NumericAttribute =>
// Skip default numeric attributes.
if (numeric.withoutIndex != NumericAttribute.defaultAttr) {
- numericMetadata += numeric.toMetadata(withType = false)
+ numericMetadata += numeric.toMetadataImpl(withType = false)
}
case nominal: NominalAttribute =>
- nominalMetadata += nominal.toMetadata(withType = false)
+ nominalMetadata += nominal.toMetadataImpl(withType = false)
case binary: BinaryAttribute =>
- binaryMetadata += binary.toMetadata(withType = false)
+ binaryMetadata += binary.toMetadataImpl(withType = false)
}
val attrBldr = new MetadataBuilder
if (numericMetadata.nonEmpty) {
@@ -151,7 +151,7 @@ class AttributeGroup private (
}
/** Converts to ML metadata */
- def toMetadata: Metadata = toMetadata(Metadata.empty)
+ def toMetadata(): Metadata = toMetadata(Metadata.empty)
/** Converts to a StructField with some existing metadata. */
def toStructField(existingMetadata: Metadata): StructField = {
@@ -159,7 +159,7 @@ class AttributeGroup private (
}
/** Converts to a StructField. */
- def toStructField: StructField = toStructField(Metadata.empty)
+ def toStructField(): StructField = toStructField(Metadata.empty)
override def equals(other: Any): Boolean = {
other match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index 00b7566aab..5717d6ec2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -68,21 +68,32 @@ sealed abstract class Attribute extends Serializable {
* Converts this attribute to [[Metadata]].
* @param withType whether to include the type info
*/
- private[attribute] def toMetadata(withType: Boolean): Metadata
+ private[attribute] def toMetadataImpl(withType: Boolean): Metadata
/**
* Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to
* save space, because numeric type is the default attribute type. For nominal and binary
* attributes, the type info is included.
*/
- private[attribute] def toMetadata(): Metadata = {
+ private[attribute] def toMetadataImpl(): Metadata = {
if (attrType == AttributeType.Numeric) {
- toMetadata(withType = false)
+ toMetadataImpl(withType = false)
} else {
- toMetadata(withType = true)
+ toMetadataImpl(withType = true)
}
}
+ /** Converts to ML metadata with some existing metadata. */
+ def toMetadata(existingMetadata: Metadata): Metadata = {
+ new MetadataBuilder()
+ .withMetadata(existingMetadata)
+ .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl())
+ .build()
+ }
+
+ /** Converts to ML metadata */
+ def toMetadata(): Metadata = toMetadata(Metadata.empty)
+
/**
* Converts to a [[StructField]] with some existing metadata.
* @param existingMetadata existing metadata to carry over
@@ -90,7 +101,7 @@ sealed abstract class Attribute extends Serializable {
def toStructField(existingMetadata: Metadata): StructField = {
val newMetadata = new MetadataBuilder()
.withMetadata(existingMetadata)
- .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadata())
+ .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl())
.build()
StructField(name.get, DoubleType, nullable = false, newMetadata)
}
@@ -98,7 +109,7 @@ sealed abstract class Attribute extends Serializable {
/** Converts to a [[StructField]]. */
def toStructField(): StructField = toStructField(Metadata.empty)
- override def toString: String = toMetadata(withType = true).toString
+ override def toString: String = toMetadataImpl(withType = true).toString
}
/** Trait for ML attribute factories. */
@@ -210,7 +221,7 @@ class NumericAttribute private[ml] (
override def isNominal: Boolean = false
/** Convert this attribute to metadata. */
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder()
if (withType) bldr.putString(TYPE, attrType.name)
@@ -353,6 +364,20 @@ class NominalAttribute private[ml] (
/** Copy without the `numValues`. */
def withoutNumValues: NominalAttribute = copy(numValues = None)
+ /**
+ * Get the number of values, either from `numValues` or from `values`.
+ * Return None if unknown.
+ */
+ def getNumValues: Option[Int] = {
+ if (numValues.nonEmpty) {
+ numValues
+ } else if (values.nonEmpty) {
+ Some(values.get.length)
+ } else {
+ None
+ }
+ }
+
/** Creates a copy of this attribute with optional changes. */
private def copy(
name: Option[String] = name,
@@ -363,7 +388,7 @@ class NominalAttribute private[ml] (
new NominalAttribute(name, index, isOrdinal, numValues, values)
}
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder()
if (withType) bldr.putString(TYPE, attrType.name)
@@ -465,7 +490,7 @@ class BinaryAttribute private[ml] (
new BinaryAttribute(name, index, values)
}
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder
if (withType) bldr.putString(TYPE, attrType.name)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
new file mode 100644
index 0000000000..3855e396b5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
+ * for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class DecisionTreeClassifier
+ extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
+ with DecisionTreeParams
+ with TreeClassifierParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type =
+ super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type =
+ super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): DecisionTreeClassificationModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ case Some(n: Int) => n
+ case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
+ s" with invalid label column, without the number of classes specified.")
+ // TODO: Automatically index labels.
+ }
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy = getOldStrategy(categoricalFeatures, numClasses)
+ val oldModel = OldDecisionTree.train(oldDataset, strategy)
+ DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ override private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses)
+ strategy.algo = OldAlgo.Classification
+ strategy.setImpurity(getOldImpurity)
+ strategy
+ }
+}
+
+object DecisionTreeClassifier {
+ /** Accessor for supported impurities */
+ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class DecisionTreeClassificationModel private[ml] (
+ override val parent: DecisionTreeClassifier,
+ override val fittingParamMap: ParamMap,
+ override val rootNode: Node)
+ extends PredictionModel[Vector, DecisionTreeClassificationModel]
+ with DecisionTreeModel with Serializable {
+
+ require(rootNode != null,
+ "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+ override protected def predict(features: Vector): Double = {
+ rootNode.predict(features)
+ }
+
+ override protected def copy(): DecisionTreeClassificationModel = {
+ val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldDecisionTreeModel = {
+ new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
+ }
+}
+
+private[ml] object DecisionTreeClassificationModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldDecisionTreeModel,
+ parent: DecisionTreeClassifier,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification,
+ s"Cannot convert non-classification DecisionTreeModel (old API) to" +
+ s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 4d960df357..23956c512c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -118,7 +118,7 @@ class StringIndexerModel private[ml] (
}
val outputColName = map(outputCol)
val metadata = NominalAttribute.defaultAttr
- .withName(outputColName).withValues(labels).toStructField().metadata
+ .withName(outputColName).withValues(labels).toMetadata()
dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
new file mode 100644
index 0000000000..6f4509f03d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl.tree
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.impl.estimator.PredictorParams
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
+ Impurity => OldImpurity, Variance => OldVariance}
+
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait DecisionTreeParams extends PredictorParams {
+
+ /**
+ * Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * (default = 5)
+ * @group param
+ */
+ final val maxDepth: IntParam =
+ new IntParam(this, "maxDepth", "Maximum depth of the tree." +
+ " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
+
+ /**
+ * Maximum number of bins used for discretizing continuous features and for choosing how to split
+ * on features at each node. More bins give higher granularity.
+ * Must be >= 2 and >= number of categories in any categorical feature.
+ * (default = 32)
+ * @group param
+ */
+ final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
+ " discretizing continuous features. Must be >=2 and >= number of categories for any" +
+ " categorical feature.")
+
+ /**
+ * Minimum number of instances each child must have after split.
+ * If a split causes the left or right child to have fewer than minInstancesPerNode,
+ * the split will be discarded as invalid.
+ * Should be >= 1.
+ * (default = 1)
+ * @group param
+ */
+ final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
+ " number of instances each child must have after split. If a split causes the left or right" +
+ " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
+ " Should be >= 1.")
+
+ /**
+ * Minimum information gain for a split to be considered at a tree node.
+ * (default = 0.0)
+ * @group param
+ */
+ final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
+ "Minimum information gain for a split to be considered at a tree node.")
+
+ /**
+ * Maximum memory in MB allocated to histogram aggregation.
+ * (default = 256 MB)
+ * @group expertParam
+ */
+ final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
+ "Maximum memory in MB allocated to histogram aggregation.")
+
+ /**
+ * If false, the algorithm will pass trees to executors to match instances with nodes.
+ * If true, the algorithm will cache node IDs for each instance.
+ * Caching can speed up training of deeper trees.
+ * (default = false)
+ * @group expertParam
+ */
+ final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
+ " algorithm will pass trees to executors to match instances with nodes. If true, the" +
+ " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
+ " trees.")
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be >= 1.
+ * (default = 10)
+ * @group expertParam
+ */
+ final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
+ " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" +
+ " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
+ " checkpoint directory is set in the SparkContext. Must be >= 1.")
+
+ setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
+ maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+
+ /** @group setParam */
+ def setMaxDepth(value: Int): this.type = {
+ require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
+ set(maxDepth, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMaxDepth: Int = getOrDefault(maxDepth)
+
+ /** @group setParam */
+ def setMaxBins(value: Int): this.type = {
+ require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value")
+ set(maxBins, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMaxBins: Int = getOrDefault(maxBins)
+
+ /** @group setParam */
+ def setMinInstancesPerNode(value: Int): this.type = {
+ require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value")
+ set(minInstancesPerNode, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
+
+ /** @group setParam */
+ def setMinInfoGain(value: Double): this.type = {
+ set(minInfoGain, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMinInfoGain: Double = getOrDefault(minInfoGain)
+
+ /** @group expertSetParam */
+ def setMaxMemoryInMB(value: Int): this.type = {
+ require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value")
+ set(maxMemoryInMB, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
+
+ /** @group expertSetParam */
+ def setCacheNodeIds(value: Boolean): this.type = {
+ set(cacheNodeIds, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
+
+ /** @group expertSetParam */
+ def setCheckpointInterval(value: Int): this.type = {
+ require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value")
+ set(checkpointInterval, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+
+ /**
+ * Create a Strategy instance to use with the old API.
+ * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0,
+ * the default for single trees).
+ */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = {
+ val strategy = OldStrategy.defaultStategy(OldAlgo.Classification)
+ strategy.checkpointInterval = getCheckpointInterval
+ strategy.maxBins = getMaxBins
+ strategy.maxDepth = getMaxDepth
+ strategy.maxMemoryInMB = getMaxMemoryInMB
+ strategy.minInfoGain = getMinInfoGain
+ strategy.minInstancesPerNode = getMinInstancesPerNode
+ strategy.useNodeIdCache = getCacheNodeIds
+ strategy.numClasses = numClasses
+ strategy.categoricalFeaturesInfo = categoricalFeatures
+ strategy.subsamplingRate = 1.0 // default for individual trees
+ strategy
+ }
+}
+
+/**
+ * (private trait) Parameters for Decision Tree-based classification algorithms.
+ */
+private[ml] trait TreeClassifierParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * Supported: "entropy" and "gini".
+ * (default = gini)
+ * @group param
+ */
+ val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+
+ setDefault(impurity -> "gini")
+
+ /** @group setParam */
+ def setImpurity(value: String): this.type = {
+ val impurityStr = value.toLowerCase
+ require(TreeClassifierParams.supportedImpurities.contains(impurityStr),
+ s"Tree-based classifier was given unrecognized impurity: $value." +
+ s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+ set(impurity, impurityStr)
+ this
+ }
+
+ /** @group getParam */
+ def getImpurity: String = getOrDefault(impurity)
+
+ /** Convert new impurity to old impurity. */
+ private[ml] def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "entropy" => OldEntropy
+ case "gini" => OldGini
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeClassifierParams was given unrecognized impurity: $impurity.")
+ }
+ }
+}
+
+private[ml] object TreeClassifierParams {
+ // These options should be lowercase.
+ val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+}
+
+/**
+ * (private trait) Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * Supported: "variance".
+ * (default = variance)
+ * @group param
+ */
+ val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+
+ setDefault(impurity -> "variance")
+
+ /** @group setParam */
+ def setImpurity(value: String): this.type = {
+ val impurityStr = value.toLowerCase
+ require(TreeRegressorParams.supportedImpurities.contains(impurityStr),
+ s"Tree-based regressor was given unrecognized impurity: $value." +
+ s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+ set(impurity, impurityStr)
+ this
+ }
+
+ /** @group getParam */
+ def getImpurity: String = getOrDefault(impurity)
+
+ /** Convert new impurity to old impurity. */
+ protected def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "variance" => OldVariance
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeRegressorParams was given unrecognized impurity: $impurity")
+ }
+ }
+}
+
+private[ml] object TreeRegressorParams {
+ // These options should be lowercase.
+ val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
index b45bd1499b..ac75e9de1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -32,6 +32,18 @@ package org.apache.spark
* @groupname getParam Parameter getters
* @groupprio getParam 6
*
+ * @groupname expertParam (expert-only) Parameters
+ * @groupdesc expertParam A list of advanced, expert-only (hyper-)parameter keys this algorithm can
+ * take. Users can set and get the parameter values through setters and getters,
+ * respectively.
+ * @groupprio expertParam 7
+ *
+ * @groupname expertSetParam (expert-only) Parameter setters
+ * @groupprio expertSetParam 8
+ *
+ * @groupname expertGetParam (expert-only) Parameter getters
+ * @groupprio expertGetParam 9
+ *
* @groupname Ungrouped Members
* @groupprio Ungrouped 0
*/
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 849c60433c..ddc5907e7f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -296,8 +296,9 @@ private[spark] object Params {
paramMap: ParamMap,
parent: E,
child: M): Unit = {
+ val childParams = child.params.map(_.name).toSet
parent.params.foreach { param =>
- if (paramMap.contains(param)) {
+ if (paramMap.contains(param) && childParams.contains(param.name)) {
child.set(child.getParam(param.name), paramMap(param))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
new file mode 100644
index 0000000000..49a8b77acf
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
+ * for regression.
+ * It supports both continuous and categorical features.
+ */
+@AlphaComponent
+final class DecisionTreeRegressor
+ extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
+ with DecisionTreeParams
+ with TreeRegressorParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type =
+ super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): DecisionTreeRegressionModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy = getOldStrategy(categoricalFeatures)
+ val oldModel = OldDecisionTree.train(oldDataset, strategy)
+ DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
+ strategy.algo = OldAlgo.Regression
+ strategy.setImpurity(getOldImpurity)
+ strategy
+ }
+}
+
+object DecisionTreeRegressor {
+ /** Accessor for supported impurities */
+ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression.
+ * It supports both continuous and categorical features.
+ * @param rootNode Root of the decision tree
+ */
+@AlphaComponent
+final class DecisionTreeRegressionModel private[ml] (
+ override val parent: DecisionTreeRegressor,
+ override val fittingParamMap: ParamMap,
+ override val rootNode: Node)
+ extends PredictionModel[Vector, DecisionTreeRegressionModel]
+ with DecisionTreeModel with Serializable {
+
+ require(rootNode != null,
+ "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+ override protected def predict(features: Vector): Double = {
+ rootNode.predict(features)
+ }
+
+ override protected def copy(): DecisionTreeRegressionModel = {
+ val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes"
+ }
+
+ /** Convert to a model in the old API */
+ private[ml] def toOld: OldDecisionTreeModel = {
+ new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
+ }
+}
+
+private[ml] object DecisionTreeRegressionModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldDecisionTreeModel,
+ parent: DecisionTreeRegressor,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression,
+ s"Cannot convert non-regression DecisionTreeModel (old API) to" +
+ s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
new file mode 100644
index 0000000000..d6e2203d9f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
+ Node => OldNode, Predict => OldPredict}
+
+
+/**
+ * Decision tree node interface.
+ */
+sealed abstract class Node extends Serializable {
+
+ // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
+ // code into the new API and deprecate the old API.
+
+ /** Prediction this node makes (or would make, if it is an internal node) */
+ def prediction: Double
+
+ /** Impurity measure at this node (for training data) */
+ def impurity: Double
+
+ /** Recursive prediction helper method */
+ private[ml] def predict(features: Vector): Double = prediction
+
+ /**
+ * Get the number of nodes in tree below this node, including leaf nodes.
+ * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
+ */
+ private[tree] def numDescendants: Int
+
+ /**
+ * Recursive print function.
+ * @param indentFactor The number of spaces to add to each level of indentation.
+ */
+ private[tree] def subtreeToString(indentFactor: Int = 0): String
+
+ /**
+ * Get depth of tree from this node.
+ * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes.
+ */
+ private[tree] def subtreeDepth: Int
+
+ /**
+ * Create a copy of this node in the old Node format, recursively creating child nodes as needed.
+ * @param id Node ID using old format IDs
+ */
+ private[ml] def toOld(id: Int): OldNode
+}
+
+private[ml] object Node {
+
+ /**
+ * Create a new Node from the old Node format, recursively creating child nodes as needed.
+ */
+ def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+ if (oldNode.isLeaf) {
+ // TODO: Once the implementation has been moved to this API, then include sufficient
+ // statistics here.
+ new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
+ } else {
+ val gain = if (oldNode.stats.nonEmpty) {
+ oldNode.stats.get.gain
+ } else {
+ 0.0
+ }
+ new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
+ gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
+ rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures))
+ }
+ }
+}
+
+/**
+ * Decision tree leaf node.
+ * @param prediction Prediction this node makes
+ * @param impurity Impurity measure at this node (for training data)
+ */
+final class LeafNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double) extends Node {
+
+ override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
+
+ override private[ml] def predict(features: Vector): Double = prediction
+
+ override private[tree] def numDescendants: Int = 0
+
+ override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+ val prefix: String = " " * indentFactor
+ prefix + s"Predict: $prediction\n"
+ }
+
+ override private[tree] def subtreeDepth: Int = 0
+
+ override private[ml] def toOld(id: Int): OldNode = {
+ // NOTE: We do NOT store 'prob' in the new API currently.
+ new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
+ None, None, None, None)
+ }
+}
+
+/**
+ * Internal Decision Tree node.
+ * @param prediction Prediction this node would make if it were a leaf node
+ * @param impurity Impurity measure at this node (for training data)
+ * @param gain Information gain value.
+ * Values < 0 indicate missing values; this quirk will be removed with future updates.
+ * @param leftChild Left-hand child node
+ * @param rightChild Right-hand child node
+ * @param split Information about the test used to split to the left or right child.
+ */
+final class InternalNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double,
+ val gain: Double,
+ val leftChild: Node,
+ val rightChild: Node,
+ val split: Split) extends Node {
+
+ override def toString: String = {
+ s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
+ }
+
+ override private[ml] def predict(features: Vector): Double = {
+ if (split.shouldGoLeft(features)) {
+ leftChild.predict(features)
+ } else {
+ rightChild.predict(features)
+ }
+ }
+
+ override private[tree] def numDescendants: Int = {
+ 2 + leftChild.numDescendants + rightChild.numDescendants
+ }
+
+ override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+ val prefix: String = " " * indentFactor
+ prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" +
+ leftChild.subtreeToString(indentFactor + 1) +
+ prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" +
+ rightChild.subtreeToString(indentFactor + 1)
+ }
+
+ override private[tree] def subtreeDepth: Int = {
+ 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth)
+ }
+
+ override private[ml] def toOld(id: Int): OldNode = {
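+ // In the old API, node id i has children with ids 2*i and 2*i + 1 (see
+ // OldNode.leftChildIndex/rightChildIndex), so ids can overflow Int for very deep trees.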
+ assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+ + " since the old API does not support deep trees.")
+ // NOTE: We do NOT store 'prob' in the new API currently.
+ new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
+ Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
+ Some(rightChild.toOld(OldNode.rightChildIndex(id))),
+ Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
+ new OldPredict(leftChild.prediction, prob = 0.0),
+ new OldPredict(rightChild.prediction, prob = 0.0))))
+ }
+}
+
+private object InternalNode {
+
+ /**
+ * Helper method for [[Node.subtreeToString()]].
+ * @param split Split to print
+ * @param left If true, describe the condition for going to the left child;
+ * if false, describe the condition for going to the right child.
+ */
+ private def splitToString(split: Split, left: Boolean): String = {
+ val featureStr = s"feature ${split.featureIndex}"
+ split match {
+ case contSplit: ContinuousSplit =>
+ if (left) {
+ s"$featureStr <= ${contSplit.threshold}"
+ } else {
+ s"$featureStr > ${contSplit.threshold}"
+ }
+ case catSplit: CategoricalSplit =>
+ val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}")
+ if (left) {
+ s"$featureStr in $categoriesStr"
+ } else {
+ s"$featureStr not in $categoriesStr"
+ }
+ }
+ }
+}
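+
+// For reference, with the default indentFactor a depth-1 tree renders via subtreeToString as:
+//   If (feature 0 <= 0.5)
+//    Predict: 0.0
+//   Else (feature 0 > 0.5)
+//    Predict: 1.0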
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
new file mode 100644
index 0000000000..cb940f6299
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
+import org.apache.spark.mllib.tree.model.{Split => OldSplit}
+
+
+/**
+ * Interface for a "Split," which specifies a test made at a decision tree node
+ * to choose the left or right path.
+ */
+sealed trait Split extends Serializable {
+
+ /** Index of feature which this split tests */
+ def featureIndex: Int
+
+ /** Return true (split to left) or false (split to right) */
+ private[ml] def shouldGoLeft(features: Vector): Boolean
+
+ /** Convert to old Split format */
+ private[tree] def toOld: OldSplit
+}
+
+private[ml] object Split {
+
+ def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
+ oldSplit.featureType match {
+ case OldFeatureType.Categorical =>
+ new CategoricalSplit(featureIndex = oldSplit.feature,
+ leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
+ case OldFeatureType.Continuous =>
+ new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold)
+ }
+ }
+}
+
+/**
+ * Split which tests a categorical feature.
+ * @param featureIndex Index of the feature to test
+ * @param leftCategories If the feature value is in this set of categories, then the split goes
+ * left. Otherwise, it goes right.
+ * @param numCategories Number of categories for this feature.
+ */
+final class CategoricalSplit(
+ override val featureIndex: Int,
+ leftCategories: Array[Double],
+ private val numCategories: Int)
+ extends Split {
+
+ require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
+ s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}")
+
+ /**
+ * If true, [[categories]] is the set of categories which split to the left; if false, it is
+ * the complement of that set (the categories which split to the right). The smaller of the
+ * two sets is stored.
+ */
+ private val isLeft: Boolean = leftCategories.length <= numCategories / 2
+
+ /** Set of categories determining the splitting rule, along with [[isLeft]]. */
+ private val categories: Set[Double] = {
+ if (isLeft) {
+ leftCategories.toSet
+ } else {
+ setComplement(leftCategories.toSet)
+ }
+ }
+
+ override private[ml] def shouldGoLeft(features: Vector): Boolean = {
+ if (isLeft) {
+ categories.contains(features(featureIndex))
+ } else {
+ !categories.contains(features(featureIndex))
+ }
+ }
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case other: CategoricalSplit => featureIndex == other.featureIndex &&
+ isLeft == other.isLeft && categories == other.categories
+ case _ => false
+ }
+ }
+
+ override private[tree] def toOld: OldSplit = {
+ val oldCats = if (isLeft) {
+ categories
+ } else {
+ setComplement(categories)
+ }
+ OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList)
+ }
+
+ /** Get sorted categories which split to the left */
+ def getLeftCategories: Array[Double] = {
+ val cats = if (isLeft) categories else setComplement(categories)
+ cats.toArray.sorted
+ }
+
+ /** Get sorted categories which split to the right */
+ def getRightCategories: Array[Double] = {
+ val cats = if (isLeft) setComplement(categories) else categories
+ cats.toArray.sorted
+ }
+
+ /** [0, numCategories) \ cats */
+ private def setComplement(cats: Set[Double]): Set[Double] = {
+ Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet
+ }
+}
+
+/**
+ * Split which tests a continuous feature.
+ * @param featureIndex Index of the feature to test
+ * @param threshold If the feature value is <= this threshold, then the split goes left.
+ * Otherwise, it goes right.
+ */
+final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+
+ override private[ml] def shouldGoLeft(features: Vector): Boolean = {
+ features(featureIndex) <= threshold
+ }
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case other: ContinuousSplit =>
+ featureIndex == other.featureIndex && threshold == other.threshold
+ case _ =>
+ false
+ }
+ }
+
+ override private[tree] def toOld: OldSplit = {
+ OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double])
+ }
+}
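+
+// Illustrative example: new ContinuousSplit(featureIndex = 0, threshold = 0.5).shouldGoLeft(v)
+// is true iff v(0) <= 0.5, while new CategoricalSplit(0, Array(1.0, 2.0), numCategories = 4)
+// sends feature values 1.0 and 2.0 to the left and 0.0 and 3.0 to the right.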
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
new file mode 100644
index 0000000000..8e3bc3849d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.annotation.AlphaComponent
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for Decision Tree models.
+ *
+ * TODO: Add support for predicting probabilities and raw predictions
+ */
+@AlphaComponent
+trait DecisionTreeModel {
+
+ /** Root of the decision tree */
+ def rootNode: Node
+
+ /** Number of nodes in tree, including leaf nodes. */
+ def numNodes: Int = {
+ 1 + rootNode.numDescendants
+ }
+
+ /**
+ * Depth of the tree.
+ * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
+ */
+ lazy val depth: Int = {
+ rootNode.subtreeDepth
+ }
+
+ /** Summary of the model */
+ override def toString: String = {
+ // Implementing classes should generally override this method to be more descriptive.
+ s"DecisionTreeModel of depth $depth with $numNodes nodes"
+ }
+
+ /** Full description of model */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + rootNode.subtreeToString(2)
+ }
+}
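+
+// Note: since every InternalNode has exactly two children, a tree of depth d has between
+// 2 * d + 1 nodes (a single deep path) and 2^(d + 1) - 1 nodes (a complete tree);
+// e.g., depth = 2 gives numNodes between 5 and 7.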
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
new file mode 100644
index 0000000000..c84c8b4eb7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute,
+ NumericAttribute}
+import org.apache.spark.sql.types.StructField
+
+
+/**
+ * :: Experimental ::
+ *
+ * Helper utilities for tree-based algorithms
+ */
+@Experimental
+object MetadataUtils {
+
+ /**
+ * Examine a schema to identify the number of classes in a label column.
+ * Returns None if the number of classes is not specified, or if the label column is continuous.
+ */
+ def getNumClasses(labelSchema: StructField): Option[Int] = {
+ Attribute.fromStructField(labelSchema) match {
+ case numAttr: NumericAttribute => None
+ case binAttr: BinaryAttribute => Some(2)
+ case nomAttr: NominalAttribute => nomAttr.getNumValues
+ }
+ }
+
+ /**
+ * Examine a schema to identify categorical (Binary and Nominal) features.
+ *
+ * @param featuresSchema Schema of the features column.
+ * If a feature does not have metadata, it is assumed to be continuous.
+ * If a feature is Nominal, then it must have the number of values
+ * specified.
+ * @return Map: feature index --> number of categories.
+ * The map's set of keys will be the set of categorical feature indices.
+ */
+ def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = {
+ val metadata = AttributeGroup.fromStructField(featuresSchema)
+ if (metadata.attributes.isEmpty) {
+ HashMap.empty[Int, Int]
+ } else {
+ metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) =>
+ if (attr == null) {
+ Iterator()
+ } else {
+ attr match {
+ case numAttr: NumericAttribute => Iterator()
+ case binAttr: BinaryAttribute => Iterator(idx -> 2)
+ case nomAttr: NominalAttribute =>
+ nomAttr.getNumValues match {
+ case Some(numValues: Int) => Iterator(idx -> numValues)
+ case None => throw new IllegalArgumentException(s"Feature $idx is marked as" +
+ " Nominal (categorical), but it does not have the number of values specified.")
+ }
+ }
+ }
+ }.toMap
+ }
+ }
+
+}
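+
+// Illustrative example: if the features column's attribute group marks feature 0 as Nominal
+// with 3 values, feature 1 as Binary, and feature 2 as Numeric, then getCategoricalFeatures
+// returns Map(0 -> 3, 1 -> 2).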
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b9d0c56dd1..dfe3a0b691 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -1147,7 +1147,10 @@ object DecisionTree extends Serializable with Logging {
}
}
- assert(splits.length > 0)
+ // TODO: Do not fail; just ignore the useless feature.
+ assert(splits.length > 0,
+ s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
+ " Please remove this feature and then try again.")
// set number of splits accordingly
metadata.setNumSplits(featureIndex, splits.length)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index c02c79f094..0e31c7ed58 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -81,11 +81,11 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
/**
* Method to validate a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @param validationInput Validation dataset:
- RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- Should be different from and follow the same distribution as input.
- e.g., these two datasets could be created from an original dataset
- by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+ * @param validationInput Validation dataset.
+ * This dataset should be different from the training dataset,
+ * but it should follow the same distribution.
+ * E.g., these two datasets could be created from an original dataset
+ * by using [[org.apache.spark.rdd.RDD.randomSplit()]]
* @return a gradient boosted trees model that can be used for prediction
*/
def runWithValidation(
@@ -194,8 +194,6 @@ object GradientBoostedTrees extends Logging {
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
- val startingModel = new GradientBoostedTreesModel(
- Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index db01f2e229..055e60c7d9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -249,7 +249,7 @@ private class RandomForest (
nodeIdCache.get.deleteAllCheckpoints()
} catch {
case e:IOException =>
- logWarning(s"delete all chackpoints failed. Error reason: ${e.getMessage}")
+ logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 664c8df019..2d6b01524f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -89,14 +89,14 @@ object BoostingStrategy {
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: Algo): BoostingStrategy = {
- val treeStragtegy = Strategy.defaultStategy(algo)
- treeStragtegy.maxDepth = 3
+ val treeStrategy = Strategy.defaultStategy(algo)
+ treeStrategy.maxDepth = 3
algo match {
case Algo.Classification =>
- treeStragtegy.numClasses = 2
- new BoostingStrategy(treeStragtegy, LogLoss)
+ treeStrategy.numClasses = 2
+ new BoostingStrategy(treeStrategy, LogLoss)
case Algo.Regression =>
- new BoostingStrategy(treeStragtegy, SquaredError)
+ new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index 6f570b4e09..2bdef73c4a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -45,9 +45,8 @@ object AbsoluteError extends Loss {
if (label - prediction < 0) 1.0 else -1.0
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val err = label - prediction
math.abs(err)
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 24ee9f3d51..778c24526d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -47,10 +47,9 @@ object LogLoss extends Loss {
- 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val margin = 2.0 * label * prediction
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index d3b82b752f..64ffccbce0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -22,6 +22,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
* Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -57,6 +58,5 @@ trait Loss extends Serializable {
* @param label True label.
* @return Measure of model error on datapoint.
*/
- def computeError(prediction: Double, label: Double): Double
-
+ private[mllib] def computeError(prediction: Double, label: Double): Double
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 58857ae15e..a5582d3ef3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -45,9 +45,8 @@ object SquaredError extends Loss {
2.0 * (prediction - label)
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val err = prediction - label
err * err
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index c9bafd60fb..331af42853 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -113,11 +113,13 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
}
- override protected def formatVersion: String = "1.0"
+ override protected def formatVersion: String = DecisionTreeModel.formatVersion
}
object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
+ private[spark] def formatVersion: String = "1.0"
+
private[tree] object SaveLoadV1_0 {
def thisFormatVersion: String = "1.0"
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 4f72bb8014..708ba04b56 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -175,7 +175,7 @@ class Node (
}
}
-private[tree] object Node {
+private[spark] object Node {
/**
* Return a node with the given node id (but nothing else set).
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index fef3d2acb2..8341219bfa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -38,6 +38,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
+
/**
* :: Experimental ::
* Represents a random forest model.
@@ -47,7 +48,7 @@ import org.apache.spark.util.Utils
*/
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
- extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
+ extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
combiningStrategy = if (algo == Classification) Vote else Average)
with Saveable {
@@ -58,11 +59,13 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis
RandomForestModel.SaveLoadV1_0.thisClassName)
}
- override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+ override protected def formatVersion: String = RandomForestModel.formatVersion
}
object RandomForestModel extends Loader[RandomForestModel] {
+ private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
override def load(sc: SparkContext, path: String): RandomForestModel = {
val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -102,15 +105,13 @@ class GradientBoostedTreesModel(
extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
with Saveable {
- require(trees.size == treeWeights.size)
+ require(trees.length == treeWeights.length)
override def save(sc: SparkContext, path: String): Unit = {
TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
GradientBoostedTreesModel.SaveLoadV1_0.thisClassName)
}
- override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
-
/**
* Method to compute error or loss for every iteration of gradient boosting.
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
@@ -138,7 +139,7 @@ class GradientBoostedTreesModel(
evaluationArray(0) = predictionAndError.values.mean()
val broadcastTrees = sc.broadcast(trees)
- (1 until numIterations).map { nTree =>
+ (1 until numIterations).foreach { nTree =>
predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
val currentTree = broadcastTrees.value(nTree)
val currentTreeWeight = localTreeWeights(nTree)
@@ -155,6 +156,7 @@ class GradientBoostedTreesModel(
evaluationArray
}
+ override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion
}
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
@@ -200,17 +202,17 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
loss: Loss): RDD[(Double, Double)] = {
val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
- iter.map {
- case (lp, (pred, error)) => {
- val newPred = pred + tree.predict(lp.features) * treeWeight
- val newError = loss.computeError(newPred, lp.label)
- (newPred, newError)
- }
+ iter.map { case (lp, (pred, error)) =>
+ val newPred = pred + tree.predict(lp.features) * treeWeight
+ val newError = loss.computeError(newPred, lp.label)
+ (newPred, newError)
}
}
newPredError
}
+ private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -340,12 +342,12 @@ private[tree] sealed class TreeEnsembleModel(
}
/**
- * Get number of trees in forest.
+ * Get number of trees in ensemble.
*/
- def numTrees: Int = trees.size
+ def numTrees: Int = trees.length
/**
- * Get total number of nodes, summed over all trees in the forest.
+ * Get total number of nodes, summed over all trees in the ensemble.
*/
def totalNumNodes: Int = trees.map(_.numNodes).sum
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
new file mode 100644
index 0000000000..43b8787f9d
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.File;
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.util.Utils;
+
+
+public class JavaDecisionTreeClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ DecisionTreeClassifier dt = new DecisionTreeClassifier()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) {
+ dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]);
+ }
+ DecisionTreeClassificationModel model = dt.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.numNodes();
+ model.depth();
+ model.toDebugString();
+
+ /*
+ // TODO: Add test once save/load are implemented.
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model3.save(sc.sc(), path);
+ DecisionTreeClassificationModel sameModel =
+ DecisionTreeClassificationModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model3, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
new file mode 100644
index 0000000000..a3a339004f
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.File;
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.util.Utils;
+
+
+public class JavaDecisionTreeRegressorSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ DecisionTreeRegressor dt = new DecisionTreeRegressor()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) {
+ dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]);
+ }
+ DecisionTreeRegressionModel model = dt.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.numNodes();
+ model.depth();
+ model.toDebugString();
+
+ /*
+ // TODO: Add test once save/load are implemented.
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model2.save(sc.sc(), path);
+ DecisionTreeRegressionModel sameModel = DecisionTreeRegressionModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model2, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
index 0dcfe5a200..17ddd335de 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
@@ -44,7 +44,7 @@ class AttributeGroupSuite extends FunSuite {
group("abc")
}
assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name))
- assert(group === AttributeGroup.fromStructField(group.toStructField))
+ assert(group === AttributeGroup.fromStructField(group.toStructField()))
}
test("attribute group without attributes") {
@@ -54,7 +54,7 @@ class AttributeGroupSuite extends FunSuite {
assert(group0.size === 10)
assert(group0.attributes.isEmpty)
assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name))
- assert(group0 === AttributeGroup.fromStructField(group0.toStructField))
+ assert(group0 === AttributeGroup.fromStructField(group0.toStructField()))
val group1 = new AttributeGroup("item")
assert(group1.name === "item")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 6ec35b0365..3e1a7196e3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -36,9 +36,9 @@ class AttributeSuite extends FunSuite {
assert(attr.max.isEmpty)
assert(attr.std.isEmpty)
assert(attr.sparsity.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = false) === metadata)
- assert(attr.toMetadata(withType = true) === metadataWithType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadataWithType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === Attribute.fromMetadata(metadataWithType))
intercept[NoSuchElementException] {
@@ -59,9 +59,9 @@ class AttributeSuite extends FunSuite {
assert(!attr.isNominal)
assert(attr.name === Some(name))
assert(attr.index === Some(index))
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = false) === metadata)
- assert(attr.toMetadata(withType = true) === metadataWithType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadataWithType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === Attribute.fromMetadata(metadataWithType))
val field = attr.toStructField()
@@ -81,7 +81,7 @@ class AttributeSuite extends FunSuite {
assert(attr2.max === Some(1.0))
assert(attr2.std === Some(0.5))
assert(attr2.sparsity === Some(0.3))
- assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
+ assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl()))
}
test("bad numeric attributes") {
@@ -105,9 +105,9 @@ class AttributeSuite extends FunSuite {
assert(attr.values.isEmpty)
assert(attr.numValues.isEmpty)
assert(attr.isOrdinal.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === NominalAttribute.fromMetadata(metadataWithoutType))
intercept[NoSuchElementException] {
@@ -135,9 +135,9 @@ class AttributeSuite extends FunSuite {
assert(attr.values === Some(values))
assert(attr.indexOf("medium") === 1)
assert(attr.getValue(1) === "medium")
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === NominalAttribute.fromMetadata(metadataWithoutType))
assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
@@ -147,8 +147,8 @@ class AttributeSuite extends FunSuite {
assert(attr2.index.isEmpty)
assert(attr2.values.get === Array("small", "medium", "large", "x-large"))
assert(attr2.indexOf("x-large") === 3)
- assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
- assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadata(withType = false)))
+ assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl()))
+ assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadataImpl(withType = false)))
}
test("bad nominal attributes") {
@@ -168,9 +168,9 @@ class AttributeSuite extends FunSuite {
assert(attr.name.isEmpty)
assert(attr.index.isEmpty)
assert(attr.values.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
intercept[NoSuchElementException] {
@@ -196,9 +196,9 @@ class AttributeSuite extends FunSuite {
assert(attr.name === Some(name))
assert(attr.index === Some(index))
assert(attr.values.get === values)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
new file mode 100644
index 0000000000..af88595df5
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+ DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
+
+ import DecisionTreeClassifierSuite.compareAPIs
+
+ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+ private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _
+ private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _
+ private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
+ private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
+ private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ categoricalDataPointsRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints())
+ orderedLabeledPointsWithLabel0RDD =
+ sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0())
+ orderedLabeledPointsWithLabel1RDD =
+ sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1())
+ categoricalDataPointsForMulticlassRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass())
+ continuousDataPointsForMulticlassRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass())
+ categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize(
+ OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("Binary classification stump with ordered categorical features") {
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("gini")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ val numClasses = 2
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") {
+ val dt = new DecisionTreeClassifier()
+ .setMaxDepth(3)
+ .setMaxBins(100)
+ val numClasses = 2
+ Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd =>
+ DecisionTreeClassifier.supportedImpurities.foreach { impurity =>
+ dt.setImpurity(impurity)
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+ }
+ }
+
+ test("Multiclass classification stump with 3-ary (unordered) categorical features") {
+ val rdd = categoricalDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 3
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(3.0)))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Binary classification stump with 2 continuous features") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Multiclass classification stump with unordered categorical features," +
+ " with just enough bins") {
+ val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features
+ val rdd = categoricalDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(maxBins)
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification stump with continuous features") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Multiclass classification stump with continuous + unordered categorical features") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification stump with 10-ary (ordered) categorical features") {
+ val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 10, 1 -> 10)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification tree with 10-ary (ordered) categorical features," +
+ " with just enough bins") {
+ val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(10)
+ val categoricalFeatures = Map(0 -> 10, 1 -> 10)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("split must satisfy min instances per node requirements") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setMinInstancesPerNode(2)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("do not choose split that does not satisfy min instance per node requirements") {
+ // if a split does not satisfy min instances per node requirements,
+ // this split is invalid, even though the information gain of split is large.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxBins(2)
+ .setMaxDepth(2)
+ .setMinInstancesPerNode(2)
+ val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("split must satisfy min info gain requirements") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+ val rdd = sc.parallelize(arr)
+
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setMinInfoGain(1.0)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification)
+ val newModel = DecisionTreeClassificationModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = DecisionTreeClassificationModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private[ml] object DecisionTreeClassifierSuite extends FunSuite {
+
+ /**
+ * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
+ * Convert the old tree to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ dt: DecisionTreeClassifier,
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): Unit = {
+ val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
+ val oldTree = OldDecisionTree.train(data, oldStrategy)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+ val newTree = dt.fit(newData)
+ // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent,
+ newTree.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldTreeAsNew, newTree)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 81ef831c42..1b261b2643 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -228,7 +228,7 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
}
val attrGroup = new AttributeGroup("features", featureAttributes)
val densePoints1WithMeta =
- densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata))
+ densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata()))
val vectorIndexer = getIndexer.setMaxCategories(2)
val model = vectorIndexer.fit(densePoints1WithMeta)
// Check that ML metadata are preserved.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
new file mode 100644
index 0000000000..2e57d4ce37
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+
+private[ml] object TreeTests extends FunSuite {
+
+ /**
+ * Convert the given data to a DataFrame, and set the features and label metadata.
+ * @param data Dataset. Categorical features and labels must already have 0-based indices.
+ * This must be non-empty.
+ * @param categoricalFeatures Map: categorical feature index -> number of distinct values
+ * @param numClasses Number of classes label can take. If 0, mark as continuous.
+ * @return DataFrame with metadata
+ */
+ def setMetadata(
+ data: RDD[LabeledPoint],
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): DataFrame = {
+ val sqlContext = new SQLContext(data.sparkContext)
+ import sqlContext.implicits._
+ val df = data.toDF()
+ val numFeatures = data.first().features.size
+ val featuresAttributes = Range(0, numFeatures).map { feature =>
+ if (categoricalFeatures.contains(feature)) {
+ NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature))
+ } else {
+ NumericAttribute.defaultAttr.withIndex(feature)
+ }
+ }.toArray
+ val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata()
+ val labelAttribute = if (numClasses == 0) {
+ NumericAttribute.defaultAttr.withName("label")
+ } else {
+ NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+ }
+ val labelMetadata = labelAttribute.toMetadata()
+ df.select(df("features").as("features", featuresMetadata),
+ df("label").as("label", labelMetadata))
+ }
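+
+ // For example, setMetadata(data, Map(0 -> 3), numClasses = 2) marks feature 0 as a 3-valued
+ // nominal attribute, leaves other features numeric, and marks the label as having 2 classes;
+ // numClasses = 0 marks the label as continuous instead.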
+
+ /** Java-friendly version of [[setMetadata()]] */
+ def setMetadata(
+ data: JavaRDD[LabeledPoint],
+ categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer],
+ numClasses: Int): DataFrame = {
+ setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ numClasses)
+ }
+
+ /**
+ * Check if the two trees are exactly the same.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ * If the trees are not equal, this throws an exception whose message includes both trees.
+ */
+ def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ try {
+ checkEqual(a.rootNode, b.rootNode)
+ } catch {
+ case ex: Exception =>
+ throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+ "TREE A:\n" + a.toDebugString + "\n" +
+ "TREE B:\n" + b.toDebugString + "\n", ex)
+ }
+ }
+
+ /**
+ * Check that the two nodes and all of their descendants are exactly the same.
+ * Throws an exception at the first mismatch.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ */
+ private def checkEqual(a: Node, b: Node): Unit = {
+ assert(a.prediction === b.prediction)
+ assert(a.impurity === b.impurity)
+ (a, b) match {
+ case (aye: InternalNode, bee: InternalNode) =>
+ assert(aye.split === bee.split)
+ checkEqual(aye.leftChild, bee.leftChild)
+ checkEqual(aye.rightChild, bee.rightChild)
+ case (aye: LeafNode, bee: LeafNode) => // do nothing
+ case _ =>
+ throw new AssertionError("Found mismatched nodes")
+ }
+ }
+
+ // TODO: Reinstate after adding ensembles
+ /**
+ * Check if the two models are exactly the same.
+ * If the models are not equal, this throws an exception.
+ */
+ /*
+ def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
+ try {
+ a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) =>
+ TreeTests.checkEqual(treeA, treeB)
+ }
+ assert(a.getTreeWeights === b.getTreeWeights)
+ } catch {
+ case ex: Exception => throw new AssertionError(
+ "checkEqual failed since the two tree ensembles were not identical")
+ }
+ }
+ */
+}
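
For orientation, here is a minimal sketch of how these test helpers are intended to be used. The sample data and the `sc` SparkContext (e.g. the one provided by MLlibTestSparkContext) are illustrative assumptions, not part of this patch.

```
// Illustrative only: assumes a SparkContext `sc`, e.g. from MLlibTestSparkContext.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(2.0, 0.5))))

// Feature 0 is categorical with 3 values (0-based indices); feature 1 is continuous.
// numClasses = 2 marks the label as binary; numClasses = 0 would mark it as continuous.
val df = TreeTests.setMetadata(data, categoricalFeatures = Map(0 -> 3), numClasses = 2)
```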
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
new file mode 100644
index 0000000000..0b40fe33fa
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+ DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
+
+ import DecisionTreeRegressorSuite.compareAPIs
+
+ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ categoricalDataPointsRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints())
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("Regression stump with 3-ary (ordered) categorical features") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
+ }
+
+ test("Regression stump with binary (ordered) categorical features") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: test("model save/load")
+}
+
+private[ml] object DecisionTreeRegressorSuite extends FunSuite {
+
+ /**
+ * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
+ * Convert the old tree to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ dt: DecisionTreeRegressor,
+ categoricalFeatures: Map[Int, Int]): Unit = {
+ val oldStrategy = dt.getOldStrategy(categoricalFeatures)
+ val oldTree = OldDecisionTree.train(data, oldStrategy)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newTree = dt.fit(newData)
+ // Use parent and fittingParamMap from newTree since these are not checked anyway.
+ val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent,
+ newTree.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldTreeAsNew, newTree)
+ }
+}
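
As a usage note, the estimator/model round trip that compareAPIs exercises looks roughly like the following sketch; the DataFrame `df` is assumed to carry the metadata produced by TreeTests.setMetadata and is not part of this patch.

```
// Illustrative only: `df` is assumed to be a DataFrame with tree metadata,
// e.g. TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0).
val dt = new DecisionTreeRegressor()
  .setImpurity("variance")
  .setMaxDepth(2)
  .setMaxBins(100)
val model = dt.fit(df)        // DecisionTreeRegressionModel
println(model.toDebugString)  // prints the learned tree structure
```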
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 4c162df810..249b8eae19 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -36,6 +36,10 @@ import org.apache.spark.util.Utils
class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests examining individual elements of training
+ /////////////////////////////////////////////////////////////////////////////
+
test("Binary classification with continuous features: split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
@@ -254,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(bins(0).length === 0)
}
+ test("Avoid aggregation on the last level") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue leaf nodes into node queue
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Avoid aggregation if impurity is 0.0") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue a node into node queue if its impurity is 0.0
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Second level node building with vs. without groups") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+
+ // Train a 1-node model
+ val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
+ numClasses = 2, maxBins = 100)
+ val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
+ val rootNode1 = modelOneNode.topNode.deepCopy()
+ val rootNode2 = modelOneNode.topNode.deepCopy()
+ assert(rootNode1.leftNode.nonEmpty)
+ assert(rootNode1.rightNode.nonEmpty)
+
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ // Single group second level tree construction.
+ val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
+ (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+ val children1 = new Array[Node](2)
+ children1(0) = rootNode1.leftNode.get
+ children1(1) = rootNode1.rightNode.get
+
+ // Train one second-level node at a time.
+ val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
+ val treeToNodeToIndexInfoA = Map((0, Map(
+ (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
+ val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
+ val treeToNodeToIndexInfoB = Map((0, Map(
+ (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
+ val children2 = new Array[Node](2)
+ children2(0) = rootNode2.leftNode.get
+ children2(1) = rootNode2.rightNode.get
+
+ // Verify whether the splits obtained using single group and multiple group level
+ // construction strategies are the same.
+ for (i <- 0 until 2) {
+ assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
+ assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
+ assert(children1(i).split === children2(i).split)
+ assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
+ val stats1 = children1(i).stats.get
+ val stats2 = children2(i).stats.get
+ assert(stats1.gain === stats2.gain)
+ assert(stats1.impurity === stats2.impurity)
+ assert(stats1.leftImpurity === stats2.leftImpurity)
+ assert(stats1.rightImpurity === stats2.rightImpurity)
+ assert(children1(i).predict.predict === children2(i).predict.predict)
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
test("Binary classification stump with ordered categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
@@ -438,76 +601,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(rootNode.predict.predict === 1)
}
- test("Second level node building with vs. without groups") {
- val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
- // Train a 1-node model
- val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
- numClasses = 2, maxBins = 100)
- val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val rootNode1 = modelOneNode.topNode.deepCopy()
- val rootNode2 = modelOneNode.topNode.deepCopy()
- assert(rootNode1.leftNode.nonEmpty)
- assert(rootNode1.rightNode.nonEmpty)
-
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- // Single group second level tree construction.
- val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
- (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
- val children1 = new Array[Node](2)
- children1(0) = rootNode1.leftNode.get
- children1(1) = rootNode1.rightNode.get
-
- // Train one second-level node at a time.
- val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
- val treeToNodeToIndexInfoA = Map((0, Map(
- (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
- val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
- val treeToNodeToIndexInfoB = Map((0, Map(
- (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
- val children2 = new Array[Node](2)
- children2(0) = rootNode2.leftNode.get
- children2(1) = rootNode2.rightNode.get
-
- // Verify whether the splits obtained using single group and multiple group level
- // construction strategies are the same.
- for (i <- 0 until 2) {
- assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
- assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
- assert(children1(i).split === children2(i).split)
- assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
- val stats1 = children1(i).stats.get
- val stats2 = children2(i).stats.get
- assert(stats1.gain === stats2.gain)
- assert(stats1.impurity === stats2.impurity)
- assert(stats1.leftImpurity === stats2.leftImpurity)
- assert(stats1.rightImpurity === stats2.rightImpurity)
- assert(children1(i).predict.predict === children2(i).predict.predict)
- }
- }
-
test("Multiclass classification stump with 3-ary (unordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
@@ -528,11 +621,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
- arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(3.0)))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 2)
@@ -544,11 +637,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("Binary classification stump with 2 continuous features") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -668,11 +761,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("split must satisfy min instances per node requirements") {
- val arr = new Array[LabeledPoint](3)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
-
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxDepth = 2, numClasses = 2, minInstancesPerNode = 2)
@@ -695,11 +787,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
test("do not choose split that does not satisfy min instance per node requirements") {
// if a split does not satisfy min instances per node requirements,
// this split is invalid, even though the information gain of split is large.
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
- arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
@@ -715,10 +807,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("split must satisfy min info gain requirements") {
- val arr = new Array[LabeledPoint](3)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
@@ -739,91 +831,9 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(gain == InformationGainStats.invalidInformationGainStats)
}
- test("Avoid aggregation on the last level") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue leaf nodes into node queue
- assert(nodeQueue.isEmpty)
-
- // set impurity and predict for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // set impurity and predict for child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
-
- test("Avoid aggregation if impurity is 0.0") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue a node into node queue if its impurity is 0.0
- assert(nodeQueue.isEmpty)
-
- // set impurity and predict for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // set impurity and predict for child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
test("Node.subtreeIterator") {
val model = DecisionTreeSuite.createModel(Classification)
@@ -996,8 +1006,9 @@ object DecisionTreeSuite extends FunSuite {
/**
* Create a tree model. This is deterministic and contains a variety of node and feature types.
+ * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.)
*/
- private[tree] def createModel(algo: Algo): DecisionTreeModel = {
+ private[mllib] def createModel(algo: Algo): DecisionTreeModel = {
val topNode = createInternalNode(id = 1, Continuous)
val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))
@@ -1017,7 +1028,7 @@ object DecisionTreeSuite extends FunSuite {
* make mistakes such as creating loops of Nodes.
* If the trees are not equal, this prints the two trees and throws an exception.
*/
- private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ private[mllib] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
try {
assert(a.algo === b.algo)
checkEqual(a.topNode, b.topNode)