author     Yanbo Liang <ybliang8@gmail.com>            2016-04-13 11:31:10 -0700
committer  Joseph K. Bradley <joseph@databricks.com>   2016-04-13 11:31:10 -0700
commit     f9d578eaa107d8e8503c1563a2b3990c85104298 (patch)
tree       c4410c020f61d9c48780eb0f108be250d254f42f
parent     7d2ed8cc030f3d84fea47fded072c320c3d87ca7 (diff)
[SPARK-13783][ML] Model export/import for spark.ml: GBTs
## What changes were proposed in this pull request?

* Added save/load for ```GBTClassifier```, ```GBTClassificationModel```, ```GBTRegressor```, and ```GBTRegressionModel```.
* Also modified ```EnsembleModelReadWrite.saveImpl/loadImpl``` to support saving/loading ```treeWeights```.

## How was this patch tested?

Added standard unit tests for GBT save/load.

cc jkbradley GayathriMurali

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12230 from yanboliang/spark-13783.
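For context, a minimal usage sketch of the persistence API this patch enables. This is illustrative only and not part of the patch: the model path and the ```trainingData``` DataFrame (with "label" and "features" columns) are assumed.

```scala
import org.apache.spark.ml.classification.{GBTClassifier, GBTClassificationModel}

// Fit a small GBT classifier (trainingData is an assumed DataFrame with
// "label" and "features" columns).
val gbt = new GBTClassifier()
  .setMaxIter(5)
  .setLossType("logistic")
val model = gbt.fit(trainingData)

// Save the model; trees, per-tree weights, and Params are all persisted.
model.write.overwrite().save("/tmp/gbtcModel")

// Load it back and check that the tree weights survived the round trip.
val loaded = GBTClassificationModel.load("/tmp/gbtcModel")
assert(loaded.treeWeights.sameElements(model.treeWeights))
```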
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala            110
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala     2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala                  114
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala           2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala                           25
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala                           73
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala         37
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala              36
8 files changed, 262 insertions(+), 137 deletions(-)
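The core storage change is in ```EnsembleModelReadWrite``` (see the treeModels.scala hunk below): each tree's weight is now written next to its metadata in the ```treesMetadata``` parquet file under the model directory. A hedged inspection sketch, assuming a model saved at the path used above and a ```sqlContext``` in scope (e.g. from spark-shell); the column names match the patch:

```scala
// Read the per-tree metadata written by EnsembleModelReadWrite.saveImpl and
// show the new "weights" column alongside each treeID.
val treesMeta = sqlContext.read.parquet("/tmp/gbtcModel/treesMetadata")
treesMeta.select("treeID", "weights").orderBy("treeID").show()
```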
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 46e8b89d01..39a698af15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -18,19 +18,21 @@
package org.apache.spark.ml.classification
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -58,7 +60,7 @@ import org.apache.spark.sql.functions._
final class GBTClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
- with GBTParams with TreeClassifierParams with Logging {
+ with GBTClassifierParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("gbtc"))
@@ -115,40 +117,12 @@ final class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
override def setStepSize(value: Double): this.type = super.setStepSize(value)
- // Parameters for GBTClassifier:
-
- /**
- * Loss function which GBT tries to minimize. (case-insensitive)
- * Supported: "logistic"
- * (default = logistic)
- * @group param
- */
- @Since("1.4.0")
- val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
- " tries to minimize (case-insensitive). Supported options:" +
- s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
- (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase))
-
- setDefault(lossType -> "logistic")
+ // Parameters from GBTClassifierParams:
/** @group setParam */
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
- /** @group getParam */
- @Since("1.4.0")
- def getLossType: String = $(lossType).toLowerCase
-
- /** (private[ml]) Convert new loss to old loss. */
- override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
- case "logistic" => OldLogLoss
- case _ =>
- // Should never happen because of check in setter method.
- throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
- }
- }
-
override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -175,11 +149,14 @@ final class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
@Experimental
-object GBTClassifier {
- // The losses below should be lowercase.
+object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
+
/** Accessor for supported loss settings: logistic */
@Since("1.4.0")
- final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+ final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes
+
+ @Since("2.0.0")
+ override def load(path: String): GBTClassifier = super.load(path)
}
/**
@@ -199,7 +176,8 @@ final class GBTClassificationModel private[ml](
private val _treeWeights: Array[Double],
@Since("1.6.0") override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
- with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
+ with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
@@ -267,12 +245,62 @@ final class GBTClassificationModel private[ml](
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
}
-private[ml] object GBTClassificationModel {
+@Since("2.0.0")
+object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GBTClassificationModel = super.load(path)
+
+ private[GBTClassificationModel]
+ class GBTClassificationModelWriter(instance: GBTClassificationModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class GBTClassificationModelReader extends MLReader[GBTClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GBTClassificationModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): GBTClassificationModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+ val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldGBTModel,
parent: GBTClassifier,
categoricalFeatures: Map[Int, Int],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 9d80b8eb68..dfa711b243 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -294,7 +294,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
override def load(path: String): RandomForestClassificationModel = {
implicit val format = DefaultFormats
- val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 0b52fe2d13..741724d7a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -18,19 +18,20 @@
package org.apache.spark.ml.regression
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
- SquaredError => OldSquaredError}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -58,7 +59,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
- with GBTParams with TreeRegressorParams with Logging {
+ with GBTRegressorParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("gbtr"))
@@ -112,41 +113,12 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
@Since("1.4.0")
override def setStepSize(value: Double): this.type = super.setStepSize(value)
- // Parameters for GBTRegressor:
-
- /**
- * Loss function which GBT tries to minimize. (case-insensitive)
- * Supported: "squared" (L2) and "absolute" (L1)
- * (default = squared)
- * @group param
- */
- @Since("1.4.0")
- val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
- " tries to minimize (case-insensitive). Supported options:" +
- s" ${GBTRegressor.supportedLossTypes.mkString(", ")}",
- (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase))
-
- setDefault(lossType -> "squared")
+ // Parameters from GBTRegressorParams:
/** @group setParam */
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
- /** @group getParam */
- @Since("1.4.0")
- def getLossType: String = $(lossType).toLowerCase
-
- /** (private[ml]) Convert new loss to old loss. */
- override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
- case "squared" => OldSquaredError
- case "absolute" => OldAbsoluteError
- case _ =>
- // Should never happen because of check in setter method.
- throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
- }
- }
-
override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -164,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
@Since("1.4.0")
@Experimental
-object GBTRegressor {
- // The losses below should be lowercase.
+object GBTRegressor extends DefaultParamsReadable[GBTRegressor] {
+
/** Accessor for supported loss settings: squared (L2), absolute (L1) */
@Since("1.4.0")
- final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+ final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressor = super.load(path)
}
/**
@@ -188,7 +163,8 @@ final class GBTRegressionModel private[ml](
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
- with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
+ with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
@@ -255,12 +231,64 @@ final class GBTRegressionModel private[ml](
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
}
-private[ml] object GBTRegressionModel {
+@Since("2.0.0")
+object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressionModel = super.load(path)
+
+ private[GBTRegressionModel]
+ class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GBTRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): GBTRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+
+ require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldGBTModel,
parent: GBTRegressor,
categoricalFeatures: Map[Int, Int],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index bee13c2ebf..4c4ff278d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -249,7 +249,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
override def load(path: String): RandomForestRegressionModel = {
implicit val format = DefaultFormats
- val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index c4ab673d9a..f38e1ec7c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -396,12 +396,14 @@ private[ml] object EnsembleModelReadWrite {
sql: SQLContext,
extraMetadata: JObject): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
- val treesMetadataJson: Array[(Int, String)] = instance.trees.zipWithIndex.map {
+ val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map {
case (tree, treeID) =>
- treeID -> DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext)
+ (treeID,
+ DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext),
+ instance.treeWeights(treeID))
}
val treesMetadataPath = new Path(path, "treesMetadata").toString
- sql.createDataFrame(treesMetadataJson).toDF("treeID", "metadata")
+ sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights")
.write.parquet(treesMetadataPath)
val dataPath = new Path(path, "data").toString
val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap {
@@ -424,7 +426,7 @@ private[ml] object EnsembleModelReadWrite {
path: String,
sql: SQLContext,
className: String,
- treeClassName: String): (Metadata, Array[(Metadata, Node)]) = {
+ treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
import sql.implicits._
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
@@ -436,12 +438,15 @@ private[ml] object EnsembleModelReadWrite {
}
val treesMetadataPath = new Path(path, "treesMetadata").toString
- val treesMetadataRDD: RDD[(Int, Metadata)] = sql.read.parquet(treesMetadataPath)
- .select("treeID", "metadata").as[(Int, String)].rdd.map {
- case (treeID: Int, json: String) =>
- treeID -> DefaultParamsReader.parseMetadata(json, treeClassName)
+ val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath)
+ .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map {
+ case (treeID: Int, json: String, weights: Double) =>
+ treeID -> (DefaultParamsReader.parseMetadata(json, treeClassName), weights)
}
- val treesMetadata: Array[Metadata] = treesMetadataRDD.sortByKey().values.collect()
+
+ val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect()
+ val treesMetadata = treesMetadataWeights.map(_._1)
+ val treesWeights = treesMetadataWeights.map(_._2)
val dataPath = new Path(path, "data").toString
val nodeData: Dataset[EnsembleNodeData] =
@@ -452,7 +457,7 @@ private[ml] object EnsembleModelReadWrite {
treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
}
val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
- (metadata, treesMetadata.zip(rootNodes))
+ (metadata, treesMetadata.zip(rootNodes), treesWeights)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 0767dc17e5..b6783911ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
@@ -462,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
/** Get old Gradient Boosting Loss type */
private[ml] def getOldLossType: OldLoss
}
+
+private[ml] object GBTClassifierParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: logistic */
+ final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+}
+
+private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "logistic"
+ * (default = logistic)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase))
+
+ setDefault(lossType -> "logistic")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "logistic" => OldLogLoss
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
+ }
+ }
+}
+
+private[ml] object GBTRegressorParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: squared (L2), absolute (L1) */
+ final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+}
+
+private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "squared" (L2) and "absolute" (L1)
+ * (default = squared)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase))
+
+ setDefault(lossType -> "squared")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "squared" => OldSquaredError
+ case "absolute" => OldAbsoluteError
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 76d8c9372e..7e6aec6b1b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -34,7 +34,8 @@ import org.apache.spark.util.Utils
/**
* Test suite for [[GBTClassifier]].
*/
-class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
import GBTClassifierSuite.compareAPIs
@@ -156,27 +157,23 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val treeWeights = Array(0.1, 0.3, 1.1)
- val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights)
- val newModel = GBTClassificationModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = GBTClassificationModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ def checkModelData(
+ model: GBTClassificationModel,
+ model2: GBTClassificationModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val gbt = new GBTClassifier()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
}
- */
}
private object GBTClassifierSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 3c11631f98..216377959e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -32,7 +32,8 @@ import org.apache.spark.util.Utils
/**
* Test suite for [[GBTRegressor]].
*/
-class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
import GBTRegressorSuite.compareAPIs
@@ -164,27 +165,22 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val treeWeights = Array(0.1, 0.3, 1.1)
- val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights)
- val newModel = GBTRegressionModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = GBTRegressionModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ def checkModelData(
+ model: GBTRegressionModel,
+ model2: GBTRegressionModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val gbt = new GBTRegressor()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
}
- */
}
private object GBTRegressorSuite extends SparkFunSuite {