about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
author    Joseph K. Bradley <joseph.kurata.bradley@gmail.com>  2014-08-06 22:58:59 -0700
committer Xiangrui Meng <meng@databricks.com>                  2014-08-06 22:58:59 -0700
commit    47ccd5e71be49b723476f3ff8d5768f0f45c2ea6 (patch)
tree      18b61526f97c93c4112e5a75dddefb42e0a9fafc /mllib
parent    ffd1f59a62a9dd9a4d5a7b09490b9d01ff1cd42d (diff)
download  spark-47ccd5e71be49b723476f3ff8d5768f0f45c2ea6.tar.gz
          spark-47ccd5e71be49b723476f3ff8d5768f0f45c2ea6.tar.bz2
          spark-47ccd5e71be49b723476f3ff8d5768f0f45c2ea6.zip
[SPARK-2851] [mllib] DecisionTree Python consistency update
Added 6 static train methods to match Python API, but without default arguments
(but with Python default args noted in docs). Added factory classes for Algo and
Impurity, but made private[mllib].

CC: mengxr dorx  Please let me know if there are other changes which would help
with API consistency---thanks!

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1798 from jkbradley/dt-python-consistency and squashes the following commits:

6f7edf8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency
a0d7dbe [Joseph K. Bradley] DecisionTree: In Java-friendly train* methods, changed to use JavaRDD instead of RDD.
ee1d236 [Joseph K. Bradley] DecisionTree API updates: * Removed train() function in Python API (tree.py) ** Removed corresponding function in Scala/Java API (the ones taking basic types)
00f820e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency
fe6dbfa [Joseph K. Bradley] removed unnecessary imports
e358661 [Joseph K. Bradley] DecisionTree API change: * Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs).
c699850 [Joseph K. Bradley] a few doc comments
eaf84c0 [Joseph K. Bradley] Added DecisionTree static train() methods API to match Python, but without default parameters
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala |  19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala         | 151
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala   |   6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala  |  32
4 files changed, 166 insertions(+), 42 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index fd0b9556c7..ba7ccd8ce4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -25,16 +25,14 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
-import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.DecisionTree
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.stat.correlation.CorrelationNames
@@ -523,17 +521,8 @@ class PythonMLLibAPI extends Serializable {
val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
- val algo: Algo = algoStr match {
- case "classification" => Classification
- case "regression" => Regression
- case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr")
- }
- val impurity: Impurity = impurityStr match {
- case "gini" => Gini
- case "entropy" => Entropy
- case "variance" => Variance
- case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr")
- }
+ val algo = Algo.fromString(algoStr)
+ val impurity = Impurities.fromString(impurityStr)
val strategy = new Strategy(
algo = algo,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 1d03e6e3b3..c8a8656596 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,14 +17,18 @@
package org.apache.spark.mllib.tree
+import org.apache.spark.api.java.JavaRDD
+
+import scala.collection.JavaConverters._
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
@@ -200,6 +204,10 @@ object DecisionTree extends Serializable with Logging {
* Method to train a decision tree model.
* The method supports binary and multiclass classification and regression.
*
+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+ * is recommended to clearly separate classification and regression.
+ *
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
@@ -213,10 +221,12 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Method to train a decision tree model where the instances are represented as an RDD of
- * (label, features) pairs. The method supports binary classification and regression. For the
- * binary classification, the label for each instance should either be 0 or 1 to denote the two
- * classes.
+ * Method to train a decision tree model.
+ * The method supports binary and multiclass classification and regression.
+ *
+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+ * is recommended to clearly separate classification and regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -237,10 +247,12 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Method to train a decision tree model where the instances are represented as an RDD of
- * (label, features) pairs. The method supports binary classification and regression. For the
- * binary classification, the label for each instance should either be 0 or 1 to denote the two
- * classes.
+ * Method to train a decision tree model.
+ * The method supports binary and multiclass classification and regression.
+ *
+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+ * is recommended to clearly separate classification and regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -263,11 +275,12 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Method to train a decision tree model where the instances are represented as an RDD of
- * (label, features) pairs. The decision tree method supports binary classification and
- * regression. For the binary classification, the label for each instance should either be 0 or
- * 1 to denote the two classes. The method also supports categorical features inputs where the
- * number of categories can specified using the categoricalFeaturesInfo option.
+ * Method to train a decision tree model.
+ * The method supports binary and multiclass classification and regression.
+ *
+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+ * is recommended to clearly separate classification and regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -279,11 +292,9 @@ object DecisionTree extends Serializable with Logging {
* @param numClassesForClassification number of classes for classification. Default value of 2.
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and
- * the number of discrete values they take. For example,
- * an entry (n -> k) implies the feature n is categorical with k
- * categories 0, 1, 2, ... , k-1. It's important to note that
- * features are zero-indexed.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features.
+ * E.g., an entry (n -> k) indicates that feature n is categorical
+ * with k categories indexed from 0: {0, 1, ..., k-1}.
* @return DecisionTreeModel that can be used for prediction
*/
def train(
@@ -300,6 +311,93 @@ object DecisionTree extends Serializable with Logging {
new DecisionTree(strategy).train(input)
}
+ /**
+ * Method to train a decision tree model for binary or multiclass classification.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * Labels should take values {0, 1, ..., numClasses-1}.
+ * @param numClassesForClassification number of classes for classification.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features.
+ * E.g., an entry (n -> k) indicates that feature n is categorical
+ * with k categories indexed from 0: {0, 1, ..., k-1}.
+ * @param impurity Criterion used for information gain calculation.
+ * Supported values: "gini" (recommended) or "entropy".
+ * @param maxDepth Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * (suggested value: 4)
+ * @param maxBins maximum number of bins used for splitting features
+ * (suggested value: 100)
+ * @return DecisionTreeModel that can be used for prediction
+ */
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: Map[Int, Int],
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int): DecisionTreeModel = {
+ val impurityType = Impurities.fromString(impurity)
+ train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort,
+ categoricalFeaturesInfo)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+ */
+ def trainClassifier(
+ input: JavaRDD[LabeledPoint],
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int): DecisionTreeModel = {
+ trainClassifier(input.rdd, numClassesForClassification,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ impurity, maxDepth, maxBins)
+ }
+
+ /**
+ * Method to train a decision tree model for regression.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * Labels are real numbers.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features.
+ * E.g., an entry (n -> k) indicates that feature n is categorical
+ * with k categories indexed from 0: {0, 1, ..., k-1}.
+ * @param impurity Criterion used for information gain calculation.
+ * Supported values: "variance".
+ * @param maxDepth Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * (suggested value: 4)
+ * @param maxBins maximum number of bins used for splitting features
+ * (suggested value: 100)
+ * @return DecisionTreeModel that can be used for prediction
+ */
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ categoricalFeaturesInfo: Map[Int, Int],
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int): DecisionTreeModel = {
+ val impurityType = Impurities.fromString(impurity)
+ train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+ */
+ def trainRegressor(
+ input: JavaRDD[LabeledPoint],
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int): DecisionTreeModel = {
+ trainRegressor(input.rdd,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ impurity, maxDepth, maxBins)
+ }
+
+
private val InvalidBinIndex = -1
/**
@@ -1331,16 +1429,15 @@ object DecisionTree extends Serializable with Logging {
* Categorical features:
* For each feature, there is 1 bin per split.
* Splits and bins are handled in 2 ways:
- * (a) For multiclass classification with a low-arity feature
+ * (a) "unordered features"
+ * For multiclass classification with a low-arity feature
* (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
* the feature is split based on subsets of categories.
- * There are 2^(maxFeatureValue - 1) - 1 splits.
- * (b) For regression and binary classification,
+ * There are math.pow(2, maxFeatureValue - 1) - 1 splits.
+ * (b) "ordered features"
+ * For regression and binary classification,
* and for multiclass classification with a high-arity feature,
- * there is one split per category.
-
- * Categorical case (a) features are called unordered features.
- * Other cases are called ordered features.
+ * there is one bin per category.
*
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
index 79a01f5831..0ef9c6181a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
@@ -27,4 +27,10 @@ import org.apache.spark.annotation.Experimental
object Algo extends Enumeration {
type Algo = Value
val Classification, Regression = Value
+
+ private[mllib] def fromString(name: String): Algo = name match {
+ case "classification" => Classification
+ case "regression" => Regression
+ case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name")
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala
new file mode 100644
index 0000000000..9a6452aa13
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Factory for Impurity instances.
+ */
+private[mllib] object Impurities {
+
+ def fromString(name: String): Impurity = name match {
+ case "gini" => Gini
+ case "entropy" => Entropy
+ case "variance" => Variance
+ case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name")
+ }
+
+}