author     Joseph K. Bradley <joseph@databricks.com>  2016-04-28 16:20:00 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2016-04-28 16:20:00 -0700
commit     4f4721a21cc9acc2b6f685bbfc8757d29563a775 (patch)
tree       6cd62a33cb375e32ba72abfba71c2cf9b64df616
parent     dae538a4d7c36191c1feb02ba87ffc624ab960dc (diff)
[SPARK-14862][ML] Updated Classifiers to not require labelCol metadata
## What changes were proposed in this pull request?

Updated Classifier, DecisionTreeClassifier, RandomForestClassifier, and GBTClassifier to not require input column metadata.
* They first check for metadata.
* If numClasses is not specified in the metadata, they identify the largest label value (up to a limit). This functionality is implemented in a new Classifier.getNumClasses method.

Also:
* Updated Classifier.extractLabeledPoints to (a) check label values and (b) include a second version which takes a numClasses value for validity checking.

## How was this patch tested?

* Unit tests in ClassifierSuite for the helper methods
* Unit tests for DecisionTreeClassifier, RandomForestClassifier, and GBTClassifier with toy datasets lacking label metadata

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12663 from jkbradley/trees-no-metadata.
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala                  70
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala      10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala               25
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala      10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala            108
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala  6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala          40
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala  7
8 files changed, 245 insertions, 31 deletions
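
For context, a minimal sketch of what this change enables: fitting a tree classifier on a DataFrame whose label column carries no ML attribute metadata. This is illustrative only; the toy data and the spark-shell-style sqlContext are assumptions, and the mllib linalg types match the imports used in this patch.

    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Toy DataFrame with "label" and "features" columns and no attribute metadata.
    val df = sqlContext.createDataFrame(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0)),
      LabeledPoint(1.0, Vectors.dense(1.0)),
      LabeledPoint(2.0, Vectors.dense(2.0))))

    // Before this patch: IllegalArgumentException, "...without the number of classes specified".
    // After this patch: numClasses = 3 is inferred from the maximum label value.
    val model = new DecisionTreeClassifier().setMaxDepth(2).fit(df)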
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 473e801794..bc5fe35ad4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -17,14 +17,17 @@
package org.apache.spark.ml.classification
+import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
* (private[spark]) Params for classification.
@@ -62,6 +65,67 @@ abstract class Classifier[
def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
// TODO: defaultEvaluator (follow-up PR)
+
+ /**
+ * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
+ * and put it in an RDD with strong types.
+ *
+ * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
+ * and features ([[Vector]]). Labels are cast to [[DoubleType]].
+ * @param numClasses Number of classes label can take. Labels must be integers in the range
+ * [0, numClasses).
+ * @throws SparkException if any label is not an integer in the range [0, numClasses)
+ */
+ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
+ require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
+ s" $numClasses, but requires numClasses > 0.")
+ dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) =>
+ require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
+ s" dataset with invalid label $label. Labels must be integers in range" +
+ s" [0, 1, ..., $numClasses), where numClasses=$numClasses.")
+ LabeledPoint(label, features)
+ }
+ }
+
+ /**
+ * Get the number of classes. This looks in column metadata first, and if that is missing,
+ * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+ * by finding the maximum label value.
+ *
+ * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+ * such as in [[extractLabeledPoints()]].
+ *
+ * @param dataset Dataset which contains a column [[labelCol]]
+ * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses
+ * is specified in the metadata, then maxNumClasses is ignored.
+ * @return number of classes
+ * @throws IllegalArgumentException if metadata does not specify numClasses, and the
+ * actual numClasses exceeds maxNumClasses
+ */
+ protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
+ MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+ case Some(n: Int) => n
+ case None =>
+ // Get number of classes from dataset itself.
+ val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1)
+ if (maxLabelRow.isEmpty) {
+ throw new SparkException("ML algorithm was given empty dataset.")
+ }
+ val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
+ require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
+ s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
+ val numClasses = maxDoubleLabel.toInt + 1
+ require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
+ s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
+ s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
+ s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
+ s" StringIndexer to the label column.")
+ logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
+ s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
+ numClasses
+ }
+ }
}
/**
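
As the new error message notes, the metadata branch that getNumClasses consults first is normally populated by StringIndexer. A hedged sketch of that path; the input DataFrame "raw" and its column names are hypothetical:

    import org.apache.spark.ml.feature.StringIndexer

    // "raw" is a hypothetical DataFrame with a string label column "category".
    val indexed = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("label")
      .fit(raw)
      .transform(raw)

    // StringIndexer attaches nominal-attribute metadata to "label", so
    // MetadataUtils.getNumClasses(indexed.schema("label")) returns Some(n)
    // and the max-label scan in getNumClasses is skipped entirely.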
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index ecb218e2a3..2b2e13d496 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -85,14 +85,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
- case Some(n: Int) => n
- case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
- s" with invalid label column ${$(labelCol)}, without the number of classes" +
- " specified. See StringIndexer.")
- // TODO: Automatically index labels: SPARK-7126
- }
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val numClasses: Int = getNumClasses(dataset)
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), parentUID = Some(uid))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index e736f01cc6..acc04582b8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -35,8 +35,9 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -126,16 +127,16 @@ class GBTClassifier @Since("1.4.0") (
override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
- case Some(n: Int) => n
- case None => throw new IllegalArgumentException("GBTClassifier was given input" +
- s" with invalid label column ${$(labelCol)}, without the number of classes" +
- " specified. See StringIndexer.")
- // TODO: Automatically index labels: SPARK-7126
- }
- require(numClasses == 2,
- s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
+ // 2 classes now. This lets us provide a more precise error message.
+ val oldDataset: RDD[LabeledPoint] =
+ dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) =>
+ require(label == 0 || label == 1, s"GBTClassifier was given" +
+ s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
+ s" GBTClassifier currently only supports binary classification.")
+ LabeledPoint(label, features)
+ }
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
@@ -165,6 +166,7 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
* model for classification.
* It supports binary labels, as well as both continuous and categorical features.
* Note: Multiclass labels are not currently supported.
+ *
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
*/
@@ -185,6 +187,7 @@ class GBTClassificationModel private[ml](
/**
* Construct a GBTClassificationModel
+ *
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
*/
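
A sketch of the sharper failure mode this gives GBTClassifier, with toy data and sqlContext assumed as in the earlier example; label 2.0 makes the dataset multiclass:

    import org.apache.spark.ml.classification.GBTClassifier
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    val multiclass = sqlContext.createDataFrame(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0)),
      LabeledPoint(2.0, Vectors.dense(1.0))))

    // The per-record require now fails inside the RDD map, surfacing as a
    // SparkException at fit time that says GBTClassifier currently only
    // supports binary classification, rather than a generic metadata error.
    new GBTClassifier().setMaxIter(1).fit(multiclass)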
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 28364c2593..fb3418d78b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -101,14 +101,8 @@ class RandomForestClassifier @Since("1.4.0") (
override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
- case Some(n: Int) => n
- case None => throw new IllegalArgumentException("RandomForestClassifier was given input" +
- s" with invalid label column ${$(labelCol)}, without the number of classes" +
- " specified. See StringIndexer.")
- // TODO: Automatically index labels: SPARK-7126
- }
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val numClasses: Int = getNumClasses(dataset)
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
val trees =
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
index d0e3fe7ad1..89afb94b0f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -17,6 +17,86 @@
package org.apache.spark.ml.classification
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("extractLabeledPoints") {
+ def getTestData(labels: Seq[Double]): DataFrame = {
+ val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
+ sqlContext.createDataFrame(data)
+ }
+
+ val c = new MockClassifier
+ // Valid dataset
+ val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
+ c.extractLabeledPoints(df0, 6).count()
+ // Invalid datasets
+ val df1 = getTestData(Seq(0.0, -2.0, 1.0, 5.0))
+ withClue("Classifier should fail if label is negative") {
+ val e: SparkException = intercept[SparkException] {
+ c.extractLabeledPoints(df1, 6).count()
+ }
+ assert(e.getMessage.contains("given dataset with invalid label"))
+ }
+ val df2 = getTestData(Seq(0.0, 2.1, 1.0, 5.0))
+ withClue("Classifier should fail if label is not an integer") {
+ val e: SparkException = intercept[SparkException] {
+ c.extractLabeledPoints(df2, 6).count()
+ }
+ assert(e.getMessage.contains("given dataset with invalid label"))
+ }
+ // extractLabeledPoints with numClasses specified
+ withClue("Classifier should fail if label is >= numClasses") {
+ val e: SparkException = intercept[SparkException] {
+ c.extractLabeledPoints(df0, numClasses = 5).count()
+ }
+ assert(e.getMessage.contains("given dataset with invalid label"))
+ }
+ withClue("Classifier.extractLabeledPoints should fail if numClasses <= 0") {
+ val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+ c.extractLabeledPoints(df0, numClasses = 0).count()
+ }
+ assert(e.getMessage.contains("but requires numClasses > 0"))
+ }
+ }
+
+ test("getNumClasses") {
+ def getTestData(labels: Seq[Double]): DataFrame = {
+ val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
+ sqlContext.createDataFrame(data)
+ }
+
+ val c = new MockClassifier
+ // Valid dataset
+ val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
+ assert(c.getNumClasses(df0) === 6)
+ // Invalid datasets
+ val df1 = getTestData(Seq(0.0, 2.0, 1.0, 5.1))
+ withClue("getNumClasses should fail if label is max label not an integer") {
+ val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+ c.getNumClasses(df1)
+ }
+ assert(e.getMessage.contains("requires integers in range"))
+ }
+ val df2 = getTestData(Seq(0.0, 2.0, 1.0, Int.MaxValue.toDouble))
+ withClue("getNumClasses should fail if label is max label is >= Int.MaxValue") {
+ val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+ c.getNumClasses(df2)
+ }
+ assert(e.getMessage.contains("requires integers in range"))
+ }
+ }
+}
+
object ClassifierSuite {
/**
@@ -29,4 +109,32 @@ object ClassifierSuite {
"rawPredictionCol" -> "myRawPrediction"
)
+ class MockClassifier(override val uid: String)
+ extends Classifier[Vector, MockClassifier, MockClassificationModel] {
+
+ def this() = this(Identifiable.randomUID("mockclassifier"))
+
+ override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError()
+
+ override def train(dataset: Dataset[_]): MockClassificationModel =
+ throw new NotImplementedError()
+
+ // Make methods public
+ override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] =
+ super.extractLabeledPoints(dataset, numClasses)
+ def getNumClasses(dataset: Dataset[_]): Int = super.getNumClasses(dataset)
+ }
+
+ class MockClassificationModel(override val uid: String)
+ extends ClassificationModel[Vector, MockClassificationModel] {
+
+ def this() = this(Identifiable.randomUID("mockclassificationmodel"))
+
+ protected def predictRaw(features: Vector): Vector = throw new NotImplementedError()
+
+ override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError()
+
+ override def numClasses: Int = throw new NotImplementedError()
+ }
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index fe839e15e9..29845b5554 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -342,6 +342,12 @@ class DecisionTreeClassifierSuite
}
}
+ test("Fitting without numClasses in metadata") {
+ val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+ val dt = new DecisionTreeClassifier().setMaxDepth(1)
+ dt.fit(df)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 7e6aec6b1b..087e201234 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -17,12 +17,13 @@
package org.apache.spark.ml.classification
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -128,6 +129,43 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
}
*/
+ test("Fitting without numClasses in metadata") {
+ val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+ val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
+ gbt.fit(df)
+ }
+
+ test("extractLabeledPoints with bad data") {
+ def getTestData(labels: Seq[Double]): DataFrame = {
+ val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
+ sqlContext.createDataFrame(data)
+ }
+
+ val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
+ // Invalid datasets
+ val df1 = getTestData(Seq(0.0, -1.0, 1.0, 0.0))
+ withClue("Classifier should fail if label is negative") {
+ val e: SparkException = intercept[SparkException] {
+ gbt.fit(df1)
+ }
+ assert(e.getMessage.contains("currently only supports binary classification"))
+ }
+ val df2 = getTestData(Seq(0.0, 0.1, 1.0, 0.0))
+ withClue("Classifier should fail if label is not an integer") {
+ val e: SparkException = intercept[SparkException] {
+ gbt.fit(df2)
+ }
+ assert(e.getMessage.contains("currently only supports binary classification"))
+ }
+ val df3 = getTestData(Seq(0.0, 2.0, 1.0, 0.0))
+ withClue("Classifier should fail if label is >= 2") {
+ val e: SparkException = intercept[SparkException] {
+ gbt.fit(df3)
+ }
+ assert(e.getMessage.contains("currently only supports binary classification"))
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of feature importance
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index aaaa429103..90744353d9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -154,9 +154,16 @@ class RandomForestClassifierSuite
}
}
+ test("Fitting without numClasses in metadata") {
+ val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+ val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)
+ rf.fit(df)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of feature importance
/////////////////////////////////////////////////////////////////////////////
+
test("Feature importance with toy data") {
val numClasses = 2
val rf = new RandomForestClassifier()