aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph.kurata.bradley@gmail.com>2014-07-31 20:51:48 -0700
committerXiangrui Meng <meng@databricks.com>2014-07-31 20:51:48 -0700
commitb124de584a45b7ebde9fbe10128db429c56aeaee (patch)
tree6f8b70447a4d825cfe1ad71870741f8f8d77ba3d /examples
parentd8430148ee1f6ba02569db0538eeae473a32c78e (diff)
downloadspark-b124de584a45b7ebde9fbe10128db429c56aeaee.tar.gz
spark-b124de584a45b7ebde9fbe10128db429c56aeaee.tar.bz2
spark-b124de584a45b7ebde9fbe10128db429c56aeaee.zip
[SPARK-2756] [mllib] Decision tree bug fixes
(1) Inconsistent aggregate (agg) indexing for unordered features. (2) Fixed gain calculations for edge cases. (3) One-off error in choosing thresholds for continuous features for small datasets. (4) (not a bug) Changed meaning of tree depth by 1 to fit scikit-learn and rpart. (Depth 1 used to mean 1 leaf node; depth 0 now means 1 leaf node.) Other updates, to help with tests: * Updated DecisionTreeRunner to print more info. * Added utility functions to DecisionTreeModel: toString, depth, numNodes * Improved internal DecisionTree documentation Bug fix details: (1) Indexing was inconsistent for aggregate calculations for unordered features (in multiclass classification with categorical features, where the features had few enough values such that they could be considered unordered, i.e., isSpaceSufficientForAllCategoricalSplits=true). * updateBinForUnorderedFeature indexed agg as (node, feature, featureValue, binIndex), where ** featureValue was from arr (so it was a feature value) ** binIndex was in [0,…, 2^(maxFeatureValue-1)-1) * The rest of the code indexed agg as (node, feature, binIndex, label). * Corrected this bug by changing updateBinForUnorderedFeature to use the second indexing pattern. Unit tests in DecisionTreeSuite * Updated a few tests to train a model and test its training accuracy, which catches the indexing bug from updateBinForUnorderedFeature() discussed above. * Added new test (“stump with categorical variables for multiclass classification, with just enough bins”) to test bin extremes. (2) Bug fix: calculateGainForSplit (for classification): * It used to return dummy prediction values when either the right or left children had 0 weight. These were incorrect for multiclass classification. It has been corrected. Updated impurities to allow for count = 0. This was related to the above bug fix for calculateGainForSplit (for classification). Small updates to documentation and coding style. (3) Bug fix: Off-by-1 when finding thresholds for splits for continuous features. * Exhibited bug in new test in DecisionTreeSuite: “stump with 1 continuous variable for binary classification, to check off-by-1 error” * Description: When finding thresholds for possible splits for continuous features in DecisionTree.findSplitsBins, the thresholds were set according to individual training examples’ feature values. * Fix: The threshold is set to be the average of 2 consecutive (sorted) examples’ feature values. E.g.: If the old code set the threshold using example i, the new code sets the threshold using exam * Note: In 4 DecisionTreeSuite tests with all labels identical, removed check of threshold since it is somewhat arbitrary. CC: mengxr manishamde Please let me know if I missed something! Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes #1673 from jkbradley/decisiontree-bugfix and squashes the following commits: 2b20c61 [Joseph K. Bradley] Small doc and style updates dab0b67 [Joseph K. Bradley] Added documentation for DecisionTree internals 8bb8aa0 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 978cfcf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 6eed482 [Joseph K. Bradley] In DecisionTree: Changed from using procedural syntax for functions returning Unit to explicitly writing Unit return type. 376dca2 [Joseph K. Bradley] Updated meaning of maxDepth by 1 to fit scikit-learn and rpart. * In code, replaced usages of maxDepth <-- maxDepth + 1 * In params, replace settings of maxDepth <-- maxDepth - 1 59750f8 [Joseph K. Bradley] * Updated Strategy to check numClassesForClassification only if algo=Classification. * Updates based on comments: ** DecisionTreeRunner *** Made dataFormat arg default to libsvm ** Small cleanups ** tree.Node: Made recursive helper methods private, and renamed them. 52e17c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix da50db7 [Joseph K. Bradley] Added one more test to DecisionTreeSuite: stump with 2 continuous variables for binary classification. Caused problems in past, but fixed now. 8ea8750 [Joseph K. Bradley] Bug fix: Off-by-1 when finding thresholds for splits for continuous features. 2283df8 [Joseph K. Bradley] 2 bug fixes. 73fbea2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 5f920a1 [Joseph K. Bradley] Demonstration of bug before submitting fix: Updated DecisionTreeSuite so that 3 tests fail. Will describe bug in next commit.
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala92
1 files changed, 72 insertions, 20 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 6db9bf3cf5..cf3d2cca81 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -21,7 +21,6 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
@@ -36,6 +35,9 @@ import org.apache.spark.rdd.RDD
* ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ * To include categorical features, modify categoricalFeaturesInfo.
*/
object DecisionTreeRunner {
@@ -48,11 +50,12 @@ object DecisionTreeRunner {
case class Params(
input: String = null,
+ dataFormat: String = "libsvm",
algo: Algo = Classification,
- numClassesForClassification: Int = 2,
- maxDepth: Int = 5,
+ maxDepth: Int = 4,
impurity: ImpurityType = Gini,
- maxBins: Int = 100)
+ maxBins: Int = 100,
+ fracTest: Double = 0.2)
def main(args: Array[String]) {
val defaultParams = Params()
@@ -69,25 +72,31 @@ object DecisionTreeRunner {
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
.action((x, c) => c.copy(maxDepth = x))
- opt[Int]("numClassesForClassification")
- .text(s"number of classes for classification, "
- + s"default: ${defaultParams.numClassesForClassification}")
- .action((x, c) => c.copy(numClassesForClassification = x))
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
arg[String]("<input>")
.text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
- if (params.algo == Classification &&
- (params.impurity == Gini || params.impurity == Entropy)) {
- success
- } else if (params.algo == Regression && params.impurity == Variance) {
- success
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
} else {
- failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+ if (params.algo == Classification &&
+ (params.impurity == Gini || params.impurity == Entropy)) {
+ success
+ } else if (params.algo == Regression && params.impurity == Variance) {
+ success
+ } else {
+ failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+ }
}
}
}
@@ -100,16 +109,57 @@ object DecisionTreeRunner {
}
def run(params: Params) {
+
val conf = new SparkConf().setAppName("DecisionTreeRunner")
val sc = new SparkContext(conf)
// Load training data and cache it.
- val examples = MLUtils.loadLabeledPoints(sc, params.input).cache()
+ val origExamples = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
+ }
+ // For classification, re-index classes if needed.
+ val (examples, numClasses) = params.algo match {
+ case Classification => {
+ // classCounts: class --> # examples in class
+ val classCounts = origExamples.map(_.label).countByValue()
+ val sortedClasses = classCounts.keys.toList.sorted
+ val numClasses = classCounts.size
+ // classIndexMap: class --> index in 0,...,numClasses-1
+ val classIndexMap = {
+ if (classCounts.keySet != Set(0.0, 1.0)) {
+ sortedClasses.zipWithIndex.toMap
+ } else {
+ Map[Double, Int]()
+ }
+ }
+ val examples = {
+ if (classIndexMap.isEmpty) {
+ origExamples
+ } else {
+ origExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features))
+ }
+ }
+ val numExamples = examples.count()
+ println(s"numClasses = $numClasses.")
+ println(s"Per-class example fractions, counts:")
+ println(s"Class\tFrac\tCount")
+ sortedClasses.foreach { c =>
+ val frac = classCounts(c) / numExamples.toDouble
+ println(s"$c\t$frac\t${classCounts(c)}")
+ }
+ (examples, numClasses)
+ }
+ case Regression =>
+ (origExamples, 0)
+ case _ =>
+ throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+ }
- val splits = examples.randomSplit(Array(0.8, 0.2))
+ // Split into training, test.
+ val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
val training = splits(0).cache()
val test = splits(1).cache()
-
val numTraining = training.count()
val numTest = test.count()
@@ -129,17 +179,19 @@ object DecisionTreeRunner {
impurity = impurityCalculator,
maxDepth = params.maxDepth,
maxBins = params.maxBins,
- numClassesForClassification = params.numClassesForClassification)
+ numClassesForClassification = numClasses)
val model = DecisionTree.train(training, strategy)
+ println(model)
+
if (params.algo == Classification) {
val accuracy = accuracyScore(model, test)
- println(s"Test accuracy = $accuracy.")
+ println(s"Test accuracy = $accuracy")
}
if (params.algo == Regression) {
val mse = meanSquaredError(model, test)
- println(s"Test mean squared error = $mse.")
+ println(s"Test mean squared error = $mse")
}
sc.stop()