about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala 3
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala 16
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala 7
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala 5
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala 16
6 files changed, 31 insertions, 20 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 837d059147..0890e6263e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -189,9 +189,10 @@ object DecisionTreeRunner {
// Create training, test sets.
val splits = if (params.testInput != "") {
// Load testInput.
+ val numFeatures = examples.take(1)(0).features.size
val origTestExamples = params.dataFormat match {
case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
}
params.algo match {
case Classification => {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 55f422dff0..ce8825cc03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -65,12 +65,6 @@ private[tree] class DTStatsAggregator(
}
/**
- * Indicator for each feature of whether that feature is an unordered feature.
- * TODO: Is Array[Boolean] any faster?
- */
- def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
-
- /**
* Total number of elements stored in this aggregator
*/
private val allStatsSize: Int = featureOffsets.last
@@ -128,21 +122,13 @@ private[tree] class DTStatsAggregator(
* Pre-compute feature offset for use with [[featureUpdate]].
* For ordered features only.
*/
- def getFeatureOffset(featureIndex: Int): Int = {
- require(!isUnordered(featureIndex),
- s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" +
- s" for unordered feature $featureIndex.")
- featureOffsets(featureIndex)
- }
+ def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
/**
* Pre-compute feature offset for use with [[featureUpdate]].
* For unordered features only.
*/
def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
- require(isUnordered(featureIndex),
- s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," +
- s" but was called for ordered feature $featureIndex.")
val baseOffset = featureOffsets(featureIndex)
(baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 212dce2523..772c02670e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.impl
import scala.collection.mutable
+import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -82,7 +83,7 @@ private[tree] class DecisionTreeMetadata(
}
-private[tree] object DecisionTreeMetadata {
+private[tree] object DecisionTreeMetadata extends Logging {
/**
* Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
@@ -103,6 +104,10 @@ private[tree] object DecisionTreeMetadata {
}
val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+ if (maxPossibleBins < strategy.maxBins) {
+ logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
+ s" (= number of training instances)")
+ }
// We check the number of bins here against maxPossibleBins.
// This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index d8476b5cd7..004838ee5b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -17,12 +17,15 @@
package org.apache.spark.mllib.tree.model
+import org.apache.spark.annotation.DeveloperApi
+
/**
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
-private[tree] class Predict(
+@DeveloperApi
+class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
index 4d66d6d81c..6a22e2abe5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
@@ -82,9 +82,9 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext
*/
override def toString: String = algo match {
case Classification =>
- s"RandomForestModel classifier with $numTrees trees"
+ s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes"
case Regression =>
- s"RandomForestModel regressor with $numTrees trees"
+ s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes"
case _ => throw new IllegalArgumentException(
s"RandomForestModel given unknown algo parameter: $algo.")
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 20d372dc1d..fb44ceb0f5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -173,6 +173,22 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
}
+ test("alternating categorical and continuous features with multiclass labels to test indexing") {
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0))
+ arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0))
+ arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))
+ val categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4)
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
+ featureSubsetStrategy = "sqrt", seed = 12345)
+ RandomForestSuite.validateClassifier(model, arr, 1.0)
+ }
+
}
object RandomForestSuite {