aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala7
2 files changed, 6 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
index 6f4509f03d..eb2609faef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
def setMaxDepth(value: Int): this.type = {
require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
set(maxDepth, value)
- this.asInstanceOf[this.type]
+ this
}
/** @group getParam */
@@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params {
def getImpurity: String = getOrDefault(impurity)
/** Convert new impurity to old impurity. */
- protected def getOldImpurity: OldImpurity = {
+ private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
case "variance" => OldVariance
case _ =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index cb940f6299..708c769087 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -38,7 +38,7 @@ sealed trait Split extends Serializable {
private[tree] def toOld: OldSplit
}
-private[ml] object Split {
+private[tree] object Split {
def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
oldSplit.featureType match {
@@ -58,7 +58,7 @@ private[ml] object Split {
* left. Otherwise, it goes right.
* @param numCategories Number of categories for this feature.
*/
-final class CategoricalSplit(
+final class CategoricalSplit private[ml] (
override val featureIndex: Int,
leftCategories: Array[Double],
private val numCategories: Int)
@@ -130,7 +130,8 @@ final class CategoricalSplit(
* @param threshold If the feature value is <= this threshold, then the split goes left.
* Otherwise, it goes right.
*/
-final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
+ extends Split {
override private[ml] def shouldGoLeft(features: Vector): Boolean = {
features(featureIndex) <= threshold