about | summary | refs | log | tree | commit | diff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 68
1 file changed, 60 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 6e46292451..428bc7a6d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -17,12 +17,17 @@
package org.apache.spark.ml.regression
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
@@ -31,6 +36,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
+
/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
@@ -41,7 +47,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
- with DecisionTreeRegressorParams {
+ with DecisionTreeRegressorParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("dtr"))
@@ -107,9 +113,12 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
@Since("1.4.0")
@Experimental
-object DecisionTreeRegressor {
+object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] {
/** Accessor for supported impurities: variance */
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+
+  /** Loads a [[DecisionTreeRegressor]] from `path`; simply forwards to the inherited `load`. */
+  @Since("2.0.0")
+  override def load(path: String): DecisionTreeRegressor = super.load(path)
}
/**
@@ -125,13 +134,13 @@ final class DecisionTreeRegressionModel private[ml] (
override val rootNode: Node,
override val numFeatures: Int)
extends PredictionModel[Vector, DecisionTreeRegressionModel]
- with DecisionTreeModel with DecisionTreeRegressorParams with Serializable {
+ with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {
/** @group setParam */
def setVarianceCol(value: String): this.type = set(varianceCol, value)
require(rootNode != null,
- "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+ "DecisionTreeRegressionModel given null rootNode, but it requires a non-null rootNode.")
/**
* Construct a decision tree regression model.
@@ -200,12 +209,55 @@ final class DecisionTreeRegressionModel private[ml] (
private[ml] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
}
+
+  /** Creates the [[MLWriter]] (a `DecisionTreeRegressionModelWriter`) that saves this model. */
+  @Since("2.0.0")
+  override def write: MLWriter =
+    new DecisionTreeRegressionModel.DecisionTreeRegressionModelWriter(this)
}
-private[ml] object DecisionTreeRegressionModel {
+@Since("2.0.0")
+object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionModel] {
+
+  /** Creates the [[MLReader]] (a `DecisionTreeRegressionModelReader`) that loads saved models. */
+  @Since("2.0.0")
+  override def read: MLReader[DecisionTreeRegressionModel] =
+    new DecisionTreeRegressionModelReader
+
+  /** Loads a [[DecisionTreeRegressionModel]] from `path`; forwards to the inherited `load`. */
+  @Since("2.0.0")
+  override def load(path: String): DecisionTreeRegressionModel = super.load(path)
+
+  /** Saves model metadata (plus `numFeatures`) and the tree's node data beneath a path. */
+  private[DecisionTreeRegressionModel]
+  class DecisionTreeRegressionModelWriter(instance: DecisionTreeRegressionModel)
+    extends MLWriter {
+    override protected def saveImpl(path: String): Unit = {
+      // `numFeatures` travels as an extra metadata field alongside the instance metadata.
+      val numFeaturesJson: JObject = Map("numFeatures" -> instance.numFeatures)
+      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(numFeaturesJson))
+      // Node records built from the root are written as Parquet to the "data" subpath.
+      val nodes = NodeData.build(instance.rootNode, 0)._1
+      sqlContext.createDataFrame(nodes).write.parquet(new Path(path, "data").toString)
+    }
+  }
+
+  /** Rebuilds a [[DecisionTreeRegressionModel]] from saved metadata and tree node data. */
+  private class DecisionTreeRegressionModelReader
+    extends MLReader[DecisionTreeRegressionModel] {
+    /** Expected class name; passed to `loadMetadata`, which checks it against the metadata. */
+    private val className = classOf[DecisionTreeRegressionModel].getName
+
+    override def load(path: String): DecisionTreeRegressionModel = {
+      implicit val format = DefaultFormats  // required by `extract[Int]` below
+      val saved = DefaultParamsReader.loadMetadata(path, sc, className)
+      val featureCount = (saved.metadata \ "numFeatures").extract[Int]
+      val restored = new DecisionTreeRegressionModel(
+        saved.uid, loadTreeNodes(path, saved, sqlContext), featureCount)
+      DefaultParamsReader.getAndSetParams(restored, saved)
+      restored
+    }
+  }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldDecisionTreeModel,
parent: DecisionTreeRegressor,
categoricalFeatures: Map[Int, Int],