author     Joseph K. Bradley <joseph.kurata.bradley@gmail.com>   2014-08-03 10:36:52 -0700
committer  Xiangrui Meng <meng@databricks.com>                   2014-08-03 10:36:52 -0700
commit     2998e38a942351974da36cb619e863c6f0316e7a (patch)
tree       bba84386898fa3fd49cbc4acfc3b38457f0c4882
parent     a0bcbc159e89be868ccc96175dbf1439461557e1 (diff)
[SPARK-2197] [mllib] Java DecisionTree bug fix and ease-of-use
Bug fix: Before, when an RDD was created in Java and passed to DecisionTree.train(), the fake class tag caused problems.
* Fix: DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java.

Other improvements to Decision Trees for ease-of-use with Java:
* impurity classes: Added instance() methods to help with Java interface.
* Strategy: Added Java-friendly constructor
  --> Note: I removed quantileCalculationStrategy from the Java-friendly constructor since (a) it is a special class and (b) there is only 1 option currently. I suspect we will redo the API before the other options are included.

CC: mengxr

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1740 from jkbradley/dt-java-new and squashes the following commits:

0805dc6 [Joseph K. Bradley] Changed Strategy to use JavaConverters instead of JavaConversions
519b1b7 [Joseph K. Bradley] * Organized imports in JavaDecisionTreeSuite.java * Using JavaConverters instead of JavaConversions in DecisionTreeSuite.scala
f7b5ca1 [Joseph K. Bradley] Improvements to make it easier to run DecisionTree from Java.
* DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java.
* impurity classes: Added instance() methods to help with Java interface.
* Strategy: Added Java-friendly constructor
** Note: I removed quantileCalculationStrategy from the Java-friendly constructor since (a) it is a special class and (b) there is only 1 option currently. I suspect we will redo the API before the other options are included.
d78ada6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-java
320853f [Joseph K. Bradley] Added JavaDecisionTreeSuite, partly written
13a585e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-java
f1a8283 [Joseph K. Bradley] Added old JavaDecisionTreeSuite, to be updated later
225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature.
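To make the bin-indexing bug in the last squashed commit (225822f) concrete, here is a worked check. This is an illustrative standalone snippet, not part of the patch; the class name is hypothetical.

// Worked check of the bin-indexing bug fixed in commit 225822f.
// For a categorical feature with k = 4 categories:
public class BinBoundSketch {
  public static void main(String[] args) {
    int k = 4;
    // Unordered categorical features get one bin per candidate subset split:
    int unorderedBound = (int) Math.pow(2, k - 1) - 1;  // 7
    // Ordered categorical features get one bin per category value (the arity):
    int orderedBound = k;                               // 4
    // The buggy search applied the unordered bound to ordered features,
    // indexing bins 0..6 when only bins 0..3 actually exist.
    System.out.println(unorderedBound + " bins searched vs " + orderedBound + " bins present");
  }
}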
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala               8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala    29
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala           7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala              7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala          7
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java       102
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala           6
7 files changed, 162 insertions(+), 4 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 382e76a9b7..1d03e6e3b3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -48,12 +48,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
// Cache input RDD for speedup during multiple passes.
- input.cache()
+ val retaggedInput = input.retag(classOf[LabeledPoint]).cache()
logDebug("algo = " + strategy.algo)
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
val numBins = bins(0).length
logDebug("numBins = " + numBins)
@@ -70,7 +70,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
// num features
- val numFeatures = input.take(1)(0).features.size
+ val numFeatures = retaggedInput.take(1)(0).features.size
// Calculate level for single group construction
@@ -107,7 +107,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("#####################################")
// Find best split for all nodes at a level.
- val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities,
+ val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities,
strategy, level, filters, splits, bins, maxLevelForSingleGroup)
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
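For context on the retag() call above: a JavaRDD built from Java carries a fake ClassTag, so the underlying Scala RDD believes its elements are Object, and Scala code that allocates typed arrays from it can fail at runtime. Below is a minimal sketch of the Java boundary (hypothetical class and variable names; assumes the RDD.retag() method this patch depends on):

import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;

public class RetagSketch {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "RetagSketch");
    JavaRDD<LabeledPoint> javaRdd = sc.parallelize(Arrays.asList(
        new LabeledPoint(1.0, Vectors.dense(new double[]{0.0})),
        new LabeledPoint(0.0, Vectors.dense(new double[]{1.0}))));

    // javaRdd.rdd() still carries the fake ClassTag. retag() swaps in the
    // real element type; DecisionTree.train() now does this internally,
    // so Java callers get it for free.
    RDD<LabeledPoint> retagged = javaRdd.rdd().retag(LabeledPoint.class);

    sc.stop();
  }
}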
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index fdad4f029a..4ee4bcd0bc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,6 +17,8 @@
package org.apache.spark.mllib.tree.configuration
+import scala.collection.JavaConverters._
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -61,4 +63,31 @@ class Strategy (
val isMulticlassWithCategoricalFeatures
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+ /**
+ * Java-friendly constructor.
+ *
+ * @param algo classification or regression
+ * @param impurity criterion used for information gain calculation
+ * @param maxDepth Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * @param numClassesForClassification number of classes for classification. The default
+ *                                     value of 2 gives binary classification.
+ * @param maxBins maximum number of bins used for splitting features
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+ * the number of discrete values they take. For example, an entry
+ * (n -> k) implies the feature n is categorical with k categories
+ * 0, 1, 2, ... , k-1. It's important to note that features are
+ * zero-indexed.
+ */
+ def this(
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int,
+ numClassesForClassification: Int,
+ maxBins: Int,
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
+ this(algo, impurity, maxDepth, numClassesForClassification, maxBins, Sort,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+ }
+
}
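A minimal sketch of what this constructor enables from Java (hypothetical class name, illustrative parameter values). Note that quantileCalculationStrategy is fixed to Sort on the caller's behalf:

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini;

public class StrategySketch {
  public static void main(String[] args) {
    // Feature 0 is categorical with 2 categories {0, 1}; all others are continuous.
    Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
    categoricalFeaturesInfo.put(0, 2);

    // No ClassTags, Scala Maps, or QuantileStrategy values needed from Java:
    Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(),
        5 /* maxDepth */, 2 /* numClassesForClassification */,
        100 /* maxBins */, categoricalFeaturesInfo);
  }
}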
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 9297c20596..96d2471e1f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -66,4 +66,11 @@ object Entropy extends Impurity {
@DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Entropy.calculate")
+
+ /**
+ * Get this impurity instance.
+ * This is useful for passing impurity parameters to a Strategy in Java.
+ */
+ def instance = this
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 2874bcf496..d586f44904 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -62,4 +62,11 @@ object Gini extends Impurity {
@DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Gini.calculate")
+
+ /**
+ * Get this impurity instance.
+ * This is useful for passing impurity parameters to a Strategy in Java.
+ */
+ def instance = this
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 698a1a2a8e..f7d99a40eb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -53,4 +53,11 @@ object Variance extends Impurity {
val squaredLoss = sumSquares - (sum * sum) / count
squaredLoss / count
}
+
+ /**
+ * Get this impurity instance.
+ * This is useful for passing impurity parameters to a Strategy in Java.
+ */
+ def instance = this
+
}
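The three instance() methods above exist because each impurity is a Scala object (a singleton), which Java would otherwise have to reach through the compiler-generated MODULE$ field. A small sketch of the difference (hypothetical class name):

import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.impurity.Gini$;
import org.apache.spark.mllib.tree.impurity.Impurity;

public class ImpuritySketch {
  public static void main(String[] args) {
    // Both expressions return the same singleton; instance() just hides the
    // Scala object encoding from Java callers.
    Impurity viaModuleField = Gini$.MODULE$;
    Impurity viaInstance = Gini.instance();
    System.out.println(viaModuleField == viaInstance);  // true
  }
}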
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
new file mode 100644
index 0000000000..2c281a1ee7
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.configuration.Algo;
+import org.apache.spark.mllib.tree.configuration.Strategy;
+import org.apache.spark.mllib.tree.impurity.Gini;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+
+
+public class JavaDecisionTreeSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaDecisionTreeSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
+ int numCorrect = 0;
+ for (LabeledPoint point: validationData) {
+ Double prediction = model.predict(point.features());
+ if (prediction == point.label()) {
+ numCorrect++;
+ }
+ }
+ return numCorrect;
+ }
+
+ @Test
+ public void runDTUsingConstructor() {
+ List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
+ JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+ HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+ categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
+
+ int maxDepth = 4;
+ int numClasses = 2;
+ int maxBins = 100;
+ Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
+ maxBins, categoricalFeaturesInfo);
+
+ DecisionTree learner = new DecisionTree(strategy);
+ DecisionTreeModel model = learner.train(rdd.rdd());
+
+ int numCorrect = validatePrediction(arr, model);
+ Assert.assertTrue(numCorrect == rdd.count());
+ }
+
+ @Test
+ public void runDTUsingStaticMethods() {
+ List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
+ JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+ HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+ categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
+
+ int maxDepth = 4;
+ int numClasses = 2;
+ int maxBins = 100;
+ Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
+ maxBins, categoricalFeaturesInfo);
+
+ DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
+
+ int numCorrect = validatePrediction(arr, model);
+ Assert.assertTrue(numCorrect == rdd.count());
+ }
+
+}
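Outside a test harness, the pieces above combine as follows. This is a hedged end-to-end sketch with made-up data, assuming a local master; as in runDTUsingStaticMethods, the companion object is reached through the synthetic DecisionTree$.MODULE$ field:

import java.util.Arrays;
import java.util.HashMap;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree$;
import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;

public class JavaDecisionTreeExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JavaDecisionTreeExample");
    JavaRDD<LabeledPoint> data = sc.parallelize(Arrays.asList(
        new LabeledPoint(0.0, Vectors.dense(new double[]{0.0})),
        new LabeledPoint(0.0, Vectors.dense(new double[]{1.0})),
        new LabeledPoint(0.0, Vectors.dense(new double[]{2.0})),
        new LabeledPoint(1.0, Vectors.dense(new double[]{8.0})),
        new LabeledPoint(1.0, Vectors.dense(new double[]{9.0})),
        new LabeledPoint(1.0, Vectors.dense(new double[]{10.0}))));

    // All features continuous, so the categorical-feature map is empty.
    Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(),
        4 /* maxDepth */, 2 /* numClasses */, 32 /* maxBins */,
        new HashMap<Integer, Integer>());

    // Thanks to the retag() fix, a Java-created RDD can be passed straight in.
    DecisionTreeModel model = DecisionTree$.MODULE$.train(data.rdd(), strategy);
    System.out.println(model.predict(Vectors.dense(new double[]{9.0})));
    sc.stop();
  }
}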
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 8665a00f3b..70ca7c8a26 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.mllib.tree
+import scala.collection.JavaConverters._
+
import org.scalatest.FunSuite
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
@@ -815,6 +817,10 @@ object DecisionTreeSuite {
arr
}
+ def generateCategoricalDataPointsAsJavaList(): java.util.List[LabeledPoint] = {
+ generateCategoricalDataPoints().toList.asJava
+ }
+
def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](3000)
for (i <- 0 until 3000) {