diff options
author | Kazuki Taniguchi <kazuki.t.1018@gmail.com> | 2015-01-30 00:39:44 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-01-30 00:39:44 -0800 |
commit | bc1fc9b60dab69ae74419e35dc6bd263dc504f34 (patch) | |
tree | 99bb73f18a7cf2bb70ab31b99cfa72e71699bdf5 /mllib | |
parent | dd4d84cf809e6e425958fe768c518679d1828779 (diff) | |
download | spark-bc1fc9b60dab69ae74419e35dc6bd263dc504f34.tar.gz spark-bc1fc9b60dab69ae74419e35dc6bd263dc504f34.tar.bz2 spark-bc1fc9b60dab69ae74419e35dc6bd263dc504f34.zip |
[SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees
This PR is implementing the Gradient Boosted Trees for Python API.
Author: Kazuki Taniguchi <kazuki.t.1018@gmail.com>
Closes #3951 from kazk1018/gbt_for_py and squashes the following commits:
620d247 [Kazuki Taniguchi] [SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 430d763ef7..a66d6f0cf2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -41,10 +41,11 @@ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.{RandomForest, DecisionTree} -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree} +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy} import org.apache.spark.mllib.tree.impurity._ -import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.loss.Losses +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -533,6 +534,35 @@ class PythonMLLibAPI extends Serializable { } /** + * Java stub for Python mllib GradientBoostedTrees.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + */ + def trainGradientBoostedTreesModel( + data: JavaRDD[LabeledPoint], + algoStr: String, + categoricalFeaturesInfo: JMap[Int, Int], + lossStr: String, + numIterations: Int, + learningRate: Double, + maxDepth: Int): GradientBoostedTreesModel = { + val boostingStrategy = BoostingStrategy.defaultParams(algoStr) + boostingStrategy.setLoss(Losses.fromString(lossStr)) + boostingStrategy.setNumIterations(numIterations) + boostingStrategy.setLearningRate(learningRate) + boostingStrategy.treeStrategy.setMaxDepth(maxDepth) + boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap + + val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) + try { + GradientBoostedTrees.train(cached, boostingStrategy) + } finally { + cached.unpersist(blocking = false) + } + } + + /** * Java stub for mllib Statistics.colStats(X: RDD[Vector]). * TODO figure out return type. */ |