From 607eff0edfc10a1473fa9713a0500bf09f105c13 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 21 Apr 2015 21:44:44 -0700 Subject: [SPARK-6113] [ML] Small cleanups after original tree API PR This does a few clean-ups. With this PR, all spark.ml tree components have ```private[ml]``` constructors. CC: mengxr Author: Joseph K. Bradley Closes #5567 from jkbradley/dt-api-dt2 and squashes the following commits: 2263b5b [Joseph K. Bradley] Added note about tree example issue. bb9f610 [Joseph K. Bradley] Small cleanups after original tree API PR --- .../src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala | 4 ++-- mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'mllib/src') diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index 6f4509f03d..eb2609faef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { def setMaxDepth(value: Int): this.type = { require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") set(maxDepth, value) - this.asInstanceOf[this.type] + this } /** @group getParam */ @@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params { def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ - protected def getOldImpurity: OldImpurity = { + private[ml] def getOldImpurity: OldImpurity = { getImpurity match { case "variance" => OldVariance case _ => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index cb940f6299..708c769087 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -38,7 +38,7 @@ sealed trait Split extends Serializable { private[tree] def toOld: OldSplit } -private[ml] object Split { +private[tree] object Split { def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = { oldSplit.featureType match { @@ -58,7 +58,7 @@ private[ml] object Split { * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ -final class CategoricalSplit( +final class CategoricalSplit private[ml] ( override val featureIndex: Int, leftCategories: Array[Double], private val numCategories: Int) @@ -130,7 +130,8 @@ final class CategoricalSplit( * @param threshold If the feature value is <= this threshold, then the split goes left. * Otherwise, it goes right. */ -final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split { +final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) + extends Split { override private[ml] def shouldGoLeft(features: Vector): Boolean = { features(featureIndex) <= threshold -- cgit v1.2.3