author     Yanbo Liang <ybliang8@gmail.com>            2016-01-04 13:32:14 -0800
committer  Joseph K. Bradley <joseph@databricks.com>   2016-01-04 13:32:14 -0800
commit     93ef9b6a2aa1830170cb101f191022f2dda62c41 (patch)
tree       fae4dc34f0bc0bc12d3ee374e97c3090e1f7da90 /mllib/src
parent     ba5f81859d6ba37a228a1c43d26c47e64c0382cd (diff)
[SPARK-9622][ML] DecisionTreeRegressor: provide variance of prediction
DecisionTreeRegressor will provide variance of prediction as a Double column.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8866 from yanboliang/spark-9622.
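For context, a minimal usage sketch of the API this patch adds (the "training"/"test" DataFrames and column names are illustrative, not part of the patch):

import org.apache.spark.ml.regression.DecisionTreeRegressor

// Hypothetical DataFrames `training` and `test` with "label"/"features" columns.
val dt = new DecisionTreeRegressor()
  .setImpurity("variance")      // the only supported regression impurity
  .setVarianceCol("variance")   // new in this patch: request a variance output column
val model = dt.fit(training)

// The output DataFrame gains a Double "variance" column holding the biased
// sample variance at the leaf node that produced each prediction.
model.transform(test).select("prediction", "variance").show()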
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala       1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala              15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala       36
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala                        18
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala  26
5 files changed, 92 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index c7bca12430..4aff749ff7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -44,6 +44,7 @@ private[shared] object SharedParamsCodeGen {
" probabilities. Note: Not all models output well-calibrated probability estimates!" +
" These probabilities should be treated as confidences, not precise probabilities",
Some("\"probability\"")),
+ ParamDesc[String]("varianceCol", "Column name for the biased sample variance of prediction"),
ParamDesc[Double]("threshold",
"threshold in binary classification prediction, in range [0, 1]", Some("0.5"),
isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index cb2a060a34..c088c16d1b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -139,6 +139,21 @@ private[ml] trait HasProbabilityCol extends Params {
}
/**
+ * Trait for shared param varianceCol.
+ */
+private[ml] trait HasVarianceCol extends Params {
+
+ /**
+ * Param for Column name for the biased sample variance of prediction.
+ * @group param
+ */
+ final val varianceCol: Param[String] = new Param[String](this, "varianceCol", "Column name for the biased sample variance of prediction")
+
+ /** @group getParam */
+ final def getVarianceCol: String = $(varianceCol)
+}
+
+/**
* Trait for shared param threshold (default: 0.5).
*/
private[ml] trait HasThreshold extends Params {
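A hedged sketch of how the generated param is consumed downstream (the column name here is illustrative): the setter added on DecisionTreeRegressor in the next file stores a value, and the getter generated from the trait above reads it back.

// Set fluently via the setter added in DecisionTreeRegressor (next file),
// then read back through the generated getter.
val dt = new DecisionTreeRegressor().setVarianceCol("predVariance")
assert(dt.getVarianceCol == "predVariance")
assert(dt.isDefined(dt.varianceCol))  // explicitly set; the param has no default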
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 477030d9ea..18c94f3638 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
@@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
/**
* :: Experimental ::
@@ -40,7 +41,7 @@ import org.apache.spark.sql.DataFrame
@Experimental
final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
- with DecisionTreeParams with TreeRegressorParams {
+ with DecisionTreeRegressorParams {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("dtr"))
@@ -73,6 +74,9 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
override def setSeed(value: Long): this.type = super.setSeed(value)
+ /** @group setParam */
+ def setVarianceCol(value: String): this.type = set(varianceCol, value)
+
override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -113,7 +117,10 @@ final class DecisionTreeRegressionModel private[ml] (
override val rootNode: Node,
override val numFeatures: Int)
extends PredictionModel[Vector, DecisionTreeRegressionModel]
- with DecisionTreeModel with Serializable {
+ with DecisionTreeModel with DecisionTreeRegressorParams with Serializable {
+
+ /** @group setParam */
+ def setVarianceCol(value: String): this.type = set(varianceCol, value)
require(rootNode != null,
"DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
@@ -129,6 +136,29 @@ final class DecisionTreeRegressionModel private[ml] (
rootNode.predictImpl(features).prediction
}
+ /** We need to update this function if we ever add other impurity measures. */
+ protected def predictVariance(features: Vector): Double = {
+ rootNode.predictImpl(features).impurityStats.calculate()
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ transformImpl(dataset)
+ }
+
+ override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ val predictUDF = udf { (features: Vector) => predict(features) }
+ val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
+ var output = dataset
+ if ($(predictionCol).nonEmpty) {
+ output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+ if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
+ output = output.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol))))
+ }
+ output
+ }
+
@Since("1.4.0")
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent)
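For intuition: with the variance impurity, impurityStats.calculate() at a leaf is the biased sample variance of the training labels that reached that leaf, i.e. sum((y_i - mean)^2) / n with divisor n rather than n - 1. A self-contained sketch of that quantity (the helper name is illustrative):

// Biased sample variance, mirroring what the Variance impurity accumulates per node.
def biasedVariance(labels: Seq[Double]): Double = {
  val n = labels.size.toDouble
  val mean = labels.sum / n
  labels.map(y => (y - mean) * (y - mean)).sum / n
}

// Example: labels {1, 2, 3} have mean 2 and biased variance 2/3
// (the unbiased estimate, with divisor n - 1, would be 1).
assert(math.abs(biasedVariance(Seq(1.0, 2.0, 3.0)) - 2.0 / 3.0) < 1e-12)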
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 1da97db927..7443097492 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -20,9 +20,11 @@ package org.apache.spark.ml.tree
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
/**
* Parameters for Decision Tree-based algorithms.
@@ -256,6 +258,22 @@ private[ml] object TreeRegressorParams {
final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
}
+private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
+ with TreeRegressorParams with HasVarianceCol {
+
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
+ SchemaUtils.appendColumn(newSchema, $(varianceCol), DoubleType)
+ } else {
+ newSchema
+ }
+ }
+}
+
/**
* Parameters for Decision Tree-based ensemble algorithms.
*
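SchemaUtils.appendColumn (imported above) rejects duplicate column names and appends a new field to the schema; a hedged sketch of the equivalent StructType manipulation for this call site (the helper name is illustrative):

import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Roughly what appendColumn does here: reject duplicates, then append
// a non-nullable Double field for the variance column.
def appendDoubleColumn(schema: StructType, colName: String): StructType = {
  require(!schema.fieldNames.contains(colName), s"Column $colName already exists.")
  StructType(schema.fields :+ StructField(colName, DoubleType, nullable = false))
}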
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 6999a910c3..0b39af5543 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Row, DataFrame}
class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -73,6 +74,29 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
MLTestingUtils.checkCopy(model)
}
+ test("predictVariance") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ .setPredictionCol("")
+ .setVarianceCol("variance")
+ val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+
+ val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
+ val model = dt.fit(df)
+
+ val predictions = model.transform(df)
+ .select(model.getFeaturesCol, model.getVarianceCol)
+ .collect()
+
+ predictions.foreach { case Row(features: Vector, variance: Double) =>
+ val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
+ assert(variance === expectedVariance,
+ s"Expected variance $expectedVariance but got $variance.")
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////