Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala | 21 ++++++++++++++-------
1 file changed, 14 insertions(+), 7 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index b3cc361154..43f13fe24f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -49,6 +49,7 @@ object DecisionTreeRunner {
   case class Params(
       input: String = null,
       algo: Algo = Classification,
+      numClassesForClassification: Int = 2,
       maxDepth: Int = 5,
       impurity: ImpurityType = Gini,
       maxBins: Int = 100)
@@ -68,6 +69,10 @@
       opt[Int]("maxDepth")
         .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
         .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("numClassesForClassification")
+        .text(s"number of classes for classification, "
+          + s"default: ${defaultParams.numClassesForClassification}")
+        .action((x, c) => c.copy(numClassesForClassification = x))
       opt[Int]("maxBins")
         .text(s"max number of bins, default: ${defaultParams.maxBins}")
         .action((x, c) => c.copy(maxBins = x))
@@ -118,7 +123,13 @@
       case Variance => impurity.Variance
     }
 
-    val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
+    val strategy
+      = new Strategy(
+          algo = params.algo,
+          impurity = impurityCalculator,
+          maxDepth = params.maxDepth,
+          maxBins = params.maxBins,
+          numClassesForClassification = params.numClassesForClassification)
     val model = DecisionTree.train(training, strategy)
 
     if (params.algo == Classification) {
@@ -139,12 +150,8 @@
    */
   private def accuracyScore(
       model: DecisionTreeModel,
-      data: RDD[LabeledPoint],
-      threshold: Double = 0.5): Double = {
-    def predictedValue(features: Vector): Double = {
-      if (model.predict(features) < threshold) 0.0 else 1.0
-    }
-    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+      data: RDD[LabeledPoint]): Double = {
+    val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
     val count = data.count()
     correctCount.toDouble / count
   }
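
For reference, below is a minimal, self-contained sketch of how the patched example wires the new parameter into MLlib's decision tree API. It mirrors the named-parameter Strategy construction and the simplified accuracyScore from the diff above; the input path, the application name, and the choice of 3 classes are illustrative assumptions, not part of the patch.

// Sketch only: the dataset path and numClassesForClassification = 3 are assumptions.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

object MulticlassDecisionTreeSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("MulticlassDecisionTreeSketch"))

    // LIBSVM-formatted data whose labels are 0.0, 1.0, 2.0 (hypothetical path).
    val data: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "data/multiclass_sample.txt").cache()

    // Named-parameter Strategy, as in the patched runner; numClassesForClassification
    // tells the tree how many distinct label values to expect.
    val strategy = new Strategy(
      algo = Classification,
      impurity = Gini,
      maxDepth = 5,
      maxBins = 100,
      numClassesForClassification = 3)

    val model = DecisionTree.train(data, strategy)

    // Accuracy as in the updated accuracyScore: exact label match, no 0/1 thresholding.
    val accuracy =
      data.filter(p => model.predict(p.features) == p.label).count().toDouble / data.count()
    println(s"Training accuracy = $accuracy")

    sc.stop()
  }
}

With the patch applied, the same behavior can be exercised from the command line through the runner's new --numClassesForClassification option.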