aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-11-10 17:13:10 -0800
committerYanbo Liang <ybliang8@gmail.com>2016-11-10 17:13:10 -0800
commit5ddf69470b93c0b8a28bb4ac905e7670d9c50a95 (patch)
treea6f7eff240d2f1f299138bce167e2599634aad83 /mllib/src/main
parenta3356343cbf58b930326f45721fb4ecade6f8029 (diff)
downloadspark-5ddf69470b93c0b8a28bb4ac905e7670d9c50a95.tar.gz
spark-5ddf69470b93c0b8a28bb4ac905e7670d9c50a95.tar.bz2
spark-5ddf69470b93c0b8a28bb4ac905e7670d9c50a95.zip
[SPARK-18401][SPARKR][ML] SparkR random forest should support output original label.
## What changes were proposed in this pull request? SparkR ```spark.randomForest``` classification prediction should output original label rather than the indexed label. This issue is very similar with [SPARK-18291](https://issues.apache.org/jira/browse/SPARK-18291). ## How was this patch tested? Add unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #15842 from yanboliang/spark-18401.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala28
1 files changed, 24 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 6947ba7e75..31f846dc6c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -23,9 +23,9 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -35,6 +35,8 @@ private[r] class RandomForestClassifierWrapper private (
val formula: String,
val features: Array[String]) extends MLWritable {
+ import RandomForestClassifierWrapper._
+
private val rfcModel: RandomForestClassificationModel =
pipeline.stages(1).asInstanceOf[RandomForestClassificationModel]
@@ -46,7 +48,9 @@ private[r] class RandomForestClassifierWrapper private (
def summary: String = rfcModel.toDebugString
def transform(dataset: Dataset[_]): DataFrame = {
- pipeline.transform(dataset).drop(rfcModel.getFeaturesCol)
+ pipeline.transform(dataset)
+ .drop(PREDICTED_LABEL_INDEX_COL)
+ .drop(rfcModel.getFeaturesCol)
}
override def write: MLWriter = new
@@ -54,6 +58,10 @@ private[r] class RandomForestClassifierWrapper private (
}
private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] {
+
+ val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+ val PREDICTED_LABEL_COL = "prediction"
+
def fit( // scalastyle:ignore
data: DataFrame,
formula: String,
@@ -73,6 +81,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
val rFormula = new RFormula()
.setFormula(formula)
+ .setForceIndexLabel(true)
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
@@ -82,6 +91,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
.attributes.get
val features = featureAttrs.map(_.name.get)
+ // get label names from output schema
+ val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
+ .asInstanceOf[NominalAttribute]
+ val labels = labelAttr.values.get
+
// assemble and fit the pipeline
val rfc = new RandomForestClassifier()
.setMaxDepth(maxDepth)
@@ -97,10 +111,16 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
.setCacheNodeIds(cacheNodeIds)
.setProbabilityCol(probabilityCol)
.setFeaturesCol(rFormula.getFeaturesCol)
+ .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
+ val idxToStr = new IndexToString()
+ .setInputCol(PREDICTED_LABEL_INDEX_COL)
+ .setOutputCol(PREDICTED_LABEL_COL)
+ .setLabels(labels)
+
val pipeline = new Pipeline()
- .setStages(Array(rFormulaModel, rfc))
+ .setStages(Array(rFormulaModel, rfc, idxToStr))
.fit(data)
new RandomForestClassifierWrapper(pipeline, formula, features)