author     sethah <seth.hendrickson16@gmail.com>        2016-04-01 21:23:35 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2016-04-01 21:23:35 -0700
commit     4fc35e6f5c590feb47cbcb5b1136f2e985677b3f (patch)
tree       bf52d9b2f4cf7b766e09550a84915023c399a11e /mllib
parent     36e8fb8005eccea67a9dea8cf68ec3105aa43351 (diff)
[SPARK-14308][ML][MLLIB] Remove unused mllib tree classes and move private classes to ML
## What changes were proposed in this pull request?

Decision tree helper classes are being migrated to ML. This patch moves the internal classes that are not part of the public API and removes the ones no longer used after [SPARK-12183](https://github.com/apache/spark/pull/11855). No functional changes are made.

Details:
* Bin.scala is removed, as the ML implementation does not require bins
* mllib NodeIdCache is removed; it was only used by the old mllib implementation, which no longer exists
* mllib TreePoint is removed; it was only used by the old mllib implementation, which no longer exists
* BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, BaggedPointSuite and TimeTracker are all moved to ML

## How was this patch tested?

No functional changes are made. Existing unit tests ensure behavior is unchanged.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12097 from sethah/cleanup_mllib_tree.
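As context for the import changes throughout the diff below, here is a minimal sketch (not part of this patch) of what the move means for internal callers: the helpers keep their names and behavior, only their package changes from `org.apache.spark.mllib.tree.impl` to `org.apache.spark.ml.tree.impl`. The enclosing package and object below are hypothetical, and since the helpers are `private[spark]`, code like this only compiles inside the Spark source tree.

```scala
// Hypothetical internal caller; compiles only inside the Spark source tree because
// the moved helpers are private[spark].
package org.apache.spark.ml.tree.example

// Before this patch: import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.ml.tree.impl.TimeTracker

object MovedHelperUsage {
  def timedSection(): String = {
    val timer = new TimeTracker()
    timer.start("total")          // start a named timer
    // ... training work would run here ...
    timer.stop("total")           // stop it and record the elapsed time
    timer.toString                // summary of all recorded timers
  }
}
```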
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala)  |    2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala)  |    5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala)  |    2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala  |    1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala  |    1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala  |    4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala)  |    2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala  |    1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala  |    3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala  |  195
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala  |  150
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala  |    2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala  |    2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala  |    2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala  |   47
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala)  |    2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala  |    1
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala  |    2
18 files changed, 15 insertions(+), 409 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
index 572815df0b..4e372702f0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.commons.math3.distribution.PoissonDistribution
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index c745e9f8db..61091bb803 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.spark.mllib.tree.impurity._
@@ -86,6 +86,7 @@ private[spark] class DTStatsAggregator(
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
+ *
* @param featureOffset This is a pre-computed (node, feature) offset
* from [[getFeatureOffset]].
*/
@@ -118,6 +119,7 @@ private[spark] class DTStatsAggregator(
/**
* Faster version of [[update]].
* Update the stats for a given (feature, bin), using the given label.
+ *
* @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
*/
@@ -138,6 +140,7 @@ private[spark] class DTStatsAggregator(
/**
* For a given feature, merge the stats for two bins.
+ *
* @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
* @param binIndex The other bin is merged into this bin.
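The `featureOffset` parameters documented above refer to positions in a single flat statistics array. The following self-contained sketch illustrates that indexing scheme under stated assumptions (a fixed `statsSize` Doubles per bin, a toy per-class count update); it is not the DTStatsAggregator implementation itself.

```scala
// Standalone illustration of flat (feature, bin) offset indexing; an assumption-level
// sketch, not the DTStatsAggregator code being moved in this patch.
object FlatBinStatsSketch {
  val statsSize = 3                          // assumed number of Doubles stored per bin
  val numBinsPerFeature = Array(4, 2, 5)     // hypothetical bin counts per feature

  // Pre-computed start offset of each feature's block, analogous to getFeatureOffset.
  val featureOffsets: Array[Int] =
    numBinsPerFeature.scanLeft(0)((acc, nBins) => acc + nBins * statsSize).init

  // One flat array holding the stats of every (feature, bin) pair.
  val allStats = new Array[Double](numBinsPerFeature.map(_ * statsSize).sum)

  // Add one labeled instance to the stats of (featureIndex, binIndex); the per-class
  // count update below is a toy rule, assuming label.toInt < statsSize.
  def update(featureIndex: Int, binIndex: Int, label: Double, weight: Double = 1.0): Unit = {
    val offset = featureOffsets(featureIndex) + binIndex * statsSize
    allStats(offset + label.toInt) += weight
  }
}
```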
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 4f27dc44ef..df8eb5d1f9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import scala.collection.mutable
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index b37f4e891e..0749d93b7d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
-import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
index 2c8286766f..9d697a36b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -26,7 +26,6 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.ml.tree.{LearningNode, Split}
-import org.apache.spark.mllib.tree.impl.BaggedPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index cccf052b3e..7b1fd089f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -28,8 +28,6 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator,
- TimeTracker}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
@@ -330,7 +328,7 @@ private[spark] object RandomForest extends Logging {
/**
* Given a group of nodes, this finds the best split for each node.
*
- * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
+ * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
* @param metadata Learning and dataset metadata
* @param topNodes Root node for each tree. Used for matching instances with nodes.
* @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
index 70afaa162b..4cc250aa46 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import scala.collection.mutable.{HashMap => MutableHashMap}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
index 9fa27e5e1f..3a2bf3c725 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.rdd.RDD
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index d166dc7905..0f0c6b466d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -20,11 +20,11 @@ package org.apache.spark.mllib.tree
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.tree.impl.TimeTracker
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
import org.apache.spark.rdd.RDD
@@ -165,6 +165,7 @@ object GradientBoostedTrees extends Logging {
/**
* Internal method for performing regression using trees as base learners.
+ *
* @param input Training dataset.
* @param validationInput Validation dataset, ignored if validate is set to false.
* @param boostingStrategy Boosting parameters.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
deleted file mode 100644
index dc7e969f7b..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import scala.collection.mutable
-
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-
-/**
- * :: DeveloperApi ::
- * This is used by the node id cache to find the child id that a data point would belong to.
- * @param split Split information.
- * @param nodeIndex The current node index of a data point that this will update.
- */
-@DeveloperApi
-private[tree] case class NodeIndexUpdater(
- split: Split,
- nodeIndex: Int) {
- /**
- * Determine a child node index based on the feature value and the split.
- * @param binnedFeatures Binned feature values.
- * @param bins Bin information to convert the bin indices to approximate feature values.
- * @return Child node index to update to.
- */
- def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
- if (split.featureType == Continuous) {
- val featureIndex = split.feature
- val binIndex = binnedFeatures(featureIndex)
- val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
- if (featureValueUpperBound <= split.threshold) {
- Node.leftChildIndex(nodeIndex)
- } else {
- Node.rightChildIndex(nodeIndex)
- }
- } else {
- if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
- Node.leftChildIndex(nodeIndex)
- } else {
- Node.rightChildIndex(nodeIndex)
- }
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * A given TreePoint would belong to a particular node per tree.
- * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
- * in each tree. Initially, values should all be 1 for root node.
- * The nodeIdsForInstances RDD needs to be updated at each iteration.
- * @param nodeIdsForInstances The initial values in the cache
- * (should be an Array of all 1's (meaning the root nodes)).
- * @param checkpointInterval The checkpointing interval
- * (how often should the cache be checkpointed.).
- */
-@DeveloperApi
-private[spark] class NodeIdCache(
- var nodeIdsForInstances: RDD[Array[Int]],
- val checkpointInterval: Int) {
-
- // Keep a reference to a previous node Ids for instances.
- // Because we will keep on re-persisting updated node Ids,
- // we want to unpersist the previous RDD.
- private var prevNodeIdsForInstances: RDD[Array[Int]] = null
-
- // To keep track of the past checkpointed RDDs.
- private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
- private var rddUpdateCount = 0
-
- /**
- * Update the node index values in the cache.
- * This updates the RDD and its lineage.
- * TODO: Passing bin information to executors seems unnecessary and costly.
- * @param data The RDD of training rows.
- * @param nodeIdUpdaters A map of node index updaters.
- * The key is the indices of nodes that we want to update.
- * @param bins Bin information needed to find child node indices.
- */
- def updateNodeIndices(
- data: RDD[BaggedPoint[TreePoint]],
- nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
- bins: Array[Array[Bin]]): Unit = {
- if (prevNodeIdsForInstances != null) {
- // Unpersist the previous one if one exists.
- prevNodeIdsForInstances.unpersist()
- }
-
- prevNodeIdsForInstances = nodeIdsForInstances
- nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
- case (point, node) => {
- var treeId = 0
- while (treeId < nodeIdUpdaters.length) {
- val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null)
- if (nodeIdUpdater != null) {
- val newNodeIndex = nodeIdUpdater.updateNodeIndex(
- binnedFeatures = point.datum.binnedFeatures,
- bins = bins)
- node(treeId) = newNodeIndex
- }
-
- treeId += 1
- }
-
- node
- }
- }
-
- // Keep on persisting new ones.
- nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
- rddUpdateCount += 1
-
- // Handle checkpointing if the directory is not None.
- if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
- (rddUpdateCount % checkpointInterval) == 0) {
- // Let's see if we can delete previous checkpoints.
- var canDelete = true
- while (checkpointQueue.size > 1 && canDelete) {
- // We can delete the oldest checkpoint iff
- // the next checkpoint actually exists in the file system.
- if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
- val old = checkpointQueue.dequeue()
-
- // Since the old checkpoint is not deleted by Spark,
- // we'll manually delete it here.
- val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
- fs.delete(new Path(old.getCheckpointFile.get), true)
- } else {
- canDelete = false
- }
- }
-
- nodeIdsForInstances.checkpoint()
- checkpointQueue.enqueue(nodeIdsForInstances)
- }
- }
-
- /**
- * Call this after training is finished to delete any remaining checkpoints.
- */
- def deleteAllCheckpoints(): Unit = {
- while (checkpointQueue.nonEmpty) {
- val old = checkpointQueue.dequeue()
- for (checkpointFile <- old.getCheckpointFile) {
- val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
- fs.delete(new Path(checkpointFile), true)
- }
- }
- if (prevNodeIdsForInstances != null) {
- // Unpersist the previous one if one exists.
- prevNodeIdsForInstances.unpersist()
- }
- }
-}
-
-private[spark] object NodeIdCache {
- /**
- * Initialize the node Id cache with initial node Id values.
- * @param data The RDD of training rows.
- * @param numTrees The number of trees that we want to create cache for.
- * @param checkpointInterval The checkpointing interval
- * (how often should the cache be checkpointed.).
- * @param initVal The initial values in the cache.
- * @return A node Id cache containing an RDD of initial root node Indices.
- */
- def init(
- data: RDD[BaggedPoint[TreePoint]],
- numTrees: Int,
- checkpointInterval: Int,
- initVal: Int = 1): NodeIdCache = {
- new NodeIdCache(
- data.map(_ => Array.fill[Int](numTrees)(initVal)),
- checkpointInterval)
- }
-}
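For context on the class removed above: `NodeIndexUpdater.updateNodeIndex` routes a binned data point to one of a node's children by comparing the upper bound of its bin against the split threshold. Below is a self-contained sketch of the same routing rule for continuous features; it assumes the conventional 1-based binary-tree numbering in which the children of node `i` are `2*i` and `2*i + 1`, and it uses simplified stand-in types rather than the removed mllib `Split` and `Bin` classes.

```scala
// Standalone sketch of the continuous-feature child-routing rule from the removed
// NodeIndexUpdater; types are simplified stand-ins, not the removed mllib classes.
object NodeRoutingSketch {
  final case class ContinuousSplit(featureIndex: Int, threshold: Double)

  // Assumed 1-based complete-binary-tree numbering: children of node i are 2*i and 2*i + 1.
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1

  /** Route an instance at `nodeIndex` to a child, given the upper bound of its bin. */
  def updateNodeIndex(split: ContinuousSplit, nodeIndex: Int, binUpperBound: Double): Int =
    if (binUpperBound <= split.threshold) leftChildIndex(nodeIndex)
    else rightChildIndex(nodeIndex)
}
```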
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
deleted file mode 100644
index 21919d69a3..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.Bin
-import org.apache.spark.rdd.RDD
-
-
-/**
- * Internal representation of LabeledPoint for DecisionTree.
- * This bins feature values based on a subsampled of data as follows:
- * (a) Continuous features are binned into ranges.
- * (b) Unordered categorical features are binned based on subsets of feature values.
- * "Unordered categorical features" are categorical features with low arity used in
- * multiclass classification.
- * (c) Ordered categorical features are binned based on feature values.
- * "Ordered categorical features" are categorical features with high arity,
- * or any categorical feature used in regression or binary classification.
- *
- * @param label Label from LabeledPoint
- * @param binnedFeatures Binned feature values.
- * Same length as LabeledPoint.features, but values are bin indices.
- */
-private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
- extends Serializable {
-}
-
-private[spark] object TreePoint {
-
- /**
- * Convert an input dataset into its TreePoint representation,
- * binning feature values in preparation for DecisionTree training.
- * @param input Input dataset.
- * @param bins Bins for features, of size (numFeatures, numBins).
- * @param metadata Learning and dataset metadata
- * @return TreePoint dataset representation
- */
- def convertToTreeRDD(
- input: RDD[LabeledPoint],
- bins: Array[Array[Bin]],
- metadata: DecisionTreeMetadata): RDD[TreePoint] = {
- // Construct arrays for featureArity for efficiency in the inner loop.
- val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
- var featureIndex = 0
- while (featureIndex < metadata.numFeatures) {
- featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
- featureIndex += 1
- }
- input.map { x =>
- TreePoint.labeledPointToTreePoint(x, bins, featureArity)
- }
- }
-
- /**
- * Convert one LabeledPoint into its TreePoint representation.
- * @param bins Bins for features, of size (numFeatures, numBins).
- * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
- * for categorical features.
- */
- private def labeledPointToTreePoint(
- labeledPoint: LabeledPoint,
- bins: Array[Array[Bin]],
- featureArity: Array[Int]): TreePoint = {
- val numFeatures = labeledPoint.features.size
- val arr = new Array[Int](numFeatures)
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
- bins)
- featureIndex += 1
- }
- new TreePoint(labeledPoint.label, arr)
- }
-
- /**
- * Find bin for one (labeledPoint, feature).
- *
- * @param featureArity 0 for continuous features; number of categories for categorical features.
- * @param bins Bins for features, of size (numFeatures, numBins).
- */
- private def findBin(
- featureIndex: Int,
- labeledPoint: LabeledPoint,
- featureArity: Int,
- bins: Array[Array[Bin]]): Int = {
-
- /**
- * Binary search helper method for continuous feature.
- */
- def binarySearchForBins(): Int = {
- val binForFeatures = bins(featureIndex)
- val feature = labeledPoint.features(featureIndex)
- var left = 0
- var right = binForFeatures.length - 1
- while (left <= right) {
- val mid = left + (right - left) / 2
- val bin = binForFeatures(mid)
- val lowThreshold = bin.lowSplit.threshold
- val highThreshold = bin.highSplit.threshold
- if ((lowThreshold < feature) && (highThreshold >= feature)) {
- return mid
- } else if (lowThreshold >= feature) {
- right = mid - 1
- } else {
- left = mid + 1
- }
- }
- -1
- }
-
- if (featureArity == 0) {
- // Perform binary search for finding bin for continuous features.
- val binIndex = binarySearchForBins()
- if (binIndex == -1) {
- throw new RuntimeException("No bin was found for continuous feature." +
- " This error can occur when given invalid data values (such as NaN)." +
- s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
- }
- binIndex
- } else {
- // Categorical feature bins are indexed by feature values.
- val featureValue = labeledPoint.features(featureIndex)
- if (featureValue < 0 || featureValue >= featureArity) {
- throw new IllegalArgumentException(
- s"DecisionTree given invalid data:" +
- s" Feature $featureIndex is categorical with values in" +
- s" {0,...,${featureArity - 1}," +
- s" but a data point gives it value $featureValue.\n" +
- " Bad data point: " + labeledPoint.toString)
- }
- featureValue.toInt
- }
- }
-}
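The removed `TreePoint.findBin` above locates the bin of a continuous feature value by binary search over bin boundaries, where bin `i` covers the interval (lowSplit.threshold, highSplit.threshold]. A minimal standalone sketch of that search follows; it replaces the removed `Bin` class with a plain array of upper thresholds, an assumption made only for the example.

```scala
// Standalone sketch of the binary-search binning used by the removed TreePoint.findBin.
// upperBounds(i) stands in for bins(featureIndex)(i).highSplit.threshold; bin i covers
// (upperBounds(i - 1), upperBounds(i)], with the lower bound of bin 0 taken as -Infinity.
object ContinuousBinningSketch {
  def findBin(value: Double, upperBounds: Array[Double]): Int = {
    var left = 0
    var right = upperBounds.length - 1
    while (left <= right) {
      val mid = left + (right - left) / 2
      val low = if (mid == 0) Double.NegativeInfinity else upperBounds(mid - 1)
      val high = upperBounds(mid)
      if (low < value && value <= high) {
        return mid
      } else if (low >= value) {
        right = mid - 1
      } else {
        left = mid + 1
      }
    }
    -1  // mirrors the removed code: -1 signals that no bin was found (e.g. NaN input)
  }
}
```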
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 13aff11007..ff7700d2d1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -85,7 +85,7 @@ object Entropy extends Impurity {
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param numClasses Number of classes for label.
*/
-private[tree] class EntropyAggregator(numClasses: Int)
+private[spark] class EntropyAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 39c7f9c3be..58dc79b739 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -81,7 +81,7 @@ object Gini extends Impurity {
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param numClasses Number of classes for label.
*/
-private[tree] class GiniAggregator(numClasses: Int)
+private[spark] class GiniAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 92d74a1b83..2423516123 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -71,7 +71,7 @@ object Variance extends Impurity {
* in order to compute impurity from a sample.
* Note: Instances of this class do not hold the data; they operate on views of the data.
*/
-private[tree] class VarianceAggregator()
+private[spark] class VarianceAggregator()
extends ImpurityAggregator(statsSize = 3) with Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
deleted file mode 100644
index 0cad473782..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import org.apache.spark.mllib.tree.configuration.FeatureType._
-
-/**
- * Used for "binning" the feature values for faster best split calculation.
- *
- * For a continuous feature, the bin is determined by a low and a high split,
- * where an example with featureValue falls into the bin s.t.
- * lowSplit.threshold < featureValue <= highSplit.threshold.
- *
- * For ordered categorical features, there is a 1-1-1 correspondence between
- * bins, splits, and feature values. The bin is determined by category/feature value.
- * However, the bins are not necessarily ordered by feature value;
- * they are ordered using impurity.
- *
- * For unordered categorical features, there is a 1-1 correspondence between bins, splits,
- * where bins and splits correspond to subsets of feature values (in highSplit.categories).
- * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all
- * partitionings of categories into 2 disjoint, non-empty sets.
- *
- * @param lowSplit signifying the lower threshold for the continuous feature to be
- * accepted in the bin
- * @param highSplit signifying the upper threshold for the continuous feature to be
- * accepted in the bin
- * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin for ordered features
- */
-private[tree]
-case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
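A small arithmetic note on the doc comment of the removed `Bin` class above: an unordered categorical feature with k categories admits 2^(k-1) - 1 ways to split the categories into two disjoint, non-empty sets (each split counted once), which is what `(1 << k - 1) - 1` evaluates to in Scala, since `-` binds tighter than `<<`. The tiny sketch below simply evaluates that formula; the object and method names are illustrative only.

```scala
// Number of bins/splits for an unordered categorical feature with k categories,
// following the formula in the removed Bin.scala doc comment.
object UnorderedBinCountSketch {
  def numUnorderedBins(k: Int): Int = {
    require(k >= 2, "an unordered categorical feature needs at least two categories")
    (1 << (k - 1)) - 1  // same value as (1 << k - 1) - 1, with precedence made explicit
  }
}
// e.g. numUnorderedBins(3) == 3, numUnorderedBins(4) == 7
```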
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
index 9d756da410..77ab3d8bb7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.EnsembleTestHelper
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 441338e74e..e64551f03c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index bb1041b109..49cb7e1f24 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext