Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java |  98
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java      |  97
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala                |   4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala                     |  42
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala   | 274
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala                   |   2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala                               | 132
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala        |  91
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala                    | 373
9 files changed, 908 insertions(+), 205 deletions(-)
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
new file mode 100644
index 0000000000..43b8787f9d
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.File;
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.util.Utils;
+
+
+public class JavaDecisionTreeClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ DecisionTreeClassifier dt = new DecisionTreeClassifier()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity : DecisionTreeClassifier.supportedImpurities()) {
+ dt.setImpurity(impurity);
+ }
+ DecisionTreeClassificationModel model = dt.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.numNodes();
+ model.depth();
+ model.toDebugString();
+
+ /*
+ // TODO: Add test once save/load are implemented.
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model.save(sc.sc(), path);
+ DecisionTreeClassificationModel sameModel =
+ DecisionTreeClassificationModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
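
A note on the duplicated setMaxDepth(2) call above: the suite relies on each setter returning the concrete estimator type, so chained calls keep compiling after any intermediate setter. A minimal Scala sketch of that fluent-setter pattern, using a hypothetical ExampleEstimator rather than the actual Spark class:

    // Hypothetical sketch of the fluent-setter pattern exercised above.
    // Returning this.type preserves the concrete type across chained calls.
    class ExampleEstimator {
      private var maxDepth: Int = 5
      private var maxBins: Int = 32
      def setMaxDepth(value: Int): this.type = { maxDepth = value; this }
      def setMaxBins(value: Int): this.type = { maxBins = value; this }
    }

    val est = new ExampleEstimator().setMaxDepth(2).setMaxBins(10).setMaxDepth(2)
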
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
new file mode 100644
index 0000000000..a3a339004f
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.File;
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.util.Utils;
+
+
+public class JavaDecisionTreeRegressorSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ DecisionTreeRegressor dt = new DecisionTreeRegressor()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity : DecisionTreeRegressor.supportedImpurities()) {
+ dt.setImpurity(impurity);
+ }
+ DecisionTreeRegressionModel model = dt.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.numNodes();
+ model.depth();
+ model.toDebugString();
+
+ /*
+ // TODO: Add test once save/load are implemented.
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model.save(sc.sc(), path);
+ DecisionTreeRegressionModel sameModel = DecisionTreeRegressionModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
index 0dcfe5a200..17ddd335de 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
@@ -44,7 +44,7 @@ class AttributeGroupSuite extends FunSuite {
group("abc")
}
assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name))
- assert(group === AttributeGroup.fromStructField(group.toStructField))
+ assert(group === AttributeGroup.fromStructField(group.toStructField()))
}
test("attribute group without attributes") {
@@ -54,7 +54,7 @@ class AttributeGroupSuite extends FunSuite {
assert(group0.size === 10)
assert(group0.attributes.isEmpty)
assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name))
- assert(group0 === AttributeGroup.fromStructField(group0.toStructField))
+ assert(group0 === AttributeGroup.fromStructField(group0.toStructField()))
val group1 = new AttributeGroup("item")
assert(group1.name === "item")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 6ec35b0365..3e1a7196e3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -36,9 +36,9 @@ class AttributeSuite extends FunSuite {
assert(attr.max.isEmpty)
assert(attr.std.isEmpty)
assert(attr.sparsity.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = false) === metadata)
- assert(attr.toMetadata(withType = true) === metadataWithType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadataWithType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === Attribute.fromMetadata(metadataWithType))
intercept[NoSuchElementException] {
@@ -59,9 +59,9 @@ class AttributeSuite extends FunSuite {
assert(!attr.isNominal)
assert(attr.name === Some(name))
assert(attr.index === Some(index))
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = false) === metadata)
- assert(attr.toMetadata(withType = true) === metadataWithType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadataWithType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === Attribute.fromMetadata(metadataWithType))
val field = attr.toStructField()
@@ -81,7 +81,7 @@ class AttributeSuite extends FunSuite {
assert(attr2.max === Some(1.0))
assert(attr2.std === Some(0.5))
assert(attr2.sparsity === Some(0.3))
- assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
+ assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl()))
}
test("bad numeric attributes") {
@@ -105,9 +105,9 @@ class AttributeSuite extends FunSuite {
assert(attr.values.isEmpty)
assert(attr.numValues.isEmpty)
assert(attr.isOrdinal.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === NominalAttribute.fromMetadata(metadataWithoutType))
intercept[NoSuchElementException] {
@@ -135,9 +135,9 @@ class AttributeSuite extends FunSuite {
assert(attr.values === Some(values))
assert(attr.indexOf("medium") === 1)
assert(attr.getValue(1) === "medium")
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === NominalAttribute.fromMetadata(metadataWithoutType))
assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
@@ -147,8 +147,8 @@ class AttributeSuite extends FunSuite {
assert(attr2.index.isEmpty)
assert(attr2.values.get === Array("small", "medium", "large", "x-large"))
assert(attr2.indexOf("x-large") === 3)
- assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
- assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadata(withType = false)))
+ assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl()))
+ assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadataImpl(withType = false)))
}
test("bad nominal attributes") {
@@ -168,9 +168,9 @@ class AttributeSuite extends FunSuite {
assert(attr.name.isEmpty)
assert(attr.index.isEmpty)
assert(attr.values.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
intercept[NoSuchElementException] {
@@ -196,9 +196,9 @@ class AttributeSuite extends FunSuite {
assert(attr.name === Some(name))
assert(attr.index === Some(index))
assert(attr.values.get === values)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
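
The renames above (toMetadata to toMetadataImpl) keep the round-trip contract the assertions check: converting an attribute to metadata and back must be lossless. A minimal sketch of that round trip, assuming only the attribute API already used in this suite:

    import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}

    // Attribute -> metadata -> attribute should be lossless.
    val attr = NominalAttribute.defaultAttr.withName("size").withNumValues(3)
    assert(attr === Attribute.fromMetadata(attr.toMetadataImpl()))
    // Without the type field, the concrete factory must be used instead.
    assert(attr === NominalAttribute.fromMetadata(attr.toMetadataImpl(withType = false)))
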
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
new file mode 100644
index 0000000000..af88595df5
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+ DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
+
+ import DecisionTreeClassifierSuite.compareAPIs
+
+ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+ private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _
+ private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _
+ private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
+ private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
+ private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ categoricalDataPointsRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints())
+ orderedLabeledPointsWithLabel0RDD =
+ sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0())
+ orderedLabeledPointsWithLabel1RDD =
+ sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1())
+ categoricalDataPointsForMulticlassRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass())
+ continuousDataPointsForMulticlassRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass())
+ categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize(
+ OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("Binary classification stump with ordered categorical features") {
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("gini")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ val numClasses = 2
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") {
+ val dt = new DecisionTreeClassifier()
+ .setMaxDepth(3)
+ .setMaxBins(100)
+ val numClasses = 2
+ Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd =>
+ DecisionTreeClassifier.supportedImpurities.foreach { impurity =>
+ dt.setImpurity(impurity)
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+ }
+ }
+
+ test("Multiclass classification stump with 3-ary (unordered) categorical features") {
+ val rdd = categoricalDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 3
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(3.0)))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Binary classification stump with 2 continuous features") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Multiclass classification stump with unordered categorical features," +
+ " with just enough bins") {
+ val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features
+ val rdd = categoricalDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(maxBins)
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification stump with continuous features") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Multiclass classification stump with continuous + unordered categorical features") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification stump with 10-ary (ordered) categorical features") {
+ val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 10, 1 -> 10)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification tree with 10-ary (ordered) categorical features," +
+ " with just enough bins") {
+ val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(10)
+ val categoricalFeatures = Map(0 -> 10, 1 -> 10)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("split must satisfy min instances per node requirements") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setMinInstancesPerNode(2)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("do not choose split that does not satisfy min instance per node requirements") {
+ // If a split does not satisfy the min instances per node requirement,
+ // the split is invalid, even though its information gain is large.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxBins(2)
+ .setMaxDepth(2)
+ .setMinInstancesPerNode(2)
+ val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("split must satisfy min info gain requirements") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+ val rdd = sc.parallelize(arr)
+
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setMinInfoGain(1.0)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification)
+ val newModel = DecisionTreeClassificationModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = DecisionTreeClassificationModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private[ml] object DecisionTreeClassifierSuite extends FunSuite {
+
+ /**
+ * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
+ * Convert the old tree to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ dt: DecisionTreeClassifier,
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): Unit = {
+ val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
+ val oldTree = OldDecisionTree.train(data, oldStrategy)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+ val newTree = dt.fit(newData)
+ // Use parent, fittingParamMap from newTree since these are not checked anyway.
+ val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent,
+ newTree.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldTreeAsNew, newTree)
+ }
+}
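
For the "just enough bins" tests above, the bin budget comes from how unordered categorical features are handled: a feature with k category values admits 2^(k-1) - 1 nontrivial binary partitions, and each candidate split occupies two bins. A worked sketch of that arithmetic (an illustration of the count, not the Spark internals):

    // Candidate splits for an unordered categorical feature with k values:
    // each nontrivial binary partition, counted once, gives 2^(k-1) - 1.
    def numUnorderedSplits(k: Int): Int = (1 << (k - 1)) - 1
    // Each split uses two bins (left and right category subsets).
    def numUnorderedBins(k: Int): Int = 2 * numUnorderedSplits(k)

    assert(numUnorderedBins(3) == 6)  // the maxBins used in the test above
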
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 81ef831c42..1b261b2643 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -228,7 +228,7 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
}
val attrGroup = new AttributeGroup("features", featureAttributes)
val densePoints1WithMeta =
- densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata))
+ densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata()))
val vectorIndexer = getIndexer.setMaxCategories(2)
val model = vectorIndexer.fit(densePoints1WithMeta)
// Check that ML metadata are preserved.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
new file mode 100644
index 0000000000..2e57d4ce37
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+
+private[ml] object TreeTests extends FunSuite {
+
+ /**
+ * Convert the given data to a DataFrame, and set the features and label metadata.
+ * @param data Dataset. Categorical features and labels must already have 0-based indices.
+ * This must be non-empty.
+ * @param categoricalFeatures Map: categorical feature index -> number of distinct values
+ * @param numClasses Number of classes the label can take. If 0, the label is treated as continuous.
+ * @return DataFrame with metadata
+ */
+ def setMetadata(
+ data: RDD[LabeledPoint],
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): DataFrame = {
+ val sqlContext = new SQLContext(data.sparkContext)
+ import sqlContext.implicits._
+ val df = data.toDF()
+ val numFeatures = data.first().features.size
+ val featuresAttributes = Range(0, numFeatures).map { feature =>
+ if (categoricalFeatures.contains(feature)) {
+ NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature))
+ } else {
+ NumericAttribute.defaultAttr.withIndex(feature)
+ }
+ }.toArray
+ val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata()
+ val labelAttribute = if (numClasses == 0) {
+ NumericAttribute.defaultAttr.withName("label")
+ } else {
+ NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+ }
+ val labelMetadata = labelAttribute.toMetadata()
+ df.select(df("features").as("features", featuresMetadata),
+ df("label").as("label", labelMetadata))
+ }
+
+ /** Java-friendly version of [[setMetadata()]] */
+ def setMetadata(
+ data: JavaRDD[LabeledPoint],
+ categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer],
+ numClasses: Int): DataFrame = {
+ setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ numClasses)
+ }
+
+ /**
+ * Check if the two trees are exactly the same.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ * If the trees are not equal, this prints the two trees and throws an exception.
+ */
+ def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ try {
+ checkEqual(a.rootNode, b.rootNode)
+ } catch {
+ case ex: Exception =>
+ throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+ "TREE A:\n" + a.toDebugString + "\n" +
+ "TREE B:\n" + b.toDebugString + "\n", ex)
+ }
+ }
+
+ /**
+ * Check if the two nodes and their descendants are exactly the same; if not, this throws an exception.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ */
+ private def checkEqual(a: Node, b: Node): Unit = {
+ assert(a.prediction === b.prediction)
+ assert(a.impurity === b.impurity)
+ (a, b) match {
+ case (aye: InternalNode, bee: InternalNode) =>
+ assert(aye.split === bee.split)
+ checkEqual(aye.leftChild, bee.leftChild)
+ checkEqual(aye.rightChild, bee.rightChild)
+ case (aye: LeafNode, bee: LeafNode) => // do nothing
+ case _ =>
+ throw new AssertionError("Found mismatched nodes")
+ }
+ }
+
+ // TODO: Reinstate after adding ensembles
+ /**
+ * Check if the two models are exactly the same.
+ * If the models are not equal, this throws an exception.
+ */
+ /*
+ def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
+ try {
+ a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) =>
+ TreeTests.checkEqual(treeA, treeB)
+ }
+ assert(a.getTreeWeights === b.getTreeWeights)
+ } catch {
+ case ex: Exception => throw new AssertionError(
+ "checkEqual failed since the two tree ensembles were not identical")
+ }
+ }
+ */
+}
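
For reference, a minimal sketch of driving these helpers from a suite, assuming the sc provided by MLlibTestSparkContext as in the suites above:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Assumes an active SparkContext `sc` (e.g. from MLlibTestSparkContext).
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0))))
    // Feature 0 is categorical with 2 values; the label has 2 classes.
    val df = TreeTests.setMetadata(data, Map(0 -> 2), numClasses = 2)
    df.schema("features").metadata  // carries the AttributeGroup metadata
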
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
new file mode 100644
index 0000000000..0b40fe33fa
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+ DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
+
+ import DecisionTreeRegressorSuite.compareAPIs
+
+ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ categoricalDataPointsRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints())
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("Regression stump with 3-ary (ordered) categorical features") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
+ }
+
+ test("Regression stump with binary (ordered) categorical features") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: test("model save/load")
+}
+
+private[ml] object DecisionTreeRegressorSuite extends FunSuite {
+
+ /**
+ * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
+ * Convert the old tree to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ dt: DecisionTreeRegressor,
+ categoricalFeatures: Map[Int, Int]): Unit = {
+ val oldStrategy = dt.getOldStrategy(categoricalFeatures)
+ val oldTree = OldDecisionTree.train(data, oldStrategy)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newTree = dt.fit(newData)
+ // Use parent, fittingParamMap from newTree since these are not checked anyway.
+ val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent,
+ newTree.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldTreeAsNew, newTree)
+ }
+}
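
Note that compareAPIs here passes numClasses = 0 to TreeTests.setMetadata, which (per the branch in TreeTests above) attaches NumericAttribute metadata to the label, marking it continuous for regression. A one-line illustration, reusing an RDD[LabeledPoint] named data as in the suites:

    // numClasses = 0 => continuous label (regression), not a nominal one.
    val regDF = TreeTests.setMetadata(data, Map.empty[Int, Int], numClasses = 0)
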
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 4c162df810..249b8eae19 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -36,6 +36,10 @@ import org.apache.spark.util.Utils
class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests examining individual elements of training
+ /////////////////////////////////////////////////////////////////////////////
+
test("Binary classification with continuous features: split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
@@ -254,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(bins(0).length === 0)
}
+ test("Avoid aggregation on the last level") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue leaf nodes into node queue
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Avoid aggregation if impurity is 0.0") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue a node into node queue if its impurity is 0.0
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Second level node building with vs. without groups") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+
+ // Train a 1-node model
+ val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
+ numClasses = 2, maxBins = 100)
+ val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
+ val rootNode1 = modelOneNode.topNode.deepCopy()
+ val rootNode2 = modelOneNode.topNode.deepCopy()
+ assert(rootNode1.leftNode.nonEmpty)
+ assert(rootNode1.rightNode.nonEmpty)
+
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ // Single group second level tree construction.
+ val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
+ (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+ val children1 = new Array[Node](2)
+ children1(0) = rootNode1.leftNode.get
+ children1(1) = rootNode1.rightNode.get
+
+ // Train one second-level node at a time.
+ val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
+ val treeToNodeToIndexInfoA = Map((0, Map(
+ (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
+ val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
+ val treeToNodeToIndexInfoB = Map((0, Map(
+ (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
+ val children2 = new Array[Node](2)
+ children2(0) = rootNode2.leftNode.get
+ children2(1) = rootNode2.rightNode.get
+
+ // Verify whether the splits obtained using single group and multiple group level
+ // construction strategies are the same.
+ for (i <- 0 until 2) {
+ assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
+ assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
+ assert(children1(i).split === children2(i).split)
+ assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
+ val stats1 = children1(i).stats.get
+ val stats2 = children2(i).stats.get
+ assert(stats1.gain === stats2.gain)
+ assert(stats1.impurity === stats2.impurity)
+ assert(stats1.leftImpurity === stats2.leftImpurity)
+ assert(stats1.rightImpurity === stats2.rightImpurity)
+ assert(children1(i).predict.predict === children2(i).predict.predict)
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
test("Binary classification stump with ordered categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
@@ -438,76 +601,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(rootNode.predict.predict === 1)
}
- test("Second level node building with vs. without groups") {
- val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
- // Train a 1-node model
- val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
- numClasses = 2, maxBins = 100)
- val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val rootNode1 = modelOneNode.topNode.deepCopy()
- val rootNode2 = modelOneNode.topNode.deepCopy()
- assert(rootNode1.leftNode.nonEmpty)
- assert(rootNode1.rightNode.nonEmpty)
-
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- // Single group second level tree construction.
- val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
- (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
- val children1 = new Array[Node](2)
- children1(0) = rootNode1.leftNode.get
- children1(1) = rootNode1.rightNode.get
-
- // Train one second-level node at a time.
- val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
- val treeToNodeToIndexInfoA = Map((0, Map(
- (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
- val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
- val treeToNodeToIndexInfoB = Map((0, Map(
- (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
- val children2 = new Array[Node](2)
- children2(0) = rootNode2.leftNode.get
- children2(1) = rootNode2.rightNode.get
-
- // Verify whether the splits obtained using single group and multiple group level
- // construction strategies are the same.
- for (i <- 0 until 2) {
- assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
- assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
- assert(children1(i).split === children2(i).split)
- assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
- val stats1 = children1(i).stats.get
- val stats2 = children2(i).stats.get
- assert(stats1.gain === stats2.gain)
- assert(stats1.impurity === stats2.impurity)
- assert(stats1.leftImpurity === stats2.leftImpurity)
- assert(stats1.rightImpurity === stats2.rightImpurity)
- assert(children1(i).predict.predict === children2(i).predict.predict)
- }
- }
-
test("Multiclass classification stump with 3-ary (unordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
@@ -528,11 +621,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
- arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(3.0)))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 2)
@@ -544,11 +637,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("Binary classification stump with 2 continuous features") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -668,11 +761,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("split must satisfy min instances per node requirements") {
- val arr = new Array[LabeledPoint](3)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
-
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxDepth = 2, numClasses = 2, minInstancesPerNode = 2)
@@ -695,11 +787,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
test("do not choose split that does not satisfy min instance per node requirements") {
// If a split does not satisfy the min instances per node requirement,
// the split is invalid, even though its information gain is large.
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
- arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
@@ -715,10 +807,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("split must satisfy min info gain requirements") {
- val arr = new Array[LabeledPoint](3)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
@@ -739,91 +831,9 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(gain == InformationGainStats.invalidInformationGainStats)
}
- test("Avoid aggregation on the last level") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue leaf nodes into node queue
- assert(nodeQueue.isEmpty)
-
- // set impurity and predict for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // set impurity and predict for child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
-
- test("Avoid aggregation if impurity is 0.0") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue a node into node queue if its impurity is 0.0
- assert(nodeQueue.isEmpty)
-
- // set impurity and predict for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // set impurity and predict for child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
test("Node.subtreeIterator") {
val model = DecisionTreeSuite.createModel(Classification)
@@ -996,8 +1006,9 @@ object DecisionTreeSuite extends FunSuite {
/**
* Create a tree model. This is deterministic and contains a variety of node and feature types.
+ * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.)
*/
- private[tree] def createModel(algo: Algo): DecisionTreeModel = {
+ private[mllib] def createModel(algo: Algo): DecisionTreeModel = {
val topNode = createInternalNode(id = 1, Continuous)
val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))
@@ -1017,7 +1028,7 @@ object DecisionTreeSuite extends FunSuite {
* make mistakes such as creating loops of Nodes.
* If the trees are not equal, this prints the two trees and throws an exception.
*/
- private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ private[mllib] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
try {
assert(a.algo === b.algo)
checkEqual(a.topNode, b.topNode)
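
Finally, the visibility changes in the last hunks widen createModel and checkEqual from private[tree] to private[mllib], i.e. from callers under org.apache.spark.mllib.tree to callers anywhere under org.apache.spark.mllib. A small sketch of Scala's qualified-private modifier, with hypothetical names:

    package org.apache.spark.mllib.tree

    object VisibilityExample {
      // Old: callable only from org.apache.spark.mllib.tree and below.
      private[tree] def treeOnly(): Unit = ()
      // New: callable from anywhere under org.apache.spark.mllib.
      private[mllib] def mllibWide(): Unit = ()
    }
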