From 4f4721a21cc9acc2b6f685bbfc8757d29563a775 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Thu, 28 Apr 2016 16:20:00 -0700
Subject: [SPARK-14862][ML] Updated Classifiers to not require labelCol metadata

## What changes were proposed in this pull request?

Updated Classifier, DecisionTreeClassifier, RandomForestClassifier, GBTClassifier to not require input column metadata.
* They first check for metadata.
* If numClasses is not specified in metadata, they identify the largest label value (up to a limit).

This functionality is implemented in a new Classifier.getNumClasses method.

Also
* Updated Classifier.extractLabeledPoints to (a) check label values and (b) include a second version which takes a numClasses value for validity checking.

## How was this patch tested?

* Unit tests in ClassifierSuite for helper methods
* Unit tests for DecisionTreeClassifier, RandomForestClassifier, GBTClassifier with toy datasets lacking label metadata

Author: Joseph K. Bradley

Closes #12663 from jkbradley/trees-no-metadata.
---
 .../spark/ml/classification/Classifier.scala       |  70 ++++++++++++-
 .../ml/classification/DecisionTreeClassifier.scala |  10 +-
 .../spark/ml/classification/GBTClassifier.scala    |  25 ++---
 .../ml/classification/RandomForestClassifier.scala |  10 +-
 .../spark/ml/classification/ClassifierSuite.scala  | 108 +++++++++++++++++++++
 .../DecisionTreeClassifierSuite.scala              |   6 ++
 .../ml/classification/GBTClassifierSuite.scala     |  40 +++++++-
 .../RandomForestClassifierSuite.scala              |   7 ++
 8 files changed, 245 insertions(+), 31 deletions(-)
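A rough spark-shell sketch of the user-visible change (illustrative only; the toy data and variable names below are not part of the patch, and a `sqlContext` is assumed to be in scope):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// A plain DataFrame: its "label" column carries no ML attribute metadata,
// so numClasses cannot be read from the schema.
val df = sqlContext.createDataFrame(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(2.0, Vectors.dense(1.0, 1.0))
))

// Before this patch: IllegalArgumentException pointing the user at StringIndexer.
// After this patch: numClasses is inferred from the data as max(label) + 1 = 3.
val model = new DecisionTreeClassifier().setMaxDepth(2).fit(df)
```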
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 473e801794..bc5fe35ad4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.ml.classification
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
 /**
  * (private[spark]) Params for classification.
@@ -62,6 +65,67 @@ abstract class Classifier[
   def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
 
   // TODO: defaultEvaluator (follow-up PR)
+
+  /**
+   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
+   * and put it in an RDD with strong types.
+   *
+   * @param dataset  DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
+   *                 and features ([[Vector]]). Labels are cast to [[DoubleType]].
+   * @param numClasses  Number of classes label can take. Labels must be integers in the range
+   *                    [0, numClasses).
+   * @throws SparkException  if any label is not an integer >= 0
+   */
+  protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
+    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
+      s" $numClasses, but requires numClasses > 0.")
+    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+      case Row(label: Double, features: Vector) =>
+        require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
+          s" dataset with invalid label $label. Labels must be integers in range" +
+          s" [0, 1, ..., $numClasses), where numClasses=$numClasses.")
+        LabeledPoint(label, features)
+    }
+  }
+
+  /**
+   * Get the number of classes. This looks in column metadata first, and if that is missing,
+   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+   * by finding the maximum label value.
+   *
+   * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+   * such as in [[extractLabeledPoints()]].
+   *
+   * @param dataset  Dataset which contains a column [[labelCol]]
+   * @param maxNumClasses  Maximum number of classes allowed when inferred from data. If numClasses
+   *                       is specified in the metadata, then maxNumClasses is ignored.
+   * @return  number of classes
+   * @throws IllegalArgumentException  if metadata does not specify numClasses, and the
+   *                                   actual numClasses exceeds maxNumClasses
+   */
+  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
+    MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+      case Some(n: Int) => n
+      case None =>
+        // Get number of classes from dataset itself.
+        val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1)
+        if (maxLabelRow.isEmpty) {
+          throw new SparkException("ML algorithm was given empty dataset.")
+        }
+        val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
+        require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
+          s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
+        val numClasses = maxDoubleLabel.toInt + 1
+        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
+          s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
+          s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
+          s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
+          s" StringIndexer to the label column.")
+        logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
+          s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
+        numClasses
+    }
+  }
 }
 
 /**
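The `maxNumClasses = 100` cap above applies only when numClasses is inferred by scanning the data; per the scaladoc, it is ignored when the metadata already specifies numClasses. For labels with more classes, the error message names the workaround: put numClasses into the column metadata, for example via StringIndexer. A minimal sketch, assuming a DataFrame `df` with a hypothetical `rawLabel` column:

```scala
import org.apache.spark.ml.feature.StringIndexer

// StringIndexer attaches nominal-attribute metadata (including the number of
// distinct values) to its output column, so getNumClasses can read numClasses
// from the schema instead of scanning the label values.
val indexed = new StringIndexer()
  .setInputCol("rawLabel")
  .setOutputCol("label")
  .fit(df)
  .transform(df)
```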
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index ecb218e2a3..2b2e13d496 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -85,14 +85,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
   override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
-        s" with invalid label column ${$(labelCol)}, without the number of classes" +
-        " specified. See StringIndexer.")
-      // TODO: Automatically index labels: SPARK-7126
-    }
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numClasses: Int = getNumClasses(dataset)
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
     val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), parentUID = Some(uid))
See StringIndexer.") - // TODO: Automatically index labels: SPARK-7126 - } - require(numClasses == 2, - s"GBTClassifier only supports binary classification but was given numClasses = $numClasses") - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports + // 2 classes now. This lets us provide a more precise error message. + val oldDataset: RDD[LabeledPoint] = + dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => + require(label == 0 || label == 1, s"GBTClassifier was given" + + s" dataset with invalid label $label. Labels must be in {0,1}; note that" + + s" GBTClassifier currently only supports binary classification.") + LabeledPoint(label, features) + } val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, @@ -165,6 +166,7 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { * model for classification. * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. + * * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ @@ -185,6 +187,7 @@ class GBTClassificationModel private[ml]( /** * Construct a GBTClassificationModel + * * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 28364c2593..fb3418d78b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -101,14 +101,8 @@ class RandomForestClassifier @Since("1.4.0") ( override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { - case Some(n: Int) => n - case None => throw new IllegalArgumentException("RandomForestClassifier was given input" + - s" with invalid label column ${$(labelCol)}, without the number of classes" + - " specified. 
See StringIndexer.") - // TODO: Automatically index labels: SPARK-7126 - } - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val numClasses: Int = getNumClasses(dataset) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) val trees = diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index d0e3fe7ad1..89afb94b0f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -17,6 +17,86 @@ package org.apache.spark.ml.classification +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset} + +class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("extractLabeledPoints") { + def getTestData(labels: Seq[Double]): DataFrame = { + val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } + sqlContext.createDataFrame(data) + } + + val c = new MockClassifier + // Valid dataset + val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0)) + c.extractLabeledPoints(df0, 6).count() + // Invalid datasets + val df1 = getTestData(Seq(0.0, -2.0, 1.0, 5.0)) + withClue("Classifier should fail if label is negative") { + val e: SparkException = intercept[SparkException] { + c.extractLabeledPoints(df1, 6).count() + } + assert(e.getMessage.contains("given dataset with invalid label")) + } + val df2 = getTestData(Seq(0.0, 2.1, 1.0, 5.0)) + withClue("Classifier should fail if label is not an integer") { + val e: SparkException = intercept[SparkException] { + c.extractLabeledPoints(df2, 6).count() + } + assert(e.getMessage.contains("given dataset with invalid label")) + } + // extractLabeledPoints with numClasses specified + withClue("Classifier should fail if label is >= numClasses") { + val e: SparkException = intercept[SparkException] { + c.extractLabeledPoints(df0, numClasses = 5).count() + } + assert(e.getMessage.contains("given dataset with invalid label")) + } + withClue("Classifier.extractLabeledPoints should fail if numClasses <= 0") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + c.extractLabeledPoints(df0, numClasses = 0).count() + } + assert(e.getMessage.contains("but requires numClasses > 0")) + } + } + + test("getNumClasses") { + def getTestData(labels: Seq[Double]): DataFrame = { + val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } + sqlContext.createDataFrame(data) + } + + val c = new MockClassifier + // Valid dataset + val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0)) + assert(c.getNumClasses(df0) === 6) + // Invalid datasets + val df1 = getTestData(Seq(0.0, 2.0, 1.0, 5.1)) + withClue("getNumClasses should fail if label is max label not an integer") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + c.getNumClasses(df1) + } + assert(e.getMessage.contains("requires integers in 
range")) + } + val df2 = getTestData(Seq(0.0, 2.0, 1.0, Int.MaxValue.toDouble)) + withClue("getNumClasses should fail if label is max label is >= Int.MaxValue") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + c.getNumClasses(df2) + } + assert(e.getMessage.contains("requires integers in range")) + } + } +} + object ClassifierSuite { /** @@ -29,4 +109,32 @@ object ClassifierSuite { "rawPredictionCol" -> "myRawPrediction" ) + class MockClassifier(override val uid: String) + extends Classifier[Vector, MockClassifier, MockClassificationModel] { + + def this() = this(Identifiable.randomUID("mockclassifier")) + + override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError() + + override def train(dataset: Dataset[_]): MockClassificationModel = + throw new NotImplementedError() + + // Make methods public + override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = + super.extractLabeledPoints(dataset, numClasses) + def getNumClasses(dataset: Dataset[_]): Int = super.getNumClasses(dataset) + } + + class MockClassificationModel(override val uid: String) + extends ClassificationModel[Vector, MockClassificationModel] { + + def this() = this(Identifiable.randomUID("mockclassificationmodel")) + + protected def predictRaw(features: Vector): Vector = throw new NotImplementedError() + + override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError() + + override def numClasses: Int = throw new NotImplementedError() + } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index fe839e15e9..29845b5554 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -342,6 +342,12 @@ class DecisionTreeClassifierSuite } } + test("Fitting without numClasses in metadata") { + val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc)) + val dt = new DecisionTreeClassifier().setMaxDepth(1) + dt.fit(df) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 7e6aec6b1b..087e201234 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -128,6 +129,43 @@ class GBTClassifierSuite extends 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index fe839e15e9..29845b5554 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -342,6 +342,12 @@ class DecisionTreeClassifierSuite
     }
   }
 
+  test("Fitting without numClasses in metadata") {
+    val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+    val dt = new DecisionTreeClassifier().setMaxDepth(1)
+    dt.fit(df)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 7e6aec6b1b..087e201234 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.ml.classification
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -128,6 +129,43 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
   }
   */
 
+  test("Fitting without numClasses in metadata") {
+    val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+    val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
+    gbt.fit(df)
+  }
+
+  test("extractLabeledPoints with bad data") {
+    def getTestData(labels: Seq[Double]): DataFrame = {
+      val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
+      sqlContext.createDataFrame(data)
+    }
+
+    val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
+    // Invalid datasets
+    val df1 = getTestData(Seq(0.0, -1.0, 1.0, 0.0))
+    withClue("Classifier should fail if label is negative") {
+      val e: SparkException = intercept[SparkException] {
+        gbt.fit(df1)
+      }
+      assert(e.getMessage.contains("currently only supports binary classification"))
+    }
+    val df2 = getTestData(Seq(0.0, 0.1, 1.0, 0.0))
+    withClue("Classifier should fail if label is not an integer") {
+      val e: SparkException = intercept[SparkException] {
+        gbt.fit(df2)
+      }
+      assert(e.getMessage.contains("currently only supports binary classification"))
+    }
+    val df3 = getTestData(Seq(0.0, 2.0, 1.0, 0.0))
+    withClue("Classifier should fail if label is >= 2") {
+      val e: SparkException = intercept[SparkException] {
+        gbt.fit(df3)
+      }
+      assert(e.getMessage.contains("currently only supports binary classification"))
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of feature importance
   /////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index aaaa429103..90744353d9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -154,9 +154,16 @@ class RandomForestClassifierSuite
     }
   }
 
+  test("Fitting without numClasses in metadata") {
+    val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+    val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)
+    rf.fit(df)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of feature importance
   /////////////////////////////////////////////////////////////////////////////
+
   test("Feature importance with toy data") {
     val numClasses = 2
     val rf = new RandomForestClassifier()
-- 
cgit v1.2.3