aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorManish Amde <manish9ue@gmail.com>2014-04-01 21:40:49 -0700
committerMatei Zaharia <matei@databricks.com>2014-04-01 21:40:49 -0700
commit8b3045ceab591a3f3ca18823c7e2c5faca38a06e (patch)
tree11fa8935ed2793e4c1b3e06b042b459e8a5aee27 /mllib
parent45df9127365f8942794273b8ada004bf6ea3ef10 (diff)
downloadspark-8b3045ceab591a3f3ca18823c7e2c5faca38a06e.tar.gz
spark-8b3045ceab591a3f3ca18823c7e2c5faca38a06e.tar.bz2
spark-8b3045ceab591a3f3ca18823c7e2c5faca38a06e.zip
MLI-1 Decision Trees
Joint work with @hirakendu, @etrain, @atalwalkar and @harsha2010. Key features: + Supports binary classification and regression + Supports gini, entropy and variance for information gain calculation + Supports both continuous and categorical features The algorithm has gone through several development iterations over the last few months leading to a highly optimized implementation. Optimizations include: 1. Level-wise training to reduce passes over the entire dataset. 2. Bin-wise split calculation to reduce computation overhead. 3. Aggregation over partitions before combining to reduce communication overhead. Author: Manish Amde <manish9ue@gmail.com> Author: manishamde <manish9ue@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #79 from manishamde/tree and squashes the following commits: 1e8c704 [Manish Amde] remove numBins field in the Strategy class 7d54b4f [manishamde] Merge pull request #4 from mengxr/dtree f536ae9 [Xiangrui Meng] another pass on code style e1dd86f [Manish Amde] implementing code style suggestions 62dc723 [Manish Amde] updating javadoc and converting helper methods to package private to allow unit testing 201702f [Manish Amde] making some more methods private f963ef5 [Manish Amde] making methods private c487e6a [manishamde] Merge pull request #1 from mengxr/dtree 24500c5 [Xiangrui Meng] minor style updates 4576b64 [Manish Amde] documentation and for to while loop conversion ff363a7 [Manish Amde] binary search for bins and while loop for categorical feature bins 632818f [Manish Amde] removing threshold for classification predict method 2116360 [Manish Amde] removing dummy bin calculation for categorical variables 6068356 [Manish Amde] ensuring num bins is always greater than max number of categories 62c2562 [Manish Amde] fixing comment indentation ad1fc21 [Manish Amde] incorporated mengxr's code style suggestions d1ef4f6 [Manish Amde] more documentation 794ff4d [Manish Amde] minor improvements to docs and style eb8fcbe [Manish Amde] minor code style updates cd2c2b4 [Manish Amde] fixing code style based on feedback 63e786b [Manish Amde] added multiple train methods for java compatability d3023b3 [Manish Amde] adding more docs for nested methods 84f85d6 [Manish Amde] code documentation 9372779 [Manish Amde] code style: max line lenght <= 100 dd0c0d7 [Manish Amde] minor: some docs 0dd7659 [manishamde] basic doc 5841c28 [Manish Amde] unit tests for categorical features f067d68 [Manish Amde] minor cleanup c0e522b [Manish Amde] updated predict and split threshold logic b09dc98 [Manish Amde] minor refactoring 6b7de78 [Manish Amde] minor refactoring and tests d504eb1 [Manish Amde] more tests for categorical features dbb7ac1 [Manish Amde] categorical feature support 6df35b9 [Manish Amde] regression predict logic 53108ed [Manish Amde] fixing index for highest bin e23c2e5 [Manish Amde] added regression support c8f6d60 [Manish Amde] adding enum for feature type b0e3e76 [Manish Amde] adding enum for feature type 154aa77 [Manish Amde] enums for configurations 733d6dd [Manish Amde] fixed tests 02c595c [Manish Amde] added command line parsing 98ec8d5 [Manish Amde] tree building and prediction logic b0eb866 [Manish Amde] added logic to handle leaf nodes 80e8c66 [Manish Amde] working version of multi-level split calculation 4798aae [Manish Amde] added gain stats class dad0afc [Manish Amde] decison stump functionality working 03f534c [Manish Amde] some more tests 0012a77 [Manish Amde] basic stump working 8bca1e2 [Manish Amde] additional code for creating intermediate RDD 92cedce [Manish Amde] basic building blocks for intermediate RDD calculation. untested. cd53eae [Manish Amde] skeletal framework
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala1150
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/README.md17
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala26
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala26
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala26
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala43
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala47
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala46
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala42
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala37
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala33
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala49
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala28
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala39
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala90
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala64
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala425
17 files changed, 2188 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
new file mode 100644
index 0000000000..33205b919d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -0,0 +1,1150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.util.control.Breaks._
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.model._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * A class that implements a decision tree algorithm for classification and regression. It
+ * supports both continuous and categorical features.
+ * @param strategy The configuration parameters for the tree algorithm which specify the type
+ * of algorithm (classification, regression, etc.), feature type (continuous,
+ * categorical), depth of the tree, quantile calculation strategy, etc.
+ */
+class DecisionTree private(val strategy: Strategy) extends Serializable with Logging {
+
+ /**
+ * Method to train a decision tree model over an RDD
+ * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+ * @return a DecisionTreeModel that can be used for prediction
+ */
+ def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+
+ // Cache input RDD for speedup during multiple passes.
+ input.cache()
+ logDebug("algo = " + strategy.algo)
+
+ // Find the splits and the corresponding bins (interval between the splits) using a sample
+ // of the input data.
+ val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+ logDebug("numSplits = " + bins(0).length)
+
+ // depth of the decision tree
+ val maxDepth = strategy.maxDepth
+ // the max number of nodes possible given the depth of the tree
+ val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1
+ // Initialize an array to hold filters applied to points for each node.
+ val filters = new Array[List[Filter]](maxNumNodes)
+ // The filter at the top node is an empty list.
+ filters(0) = List()
+ // Initialize an array to hold parent impurity calculations for each node.
+ val parentImpurities = new Array[Double](maxNumNodes)
+ // dummy value for top node (updated during first split calculation)
+ val nodes = new Array[Node](maxNumNodes)
+
+
+ /*
+ * The main idea here is to perform level-wise training of the decision tree nodes thus
+ * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
+ * Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
+ * the sample is only used for the split calculation at the node if the sampled would have
+ * still survived the filters of the parent nodes.
+ */
+
+ // TODO: Convert for loop to while loop
+ breakable {
+ for (level <- 0 until maxDepth) {
+
+ logDebug("#####################################")
+ logDebug("level = " + level)
+ logDebug("#####################################")
+
+ // Find best split for all nodes at a level.
+ val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
+ level, filters, splits, bins)
+
+ for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+ // Extract info for nodes at the current level.
+ extractNodeInfo(nodeSplitStats, level, index, nodes)
+ // Extract info for nodes at the next lower level.
+ extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
+ filters)
+ logDebug("final best split = " + nodeSplitStats._1)
+ }
+ require(scala.math.pow(2, level) == splitsStatsForLevel.length)
+ // Check whether all the nodes at the current level at leaves.
+ val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
+ logDebug("all leaf = " + allLeaf)
+ if (allLeaf) break // no more tree construction
+ }
+ }
+
+ // Initialize the top or root node of the tree.
+ val topNode = nodes(0)
+ // Build the full tree using the node info calculated in the level-wise best split calculations.
+ topNode.build(nodes)
+
+ new DecisionTreeModel(topNode, strategy.algo)
+ }
+
+ /**
+ * Extract the decision tree node information for the given tree level and node index
+ */
+ private def extractNodeInfo(
+ nodeSplitStats: (Split, InformationGainStats),
+ level: Int,
+ index: Int,
+ nodes: Array[Node]): Unit = {
+ val split = nodeSplitStats._1
+ val stats = nodeSplitStats._2
+ val nodeIndex = scala.math.pow(2, level).toInt - 1 + index
+ val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
+ val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+ logDebug("Node = " + node)
+ nodes(nodeIndex) = node
+ }
+
+ /**
+ * Extract the decision tree node information for the children of the node
+ */
+ private def extractInfoForLowerLevels(
+ level: Int,
+ index: Int,
+ maxDepth: Int,
+ nodeSplitStats: (Split, InformationGainStats),
+ parentImpurities: Array[Double],
+ filters: Array[List[Filter]]): Unit = {
+ // 0 corresponds to the left child node and 1 corresponds to the right child node.
+ // TODO: Convert to while loop
+ for (i <- 0 to 1) {
+ // Calculate the index of the node from the node level and the index at the current level.
+ val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
+ if (level < maxDepth - 1) {
+ val impurity = if (i == 0) {
+ nodeSplitStats._2.leftImpurity
+ } else {
+ nodeSplitStats._2.rightImpurity
+ }
+ logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
+ // noting the parent impurities
+ parentImpurities(nodeIndex) = impurity
+ // noting the parents filters for the child nodes
+ val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
+ filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
+ for (filter <- filters(nodeIndex)) {
+ logDebug("Filter = " + filter)
+ }
+ }
+ }
+ }
+}
+
+object DecisionTree extends Serializable with Logging {
+
+ /**
+ * Method to train a decision tree model where the instances are represented as an RDD of
+ * (label, features) pairs. The method supports binary classification and regression. For the
+ * binary classification, the label for each instance should either be 0 or 1 to denote the two
+ * classes. The parameters for the algorithm are specified using the strategy parameter.
+ *
+ * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+ * for DecisionTree
+ * @param strategy The configuration parameters for the tree algorithm which specify the type
+ * of algorithm (classification, regression, etc.), feature type (continuous,
+ * categorical), depth of the tree, quantile calculation strategy, etc.
+ * @return a DecisionTreeModel that can be used for prediction
+ */
+ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
+ new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+ }
+
+ /**
+ * Method to train a decision tree model where the instances are represented as an RDD of
+ * (label, features) pairs. The method supports binary classification and regression. For the
+ * binary classification, the label for each instance should either be 0 or 1 to denote the two
+ * classes.
+ *
+ * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+ * training data
+ * @param algo algorithm, classification or regression
+ * @param impurity impurity criterion used for information gain calculation
+ * @param maxDepth maxDepth maximum depth of the tree
+ * @return a DecisionTreeModel that can be used for prediction
+ */
+ def train(
+ input: RDD[LabeledPoint],
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int): DecisionTreeModel = {
+ val strategy = new Strategy(algo,impurity,maxDepth)
+ new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+ }
+
+
+ /**
+ * Method to train a decision tree model where the instances are represented as an RDD of
+ * (label, features) pairs. The decision tree method supports binary classification and
+ * regression. For the binary classification, the label for each instance should either be 0 or
+ * 1 to denote the two classes. The method also supports categorical features inputs where the
+ * number of categories can specified using the categoricalFeaturesInfo option.
+ *
+ * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+ * training data for DecisionTree
+ * @param algo classification or regression
+ * @param impurity criterion used for information gain calculation
+ * @param maxDepth maximum depth of the tree
+ * @param maxBins maximum number of bins used for splitting features
+ * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+ * the number of discrete values they take. For example,
+ * an entry (n -> k) implies the feature n is categorical with k
+ * categories 0, 1, 2, ... , k-1. It's important to note that
+ * features are zero-indexed.
+ * @return a DecisionTreeModel that can be used for prediction
+ */
+ def train(
+ input: RDD[LabeledPoint],
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int,
+ maxBins: Int,
+ quantileCalculationStrategy: QuantileStrategy,
+ categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
+ val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
+ categoricalFeaturesInfo)
+ new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+ }
+
+ private val InvalidBinIndex = -1
+
+ /**
+ * Returns an array of optimal splits for all nodes at a given level
+ *
+ * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+ * for DecisionTree
+ * @param parentImpurities Impurities for all parent nodes for the current level
+ * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
+ * parameters for construction the DecisionTree
+ * @param level Level of the tree
+ * @param filters Filters for all nodes at a given level
+ * @param splits possible splits for all features
+ * @param bins possible bins for all features
+ * @return array of splits with best splits for all nodes at a given level.
+ */
+ protected[tree] def findBestSplits(
+ input: RDD[LabeledPoint],
+ parentImpurities: Array[Double],
+ strategy: Strategy,
+ level: Int,
+ filters: Array[List[Filter]],
+ splits: Array[Array[Split]],
+ bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = {
+
+ /*
+ * The high-level description for the best split optimizations are noted here.
+ *
+ * *Level-wise training*
+ * We perform bin calculations for all nodes at the given level to avoid making multiple
+ * passes over the data. Thus, for a slightly increased computation and storage cost we save
+ * several iterations over the data especially at higher levels of the decision tree.
+ *
+ * *Bin-wise computation*
+ * We use a bin-wise best split computation strategy instead of a straightforward best split
+ * computation strategy. Instead of analyzing each sample for contribution to the left/right
+ * child node impurity of every split, we first categorize each feature of a sample into a
+ * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin,
+ * is ordered (read ordering for categorical variables in the findSplitsBins method),
+ * we exploit this structure to calculate aggregates for bins and then use these aggregates
+ * to calculate information gain for each split.
+ *
+ * *Aggregation over partitions*
+ * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
+ * the number of splits in advance. Thus, we store the aggregates (at the appropriate
+ * indices) in a single array for all bins and rely upon the RDD aggregate method to
+ * drastically reduce the communication overhead.
+ */
+
+ // common calculations for multiple nested methods
+ val numNodes = scala.math.pow(2, level).toInt
+ logDebug("numNodes = " + numNodes)
+ // Find the number of features by looking at the first sample.
+ val numFeatures = input.first().features.length
+ logDebug("numFeatures = " + numFeatures)
+ val numBins = bins(0).length
+ logDebug("numBins = " + numBins)
+
+ /** Find the filters used before reaching the current code. */
+ def findParentFilters(nodeIndex: Int): List[Filter] = {
+ if (level == 0) {
+ List[Filter]()
+ } else {
+ val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
+ filters(nodeFilterIndex)
+ }
+ }
+
+ /**
+ * Find whether the sample is valid input for the current node, i.e., whether it passes through
+ * all the filters for the current node.
+ */
+ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
+ // leaf
+ if ((level > 0) & (parentFilters.length == 0)) {
+ return false
+ }
+
+ // Apply each filter and check sample validity. Return false when invalid condition found.
+ for (filter <- parentFilters) {
+ val features = labeledPoint.features
+ val featureIndex = filter.split.feature
+ val threshold = filter.split.threshold
+ val comparison = filter.comparison
+ val categories = filter.split.categories
+ val isFeatureContinuous = filter.split.featureType == Continuous
+ val feature = features(featureIndex)
+ if (isFeatureContinuous) {
+ comparison match {
+ case -1 => if (feature > threshold) return false
+ case 1 => if (feature <= threshold) return false
+ }
+ } else {
+ val containsFeature = categories.contains(feature)
+ comparison match {
+ case -1 => if (!containsFeature) return false
+ case 1 => if (containsFeature) return false
+ }
+
+ }
+ }
+
+ // Return true when the sample is valid for all filters.
+ true
+ }
+
+ /**
+ * Find bin for one feature.
+ */
+ def findBin(
+ featureIndex: Int,
+ labeledPoint: LabeledPoint,
+ isFeatureContinuous: Boolean): Int = {
+ val binForFeatures = bins(featureIndex)
+ val feature = labeledPoint.features(featureIndex)
+
+ /**
+ * Binary search helper method for continuous feature.
+ */
+ def binarySearchForBins(): Int = {
+ var left = 0
+ var right = binForFeatures.length - 1
+ while (left <= right) {
+ val mid = left + (right - left) / 2
+ val bin = binForFeatures(mid)
+ val lowThreshold = bin.lowSplit.threshold
+ val highThreshold = bin.highSplit.threshold
+ if ((lowThreshold < feature) & (highThreshold >= feature)){
+ return mid
+ }
+ else if (lowThreshold >= feature) {
+ right = mid - 1
+ }
+ else {
+ left = mid + 1
+ }
+ }
+ -1
+ }
+
+ /**
+ * Sequential search helper method to find bin for categorical feature.
+ */
+ def sequentialBinSearchForCategoricalFeature(): Int = {
+ val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
+ var binIndex = 0
+ while (binIndex < numCategoricalBins) {
+ val bin = bins(featureIndex)(binIndex)
+ val category = bin.category
+ val features = labeledPoint.features
+ if (category == features(featureIndex)) {
+ return binIndex
+ }
+ binIndex += 1
+ }
+ -1
+ }
+
+ if (isFeatureContinuous) {
+ // Perform binary search for finding bin for continuous features.
+ val binIndex = binarySearchForBins()
+ if (binIndex == -1){
+ throw new UnknownError("no bin was found for continuous variable.")
+ }
+ binIndex
+ } else {
+ // Perform sequential search to find bin for categorical features.
+ val binIndex = sequentialBinSearchForCategoricalFeature()
+ if (binIndex == -1){
+ throw new UnknownError("no bin was found for categorical variable.")
+ }
+ binIndex
+ }
+ }
+
+ /**
+ * Finds bins for all nodes (and all features) at a given level.
+ * For l nodes, k features the storage is as follows:
+ * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
+ * where b_ij is an integer between 0 and numBins - 1.
+ * Invalid sample is denoted by noting bin for feature 1 as -1.
+ */
+ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
+ // Calculate bin index and label per feature per node.
+ val arr = new Array[Double](1 + (numFeatures * numNodes))
+ arr(0) = labeledPoint.label
+ var nodeIndex = 0
+ while (nodeIndex < numNodes) {
+ val parentFilters = findParentFilters(nodeIndex)
+ // Find out whether the sample qualifies for the particular node.
+ val sampleValid = isSampleValid(parentFilters, labeledPoint)
+ val shift = 1 + numFeatures * nodeIndex
+ if (!sampleValid) {
+ // Mark one bin as -1 is sufficient.
+ arr(shift) = InvalidBinIndex
+ } else {
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous)
+ featureIndex += 1
+ }
+ }
+ nodeIndex += 1
+ }
+ arr
+ }
+
+ /**
+ * Performs a sequential aggregation over a partition for classification. For l nodes,
+ * k features, either the left count or the right count of one of the p bins is
+ * incremented based upon whether the feature is classified as 0 or 1.
+ *
+ * @param agg Array[Double] storing aggregate calculation of size
+ * 2 * numSplits * numFeatures*numNodes for classification
+ * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
+ * @return Array[Double] storing aggregate calculation of size
+ * 2 * numSplits * numFeatures * numNodes for classification
+ */
+ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+ // Iterate over all nodes.
+ var nodeIndex = 0
+ while (nodeIndex < numNodes) {
+ // Check whether the instance was valid for this nodeIndex.
+ val validSignalIndex = 1 + numFeatures * nodeIndex
+ val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
+ if (isSampleValidForNode) {
+ // actual class label
+ val label = arr(0)
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // Find the bin index for this feature.
+ val arrShift = 1 + numFeatures * nodeIndex
+ val arrIndex = arrShift + featureIndex
+ // Update the left or right count for one bin.
+ val aggShift = 2 * numBins * numFeatures * nodeIndex
+ val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
+ label match {
+ case 0.0 => agg(aggIndex) = agg(aggIndex) + 1
+ case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1
+ }
+ featureIndex += 1
+ }
+ }
+ nodeIndex += 1
+ }
+ }
+
+ /**
+ * Performs a sequential aggregation over a partition for regression. For l nodes, k features,
+ * the count, sum, sum of squares of one of the p bins is incremented.
+ *
+ * @param agg Array[Double] storing aggregate calculation of size
+ * 3 * numSplits * numFeatures * numNodes for classification
+ * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
+ * @return Array[Double] storing aggregate calculation of size
+ * 3 * numSplits * numFeatures * numNodes for regression
+ */
+ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+ // Iterate over all nodes.
+ var nodeIndex = 0
+ while (nodeIndex < numNodes) {
+ // Check whether the instance was valid for this nodeIndex.
+ val validSignalIndex = 1 + numFeatures * nodeIndex
+ val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
+ if (isSampleValidForNode) {
+ // actual class label
+ val label = arr(0)
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // Find the bin index for this feature.
+ val arrShift = 1 + numFeatures * nodeIndex
+ val arrIndex = arrShift + featureIndex
+ // Update count, sum, and sum^2 for one bin.
+ val aggShift = 3 * numBins * numFeatures * nodeIndex
+ val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
+ agg(aggIndex) = agg(aggIndex) + 1
+ agg(aggIndex + 1) = agg(aggIndex + 1) + label
+ agg(aggIndex + 2) = agg(aggIndex + 2) + label*label
+ featureIndex += 1
+ }
+ }
+ nodeIndex += 1
+ }
+ }
+
+ /**
+ * Performs a sequential aggregation over a partition.
+ */
+ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
+ strategy.algo match {
+ case Classification => classificationBinSeqOp(arr, agg)
+ case Regression => regressionBinSeqOp(arr, agg)
+ }
+ agg
+ }
+
+ // Calculate bin aggregate length for classification or regression.
+ val binAggregateLength = strategy.algo match {
+ case Classification => 2 * numBins * numFeatures * numNodes
+ case Regression => 3 * numBins * numFeatures * numNodes
+ }
+ logDebug("binAggregateLength = " + binAggregateLength)
+
+ /**
+ * Combines the aggregates from partitions.
+ * @param agg1 Array containing aggregates from one or more partitions
+ * @param agg2 Array containing aggregates from one or more partitions
+ * @return Combined aggregate from agg1 and agg2
+ */
+ def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = {
+ var index = 0
+ val combinedAggregate = new Array[Double](binAggregateLength)
+ while (index < binAggregateLength) {
+ combinedAggregate(index) = agg1(index) + agg2(index)
+ index += 1
+ }
+ combinedAggregate
+ }
+
+ // Find feature bins for all nodes at a level.
+ val binMappedRDD = input.map(x => findBinsForLevel(x))
+
+ // Calculate bin aggregates.
+ val binAggregates = {
+ binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
+ }
+ logDebug("binAggregates.length = " + binAggregates.length)
+
+ /**
+ * Calculates the information gain for all splits based upon left/right split aggregates.
+ * @param leftNodeAgg left node aggregates
+ * @param featureIndex feature index
+ * @param splitIndex split index
+ * @param rightNodeAgg right node aggregate
+ * @param topImpurity impurity of the parent node
+ * @return information gain and statistics for all splits
+ */
+ def calculateGainForSplit(
+ leftNodeAgg: Array[Array[Double]],
+ featureIndex: Int,
+ splitIndex: Int,
+ rightNodeAgg: Array[Array[Double]],
+ topImpurity: Double): InformationGainStats = {
+ strategy.algo match {
+ case Classification =>
+ val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
+ val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
+ val leftCount = left0Count + left1Count
+
+ val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
+ val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
+ val rightCount = right0Count + right1Count
+
+ val impurity = {
+ if (level > 0) {
+ topImpurity
+ } else {
+ // Calculate impurity for root node.
+ strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
+ }
+ }
+
+ if (leftCount == 0) {
+ return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1)
+ }
+ if (rightCount == 0) {
+ return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0)
+ }
+
+ val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
+ val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
+
+ val leftWeight = leftCount.toDouble / (leftCount + rightCount)
+ val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+
+ val gain = {
+ if (level > 0) {
+ impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ } else {
+ impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ }
+ }
+
+ val predict = (left1Count + right1Count) / (leftCount + rightCount)
+
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+ case Regression =>
+ val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
+ val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
+ val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)
+
+ val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
+ val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
+ val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)
+
+ val impurity = {
+ if (level > 0) {
+ topImpurity
+ } else {
+ // Calculate impurity for root node.
+ val count = leftCount + rightCount
+ val sum = leftSum + rightSum
+ val sumSquares = leftSumSquares + rightSumSquares
+ strategy.impurity.calculate(count, sum, sumSquares)
+ }
+ }
+
+ if (leftCount == 0) {
+ return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
+ rightSum / rightCount)
+ }
+ if (rightCount == 0) {
+ return new InformationGainStats(0, topImpurity ,topImpurity,
+ Double.MinValue, leftSum / leftCount)
+ }
+
+ val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
+ val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
+
+ val leftWeight = leftCount.toDouble / (leftCount + rightCount)
+ val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+
+ val gain = {
+ if (level > 0) {
+ impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ } else {
+ impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ }
+ }
+
+ val predict = (leftSum + rightSum) / (leftCount + rightCount)
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+ }
+ }
+
+ /**
+ * Extracts left and right split aggregates.
+ * @param binData Array[Double] of size 2*numFeatures*numSplits
+ * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double],
+ * Array[Double]) where each array is of size(numFeature,2*(numSplits-1))
+ */
+ def extractLeftRightNodeAggregates(
+ binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
+ strategy.algo match {
+ case Classification =>
+ // Initialize left and right split aggregates.
+ val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
+ val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // shift for this featureIndex
+ val shift = 2 * featureIndex * numBins
+
+ // left node aggregate for the lowest split
+ leftNodeAgg(featureIndex)(0) = binData(shift + 0)
+ leftNodeAgg(featureIndex)(1) = binData(shift + 1)
+
+ // right node aggregate for the highest split
+ rightNodeAgg(featureIndex)(2 * (numBins - 2))
+ = binData(shift + (2 * (numBins - 1)))
+ rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1)
+ = binData(shift + (2 * (numBins - 1)) + 1)
+
+ // Iterate over all splits.
+ var splitIndex = 1
+ while (splitIndex < numBins - 1) {
+ // calculating left node aggregate for a split as a sum of left node aggregate of a
+ // lower split and the left bin aggregate of a bin where the split is a high split
+ leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) +
+ leftNodeAgg(featureIndex)(2 * splitIndex - 2)
+ leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) +
+ leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
+
+ // calculating right node aggregate for a split as a sum of right node aggregate of a
+ // higher split and the right bin aggregate of a bin where the split is a low split
+ rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
+ binData(shift + (2 *(numBins - 2 - splitIndex))) +
+ rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
+ rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
+ binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
+ rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
+
+ splitIndex += 1
+ }
+ featureIndex += 1
+ }
+ (leftNodeAgg, rightNodeAgg)
+ case Regression =>
+ // Initialize left and right split aggregates.
+ val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
+ val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // shift for this featureIndex
+ val shift = 3 * featureIndex * numBins
+ // left node aggregate for the lowest split
+ leftNodeAgg(featureIndex)(0) = binData(shift + 0)
+ leftNodeAgg(featureIndex)(1) = binData(shift + 1)
+ leftNodeAgg(featureIndex)(2) = binData(shift + 2)
+
+ // right node aggregate for the highest split
+ rightNodeAgg(featureIndex)(3 * (numBins - 2)) =
+ binData(shift + (3 * (numBins - 1)))
+ rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) =
+ binData(shift + (3 * (numBins - 1)) + 1)
+ rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) =
+ binData(shift + (3 * (numBins - 1)) + 2)
+
+ // Iterate over all splits.
+ var splitIndex = 1
+ while (splitIndex < numBins - 1) {
+ // calculating left node aggregate for a split as a sum of left node aggregate of a
+ // lower split and the left bin aggregate of a bin where the split is a high split
+ leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) +
+ leftNodeAgg(featureIndex)(3 * splitIndex - 3)
+ leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) +
+ leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
+ leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) +
+ leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
+
+ // calculating right node aggregate for a split as a sum of right node aggregate of a
+ // higher split and the right bin aggregate of a bin where the split is a low split
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
+ binData(shift + (3 * (numBins - 2 - splitIndex))) +
+ rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
+ binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
+ rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
+ binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
+ rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
+
+ splitIndex += 1
+ }
+ featureIndex += 1
+ }
+ (leftNodeAgg, rightNodeAgg)
+ }
+ }
+
+ /**
+ * Calculates information gain for all nodes splits.
+ */
+ def calculateGainsForAllNodeSplits(
+ leftNodeAgg: Array[Array[Double]],
+ rightNodeAgg: Array[Array[Double]],
+ nodeImpurity: Double): Array[Array[InformationGainStats]] = {
+ val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
+
+ for (featureIndex <- 0 until numFeatures) {
+ for (splitIndex <- 0 until numBins - 1) {
+ gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
+ splitIndex, rightNodeAgg, nodeImpurity)
+ }
+ }
+ gains
+ }
+
+ /**
+ * Find the best split for a node.
+ * @param binData Array[Double] of size 2 * numSplits * numFeatures
+ * @param nodeImpurity impurity of the top node
+ * @return tuple of split and information gain
+ */
+ def binsToBestSplit(
+ binData: Array[Double],
+ nodeImpurity: Double): (Split, InformationGainStats) = {
+
+ logDebug("node impurity = " + nodeImpurity)
+
+ // Extract left right node aggregates.
+ val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
+
+ // Calculate gains for all splits.
+ val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
+
+ val (bestFeatureIndex,bestSplitIndex, gainStats) = {
+ // Initialize with infeasible values.
+ var bestFeatureIndex = Int.MinValue
+ var bestSplitIndex = Int.MinValue
+ var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0)
+ // Iterate over features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // Iterate over all splits.
+ var splitIndex = 0
+ while (splitIndex < numBins - 1) {
+ val gainStats = gains(featureIndex)(splitIndex)
+ if (gainStats.gain > bestGainStats.gain) {
+ bestGainStats = gainStats
+ bestFeatureIndex = featureIndex
+ bestSplitIndex = splitIndex
+ }
+ splitIndex += 1
+ }
+ featureIndex += 1
+ }
+ (bestFeatureIndex, bestSplitIndex, bestGainStats)
+ }
+
+ logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
+ logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex))
+
+ (splits(bestFeatureIndex)(bestSplitIndex), gainStats)
+ }
+
+ /**
+ * Get bin data for one node.
+ */
+ def getBinDataForNode(node: Int): Array[Double] = {
+ strategy.algo match {
+ case Classification =>
+ val shift = 2 * node * numBins * numFeatures
+ val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
+ binsForNode
+ case Regression =>
+ val shift = 3 * node * numBins * numFeatures
+ val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
+ binsForNode
+ }
+ }
+
+ // Calculate best splits for all nodes at a given level
+ val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+ // Iterating over all nodes at this level
+ var node = 0
+ while (node < numNodes) {
+ val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
+ val binsForNode: Array[Double] = getBinDataForNode(node)
+ logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
+ val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
+ logDebug("node impurity = " + parentNodeImpurity)
+ bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
+ node += 1
+ }
+
+ bestSplits
+ }
+
+ /**
+ * Returns split and bins for decision tree calculation.
+ * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+ * for DecisionTree
+ * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
+ * parameters for construction the DecisionTree
+ * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree
+ * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache
+ * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
+ */
+ protected[tree] def findSplitsBins(
+ input: RDD[LabeledPoint],
+ strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
+ val count = input.count()
+
+ // Find the number of features by looking at the first sample
+ val numFeatures = input.take(1)(0).features.length
+
+ val maxBins = strategy.maxBins
+ val numBins = if (maxBins <= count) maxBins else count.toInt
+ logDebug("numBins = " + numBins)
+
+ /*
+ * TODO: Add a require statement ensuring #bins is always greater than the categories.
+ * It's a limitation of the current implementation but a reasonable trade-off since features
+ * with large number of categories get favored over continuous features.
+ */
+ if (strategy.categoricalFeaturesInfo.size > 0) {
+ val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
+ require(numBins >= maxCategoriesForFeatures)
+ }
+
+ // Calculate the number of sample for approximate quantile calculation.
+ val requiredSamples = numBins*numBins
+ val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
+ logDebug("fraction of data used for calculating quantiles = " + fraction)
+
+ // sampled input for RDD calculation
+ val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect()
+ val numSamples = sampledInput.length
+
+ val stride: Double = numSamples.toDouble / numBins
+ logDebug("stride = " + stride)
+
+ strategy.quantileCalculationStrategy match {
+ case Sort =>
+ val splits = Array.ofDim[Split](numFeatures, numBins - 1)
+ val bins = Array.ofDim[Bin](numFeatures, numBins)
+
+ // Find all splits.
+
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures){
+ // Check whether the feature is continuous.
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ if (isFeatureContinuous) {
+ val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
+ val stride: Double = numSamples.toDouble / numBins
+ logDebug("stride = " + stride)
+ for (index <- 0 until numBins - 1) {
+ val sampleIndex = (index + 1) * stride.toInt
+ val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List())
+ splits(featureIndex)(index) = split
+ }
+ } else {
+ val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
+ require(maxFeatureValue < numBins, "number of categories should be less than number " +
+ "of bins")
+
+ // For categorical variables, each bin is a category. The bins are sorted and they
+ // are ordered by calculating the centroid of their corresponding labels.
+ val centroidForCategories =
+ sampledInput.map(lp => (lp.features(featureIndex),lp.label))
+ .groupBy(_._1)
+ .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+
+ // Check for missing categorical variables and putting them last in the sorted list.
+ val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
+ for (i <- 0 until maxFeatureValue) {
+ if (centroidForCategories.contains(i)) {
+ fullCentroidForCategories(i) = centroidForCategories(i)
+ } else {
+ fullCentroidForCategories(i) = Double.MaxValue
+ }
+ }
+
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
+
+ logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)
+
+ var categoriesForSplit = List[Double]()
+ categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
+ case ((key, value), index) =>
+ categoriesForSplit = key :: categoriesForSplit
+ splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
+ categoriesForSplit)
+ bins(featureIndex)(index) = {
+ if (index == 0) {
+ new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
+ splits(featureIndex)(0), Categorical, key)
+ } else {
+ new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
+ Categorical, key)
+ }
+ }
+ }
+ }
+ featureIndex += 1
+ }
+
+ // Find all bins.
+ featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
+ bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
+ splits(featureIndex)(0), Continuous, Double.MinValue)
+ for (index <- 1 until numBins - 1){
+ val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
+ Continuous, Double.MinValue)
+ bins(featureIndex)(index) = bin
+ }
+ bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),
+ new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
+ }
+ featureIndex += 1
+ }
+ (splits,bins)
+ case MinMax =>
+ throw new UnsupportedOperationException("minmax not supported yet.")
+ case ApproxHist =>
+ throw new UnsupportedOperationException("approximate histogram not supported yet.")
+ }
+ }
+
+ val usage = """
+ Usage: DecisionTreeRunner <master>[slices] --algo <Classification,
+ Regression> --trainDataDir path --testDataDir path --maxDepth num [--impurity <Gini,Entropy,
+ Variance>] [--maxBins num]
+ """
+
+ def main(args: Array[String]) {
+
+ if (args.length < 2) {
+ System.err.println(usage)
+ System.exit(1)
+ }
+
+ val sc = new SparkContext(args(0), "DecisionTree")
+
+ val argList = args.toList.drop(1)
+ type OptionMap = Map[Symbol, Any]
+
+ def nextOption(map : OptionMap, list: List[String]): OptionMap = {
+ list match {
+ case Nil => map
+ case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail)
+ case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail)
+ case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail)
+ case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail)
+ case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string)
+ , tail)
+ case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string),
+ tail)
+ case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail)
+ case option :: tail => logError("Unknown option " + option)
+ sys.exit(1)
+ }
+ }
+ val options = nextOption(Map(), argList)
+ logDebug(options.toString())
+
+ // Load training data.
+ val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)
+
+ // Identify the type of algorithm.
+ val algoStr = options.get('algo).get.toString
+ val algo = algoStr match {
+ case "Classification" => Classification
+ case "Regression" => Regression
+ }
+
+ // Identify the type of impurity.
+ val impurityStr = options.getOrElse('impurity,
+ if (algo == Classification) "Gini" else "Variance").toString
+ val impurity = impurityStr match {
+ case "Gini" => Gini
+ case "Entropy" => Entropy
+ case "Variance" => Variance
+ }
+
+ val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
+ val maxBins = options.getOrElse('maxBins, "100").toString.toInt
+
+ val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
+ val model = DecisionTree.train(trainData, strategy)
+
+ // Load test data.
+ val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
+
+ // Measure algorithm accuracy
+ if (algo == Classification) {
+ val accuracy = accuracyScore(model, testData)
+ logDebug("accuracy = " + accuracy)
+ }
+
+ if (algo == Regression) {
+ val mse = meanSquaredError(model, testData)
+ logDebug("mean square error = " + mse)
+ }
+
+ sc.stop()
+ }
+
+ /**
+ * Load labeled data from a file. The data format used here is
+ * <L>, <f1> <f2> ...,
+ * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
+ *
+ * @param sc SparkContext
+ * @param dir Directory to the input data files.
+ * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
+ * the label, and the second element represents the feature values (an array of Double).
+ */
+ def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
+ sc.textFile(dir).map { line =>
+ val parts = line.trim().split(",")
+ val label = parts(0).toDouble
+ val features = parts.slice(1,parts.length).map(_.toDouble)
+ LabeledPoint(label, features)
+ }
+ }
+
+ // TODO: Port this method to a generic metrics package.
+ /**
+ * Calculates the classifier accuracy.
+ */
+ private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
+ threshold: Double = 0.5): Double = {
+ def predictedValue(features: Array[Double]) = {
+ if (model.predict(features) < threshold) 0.0 else 1.0
+ }
+ val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+ val count = data.count()
+ logDebug("correct prediction count = " + correctCount)
+ logDebug("data count = " + count)
+ correctCount.toDouble / count
+ }
+
+ // TODO: Port this method to a generic metrics package
+ /**
+ * Calculates the mean squared error for regression.
+ */
+ private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
+ data.map { y =>
+ val err = tree.predict(y.features) - y.label
+ err * err
+ }.mean()
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
new file mode 100644
index 0000000000..0fd71aa973
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
@@ -0,0 +1,17 @@
+This package contains the default implementation of the decision tree algorithm.
+
+The decision tree algorithm supports:
++ Binary classification
++ Regression
++ Information loss calculation with entropy and gini for classification and variance for regression
++ Both continuous and categorical features
+
+# Tree improvements
++ Node model pruning
++ Printing to dot files
+
+# Future Ensemble Extensions
+
++ Random forests
++ Boosting
++ Extremely randomized trees
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
new file mode 100644
index 0000000000..2dd1f0f27b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum to select the algorithm for the decision tree
+ */
+object Algo extends Enumeration {
+ type Algo = Value
+ val Classification, Regression = Value
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
new file mode 100644
index 0000000000..09ee0586c5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum to describe whether a feature is "continuous" or "categorical"
+ */
+object FeatureType extends Enumeration {
+ type FeatureType = Value
+ val Continuous, Categorical = Value
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
new file mode 100644
index 0000000000..2457a480c2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum for selecting the quantile calculation strategy
+ */
+object QuantileStrategy extends Enumeration {
+ type QuantileStrategy = Value
+ val Sort, MinMax, ApproxHist = Value
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
new file mode 100644
index 0000000000..df565f3eb8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+
+/**
+ * Stores all the configuration options for tree construction
+ * @param algo classification or regression
+ * @param impurity criterion used for information gain calculation
+ * @param maxDepth maximum depth of the tree
+ * @param maxBins maximum number of bins used for splitting features
+ * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
+ * number of discrete values they take. For example, an entry (n ->
+ * k) implies the feature n is categorical with k categories 0,
+ * 1, 2, ... , k-1. It's important to note that features are
+ * zero-indexed.
+ */
+class Strategy (
+ val algo: Algo,
+ val impurity: Impurity,
+ val maxDepth: Int,
+ val maxBins: Int = 100,
+ val quantileCalculationStrategy: QuantileStrategy = Sort,
+ val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
new file mode 100644
index 0000000000..b93995fcf9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
+ * binary classification.
+ */
+object Entropy extends Impurity {
+
+ def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
+
+ /**
+ * entropy calculation
+ * @param c0 count of instances with label 0
+ * @param c1 count of instances with label 1
+ * @return entropy value
+ */
+ def calculate(c0: Double, c1: Double): Double = {
+ if (c0 == 0 || c1 == 0) {
+ 0
+ } else {
+ val total = c0 + c1
+ val f0 = c0 / total
+ val f1 = c1 / total
+ -(f0 * log2(f0)) - (f1 * log2(f1))
+ }
+ }
+
+ def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+ throw new UnsupportedOperationException("Entropy.calculate")
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
new file mode 100644
index 0000000000..c0407554a9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating the
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
+ * during binary classification.
+ */
+object Gini extends Impurity {
+
+ /**
+ * Gini coefficient calculation
+ * @param c0 count of instances with label 0
+ * @param c1 count of instances with label 1
+ * @return Gini coefficient value
+ */
+ override def calculate(c0: Double, c1: Double): Double = {
+ if (c0 == 0 || c1 == 0) {
+ 0
+ } else {
+ val total = c0 + c1
+ val f0 = c0 / total
+ val f1 = c1 / total
+ 1 - f0 * f0 - f1 * f1
+ }
+ }
+
+ def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+ throw new UnsupportedOperationException("Gini.calculate")
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
new file mode 100644
index 0000000000..a4069063af
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Trait for calculating information gain.
+ */
+trait Impurity extends Serializable {
+
+ /**
+ * information calculation for binary classification
+ * @param c0 count of instances with label 0
+ * @param c1 count of instances with label 1
+ * @return information value
+ */
+ def calculate(c0 : Double, c1 : Double): Double
+
+ /**
+ * information calculation for regression
+ * @param count number of instances
+ * @param sum sum of labels
+ * @param sumSquares summation of squares of the labels
+ * @return information value
+ */
+ def calculate(count: Double, sum: Double, sumSquares: Double): Double
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
new file mode 100644
index 0000000000..b74577dcec
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating variance during regression
+ */
+object Variance extends Impurity {
+ override def calculate(c0: Double, c1: Double): Double =
+ throw new UnsupportedOperationException("Variance.calculate")
+
+ /**
+ * variance calculation
+ * @param count number of instances
+ * @param sum sum of labels
+ * @param sumSquares summation of squares of the labels
+ */
+ override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
+ val squaredLoss = sumSquares - (sum * sum) / count
+ squaredLoss / count
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
new file mode 100644
index 0000000000..a57faa1374
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+
+/**
+ * Used for "binning" the features bins for faster best split calculation. For a continuous
+ * feature, a bin is determined by a low and a high "split". For a categorical feature,
+ * the a bin is determined using a single label value (category).
+ * @param lowSplit signifying the lower threshold for the continuous feature to be
+ * accepted in the bin
+ * @param highSplit signifying the upper threshold for the continuous feature to be
+ * accepted in the bin
+ * @param featureType type of feature -- categorical or continuous
+ * @param category categorical label value accepted in the bin
+ */
+case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
new file mode 100644
index 0000000000..a8bbf21dae
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.rdd.RDD
+
+/**
+ * Model to store the decision tree parameters
+ * @param topNode root node
+ * @param algo algorithm type -- classification or regression
+ */
+class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+ * @return Double prediction from the trained model
+ */
+ def predict(features: Array[Double]): Double = {
+ topNode.predictIfLeaf(features)
+ }
+
+ /**
+ * Predict values for the given data set using the model trained.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Int] where each entry contains the corresponding prediction
+ */
+ def predict(features: RDD[Array[Double]]): RDD[Double] = {
+ features.map(x => predict(x))
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
new file mode 100644
index 0000000000..ebc9595eaf
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+/**
+ * Filter specifying a split and type of comparison to be applied on features
+ * @param split split specifying the feature index, type and threshold
+ * @param comparison integer specifying <,=,>
+ */
+case class Filter(split: Split, comparison: Int) {
+ // Comparison -1,0,1 signifies <.=,>
+ override def toString = " split = " + split + "comparison = " + comparison
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
new file mode 100644
index 0000000000..99bf79cf12
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+/**
+ * Information gain statistics for each split
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param leftImpurity left node impurity
+ * @param rightImpurity right node impurity
+ * @param predict predicted value
+ */
+class InformationGainStats(
+ val gain: Double,
+ val impurity: Double,
+ val leftImpurity: Double,
+ val rightImpurity: Double,
+ val predict: Double) extends Serializable {
+
+ override def toString = {
+ "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
+ .format(gain, impurity, leftImpurity, rightImpurity, predict)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
new file mode 100644
index 0000000000..ea4693c5c2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.Logging
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+
+/**
+ * Node in a decision tree
+ * @param id integer node id
+ * @param predict predicted value at the node
+ * @param isLeaf whether the leaf is a node
+ * @param split split to calculate left and right nodes
+ * @param leftNode left child
+ * @param rightNode right child
+ * @param stats information gain stats
+ */
+class Node (
+ val id: Int,
+ val predict: Double,
+ val isLeaf: Boolean,
+ val split: Option[Split],
+ var leftNode: Option[Node],
+ var rightNode: Option[Node],
+ val stats: Option[InformationGainStats]) extends Serializable with Logging {
+
+ override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
+ "split = " + split + ", stats = " + stats
+
+ /**
+ * build the left node and right nodes if not leaf
+ * @param nodes array of nodes
+ */
+ def build(nodes: Array[Node]): Unit = {
+
+ logDebug("building node " + id + " at level " +
+ (scala.math.log(id + 1)/scala.math.log(2)).toInt )
+ logDebug("id = " + id + ", split = " + split)
+ logDebug("stats = " + stats)
+ logDebug("predict = " + predict)
+ if (!isLeaf) {
+ val leftNodeIndex = id*2 + 1
+ val rightNodeIndex = id*2 + 2
+ leftNode = Some(nodes(leftNodeIndex))
+ rightNode = Some(nodes(rightNodeIndex))
+ leftNode.get.build(nodes)
+ rightNode.get.build(nodes)
+ }
+ }
+
+ /**
+ * predict value if node is not leaf
+ * @param feature feature value
+ * @return predicted value
+ */
+ def predictIfLeaf(feature: Array[Double]) : Double = {
+ if (isLeaf) {
+ predict
+ } else{
+ if (split.get.featureType == Continuous) {
+ if (feature(split.get.feature) <= split.get.threshold) {
+ leftNode.get.predictIfLeaf(feature)
+ } else {
+ rightNode.get.predictIfLeaf(feature)
+ }
+ } else {
+ if (split.get.categories.contains(feature(split.get.feature))) {
+ leftNode.get.predictIfLeaf(feature)
+ } else {
+ rightNode.get.predictIfLeaf(feature)
+ }
+ }
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
new file mode 100644
index 0000000000..4e64a81dda
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
+
+/**
+ * Split applied to a feature
+ * @param feature feature index
+ * @param threshold threshold for continuous feature
+ * @param featureType type of feature -- categorical or continuous
+ * @param categories accepted values for categorical variables
+ */
+case class Split(
+ feature: Int,
+ threshold: Double,
+ featureType: FeatureType,
+ categories: List[Double]){
+
+ override def toString =
+ "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
+ ", categories = " + categories
+}
+
+/**
+ * Split with minimum threshold for continuous features. Helps with the smallest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyLowSplit(feature: Int, featureType: FeatureType)
+ extends Split(feature, Double.MinValue, featureType, List())
+
+/**
+ * Split with maximum threshold for continuous features. Helps with the highest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyHighSplit(feature: Int, featureType: FeatureType)
+ extends Split(feature, Double.MaxValue, featureType, List())
+
+/**
+ * Split with no acceptable feature values for categorical features. Helps with the first bin
+ * creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
+ extends Split(feature, Double.MaxValue, featureType, List())
+
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
new file mode 100644
index 0000000000..4349c7000a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
+import org.apache.spark.mllib.tree.model.Filter
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+
+class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
+
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ test("split and bin calculation") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+ }
+
+ test("split and bin calculation for categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ // Check splits.
+
+ assert(splits(0)(0).feature === 0)
+ assert(splits(0)(0).threshold === Double.MinValue)
+ assert(splits(0)(0).featureType === Categorical)
+ assert(splits(0)(0).categories.length === 1)
+ assert(splits(0)(0).categories.contains(1.0))
+
+ assert(splits(0)(1).feature === 0)
+ assert(splits(0)(1).threshold === Double.MinValue)
+ assert(splits(0)(1).featureType === Categorical)
+ assert(splits(0)(1).categories.length === 2)
+ assert(splits(0)(1).categories.contains(1.0))
+ assert(splits(0)(1).categories.contains(0.0))
+
+ assert(splits(0)(2) === null)
+
+ assert(splits(1)(0).feature === 1)
+ assert(splits(1)(0).threshold === Double.MinValue)
+ assert(splits(1)(0).featureType === Categorical)
+ assert(splits(1)(0).categories.length === 1)
+ assert(splits(1)(0).categories.contains(0.0))
+
+ assert(splits(1)(1).feature === 1)
+ assert(splits(1)(1).threshold === Double.MinValue)
+ assert(splits(1)(1).featureType === Categorical)
+ assert(splits(1)(1).categories.length === 2)
+ assert(splits(1)(1).categories.contains(1.0))
+ assert(splits(1)(1).categories.contains(0.0))
+
+ assert(splits(1)(2) === null)
+
+ // Check bins.
+
+ assert(bins(0)(0).category === 1.0)
+ assert(bins(0)(0).lowSplit.categories.length === 0)
+ assert(bins(0)(0).highSplit.categories.length === 1)
+ assert(bins(0)(0).highSplit.categories.contains(1.0))
+
+ assert(bins(0)(1).category === 0.0)
+ assert(bins(0)(1).lowSplit.categories.length === 1)
+ assert(bins(0)(1).lowSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.length === 2)
+ assert(bins(0)(1).highSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.contains(0.0))
+
+ assert(bins(0)(2) === null)
+
+ assert(bins(1)(0).category === 0.0)
+ assert(bins(1)(0).lowSplit.categories.length === 0)
+ assert(bins(1)(0).highSplit.categories.length === 1)
+ assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+ assert(bins(1)(1).category === 1.0)
+ assert(bins(1)(1).lowSplit.categories.length === 1)
+ assert(bins(1)(1).lowSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.length === 2)
+ assert(bins(1)(1).highSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+ assert(bins(1)(2) === null)
+ }
+
+ test("split and bin calculations for categorical variables with no sample for one category") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+ // Check splits.
+
+ assert(splits(0)(0).feature === 0)
+ assert(splits(0)(0).threshold === Double.MinValue)
+ assert(splits(0)(0).featureType === Categorical)
+ assert(splits(0)(0).categories.length === 1)
+ assert(splits(0)(0).categories.contains(1.0))
+
+ assert(splits(0)(1).feature === 0)
+ assert(splits(0)(1).threshold === Double.MinValue)
+ assert(splits(0)(1).featureType === Categorical)
+ assert(splits(0)(1).categories.length === 2)
+ assert(splits(0)(1).categories.contains(1.0))
+ assert(splits(0)(1).categories.contains(0.0))
+
+ assert(splits(0)(2).feature === 0)
+ assert(splits(0)(2).threshold === Double.MinValue)
+ assert(splits(0)(2).featureType === Categorical)
+ assert(splits(0)(2).categories.length === 3)
+ assert(splits(0)(2).categories.contains(1.0))
+ assert(splits(0)(2).categories.contains(0.0))
+ assert(splits(0)(2).categories.contains(2.0))
+
+ assert(splits(0)(3) === null)
+
+ assert(splits(1)(0).feature === 1)
+ assert(splits(1)(0).threshold === Double.MinValue)
+ assert(splits(1)(0).featureType === Categorical)
+ assert(splits(1)(0).categories.length === 1)
+ assert(splits(1)(0).categories.contains(0.0))
+
+ assert(splits(1)(1).feature === 1)
+ assert(splits(1)(1).threshold === Double.MinValue)
+ assert(splits(1)(1).featureType === Categorical)
+ assert(splits(1)(1).categories.length === 2)
+ assert(splits(1)(1).categories.contains(1.0))
+ assert(splits(1)(1).categories.contains(0.0))
+
+ assert(splits(1)(2).feature === 1)
+ assert(splits(1)(2).threshold === Double.MinValue)
+ assert(splits(1)(2).featureType === Categorical)
+ assert(splits(1)(2).categories.length === 3)
+ assert(splits(1)(2).categories.contains(1.0))
+ assert(splits(1)(2).categories.contains(0.0))
+ assert(splits(1)(2).categories.contains(2.0))
+
+ assert(splits(1)(3) === null)
+
+ // Check bins.
+
+ assert(bins(0)(0).category === 1.0)
+ assert(bins(0)(0).lowSplit.categories.length === 0)
+ assert(bins(0)(0).highSplit.categories.length === 1)
+ assert(bins(0)(0).highSplit.categories.contains(1.0))
+
+ assert(bins(0)(1).category === 0.0)
+ assert(bins(0)(1).lowSplit.categories.length === 1)
+ assert(bins(0)(1).lowSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.length === 2)
+ assert(bins(0)(1).highSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.contains(0.0))
+
+ assert(bins(0)(2).category === 2.0)
+ assert(bins(0)(2).lowSplit.categories.length === 2)
+ assert(bins(0)(2).lowSplit.categories.contains(1.0))
+ assert(bins(0)(2).lowSplit.categories.contains(0.0))
+ assert(bins(0)(2).highSplit.categories.length === 3)
+ assert(bins(0)(2).highSplit.categories.contains(1.0))
+ assert(bins(0)(2).highSplit.categories.contains(0.0))
+ assert(bins(0)(2).highSplit.categories.contains(2.0))
+
+ assert(bins(0)(3) === null)
+
+ assert(bins(1)(0).category === 0.0)
+ assert(bins(1)(0).lowSplit.categories.length === 0)
+ assert(bins(1)(0).highSplit.categories.length === 1)
+ assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+ assert(bins(1)(1).category === 1.0)
+ assert(bins(1)(1).lowSplit.categories.length === 1)
+ assert(bins(1)(1).lowSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.length === 2)
+ assert(bins(1)(1).highSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+ assert(bins(1)(2).category === 2.0)
+ assert(bins(1)(2).lowSplit.categories.length === 2)
+ assert(bins(1)(2).lowSplit.categories.contains(0.0))
+ assert(bins(1)(2).lowSplit.categories.contains(1.0))
+ assert(bins(1)(2).highSplit.categories.length === 3)
+ assert(bins(1)(2).highSplit.categories.contains(0.0))
+ assert(bins(1)(2).highSplit.categories.contains(1.0))
+ assert(bins(1)(2).highSplit.categories.contains(2.0))
+
+ assert(bins(1)(3) === null)
+ }
+
+ test("classification stump with all categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+
+ val split = bestSplits(0)._1
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1.0))
+ assert(split.featureType === Categorical)
+ assert(split.threshold === Double.MinValue)
+
+ val stats = bestSplits(0)._2
+ assert(stats.gain > 0)
+ assert(stats.predict > 0.4)
+ assert(stats.predict < 0.5)
+ assert(stats.impurity > 0.2)
+ }
+
+ test("regression stump with all categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Regression,
+ Variance,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+
+ val split = bestSplits(0)._1
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1.0))
+ assert(split.featureType === Categorical)
+ assert(split.threshold === Double.MinValue)
+
+ val stats = bestSplits(0)._2
+ assert(stats.gain > 0)
+ assert(stats.predict > 0.4)
+ assert(stats.predict < 0.5)
+ assert(stats.impurity > 0.2)
+ }
+
+ test("stump with fixed label 0 for Gini") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ }
+
+ test("stump with fixed label 1 for Gini") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 1)
+ }
+
+ test("stump with fixed label 0 for Entropy") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 0)
+ }
+
+ test("stump with fixed label 1 for Entropy") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 1)
+ }
+}
+
+object DecisionTreeSuite {
+
+ def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i))
+ arr(i) = lp
+ }
+ arr
+ }
+
+ def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i))
+ arr(i) = lp
+ }
+ arr
+ }
+
+ def generateCategoricalDataPoints(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ if (i < 600){
+ arr(i) = new LabeledPoint(1.0,Array(0.0,1.0))
+ } else {
+ arr(i) = new LabeledPoint(0.0,Array(1.0,0.0))
+ }
+ }
+ arr
+ }
+}