author     Felix Cheung <felixcheung_m@hotmail.com>  2016-11-08 16:00:45 -0800
committer  Felix Cheung <felixcheung@apache.org>     2016-11-08 16:00:45 -0800
commit     55964c15a7b639f920dfe6c104ae4fdcd673705c (patch)
tree       1e551bd8c155145135acc161f711e0464b053f8c /mllib/src/main
parent     6f7ecb0f2975d24a71e4240cf623f5bd8992bbeb (diff)
[SPARK-18239][SPARKR] Gradient Boosted Tree for R
## What changes were proposed in this pull request?

Gradient Boosted Tree in R, with a few minor improvements to RandomForest in R. Since this is relatively isolated, I'd like to target this for branch-2.1.

## How was this patch tested?

Manual tests, unit tests.

Author: Felix Cheung <felixcheung_m@hotmail.com>

Closes #15746 from felixcheung/rgbt.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala           164
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala               144
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala                            4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala   14
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala       14
5 files changed, 326 insertions(+), 14 deletions(-)
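Note on the API surface added here: each wrapper exposes a single fit entry point that takes the GBT tuning parameters as flattened primitives, with the random seed passed as a String (parsed to a Long only when non-empty, since R lacks a 64-bit integer type). Below is a minimal sketch of driving the classifier wrapper directly; it is purely illustrative, because the object is private[r] and is normally reached only through the SparkR API. The DataFrame df and all parameter values are assumptions for the example:

    import org.apache.spark.ml.r.GBTClassifierWrapper

    // Illustrative direct call; real callers go through SparkR.
    val model = GBTClassifierWrapper.fit(
      data = df,
      formula = "label ~ .",
      maxDepth = 5,
      maxBins = 32,
      maxIter = 20,
      stepSize = 0.1,
      minInstancesPerNode = 1,
      minInfoGain = 0.0,
      checkpointInterval = 10,
      lossType = "logistic",
      seed = "42",              // a String; applied via seed.toLong when non-empty
      subsamplingRate = 1.0,
      maxMemoryInMB = 256,
      cacheNodeIds = false)

    // transform() drops the internal "pred_label_idx" column and the features
    // column, keeping the string "prediction" column from IndexToString.
    val predictions = model.transform(df)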
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
new file mode 100644
index 0000000000..8946025032
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
+import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class GBTClassifierWrapper private (
+ val pipeline: PipelineModel,
+ val formula: String,
+ val features: Array[String]) extends MLWritable {
+
+ import GBTClassifierWrapper._
+
+ private val gbtcModel: GBTClassificationModel =
+ pipeline.stages(1).asInstanceOf[GBTClassificationModel]
+
+ lazy val numFeatures: Int = gbtcModel.numFeatures
+ lazy val featureImportances: Vector = gbtcModel.featureImportances
+ lazy val numTrees: Int = gbtcModel.getNumTrees
+ lazy val treeWeights: Array[Double] = gbtcModel.treeWeights
+
+ def summary: String = gbtcModel.toDebugString
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset)
+ .drop(PREDICTED_LABEL_INDEX_COL)
+ .drop(gbtcModel.getFeaturesCol)
+ }
+
+ override def write: MLWriter = new
+ GBTClassifierWrapper.GBTClassifierWrapperWriter(this)
+}
+
+private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] {
+
+ val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+ val PREDICTED_LABEL_COL = "prediction"
+
+ def fit( // scalastyle:ignore
+ data: DataFrame,
+ formula: String,
+ maxDepth: Int,
+ maxBins: Int,
+ maxIter: Int,
+ stepSize: Double,
+ minInstancesPerNode: Int,
+ minInfoGain: Double,
+ checkpointInterval: Int,
+ lossType: String,
+ seed: String,
+ subsamplingRate: Double,
+ maxMemoryInMB: Int,
+ cacheNodeIds: Boolean): GBTClassifierWrapper = {
+
+ val rFormula = new RFormula()
+ .setFormula(formula)
+ .setForceIndexLabel(true)
+ RWrapperUtils.checkDataColumns(rFormula, data)
+ val rFormulaModel = rFormula.fit(data)
+
+ // get feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
+
+ // get label names from output schema
+ val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
+ .asInstanceOf[NominalAttribute]
+ val labels = labelAttr.values.get
+
+ // assemble and fit the pipeline
+ val rfc = new GBTClassifier()
+ .setMaxDepth(maxDepth)
+ .setMaxBins(maxBins)
+ .setMaxIter(maxIter)
+ .setStepSize(stepSize)
+ .setMinInstancesPerNode(minInstancesPerNode)
+ .setMinInfoGain(minInfoGain)
+ .setCheckpointInterval(checkpointInterval)
+ .setLossType(lossType)
+ .setSubsamplingRate(subsamplingRate)
+ .setMaxMemoryInMB(maxMemoryInMB)
+ .setCacheNodeIds(cacheNodeIds)
+ .setFeaturesCol(rFormula.getFeaturesCol)
+ .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+ if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
+
+ val idxToStr = new IndexToString()
+ .setInputCol(PREDICTED_LABEL_INDEX_COL)
+ .setOutputCol(PREDICTED_LABEL_COL)
+ .setLabels(labels)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, rfc, idxToStr))
+ .fit(data)
+
+ new GBTClassifierWrapper(pipeline, formula, features)
+ }
+
+ override def read: MLReader[GBTClassifierWrapper] = new GBTClassifierWrapperReader
+
+ override def load(path: String): GBTClassifierWrapper = super.load(path)
+
+ class GBTClassifierWrapperWriter(instance: GBTClassifierWrapper)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("formula" -> instance.formula) ~
+ ("features" -> instance.features.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class GBTClassifierWrapperReader extends MLReader[GBTClassifierWrapper] {
+
+ override def load(path: String): GBTClassifierWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+ val pipeline = PipelineModel.load(pipelinePath)
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val formula = (rMetadata \ "formula").extract[String]
+ val features = (rMetadata \ "features").extract[Array[String]]
+
+ new GBTClassifierWrapper(pipeline, formula, features)
+ }
+ }
+}
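Persistence above follows a two-artifact layout: a one-line JSON rMetadata file (class name, formula, feature names) next to the fitted pipeline saved through PipelineModel. A sketch of the round trip, assuming the fitted model from the earlier example and a hypothetical writable path:

    // Writes <path>/rMetadata (one JSON line) and <path>/pipeline.
    model.write.save("/tmp/gbtc-example")

    // load() re-parses the metadata and rebuilds the wrapper around the
    // reloaded PipelineModel.
    val reloaded = GBTClassifierWrapper.load("/tmp/gbtc-example")
    assert(reloaded.formula == model.formula)
    assert(reloaded.features.sameElements(model.features))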
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
new file mode 100644
index 0000000000..585077588e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class GBTRegressorWrapper private (
+ val pipeline: PipelineModel,
+ val formula: String,
+ val features: Array[String]) extends MLWritable {
+
+ private val gbtrModel: GBTRegressionModel =
+ pipeline.stages(1).asInstanceOf[GBTRegressionModel]
+
+ lazy val numFeatures: Int = gbtrModel.numFeatures
+ lazy val featureImportances: Vector = gbtrModel.featureImportances
+ lazy val numTrees: Int = gbtrModel.getNumTrees
+ lazy val treeWeights: Array[Double] = gbtrModel.treeWeights
+
+ def summary: String = gbtrModel.toDebugString
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset).drop(gbtrModel.getFeaturesCol)
+ }
+
+ override def write: MLWriter = new
+ GBTRegressorWrapper.GBTRegressorWrapperWriter(this)
+}
+
+private[r] object GBTRegressorWrapper extends MLReadable[GBTRegressorWrapper] {
+ def fit( // scalastyle:ignore
+ data: DataFrame,
+ formula: String,
+ maxDepth: Int,
+ maxBins: Int,
+ maxIter: Int,
+ stepSize: Double,
+ minInstancesPerNode: Int,
+ minInfoGain: Double,
+ checkpointInterval: Int,
+ lossType: String,
+ seed: String,
+ subsamplingRate: Double,
+ maxMemoryInMB: Int,
+ cacheNodeIds: Boolean): GBTRegressorWrapper = {
+
+ val rFormula = new RFormula()
+ .setFormula(formula)
+ RWrapperUtils.checkDataColumns(rFormula, data)
+ val rFormulaModel = rFormula.fit(data)
+
+ // get feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
+
+ // assemble and fit the pipeline
+ val rfr = new GBTRegressor()
+ .setMaxDepth(maxDepth)
+ .setMaxBins(maxBins)
+ .setMaxIter(maxIter)
+ .setStepSize(stepSize)
+ .setMinInstancesPerNode(minInstancesPerNode)
+ .setMinInfoGain(minInfoGain)
+ .setCheckpointInterval(checkpointInterval)
+ .setLossType(lossType)
+ .setSubsamplingRate(subsamplingRate)
+ .setMaxMemoryInMB(maxMemoryInMB)
+ .setCacheNodeIds(cacheNodeIds)
+ .setFeaturesCol(rFormula.getFeaturesCol)
+ if (seed != null && seed.length > 0) rfr.setSeed(seed.toLong)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, rfr))
+ .fit(data)
+
+ new GBTRegressorWrapper(pipeline, formula, features)
+ }
+
+ override def read: MLReader[GBTRegressorWrapper] = new GBTRegressorWrapperReader
+
+ override def load(path: String): GBTRegressorWrapper = super.load(path)
+
+ class GBTRegressorWrapperWriter(instance: GBTRegressorWrapper)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("formula" -> instance.formula) ~
+ ("features" -> instance.features.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class GBTRegressorWrapperReader extends MLReader[GBTRegressorWrapper] {
+
+ override def load(path: String): GBTRegressorWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+ val pipeline = PipelineModel.load(pipelinePath)
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val formula = (rMetadata \ "formula").extract[String]
+ val features = (rMetadata \ "features").extract[Array[String]]
+
+ new GBTRegressorWrapper(pipeline, formula, features)
+ }
+ }
+}
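Both wrappers serialize the same metadata shape through the json4s DSL. A self-contained sketch of what saveImpl renders, with illustrative field values:

    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods._

    // Mirrors the rMetadata construction in saveImpl above.
    val rMetadata = ("class" -> "org.apache.spark.ml.r.GBTRegressorWrapper") ~
      ("formula" -> "y ~ x1 + x2") ~
      ("features" -> Seq("x1", "x2"))

    compact(render(rMetadata))
    // {"class":"org.apache.spark.ml.r.GBTRegressorWrapper","formula":"y ~ x1 + x2","features":["x1","x2"]}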
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 0e09e18027..b59fe29234 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -60,6 +60,10 @@ private[r] object RWrappers extends MLReader[Object] {
RandomForestRegressorWrapper.load(path)
case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
RandomForestClassifierWrapper.load(path)
+ case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
+ GBTRegressorWrapper.load(path)
+ case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
+ GBTClassifierWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR read.ml does not support load $className")
}
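These two cases extend RWrappers.load, the shared dispatcher behind SparkR's read.ml: it reads the saved rMetadata, extracts the class field, and delegates to the matching wrapper's own load. A simplified sketch of the surrounding method (not a verbatim excerpt):

    implicit val format = DefaultFormats
    val rMetadataPath = new Path(path, "rMetadata").toString
    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
    val className = (parse(rMetadataStr) \ "class").extract[String]
    className match {
      // ...cases for the other wrappers elided...
      case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
        GBTRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
        GBTClassifierWrapper.load(path)
      case _ =>
        throw new SparkException(s"SparkR read.ml does not support load $className")
    }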
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index b0088ddaf3..6947ba7e75 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -35,18 +35,18 @@ private[r] class RandomForestClassifierWrapper private (
val formula: String,
val features: Array[String]) extends MLWritable {

- private val DTModel: RandomForestClassificationModel =
+ private val rfcModel: RandomForestClassificationModel =
pipeline.stages(1).asInstanceOf[RandomForestClassificationModel]

- lazy val numFeatures: Int = DTModel.numFeatures
- lazy val featureImportances: Vector = DTModel.featureImportances
- lazy val numTrees: Int = DTModel.getNumTrees
- lazy val treeWeights: Array[Double] = DTModel.treeWeights
+ lazy val numFeatures: Int = rfcModel.numFeatures
+ lazy val featureImportances: Vector = rfcModel.featureImportances
+ lazy val numTrees: Int = rfcModel.getNumTrees
+ lazy val treeWeights: Array[Double] = rfcModel.treeWeights

- def summary: String = DTModel.toDebugString
+ def summary: String = rfcModel.toDebugString

def transform(dataset: Dataset[_]): DataFrame = {
- pipeline.transform(dataset).drop(DTModel.getFeaturesCol)
+ pipeline.transform(dataset).drop(rfcModel.getFeaturesCol)
}

override def write: MLWriter = new
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
index c8874407fa..4b9a3a731d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
@@ -35,18 +35,18 @@ private[r] class RandomForestRegressorWrapper private (
val formula: String,
val features: Array[String]) extends MLWritable {

- private val DTModel: RandomForestRegressionModel =
+ private val rfrModel: RandomForestRegressionModel =
pipeline.stages(1).asInstanceOf[RandomForestRegressionModel]

- lazy val numFeatures: Int = DTModel.numFeatures
- lazy val featureImportances: Vector = DTModel.featureImportances
- lazy val numTrees: Int = DTModel.getNumTrees
- lazy val treeWeights: Array[Double] = DTModel.treeWeights
+ lazy val numFeatures: Int = rfrModel.numFeatures
+ lazy val featureImportances: Vector = rfrModel.featureImportances
+ lazy val numTrees: Int = rfrModel.getNumTrees
+ lazy val treeWeights: Array[Double] = rfrModel.treeWeights

- def summary: String = DTModel.toDebugString
+ def summary: String = rfrModel.toDebugString

def transform(dataset: Dataset[_]): DataFrame = {
- pipeline.transform(dataset).drop(DTModel.getFeaturesCol)
+ pipeline.transform(dataset).drop(rfrModel.getFeaturesCol)
}

override def write: MLWriter = new