diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 46 |
1 files changed, 42 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 361366fde7..6db9ce150d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -328,7 +327,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { case n: InternalNode => n.split match { case s: CategoricalSplit => assert(s.leftCategories === Array(1.0)) + case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit") } + case _ => throw new AssertionError("model.rootNode was not an InternalNode") } } @@ -353,6 +354,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(n.leftChild.isInstanceOf[InternalNode]) assert(n.rightChild.isInstanceOf[InternalNode]) Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) + case _ => throw new AssertionError("rootNode was not an InternalNode") } // Single group second level tree construction. @@ -424,12 +426,48 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) + val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0") + for (strategy <- realStrategies) { + val expected = (strategy.toDouble * numFeatures).ceil.toInt + checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) + } + + val integerStrategies = Array("1", "10", "100", "1000", "10000") + for (strategy <- integerStrategies) { + val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures + checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) + } + + val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") + for (invalidStrategy <- invalidStrategies) { + intercept[MatchError]{ + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) + } + } + checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "log2", (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) + + for (strategy <- realStrategies) { + val expected = (strategy.toDouble * numFeatures).ceil.toInt + checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) + } + + for (strategy <- integerStrategies) { + val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures + checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) + } + for (invalidStrategy <- invalidStrategies) { + intercept[MatchError]{ + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) + } + } } test("Binary classification with continuous features: subsampling features") { @@ -471,7 +509,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Test feature importance computed at different subtrees. def testNode(node: Node, expected: Map[Int, Double]): Unit = { val map = new OpenHashMap[Int, Double]() - RandomForest.computeFeatureImportance(node, map) + TreeEnsembleModel.computeFeatureImportance(node, map) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } @@ -493,7 +531,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) .asInstanceOf[DecisionTreeModel] } - val importances: Vector = RandomForest.featureImportances(trees, 2) + val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, (feature1importance / tree2norm) / 2.0) @@ -504,7 +542,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val map = new OpenHashMap[Int, Double]() map(0) = 1.0 map(2) = 2.0 - RandomForest.normalizeMapValues(map) + TreeEnsembleModel.normalizeMapValues(map) val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } |