aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorXin Ren <iamshrek@126.com>2017-03-12 12:15:19 -0700
committerFelix Cheung <felixcheung@apache.org>2017-03-12 12:15:19 -0700
commit9f8ce4825e378b6a856ce65cb9986a5a0f0b624e (patch)
treef36dcc381c02cbfc86dab0e207699eddd9bc87bc /mllib/src
parent2f5187bde1544c452fe5116a2bd243653332a079 (diff)
downloadspark-9f8ce4825e378b6a856ce65cb9986a5a0f0b624e.tar.gz
spark-9f8ce4825e378b6a856ce65cb9986a5a0f0b624e.tar.bz2
spark-9f8ce4825e378b6a856ce65cb9986a5a0f0b624e.zip
[SPARK-19282][ML][SPARKR] RandomForest Wrapper and GBT Wrapper return param "maxDepth" to R models
## What changes were proposed in this pull request? RandomForest R Wrapper and GBT R Wrapper return param `maxDepth` to R models. Below 4 R wrappers are changed: * `RandomForestClassificationWrapper` * `RandomForestRegressionWrapper` * `GBTClassificationWrapper` * `GBTRegressionWrapper` ## How was this patch tested? Test manually on my local machine. Author: Xin Ren <iamshrek@126.com> Closes #17207 from keypointt/SPARK-19282.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala1
4 files changed, 4 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
index aacb41ee26..c07eadb30a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -44,6 +44,7 @@ private[r] class GBTClassifierWrapper private (
lazy val featureImportances: Vector = gbtcModel.featureImportances
lazy val numTrees: Int = gbtcModel.getNumTrees
lazy val treeWeights: Array[Double] = gbtcModel.treeWeights
+ lazy val maxDepth: Int = gbtcModel.getMaxDepth
def summary: String = gbtcModel.toDebugString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
index 585077588e..b568d78592 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
@@ -42,6 +42,7 @@ private[r] class GBTRegressorWrapper private (
lazy val featureImportances: Vector = gbtrModel.featureImportances
lazy val numTrees: Int = gbtrModel.getNumTrees
lazy val treeWeights: Array[Double] = gbtrModel.treeWeights
+ lazy val maxDepth: Int = gbtrModel.getMaxDepth
def summary: String = gbtrModel.toDebugString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 366f375b58..8a83d4e980 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -44,6 +44,7 @@ private[r] class RandomForestClassifierWrapper private (
lazy val featureImportances: Vector = rfcModel.featureImportances
lazy val numTrees: Int = rfcModel.getNumTrees
lazy val treeWeights: Array[Double] = rfcModel.treeWeights
+ lazy val maxDepth: Int = rfcModel.getMaxDepth
def summary: String = rfcModel.toDebugString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
index 4b9a3a731d..038bd79c70 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
@@ -42,6 +42,7 @@ private[r] class RandomForestRegressorWrapper private (
lazy val featureImportances: Vector = rfrModel.featureImportances
lazy val numTrees: Int = rfrModel.getNumTrees
lazy val treeWeights: Array[Double] = rfrModel.treeWeights
+ lazy val maxDepth: Int = rfrModel.getMaxDepth
def summary: String = rfrModel.toDebugString