aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorKazuki Taniguchi <kazuki.t.1018@gmail.com>2015-01-30 00:39:44 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-30 00:39:44 -0800
commitbc1fc9b60dab69ae74419e35dc6bd263dc504f34 (patch)
tree99bb73f18a7cf2bb70ab31b99cfa72e71699bdf5 /mllib
parentdd4d84cf809e6e425958fe768c518679d1828779 (diff)
downloadspark-bc1fc9b60dab69ae74419e35dc6bd263dc504f34.tar.gz
spark-bc1fc9b60dab69ae74419e35dc6bd263dc504f34.tar.bz2
spark-bc1fc9b60dab69ae74419e35dc6bd263dc504f34.zip
[SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees
This PR implements the Python API for Gradient Boosted Trees. Author: Kazuki Taniguchi <kazuki.t.1018@gmail.com> Closes #3951 from kazk1018/gbt_for_py and squashes the following commits: 620d247 [Kazuki Taniguchi] [SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala36
1 files changed, 33 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 430d763ef7..a66d6f0cf2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -41,10 +41,11 @@ import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.test.ChiSqTestResult
-import org.apache.spark.mllib.tree.{RandomForest, DecisionTree}
-import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
+import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree}
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy}
import org.apache.spark.mllib.tree.impurity._
-import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.loss.Losses
+import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -533,6 +534,35 @@ class PythonMLLibAPI extends Serializable {
}
/**
+ * Java stub for Python mllib GradientBoostedTrees.train().
+ * This stub returns a handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on exit;
+ * see the Py4J documentation.
+ *
+ * Starts from `BoostingStrategy.defaultParams(algoStr)` and overrides the loss, the number of
+ * iterations, the learning rate, the tree depth limit and the categorical feature info with
+ * the given arguments. The input RDD is persisted with `StorageLevel.MEMORY_AND_DISK` while
+ * training runs and is unpersisted (non-blocking) afterwards, whether or not training succeeds.
+ *
+ * @param data training set of labeled points
+ * @param algoStr algorithm name handed to `BoostingStrategy.defaultParams`
+ *                (presumably "Classification" or "Regression" -- TODO confirm)
+ * @param categoricalFeaturesInfo Java map copied into an immutable Scala map and stored on the
+ *                                tree strategy (presumably feature index -> number of
+ *                                categories -- TODO confirm)
+ * @param lossStr loss name resolved through `Losses.fromString`
+ * @param numIterations value passed to `setNumIterations` on the boosting strategy
+ * @param learningRate value passed to `setLearningRate` on the boosting strategy
+ * @param maxDepth value passed to `setMaxDepth` on the underlying tree strategy
+ * @return the model produced by `GradientBoostedTrees.train`
+ */
+ def trainGradientBoostedTreesModel(
+ data: JavaRDD[LabeledPoint],
+ algoStr: String,
+ categoricalFeaturesInfo: JMap[Int, Int],
+ lossStr: String,
+ numIterations: Int,
+ learningRate: Double,
+ maxDepth: Int): GradientBoostedTreesModel = {
+ val boostingStrategy = BoostingStrategy.defaultParams(algoStr)
+ boostingStrategy.setLoss(Losses.fromString(lossStr))
+ boostingStrategy.setNumIterations(numIterations)
+ boostingStrategy.setLearningRate(learningRate)
+ boostingStrategy.treeStrategy.setMaxDepth(maxDepth)
+ boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap
+
+ val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
+ try {
+ GradientBoostedTrees.train(cached, boostingStrategy)
+ } finally {
+ cached.unpersist(blocking = false) // runs even if training throws
+ }
+ }
+
+ /**
* Java stub for mllib Statistics.colStats(X: RDD[Vector]).
* TODO figure out return type.
*/