author     Joseph K. Bradley <joseph@databricks.com>   2015-04-21 21:44:44 -0700
committer  Xiangrui Meng <meng@databricks.com>         2015-04-21 21:44:44 -0700
commit     607eff0edfc10a1473fa9713a0500bf09f105c13
tree       534b2b61a9876750dd73fbc9016bf795427459c1
parent     70f9f8ff38560967f2c84de77263a5455c45c495
[SPARK-6113] [ML] Small cleanups after original tree API PR
This does a few clean-ups. With this PR, all spark.ml tree components have `private[ml]` constructors.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5567 from jkbradley/dt-api-dt2 and squashes the following commits:

2263b5b [Joseph K. Bradley] Added note about tree example issue.
bb9f610 [Joseph K. Bradley] Small cleanups after original tree API PR
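As background, a `private[ml]` constructor makes a class publicly visible while restricting `new` to code inside the enclosing `ml` package. A minimal sketch of the pattern, using a hypothetical `myapp.ml.tree` package and `ExampleSplit` class rather than anything from this patch:

```scala
package myapp.ml.tree

// Public class: users can hold and use instances, but only code under
// myapp.ml may call the constructor directly.
final class ExampleSplit private[ml] (val featureIndex: Int)

// Package-internal factory; outside callers go through higher-level APIs.
private[ml] object ExampleSplit {
  def apply(featureIndex: Int): ExampleSplit = new ExampleSplit(featureIndex)
}
```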
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala | 25
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala            |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala                      |  7
3 files changed, 25 insertions(+), 11 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 921b396e79..2cd515c89d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -44,6 +44,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame}
* {{{
* ./bin/run-example ml.DecisionTreeExample [options]
* }}}
+ * Note that Decision Trees can take a large amount of memory. If the run-example command above
+ * fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
object DecisionTreeExample {
@@ -70,7 +77,7 @@ object DecisionTreeExample {
val parser = new OptionParser[Params]("DecisionTreeExample") {
head("DecisionTreeExample: an example decision tree app.")
opt[String]("algo")
- .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
.action((x, c) => c.copy(algo = x))
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
@@ -222,18 +229,23 @@ object DecisionTreeExample {
// (1) For classification, re-index classes.
val labelColName = if (algo == "classification") "indexedLabel" else "label"
if (algo == "classification") {
- val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
stages += labelIndexer
}
// (2) Identify categorical features using VectorIndexer.
// Features with more than maxCategories values will be treated as continuous.
- val featuresIndexer = new VectorIndexer().setInputCol("features")
- .setOutputCol("indexedFeatures").setMaxCategories(10)
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
stages += featuresIndexer
// (3) Learn DecisionTree
val dt = algo match {
case "classification" =>
- new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
+ new DecisionTreeClassifier()
+ .setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
@@ -242,7 +254,8 @@ object DecisionTreeExample {
.setCacheNodeIds(params.cacheNodeIds)
.setCheckpointInterval(params.checkpointInterval)
case "regression" =>
- new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
+ new DecisionTreeRegressor()
+ .setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
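The reformatted blocks above are stages that DecisionTreeExample chains into a spark.ml Pipeline. A rough sketch of how such stages fit together, with illustrative column names and parameter values, and the DataFrame setup omitted:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}

object PipelineSketch {
  // Mirrors the example's classification branch, with illustrative values.
  val labelIndexer = new StringIndexer()
    .setInputCol("labelString")
    .setOutputCol("indexedLabel")

  val featuresIndexer = new VectorIndexer()
    .setInputCol("features")
    .setOutputCol("indexedFeatures")
    .setMaxCategories(10)

  val dt = new DecisionTreeClassifier()
    .setFeaturesCol("indexedFeatures")
    .setLabelCol("indexedLabel")
    .setMaxDepth(5)

  // Stages run in order: index labels, index features, then fit the tree.
  val pipeline = new Pipeline().setStages(Array(labelIndexer, featuresIndexer, dt))
  // val model = pipeline.fit(training)  // training: a DataFrame with the columns above
}
```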
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
index 6f4509f03d..eb2609faef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
def setMaxDepth(value: Int): this.type = {
require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
set(maxDepth, value)
- this.asInstanceOf[this.type]
+ this
}
/** @group getParam */
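The first hunk drops a redundant cast: a method declared to return `this.type` can simply return `this`, since `this` already has that singleton type. A minimal sketch of why `this.type` is used for fluent setters (names here are illustrative):

```scala
object FluentSetterDemo {
  trait HasMaxDepth {
    protected var maxDepth: Int = 5
    def setMaxDepth(value: Int): this.type = {
      require(value >= 0, s"maxDepth must be >= 0. Given bad value: $value")
      maxDepth = value
      this  // no asInstanceOf needed: `this` already has type this.type
    }
  }

  class Learner extends HasMaxDepth {
    def setSeed(value: Long): this.type = this
  }

  // Returning this.type keeps the concrete Learner type across chained calls:
  val learner: Learner = new Learner().setMaxDepth(3).setSeed(42L)
}
```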
@@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params {
def getImpurity: String = getOrDefault(impurity)
/** Convert new impurity to old impurity. */
- protected def getOldImpurity: OldImpurity = {
+ private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
case "variance" => OldVariance
case _ =>
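The second hunk widens `protected` to `private[ml]`, so any code under the `ml` package, not just subclasses, can perform the new-to-old impurity conversion. A sketch of that conversion shape, using a hypothetical `myapp.ml.tree` package and `ImpuritySketch` trait, with the same import aliasing the real file uses:

```scala
package myapp.ml.tree

import org.apache.spark.mllib.tree.impurity.{Impurity => OldImpurity, Variance => OldVariance}

private[ml] trait ImpuritySketch {
  def getImpurity: String = "variance"

  /** Convert the new string-valued impurity parameter to the old Impurity object. */
  private[ml] def getOldImpurity: OldImpurity = getImpurity match {
    case "variance" => OldVariance
    case other =>
      // Should not happen if the parameter is validated when set.
      throw new RuntimeException(s"Unrecognized impurity: $other")
  }
}
```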
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index cb940f6299..708c769087 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -38,7 +38,7 @@ sealed trait Split extends Serializable {
private[tree] def toOld: OldSplit
}
-private[ml] object Split {
+private[tree] object Split {
def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
oldSplit.featureType match {
@@ -58,7 +58,7 @@ private[ml] object Split {
* left. Otherwise, it goes right.
* @param numCategories Number of categories for this feature.
*/
-final class CategoricalSplit(
+final class CategoricalSplit private[ml] (
override val featureIndex: Int,
leftCategories: Array[Double],
private val numCategories: Int)
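For intuition: a categorical split sends a row left when the feature's category is in the split's left set, while a continuous split (next hunk) compares the feature value against a threshold. A simplified, self-contained sketch of both behaviors; these stand-ins are not Spark's actual classes:

```scala
// Simplified stand-ins for the two split kinds; not the Spark implementation.
sealed trait SplitSketch {
  def featureIndex: Int
  def shouldGoLeft(features: Array[Double]): Boolean
}

final case class CategoricalSplitSketch(
    featureIndex: Int,
    leftCategories: Set[Double]) extends SplitSketch {
  // A row goes left iff its category value is in the left set.
  def shouldGoLeft(features: Array[Double]): Boolean =
    leftCategories.contains(features(featureIndex))
}

final case class ContinuousSplitSketch(
    featureIndex: Int,
    threshold: Double) extends SplitSketch {
  // A row goes left iff the feature value is <= the threshold.
  def shouldGoLeft(features: Array[Double]): Boolean =
    features(featureIndex) <= threshold
}
```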
@@ -130,7 +130,8 @@ final class CategoricalSplit(
* @param threshold If the feature value is <= this threshold, then the split goes left.
* Otherwise, it goes right.
*/
-final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
+ extends Split {
override private[ml] def shouldGoLeft(features: Vector): Boolean = {
features(featureIndex) <= threshold