diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-04-21 21:44:44 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-04-21 21:44:44 -0700 |
commit | 607eff0edfc10a1473fa9713a0500bf09f105c13 (patch) | |
tree | 534b2b61a9876750dd73fbc9016bf795427459c1 /mllib/src | |
parent | 70f9f8ff38560967f2c84de77263a5455c45c495 (diff) | |
download | spark-607eff0edfc10a1473fa9713a0500bf09f105c13.tar.gz spark-607eff0edfc10a1473fa9713a0500bf09f105c13.tar.bz2 spark-607eff0edfc10a1473fa9713a0500bf09f105c13.zip |
[SPARK-6113] [ML] Small cleanups after original tree API PR
This does a few clean-ups. With this PR, all spark.ml tree components have ```private[ml]``` constructors.
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #5567 from jkbradley/dt-api-dt2 and squashes the following commits:
2263b5b [Joseph K. Bradley] Added note about tree example issue.
bb9f610 [Joseph K. Bradley] Small cleanups after original tree API PR
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala | 4 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala | 7 |
2 files changed, 6 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index 6f4509f03d..eb2609faef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { def setMaxDepth(value: Int): this.type = { require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") set(maxDepth, value) - this.asInstanceOf[this.type] + this } /** @group getParam */ @@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params { def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ - protected def getOldImpurity: OldImpurity = { + private[ml] def getOldImpurity: OldImpurity = { getImpurity match { case "variance" => OldVariance case _ => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index cb940f6299..708c769087 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -38,7 +38,7 @@ sealed trait Split extends Serializable { private[tree] def toOld: OldSplit } -private[ml] object Split { +private[tree] object Split { def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = { oldSplit.featureType match { @@ -58,7 +58,7 @@ private[ml] object Split { * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ -final class CategoricalSplit( +final class CategoricalSplit private[ml] ( override val featureIndex: Int, leftCategories: Array[Double], private val numCategories: Int) @@ -130,7 +130,8 @@ final class CategoricalSplit( * @param threshold If the feature value is <= this threshold, then the split goes left. * Otherwise, it goes right. */ -final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split { +final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) + extends Split { override private[ml] def shouldGoLeft(features: Vector): Boolean = { features(featureIndex) <= threshold |