Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala  92
1 file changed, 72 insertions, 20 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 6db9bf3cf5..cf3d2cca81 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -21,7 +21,6 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
@@ -36,6 +35,9 @@ import org.apache.spark.rdd.RDD
* ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ * To include categorical features, modify categoricalFeaturesInfo.
*/
object DecisionTreeRunner {
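The note added above points at categoricalFeaturesInfo. As a hedged sketch (assuming the Spark 1.1-era Strategy constructor, which takes a map from feature index to number of categories; the feature indices and arities below are made up for illustration), declaring features 0 and 3 as categorical could look like:

    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.impurity.Gini

    // Sketch: treat feature 0 as binary categorical and feature 3 as 5-valued.
    val strategy = new Strategy(
      algo = Classification,
      impurity = Gini,
      maxDepth = 4,
      numClassesForClassification = 2,
      maxBins = 100,
      categoricalFeaturesInfo = Map(0 -> 2, 3 -> 5))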
@@ -48,11 +50,12 @@ object DecisionTreeRunner {
case class Params(
input: String = null,
+ dataFormat: String = "libsvm",
algo: Algo = Classification,
- numClassesForClassification: Int = 2,
- maxDepth: Int = 5,
+ maxDepth: Int = 4,
impurity: ImpurityType = Gini,
- maxBins: Int = 100)
+ maxBins: Int = 100,
+ fracTest: Double = 0.2)
def main(args: Array[String]) {
val defaultParams = Params()
@@ -69,25 +72,31 @@ object DecisionTreeRunner {
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
.action((x, c) => c.copy(maxDepth = x))
- opt[Int]("numClassesForClassification")
- .text(s"number of classes for classification, "
- + s"default: ${defaultParams.numClassesForClassification}")
- .action((x, c) => c.copy(numClassesForClassification = x))
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
arg[String]("<input>")
.text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
- if (params.algo == Classification &&
- (params.impurity == Gini || params.impurity == Entropy)) {
- success
- } else if (params.algo == Regression && params.impurity == Variance) {
- success
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
} else {
- failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+ if (params.algo == Classification &&
+ (params.impurity == Gini || params.impurity == Entropy)) {
+ success
+ } else if (params.algo == Regression && params.impurity == Variance) {
+ success
+ } else {
+ failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+ }
}
}
}
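For reference, the checkConfig block above runs after all options are parsed and can reject inconsistent combinations. A minimal, self-contained sketch of the same scopt pattern (assuming scopt 3.x; the Config class and program name here are hypothetical):

    import scopt.OptionParser

    case class Config(fracTest: Double = 0.2)

    object FracTestCheck {
      def main(args: Array[String]): Unit = {
        val parser = new OptionParser[Config]("FracTestCheck") {
          opt[Double]("fracTest")
            .text("fraction of data held out for testing")
            .action((x, c) => c.copy(fracTest = x))
          // checkConfig runs once parsing succeeds; failure(msg) aborts with msg.
          checkConfig { c =>
            if (c.fracTest < 0 || c.fracTest > 1) {
              failure(s"fracTest ${c.fracTest} value incorrect; should be in [0,1].")
            } else {
              success
            }
          }
        }
        parser.parse(args, Config()) match {
          case Some(cfg) => println(s"fracTest = ${cfg.fracTest}")
          case None => sys.exit(1)
        }
      }
    }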
@@ -100,16 +109,57 @@ object DecisionTreeRunner {
}
def run(params: Params) {
+
val conf = new SparkConf().setAppName("DecisionTreeRunner")
val sc = new SparkContext(conf)
// Load training data and cache it.
- val examples = MLUtils.loadLabeledPoints(sc, params.input).cache()
+ val origExamples = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
+ }
+ // For classification, re-index classes if needed.
+ val (examples, numClasses) = params.algo match {
+ case Classification => {
+ // classCounts: class --> # examples in class
+ val classCounts = origExamples.map(_.label).countByValue()
+ val sortedClasses = classCounts.keys.toList.sorted
+ val numClasses = classCounts.size
+ // classIndexMap: class --> index in 0,...,numClasses-1
+ val classIndexMap = {
+ if (classCounts.keySet != Set(0.0, 1.0)) {
+ sortedClasses.zipWithIndex.toMap
+ } else {
+ Map[Double, Int]()
+ }
+ }
+ val examples = {
+ if (classIndexMap.isEmpty) {
+ origExamples
+ } else {
+ origExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features))
+ }
+ }
+ val numExamples = examples.count()
+ println(s"numClasses = $numClasses.")
+ println(s"Per-class example fractions, counts:")
+ println(s"Class\tFrac\tCount")
+ sortedClasses.foreach { c =>
+ val frac = classCounts(c) / numExamples.toDouble
+ println(s"$c\t$frac\t${classCounts(c)}")
+ }
+ (examples, numClasses)
+ }
+ case Regression =>
+ (origExamples, 0)
+ case _ =>
+        throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
- val splits = examples.randomSplit(Array(0.8, 0.2))
+ // Split into training, test.
+ val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
val training = splits(0).cache()
val test = splits(1).cache()
-
val numTraining = training.count()
val numTest = test.count()
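The Classification branch above re-indexes arbitrary class labels into the contiguous range 0, ..., numClasses - 1 that DecisionTree expects, skipping the map when the labels are already {0.0, 1.0}. A small local sketch of the same idea (plain Scala collections, made-up labels):

    // Labels drawn from the non-contiguous classes {1.0, 3.0, 7.0}.
    val labels = Seq(1.0, 3.0, 7.0, 3.0, 1.0)
    val classCounts = labels.groupBy(identity).mapValues(_.size)
    val sortedClasses = classCounts.keys.toList.sorted
    // classIndexMap: 1.0 -> 0, 3.0 -> 1, 7.0 -> 2
    val classIndexMap = sortedClasses.zipWithIndex.toMap
    val reindexed = labels.map(label => classIndexMap(label).toDouble)
    // reindexed == Seq(0.0, 1.0, 2.0, 1.0, 0.0)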
@@ -129,17 +179,19 @@ object DecisionTreeRunner {
impurity = impurityCalculator,
maxDepth = params.maxDepth,
maxBins = params.maxBins,
- numClassesForClassification = params.numClassesForClassification)
+ numClassesForClassification = numClasses)
val model = DecisionTree.train(training, strategy)
+ println(model)
+
if (params.algo == Classification) {
val accuracy = accuracyScore(model, test)
- println(s"Test accuracy = $accuracy.")
+ println(s"Test accuracy = $accuracy")
}
if (params.algo == Regression) {
val mse = meanSquaredError(model, test)
- println(s"Test mean squared error = $mse.")
+ println(s"Test mean squared error = $mse")
}
sc.stop()
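The hunk above calls accuracyScore and meanSquaredError, which are defined elsewhere in this file. As a hedged sketch of what such helpers conventionally compute (the signatures and bodies below are assumptions for illustration, not the file's actual implementation):

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.model.DecisionTreeModel
    import org.apache.spark.rdd.RDD

    // Fraction of examples whose predicted label equals the true label.
    def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
      val correct = data.filter(lp => model.predict(lp.features) == lp.label).count()
      correct.toDouble / data.count()
    }

    // Mean of squared prediction errors over the data set.
    def meanSquaredError(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
      data.map { lp =>
        val err = model.predict(lp.features) - lp.label
        err * err
      }.mean()
    }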