Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 46
1 file changed, 42 insertions(+), 4 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 361366fde7..6db9ce150d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -328,7 +327,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
case n: InternalNode => n.split match {
case s: CategoricalSplit =>
assert(s.leftCategories === Array(1.0))
+ case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit")
}
+ case _ => throw new AssertionError("model.rootNode was not an InternalNode")
}
}
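
The two catch-all arms added above make the pattern matches exhaustive: an unexpected node or split type now fails with a descriptive AssertionError instead of an opaque scala.MatchError. A minimal sketch of the difference, using a hypothetical node hierarchy rather than Spark's actual classes:

    sealed trait Node
    case class InternalNode(featureIndex: Int) extends Node
    case class LeafNode(prediction: Double) extends Node

    def describeSplit(node: Node): String = node match {
      case n: InternalNode => s"splits on feature ${n.featureIndex}"
      // Without this arm, passing a LeafNode throws scala.MatchError,
      // which reports only the unmatched value, not what was expected.
      case _ => throw new AssertionError("node was not an InternalNode")
    }
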
@@ -353,6 +354,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(n.leftChild.isInstanceOf[InternalNode])
assert(n.rightChild.isInstanceOf[InternalNode])
Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode])
+ case _ => throw new AssertionError("rootNode was not an InternalNode")
}
// Single group second level tree construction.
@@ -424,12 +426,48 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
+ val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0")
+ for (strategy <- realStrategies) {
+ val expected = (strategy.toDouble * numFeatures).ceil.toInt
+ checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
+ }
+
+ val integerStrategies = Array("1", "10", "100", "1000", "10000")
+ for (strategy <- integerStrategies) {
+ val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures
+ checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
+ }
+
+ val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0")
+ for (invalidStrategy <- invalidStrategies) {
+ intercept[MatchError]{
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy)
+ }
+ }
+
checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "log2",
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+ for (strategy <- realStrategies) {
+ val expected = (strategy.toDouble * numFeatures).ceil.toInt
+ checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
+ }
+
+ for (strategy <- integerStrategies) {
+ val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures
+ checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
+ }
+ for (invalidStrategy <- invalidStrategies) {
+ intercept[MatchError]{
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy)
+ }
+ }
}
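
The loops added in this hunk pin down how the string forms of featureSubsetStrategy are interpreted: a real number in (0, 1] selects ceil(fraction * numFeatures) features per node, a positive integer n selects min(n, numFeatures), and anything else ("0", "0.0", "1.1", negatives) must fail with a MatchError when the metadata is built. A hedged restatement of that rule; the regexes and method name below are illustrative, not Spark's actual implementation:

    // Illustrative sketch of the strategy-to-count rule the test encodes.
    def numFeaturesPerNode(strategy: String, numFeatures: Int): Int = strategy match {
      case s if s.matches("^[1-9]\\d*$") =>                     // positive integer, e.g. "10"
        math.min(s.toInt, numFeatures)
      case s if s.matches("^(?:0?\\.\\d*[1-9]\\d*|1\\.0)$") =>  // real in (0, 1], e.g. ".1", "0.9"
        math.ceil(s.toDouble * numFeatures).toInt
      // "0", "0.0", "1.1", "-.1", ... match no case and raise scala.MatchError,
      // which is exactly what intercept[MatchError] asserts above.
    }
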
test("Binary classification with continuous features: subsampling features") {
@@ -471,7 +509,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// Test feature importance computed at different subtrees.
def testNode(node: Node, expected: Map[Int, Double]): Unit = {
val map = new OpenHashMap[Int, Double]()
- RandomForest.computeFeatureImportance(node, map)
+ TreeEnsembleModel.computeFeatureImportance(node, map)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
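
computeFeatureImportance, relocated from RandomForest to TreeEnsembleModel by this change, credits each feature with the impurity gain of every split on it, weighted by the number of training instances reaching that node; leaves contribute nothing. A hedged sketch of that recursion (field names follow the ml.tree API, but this is a simplified restatement, not the moved source):

    import org.apache.spark.ml.tree.{InternalNode, Node}
    import org.apache.spark.util.collection.OpenHashMap

    // Sketch: walk the tree, adding gain * instance-count at each internal
    // node to the running importance of the feature it splits on.
    def computeImportanceSketch(node: Node, importances: OpenHashMap[Int, Double]): Unit = {
      node match {
        case n: InternalNode =>
          val feature = n.split.featureIndex
          val scaledGain = n.gain * n.impurityStats.count
          importances.changeValue(feature, scaledGain, _ + scaledGain)
          computeImportanceSketch(n.leftChild, importances)
          computeImportanceSketch(n.rightChild, importances)
        case _ => // leaf: no split, no contribution
      }
    }
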
@@ -493,7 +531,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
.asInstanceOf[DecisionTreeModel]
}
- val importances: Vector = RandomForest.featureImportances(trees, 2)
+ val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
val tree2norm = feature0importance + feature1importance
val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
(feature1importance / tree2norm) / 2.0)
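
The expected vector above encodes the averaging rule: featureImportances normalizes each tree's importance map to sum to 1, then averages across trees. With hypothetical gain values (tree 1 splits only on feature 0, so its normalized vector is (1.0, 0.0)):

    import org.apache.spark.mllib.linalg.Vectors

    // Hypothetical numbers, just to make the rule above concrete.
    val feature0importance = 3.0
    val feature1importance = 1.0
    val tree2norm = feature0importance + feature1importance       // = 4.0

    val tree1 = Vectors.dense(1.0, 0.0)                           // only feature 0 used
    val tree2 = Vectors.dense(feature0importance / tree2norm,     // (0.75, 0.25)
      feature1importance / tree2norm)
    // Averaging the per-tree normalized vectors matches `expected` above:
    val expected = Vectors.dense((1.0 + 0.75) / 2.0, 0.25 / 2.0)  // (0.875, 0.125)
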
@@ -504,7 +542,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val map = new OpenHashMap[Int, Double]()
map(0) = 1.0
map(2) = 2.0
- RandomForest.normalizeMapValues(map)
+ TreeEnsembleModel.normalizeMapValues(map)
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
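
normalizeMapValues rescales the map in place so its values sum to 1, which is why {0 -> 1.0, 2 -> 2.0} becomes {0 -> 1/3, 2 -> 2/3}. A minimal sketch against OpenHashMap's public API, not the moved implementation itself:

    import org.apache.spark.util.collection.OpenHashMap

    // Sketch: divide every value by the total; snapshot the keys first
    // so we do not mutate the map while iterating over it.
    def normalizeSketch(map: OpenHashMap[Int, Double]): Unit = {
      val total = map.map(_._2).sum
      if (total != 0) {
        val keys = map.iterator.map(_._1).toArray
        keys.foreach { k => map.changeValue(k, 0.0, _ / total) }
      }
    }
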