From 2998e38a942351974da36cb619e863c6f0316e7a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 3 Aug 2014 10:36:52 -0700 Subject: [SPARK-2197] [mllib] Java DecisionTree bug fix and easy-of-use Bug fix: Before, when an RDD was created in Java and passed to DecisionTree.train(), the fake class tag caused problems. * Fix: DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java. Other improvements to Decision Trees for easy-of-use with Java: * impurity classes: Added instance() methods to help with Java interface. * Strategy: Added Java-friendly constructor --> Note: I removed quantileCalculationStrategy from the Java-friendly constructor since (a) it is a special class and (b) there is only 1 option currently. I suspect we will redo the API before the other options are included. CC: mengxr Author: Joseph K. Bradley Closes #1740 from jkbradley/dt-java-new and squashes the following commits: 0805dc6 [Joseph K. Bradley] Changed Strategy to use JavaConverters instead of JavaConversions 519b1b7 [Joseph K. Bradley] * Organized imports in JavaDecisionTreeSuite.java * Using JavaConverters instead of JavaConversions in DecisionTreeSuite.scala f7b5ca1 [Joseph K. Bradley] Improvements to make it easier to run DecisionTree from Java. * DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java. * impurity classes: Added instance() methods to help with Java interface. * Strategy: Added Java-friendly constructor ** Note: I removed quantileCalculationStrategy from the Java-friendly constructor since (a) it is a special class and (b) there is only 1 option currently. I suspect we will redo the API before the other options are included. d78ada6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-java 320853f [Joseph K. Bradley] Added JavaDecisionTreeSuite, partly written 13a585e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-java f1a8283 [Joseph K. Bradley] Added old JavaDecisionTreeSuite, to be updated later 225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature. --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 8 +- .../spark/mllib/tree/configuration/Strategy.scala | 29 ++++++ .../apache/spark/mllib/tree/impurity/Entropy.scala | 7 ++ .../apache/spark/mllib/tree/impurity/Gini.scala | 7 ++ .../spark/mllib/tree/impurity/Variance.scala | 7 ++ .../spark/mllib/tree/JavaDecisionTreeSuite.java | 102 +++++++++++++++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 6 ++ 7 files changed, 162 insertions(+), 4 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 382e76a9b7..1d03e6e3b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -48,12 +48,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo def train(input: RDD[LabeledPoint]): DecisionTreeModel = { // Cache input RDD for speedup during multiple passes. - input.cache() + val retaggedInput = input.retag(classOf[LabeledPoint]).cache() logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) val numBins = bins(0).length logDebug("numBins = " + numBins) @@ -70,7 +70,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = input.take(1)(0).features.size + val numFeatures = retaggedInput.take(1)(0).features.size // Calculate level for single group construction @@ -107,7 +107,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, + val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, strategy, level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index fdad4f029a..4ee4bcd0bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree.configuration +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.configuration.Algo._ @@ -61,4 +63,31 @@ class Strategy ( val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + /** + * Java-friendly constructor. + * + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param numClassesForClassification number of classes for classification. Default value is 2 + * leads to binary classification + * @param maxBins maximum number of bins used for splitting features + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, an entry + * (n -> k) implies the feature n is categorical with k categories + * 0, 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + */ + def this( + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClassesForClassification: Int, + maxBins: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) { + this(algo, impurity, maxDepth, numClassesForClassification, maxBins, Sort, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 9297c20596..96d2471e1f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -66,4 +66,11 @@ object Entropy extends Impurity { @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Entropy.calculate") + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 2874bcf496..d586f44904 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -62,4 +62,11 @@ object Gini extends Impurity { @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 698a1a2a8e..f7d99a40eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -53,4 +53,11 @@ object Variance extends Impurity { val squaredLoss = sumSquares - (sum * sum) / count squaredLoss / count } + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java new file mode 100644 index 0000000000..2c281a1ee7 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.configuration.Algo; +import org.apache.spark.mllib.tree.configuration.Strategy; +import org.apache.spark.mllib.tree.impurity.Gini; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; + + +public class JavaDecisionTreeSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + int validatePrediction(List validationData, DecisionTreeModel model) { + int numCorrect = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + if (prediction == point.label()) { + numCorrect++; + } + } + return numCorrect; + } + + @Test + public void runDTUsingConstructor() { + List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); + JavaRDD rdd = sc.parallelize(arr); + HashMap categoricalFeaturesInfo = new HashMap(); + categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories + + int maxDepth = 4; + int numClasses = 2; + int maxBins = 100; + Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, + maxBins, categoricalFeaturesInfo); + + DecisionTree learner = new DecisionTree(strategy); + DecisionTreeModel model = learner.train(rdd.rdd()); + + int numCorrect = validatePrediction(arr, model); + Assert.assertTrue(numCorrect == rdd.count()); + } + + @Test + public void runDTUsingStaticMethods() { + List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); + JavaRDD rdd = sc.parallelize(arr); + HashMap categoricalFeaturesInfo = new HashMap(); + categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories + + int maxDepth = 4; + int numClasses = 2; + int maxBins = 100; + Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, + maxBins, categoricalFeaturesInfo); + + DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy); + + int numCorrect = validatePrediction(arr, model); + Assert.assertTrue(numCorrect == rdd.count()); + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 8665a00f3b..70ca7c8a26 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import scala.collection.JavaConverters._ + import org.scalatest.FunSuite import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} @@ -815,6 +817,10 @@ object DecisionTreeSuite { arr } + def generateCategoricalDataPointsAsJavaList(): java.util.List[LabeledPoint] = { + generateCategoricalDataPoints().toList.asJava + } + def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { -- cgit v1.2.3