author     Joseph K. Bradley <joseph@databricks.com>  2015-04-17 13:15:36 -0700
committer  Xiangrui Meng <meng@databricks.com>        2015-04-17 13:15:36 -0700
commit     a83571acc938582865efb41645aa1e414f339e46 (patch)
tree       cd11b3bea5e50946c015d22672f269970f664f9b /mllib/src/main/scala/org/apache
parent     50ab8a6543ad5c31e89c16df374d0cb13222fd1e (diff)
[SPARK-6113] [ml] Stabilize DecisionTree API
This is a PR for cleaning up and finalizing the DecisionTree API. PRs for ensembles will follow once this is merged.

### Goal

Here is the description copied from the JIRA (for both trees and ensembles):

> **Issue**: The APIs for DecisionTree and ensembles (RandomForests and GradientBoostedTrees) have been experimental for a long time. The API has become very convoluted because trees and ensembles have many, many variants, some of which we have added incrementally without a long-term design.
> **Proposal**: This JIRA is for discussing changes required to finalize the APIs. After we discuss, I will make a PR to update the APIs and make them non-Experimental. This will require making many breaking changes; see the design doc for details.
> **[Design doc](https://docs.google.com/document/d/1rJ_DZinyDG3PkYkAKSsQlY0QgCeefn4hUv7GsPkzBP4)**: This outlines current issues and the proposed API.

Overall code layout:
* The old API in mllib.tree.* will remain the same.
* The new API will reside in ml.classification.* and ml.regression.*

### Summary of changes

Old API
* Exactly the same, except I made 1 method in Loss private (but that is not a breaking change since that method was introduced after the Spark 1.3 release).

New APIs
* Under Pipeline API
* The new API preserves functionality, except:
  * New API does NOT store prob (probability of label in classification). I want to have it store the full vector of probabilities but feel that should be in a later PR.
* Use abstractions for parameters, estimators, and models to avoid code duplication
* Limit parameters to relevant algorithms
* For enum-like types, only expose Strings
  * We can make these pluggable later on by adding new parameters. That is a far-future item.

Test suites
* I organized DecisionTreeSuite, but I made absolutely no changes to the tests themselves.
* The test suites for the new API only test (a) similarity with the results of the old API and (b) elements of the new API.
* After code is moved to this new API, we should move the tests from the old suites which test the internals.

### Details

#### Changed names

Parameters
* useNodeIdCache -> cacheNodeIds

#### Other changes

* Split: Changed categories to set instead of list

#### Non-decision tree changes

* AttributeGroup
  * Added parentheses to toMetadata, toStructField methods. (These were removed in a previous PR, but I ran into 1 issue with the Scala compiler not being able to disambiguate between a toMetadata method with no parentheses and a toMetadata method which takes 1 argument.)
* Attributes
  * Renamed: toMetadata -> toMetadataImpl
  * Added toMetadata methods which return ML metadata (keyed with "ML_ATTR")
  * NominalAttribute: Added getNumValues method which examines both numValues and values.
* Params.inheritValues: Checks whether the parent param really belongs to the child (to allow Estimator-Model pairs with different sets of parameters)

### Questions for reviewers

* Is "DecisionTreeClassificationModel" too long a name?
* Is this OK in the docs?
```
class DecisionTreeRegressor
  extends TreeRegressor[DecisionTreeRegressionModel]
  with DecisionTreeParams[DecisionTreeRegressor]
  with TreeRegressorParams[DecisionTreeRegressor]
```

### Future

We should open up the abstractions at some point. E.g., it would be useful to be able to set tree-related parameters in 1 place and then pass those to multiple tree-based algorithms.

Follow-up JIRAs will be (in this order):
* Tree ensembles
* Deprecate old tree code
* Move DecisionTree implementation code to new API.
* Move tests from the old suites which test the internals.
* Update programming guide
* Python API
* Change RandomForest* to always use bootstrapping, even when numTrees = 1
* Provide the probability of the predicted label for classification. After we move code to the new API and update it to maintain probabilities for all labels, then we can add the probabilities to the new API.

CC: mengxr manishamde codedeft chouqin MechCoder

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5530 from jkbradley/dt-api-dt and squashes the following commits:

6aae255 [Joseph K. Bradley] Changed tree abstractions not to take type parameters, and for setters to return this.type instead
ec17947 [Joseph K. Bradley] Updates based on code review. Main changes were: moving public types from ml.impl.tree to ml.tree, modifying CategoricalSplit to take an Array of categories but store a Set internally, making more types sealed or final
5626c81 [Joseph K. Bradley] style fixes
f8fbd24 [Joseph K. Bradley] imported reorg of DecisionTreeSuite from old PR. small cleanups
7ef63ed [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example (for real this time)
e11673f [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example
119f407 [Joseph K. Bradley] added DecisionTreeClassifier example
0bdc486 [Joseph K. Bradley] fixed issues after param PR was merged
f9fbb60 [Joseph K. Bradley] Done with DecisionTreeClassifier, but no save/load yet. Need to add example as well
2532c9a [Joseph K. Bradley] partial move to spark.ml API, not done yet
c72c1a0 [Joseph K. Bradley] Copied changes for common items, plus DecisionTreeClassifier from original PR
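For orientation (not part of this patch), here is a minimal sketch of how the new Pipeline-API classifier could be driven end to end. The `data` DataFrame with "label" and "features" columns is an assumption; the tree setters are the ones added in treeParams.scala below, and setLabelCol/setFeaturesCol are assumed to come from the Predictor base class of this era.

```
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.StringIndexer

// Index string labels first: the indexed column carries nominal metadata
// (number of classes), which DecisionTreeClassifier.train() requires.
val indexed = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
  .transform(data)

val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setMaxDepth(5)        // default
  .setMaxBins(32)        // default
  .setImpurity("gini")   // or "entropy"

val model = dt.fit(indexed)
println(model)           // "DecisionTreeClassificationModel of depth ... with ... nodes"
```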
Diffstat (limited to 'mllib/src/main/scala/org/apache')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala  10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala  43
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala  155
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala  300
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/package.scala  12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala  3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala  145
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala  205
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala  151
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala  60
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala  82
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala  5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala  12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala  10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala  5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala  5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala  5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala  32
23 files changed, 1196 insertions(+), 58 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
index aa27a668f1..d7dee8fed2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -117,12 +117,12 @@ class AttributeGroup private (
case numeric: NumericAttribute =>
// Skip default numeric attributes.
if (numeric.withoutIndex != NumericAttribute.defaultAttr) {
- numericMetadata += numeric.toMetadata(withType = false)
+ numericMetadata += numeric.toMetadataImpl(withType = false)
}
case nominal: NominalAttribute =>
- nominalMetadata += nominal.toMetadata(withType = false)
+ nominalMetadata += nominal.toMetadataImpl(withType = false)
case binary: BinaryAttribute =>
- binaryMetadata += binary.toMetadata(withType = false)
+ binaryMetadata += binary.toMetadataImpl(withType = false)
}
val attrBldr = new MetadataBuilder
if (numericMetadata.nonEmpty) {
@@ -151,7 +151,7 @@ class AttributeGroup private (
}
/** Converts to ML metadata */
- def toMetadata: Metadata = toMetadata(Metadata.empty)
+ def toMetadata(): Metadata = toMetadata(Metadata.empty)
/** Converts to a StructField with some existing metadata. */
def toStructField(existingMetadata: Metadata): StructField = {
@@ -159,7 +159,7 @@ class AttributeGroup private (
}
/** Converts to a StructField. */
- def toStructField: StructField = toStructField(Metadata.empty)
+ def toStructField(): StructField = toStructField(Metadata.empty)
override def equals(other: Any): Boolean = {
other match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index 00b7566aab..5717d6ec2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -68,21 +68,32 @@ sealed abstract class Attribute extends Serializable {
* Converts this attribute to [[Metadata]].
* @param withType whether to include the type info
*/
- private[attribute] def toMetadata(withType: Boolean): Metadata
+ private[attribute] def toMetadataImpl(withType: Boolean): Metadata
/**
* Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to
* save space, because numeric type is the default attribute type. For nominal and binary
* attributes, the type info is included.
*/
- private[attribute] def toMetadata(): Metadata = {
+ private[attribute] def toMetadataImpl(): Metadata = {
if (attrType == AttributeType.Numeric) {
- toMetadata(withType = false)
+ toMetadataImpl(withType = false)
} else {
- toMetadata(withType = true)
+ toMetadataImpl(withType = true)
}
}
+ /** Converts to ML metadata with some existing metadata. */
+ def toMetadata(existingMetadata: Metadata): Metadata = {
+ new MetadataBuilder()
+ .withMetadata(existingMetadata)
+ .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl())
+ .build()
+ }
+
+ /** Converts to ML metadata */
+ def toMetadata(): Metadata = toMetadata(Metadata.empty)
+
/**
* Converts to a [[StructField]] with some existing metadata.
* @param existingMetadata existing metadata to carry over
@@ -90,7 +101,7 @@ sealed abstract class Attribute extends Serializable {
def toStructField(existingMetadata: Metadata): StructField = {
val newMetadata = new MetadataBuilder()
.withMetadata(existingMetadata)
- .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadata())
+ .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl())
.build()
StructField(name.get, DoubleType, nullable = false, newMetadata)
}
@@ -98,7 +109,7 @@ sealed abstract class Attribute extends Serializable {
/** Converts to a [[StructField]]. */
def toStructField(): StructField = toStructField(Metadata.empty)
- override def toString: String = toMetadata(withType = true).toString
+ override def toString: String = toMetadataImpl(withType = true).toString
}
/** Trait for ML attribute factories. */
@@ -210,7 +221,7 @@ class NumericAttribute private[ml] (
override def isNominal: Boolean = false
/** Convert this attribute to metadata. */
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder()
if (withType) bldr.putString(TYPE, attrType.name)
@@ -353,6 +364,20 @@ class NominalAttribute private[ml] (
/** Copy without the `numValues`. */
def withoutNumValues: NominalAttribute = copy(numValues = None)
+ /**
+ * Get the number of values, either from `numValues` or from `values`.
+ * Return None if unknown.
+ */
+ def getNumValues: Option[Int] = {
+ if (numValues.nonEmpty) {
+ numValues
+ } else if (values.nonEmpty) {
+ Some(values.get.length)
+ } else {
+ None
+ }
+ }
+
/** Creates a copy of this attribute with optional changes. */
private def copy(
name: Option[String] = name,
@@ -363,7 +388,7 @@ class NominalAttribute private[ml] (
new NominalAttribute(name, index, isOrdinal, numValues, values)
}
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder()
if (withType) bldr.putString(TYPE, attrType.name)
@@ -465,7 +490,7 @@ class BinaryAttribute private[ml] (
new BinaryAttribute(name, index, values)
}
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder
if (withType) bldr.putString(TYPE, attrType.name)
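A small sketch (illustrative only, not from the patch) of the attribute changes above: the new public `toMetadata()` keyed with "ML_ATTR" and `NominalAttribute.getNumValues`, which falls back to `values.length` when `numValues` is unset.

```
import org.apache.spark.ml.attribute.NominalAttribute

// getNumValues falls back to values.length when numValues is not set.
val labelAttr = NominalAttribute.defaultAttr
  .withName("label")
  .withValues(Array("no", "yes"))
assert(labelAttr.getNumValues == Some(2))

// The new public toMetadata() nests the attribute under the "ML_ATTR" key,
// which is how StringIndexer (below) now attaches label metadata to a column.
println(labelAttr.toMetadata().json)
```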
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
new file mode 100644
index 0000000000..3855e396b5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
+ * for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class DecisionTreeClassifier
+ extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
+ with DecisionTreeParams
+ with TreeClassifierParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type =
+ super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type =
+ super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): DecisionTreeClassificationModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ case Some(n: Int) => n
+ case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
+ s" with invalid label column, without the number of classes specified.")
+ // TODO: Automatically index labels.
+ }
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy = getOldStrategy(categoricalFeatures, numClasses)
+ val oldModel = OldDecisionTree.train(oldDataset, strategy)
+ DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ override private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses)
+ strategy.algo = OldAlgo.Classification
+ strategy.setImpurity(getOldImpurity)
+ strategy
+ }
+}
+
+object DecisionTreeClassifier {
+ /** Accessor for supported impurities */
+ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class DecisionTreeClassificationModel private[ml] (
+ override val parent: DecisionTreeClassifier,
+ override val fittingParamMap: ParamMap,
+ override val rootNode: Node)
+ extends PredictionModel[Vector, DecisionTreeClassificationModel]
+ with DecisionTreeModel with Serializable {
+
+ require(rootNode != null,
+ "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+ override protected def predict(features: Vector): Double = {
+ rootNode.predict(features)
+ }
+
+ override protected def copy(): DecisionTreeClassificationModel = {
+ val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldDecisionTreeModel = {
+ new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
+ }
+}
+
+private[ml] object DecisionTreeClassificationModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldDecisionTreeModel,
+ parent: DecisionTreeClassifier,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification,
+ s"Cannot convert non-classification DecisionTreeModel (old API) to" +
+ s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 4d960df357..23956c512c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -118,7 +118,7 @@ class StringIndexerModel private[ml] (
}
val outputColName = map(outputCol)
val metadata = NominalAttribute.defaultAttr
- .withName(outputColName).withValues(labels).toStructField().metadata
+ .withName(outputColName).withValues(labels).toMetadata()
dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
new file mode 100644
index 0000000000..6f4509f03d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl.tree
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.impl.estimator.PredictorParams
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
+ Impurity => OldImpurity, Variance => OldVariance}
+
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait DecisionTreeParams extends PredictorParams {
+
+ /**
+ * Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * (default = 5)
+ * @group param
+ */
+ final val maxDepth: IntParam =
+ new IntParam(this, "maxDepth", "Maximum depth of the tree." +
+ " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
+
+ /**
+ * Maximum number of bins used for discretizing continuous features and for choosing how to split
+ * on features at each node. More bins give higher granularity.
+ * Must be >= 2 and >= number of categories in any categorical feature.
+ * (default = 32)
+ * @group param
+ */
+ final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
+ " discretizing continuous features. Must be >=2 and >= number of categories for any" +
+ " categorical feature.")
+
+ /**
+ * Minimum number of instances each child must have after split.
+ * If a split causes the left or right child to have fewer than minInstancesPerNode,
+ * the split will be discarded as invalid.
+ * Should be >= 1.
+ * (default = 1)
+ * @group param
+ */
+ final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
+ " number of instances each child must have after split. If a split causes the left or right" +
+ " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
+ " Should be >= 1.")
+
+ /**
+ * Minimum information gain for a split to be considered at a tree node.
+ * (default = 0.0)
+ * @group param
+ */
+ final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
+ "Minimum information gain for a split to be considered at a tree node.")
+
+ /**
+ * Maximum memory in MB allocated to histogram aggregation.
+ * (default = 256 MB)
+ * @group expertParam
+ */
+ final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
+ "Maximum memory in MB allocated to histogram aggregation.")
+
+ /**
+ * If false, the algorithm will pass trees to executors to match instances with nodes.
+ * If true, the algorithm will cache node IDs for each instance.
+ * Caching can speed up training of deeper trees.
+ * (default = false)
+ * @group expertParam
+ */
+ final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
+ " algorithm will pass trees to executors to match instances with nodes. If true, the" +
+ " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
+ " trees.")
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be >= 1.
+ * (default = 10)
+ * @group expertParam
+ */
+ final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
+ " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" +
+ " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
+ " checkpoint directory is set in the SparkContext. Must be >= 1.")
+
+ setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
+ maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+
+ /** @group setParam */
+ def setMaxDepth(value: Int): this.type = {
+ require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
+ set(maxDepth, value)
+ this.asInstanceOf[this.type]
+ }
+
+ /** @group getParam */
+ def getMaxDepth: Int = getOrDefault(maxDepth)
+
+ /** @group setParam */
+ def setMaxBins(value: Int): this.type = {
+ require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value")
+ set(maxBins, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMaxBins: Int = getOrDefault(maxBins)
+
+ /** @group setParam */
+ def setMinInstancesPerNode(value: Int): this.type = {
+ require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value")
+ set(minInstancesPerNode, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
+
+ /** @group setParam */
+ def setMinInfoGain(value: Double): this.type = {
+ set(minInfoGain, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMinInfoGain: Double = getOrDefault(minInfoGain)
+
+ /** @group expertSetParam */
+ def setMaxMemoryInMB(value: Int): this.type = {
+ require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value")
+ set(maxMemoryInMB, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
+
+ /** @group expertSetParam */
+ def setCacheNodeIds(value: Boolean): this.type = {
+ set(cacheNodeIds, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
+
+ /** @group expertSetParam */
+ def setCheckpointInterval(value: Int): this.type = {
+ require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value")
+ set(checkpointInterval, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+
+ /**
+ * Create a Strategy instance to use with the old API.
+ * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0,
+ * the default for single trees).
+ */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = {
+ val strategy = OldStrategy.defaultStategy(OldAlgo.Classification)
+ strategy.checkpointInterval = getCheckpointInterval
+ strategy.maxBins = getMaxBins
+ strategy.maxDepth = getMaxDepth
+ strategy.maxMemoryInMB = getMaxMemoryInMB
+ strategy.minInfoGain = getMinInfoGain
+ strategy.minInstancesPerNode = getMinInstancesPerNode
+ strategy.useNodeIdCache = getCacheNodeIds
+ strategy.numClasses = numClasses
+ strategy.categoricalFeaturesInfo = categoricalFeatures
+ strategy.subsamplingRate = 1.0 // default for individual trees
+ strategy
+ }
+}
+
+/**
+ * (private trait) Parameters for Decision Tree-based classification algorithms.
+ */
+private[ml] trait TreeClassifierParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * Supported: "entropy" and "gini".
+ * (default = gini)
+ * @group param
+ */
+ val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+
+ setDefault(impurity -> "gini")
+
+ /** @group setParam */
+ def setImpurity(value: String): this.type = {
+ val impurityStr = value.toLowerCase
+ require(TreeClassifierParams.supportedImpurities.contains(impurityStr),
+ s"Tree-based classifier was given unrecognized impurity: $value." +
+ s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+ set(impurity, impurityStr)
+ this
+ }
+
+ /** @group getParam */
+ def getImpurity: String = getOrDefault(impurity)
+
+ /** Convert new impurity to old impurity. */
+ private[ml] def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "entropy" => OldEntropy
+ case "gini" => OldGini
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeClassifierParams was given unrecognized impurity: $impurity.")
+ }
+ }
+}
+
+private[ml] object TreeClassifierParams {
+ // These options should be lowercase.
+ val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+}
+
+/**
+ * (private trait) Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * Supported: "variance".
+ * (default = variance)
+ * @group param
+ */
+ val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+
+ setDefault(impurity -> "variance")
+
+ /** @group setParam */
+ def setImpurity(value: String): this.type = {
+ val impurityStr = value.toLowerCase
+ require(TreeRegressorParams.supportedImpurities.contains(impurityStr),
+ s"Tree-based regressor was given unrecognized impurity: $value." +
+ s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+ set(impurity, impurityStr)
+ this
+ }
+
+ /** @group getParam */
+ def getImpurity: String = getOrDefault(impurity)
+
+ /** Convert new impurity to old impurity. */
+ protected def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "variance" => OldVariance
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeRegressorParams was given unrecognized impurity: $impurity")
+ }
+ }
+}
+
+private[ml] object TreeRegressorParams {
+ // These options should be lowercase.
+ val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+}
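As a hedged sketch of how the defaults and setter-side validation defined in DecisionTreeParams/TreeClassifierParams behave; DecisionTreeClassifier is used here only as a concrete carrier of these traits.

```
import org.apache.spark.ml.classification.DecisionTreeClassifier

val dt = new DecisionTreeClassifier()

// Defaults from setDefault(...) in DecisionTreeParams.
assert(dt.getMaxDepth == 5)
assert(dt.getMaxBins == 32)
assert(dt.getCheckpointInterval == 10)
assert(!dt.getCacheNodeIds)

// Setters validate eagerly, so an out-of-range value fails before fit().
try {
  dt.setMaxBins(1) // must be >= 2
} catch {
  case e: IllegalArgumentException => println(e.getMessage)
}

// Impurity strings are case-insensitive and normalized to lowercase.
assert(dt.setImpurity("Gini").getImpurity == "gini")
```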
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
index b45bd1499b..ac75e9de1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -32,6 +32,18 @@ package org.apache.spark
* @groupname getParam Parameter getters
* @groupprio getParam 6
*
+ * @groupname expertParam (expert-only) Parameters
+ * @groupdesc expertParam A list of advanced, expert-only (hyper-)parameter keys this algorithm can
+ * take. Users can set and get the parameter values through setters and getters,
+ * respectively.
+ * @groupprio expertParam 7
+ *
+ * @groupname expertSetParam (expert-only) Parameter setters
+ * @groupprio expertSetParam 8
+ *
+ * @groupname expertGetParam (expert-only) Parameter getters
+ * @groupprio expertGetParam 9
+ *
* @groupname Ungrouped Members
* @groupprio Ungrouped 0
*/
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 849c60433c..ddc5907e7f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -296,8 +296,9 @@ private[spark] object Params {
paramMap: ParamMap,
parent: E,
child: M): Unit = {
+ val childParams = child.params.map(_.name).toSet
parent.params.foreach { param =>
- if (paramMap.contains(param)) {
+ if (paramMap.contains(param) && childParams.contains(param.name)) {
child.set(child.getParam(param.name), paramMap(param))
}
}
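The one-line change above makes Params.inheritValues skip parent params that the child does not declare. Since that method is private[spark], the following is only a conceptual sketch of the new filtering rule, with made-up param names.

```
// Conceptual sketch only: mirrors the new guard in Params.inheritValues.
// A parent (estimator) param is copied to the child (model) only when the
// child declares a param of the same name.
case class P(name: String, value: Any)

val parentParams = Seq(P("maxDepth", 5), P("checkpointInterval", 10), P("thresholds", Array(0.5)))
val childParamNames = Set("maxDepth", "checkpointInterval") // the model has fewer params

val inherited = parentParams.filter(p => childParamNames.contains(p.name))
// inherited keeps maxDepth and checkpointInterval; "thresholds" is dropped
// instead of triggering a getParam failure on the child.
```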
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
new file mode 100644
index 0000000000..49a8b77acf
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
+ * for regression.
+ * It supports both continuous and categorical features.
+ */
+@AlphaComponent
+final class DecisionTreeRegressor
+ extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
+ with DecisionTreeParams
+ with TreeRegressorParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type =
+ super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): DecisionTreeRegressionModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy = getOldStrategy(categoricalFeatures)
+ val oldModel = OldDecisionTree.train(oldDataset, strategy)
+ DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
+ strategy.algo = OldAlgo.Regression
+ strategy.setImpurity(getOldImpurity)
+ strategy
+ }
+}
+
+object DecisionTreeRegressor {
+ /** Accessor for supported impurities */
+ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression.
+ * It supports both continuous and categorical features.
+ * @param rootNode Root of the decision tree
+ */
+@AlphaComponent
+final class DecisionTreeRegressionModel private[ml] (
+ override val parent: DecisionTreeRegressor,
+ override val fittingParamMap: ParamMap,
+ override val rootNode: Node)
+ extends PredictionModel[Vector, DecisionTreeRegressionModel]
+ with DecisionTreeModel with Serializable {
+
+ require(rootNode != null,
+ "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+ override protected def predict(features: Vector): Double = {
+ rootNode.predict(features)
+ }
+
+ override protected def copy(): DecisionTreeRegressionModel = {
+ val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes"
+ }
+
+ /** Convert to a model in the old API */
+ private[ml] def toOld: OldDecisionTreeModel = {
+ new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
+ }
+}
+
+private[ml] object DecisionTreeRegressionModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldDecisionTreeModel,
+ parent: DecisionTreeRegressor,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression,
+ s"Cannot convert non-regression DecisionTreeModel (old API) to" +
+ s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+ }
+}
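A minimal usage sketch for the regressor above, analogous to the classifier example; "data" is an assumed DataFrame with "label" and "features" columns.

```
import org.apache.spark.ml.regression.DecisionTreeRegressor

// Regression needs no label metadata (numClasses is irrelevant), so a plain
// continuous-label DataFrame works as-is.
val dtr = new DecisionTreeRegressor()
  .setMaxDepth(4)
  .setMinInstancesPerNode(2)
  .setImpurity("variance")    // the only supported regression impurity

val model = dtr.fit(data)
println(model.toDebugString)  // full if/else rendering of the learned tree
```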
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
new file mode 100644
index 0000000000..d6e2203d9f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
+ Node => OldNode, Predict => OldPredict}
+
+
+/**
+ * Decision tree node interface.
+ */
+sealed abstract class Node extends Serializable {
+
+ // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
+ // code into the new API and deprecate the old API.
+
+ /** Prediction this node makes (or would make, if it is an internal node) */
+ def prediction: Double
+
+ /** Impurity measure at this node (for training data) */
+ def impurity: Double
+
+ /** Recursive prediction helper method */
+ private[ml] def predict(features: Vector): Double = prediction
+
+ /**
+ * Get the number of nodes in tree below this node, including leaf nodes.
+ * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
+ */
+ private[tree] def numDescendants: Int
+
+ /**
+ * Recursive print function.
+ * @param indentFactor The number of spaces to add to each level of indentation.
+ */
+ private[tree] def subtreeToString(indentFactor: Int = 0): String
+
+ /**
+ * Get depth of tree from this node.
+ * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes.
+ */
+ private[tree] def subtreeDepth: Int
+
+ /**
+ * Create a copy of this node in the old Node format, recursively creating child nodes as needed.
+ * @param id Node ID using old format IDs
+ */
+ private[ml] def toOld(id: Int): OldNode
+}
+
+private[ml] object Node {
+
+ /**
+ * Create a new Node from the old Node format, recursively creating child nodes as needed.
+ */
+ def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+ if (oldNode.isLeaf) {
+ // TODO: Once the implementation has been moved to this API, then include sufficient
+ // statistics here.
+ new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
+ } else {
+ val gain = if (oldNode.stats.nonEmpty) {
+ oldNode.stats.get.gain
+ } else {
+ 0.0
+ }
+ new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
+ gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
+ rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures))
+ }
+ }
+}
+
+/**
+ * Decision tree leaf node.
+ * @param prediction Prediction this node makes
+ * @param impurity Impurity measure at this node (for training data)
+ */
+final class LeafNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double) extends Node {
+
+ override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
+
+ override private[ml] def predict(features: Vector): Double = prediction
+
+ override private[tree] def numDescendants: Int = 0
+
+ override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+ val prefix: String = " " * indentFactor
+ prefix + s"Predict: $prediction\n"
+ }
+
+ override private[tree] def subtreeDepth: Int = 0
+
+ override private[ml] def toOld(id: Int): OldNode = {
+ // NOTE: We do NOT store 'prob' in the new API currently.
+ new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
+ None, None, None, None)
+ }
+}
+
+/**
+ * Internal Decision Tree node.
+ * @param prediction Prediction this node would make if it were a leaf node
+ * @param impurity Impurity measure at this node (for training data)
+ * @param gain Information gain value.
+ * Values < 0 indicate missing values; this quirk will be removed with future updates.
+ * @param leftChild Left-hand child node
+ * @param rightChild Right-hand child node
+ * @param split Information about the test used to split to the left or right child.
+ */
+final class InternalNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double,
+ val gain: Double,
+ val leftChild: Node,
+ val rightChild: Node,
+ val split: Split) extends Node {
+
+ override def toString: String = {
+ s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
+ }
+
+ override private[ml] def predict(features: Vector): Double = {
+ if (split.shouldGoLeft(features)) {
+ leftChild.predict(features)
+ } else {
+ rightChild.predict(features)
+ }
+ }
+
+ override private[tree] def numDescendants: Int = {
+ 2 + leftChild.numDescendants + rightChild.numDescendants
+ }
+
+ override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+ val prefix: String = " " * indentFactor
+ prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" +
+ leftChild.subtreeToString(indentFactor + 1) +
+ prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" +
+ rightChild.subtreeToString(indentFactor + 1)
+ }
+
+ override private[tree] def subtreeDepth: Int = {
+ 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth)
+ }
+
+ override private[ml] def toOld(id: Int): OldNode = {
+ assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+ + " since the old API does not support deep trees.")
+ // NOTE: We do NOT store 'prob' in the new API currently.
+ new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
+ Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
+ Some(rightChild.toOld(OldNode.rightChildIndex(id))),
+ Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
+ new OldPredict(leftChild.prediction, prob = 0.0),
+ new OldPredict(rightChild.prediction, prob = 0.0))))
+ }
+}
+
+private object InternalNode {
+
+ /**
+ * Helper method for [[Node.subtreeToString()]].
+ * @param split Split to print
+ * @param left Indicates whether this is the part of the split going to the left,
+ * or that going to the right.
+ */
+ private def splitToString(split: Split, left: Boolean): String = {
+ val featureStr = s"feature ${split.featureIndex}"
+ split match {
+ case contSplit: ContinuousSplit =>
+ if (left) {
+ s"$featureStr <= ${contSplit.threshold}"
+ } else {
+ s"$featureStr > ${contSplit.threshold}"
+ }
+ case catSplit: CategoricalSplit =>
+ val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}")
+ if (left) {
+ s"$featureStr in $categoriesStr"
+ } else {
+ s"$featureStr not in $categoriesStr"
+ }
+ }
+ }
+}
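A sketch of walking a fitted tree through the public Node hierarchy introduced above; "model" is an assumed DecisionTreeClassificationModel or DecisionTreeRegressionModel.

```
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

// Node is sealed with exactly two public subclasses, so this match is exhaustive.
def describe(node: Node, indent: String = ""): Unit = node match {
  case leaf: LeafNode =>
    println(s"${indent}Predict: ${leaf.prediction} (impurity ${leaf.impurity})")
  case internal: InternalNode =>
    println(s"${indent}Split on feature ${internal.split.featureIndex} (gain ${internal.gain})")
    describe(internal.leftChild, indent + "  ")
    describe(internal.rightChild, indent + "  ")
}

describe(model.rootNode)
```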
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
new file mode 100644
index 0000000000..cb940f6299
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
+import org.apache.spark.mllib.tree.model.{Split => OldSplit}
+
+
+/**
+ * Interface for a "Split," which specifies a test made at a decision tree node
+ * to choose the left or right path.
+ */
+sealed trait Split extends Serializable {
+
+ /** Index of feature which this split tests */
+ def featureIndex: Int
+
+ /** Return true (split to left) or false (split to right) */
+ private[ml] def shouldGoLeft(features: Vector): Boolean
+
+ /** Convert to old Split format */
+ private[tree] def toOld: OldSplit
+}
+
+private[ml] object Split {
+
+ def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
+ oldSplit.featureType match {
+ case OldFeatureType.Categorical =>
+ new CategoricalSplit(featureIndex = oldSplit.feature,
+ leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
+ case OldFeatureType.Continuous =>
+ new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold)
+ }
+ }
+}
+
+/**
+ * Split which tests a categorical feature.
+ * @param featureIndex Index of the feature to test
+ * @param leftCategories If the feature value is in this set of categories, then the split goes
+ * left. Otherwise, it goes right.
+ * @param numCategories Number of categories for this feature.
+ */
+final class CategoricalSplit(
+ override val featureIndex: Int,
+ leftCategories: Array[Double],
+ private val numCategories: Int)
+ extends Split {
+
+ require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
+ s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}")
+
+ /**
+ * If true, then "categories" is the set of categories for splitting to the left, and vice versa.
+ */
+ private val isLeft: Boolean = leftCategories.length <= numCategories / 2
+
+ /** Set of categories determining the splitting rule, along with [[isLeft]]. */
+ private val categories: Set[Double] = {
+ if (isLeft) {
+ leftCategories.toSet
+ } else {
+ setComplement(leftCategories.toSet)
+ }
+ }
+
+ override private[ml] def shouldGoLeft(features: Vector): Boolean = {
+ if (isLeft) {
+ categories.contains(features(featureIndex))
+ } else {
+ !categories.contains(features(featureIndex))
+ }
+ }
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case other: CategoricalSplit => featureIndex == other.featureIndex &&
+ isLeft == other.isLeft && categories == other.categories
+ case _ => false
+ }
+ }
+
+ override private[tree] def toOld: OldSplit = {
+ val oldCats = if (isLeft) {
+ categories
+ } else {
+ setComplement(categories)
+ }
+ OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList)
+ }
+
+ /** Get sorted categories which split to the left */
+ def getLeftCategories: Array[Double] = {
+ val cats = if (isLeft) categories else setComplement(categories)
+ cats.toArray.sorted
+ }
+
+ /** Get sorted categories which split to the right */
+ def getRightCategories: Array[Double] = {
+ val cats = if (isLeft) setComplement(categories) else categories
+ cats.toArray.sorted
+ }
+
+ /** [0, numCategories) \ cats */
+ private def setComplement(cats: Set[Double]): Set[Double] = {
+ Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet
+ }
+}
+
+/**
+ * Split which tests a continuous feature.
+ * @param featureIndex Index of the feature to test
+ * @param threshold If the feature value is <= this threshold, then the split goes left.
+ * Otherwise, it goes right.
+ */
+final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+
+ override private[ml] def shouldGoLeft(features: Vector): Boolean = {
+ features(featureIndex) <= threshold
+ }
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case other: ContinuousSplit =>
+ featureIndex == other.featureIndex && threshold == other.threshold
+ case _ =>
+ false
+ }
+ }
+
+ override private[tree] def toOld: OldSplit = {
+ OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double])
+ }
+}
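A short sketch (not from the patch) of rendering splits the way InternalNode.splitToString does; the Split instances are constructed directly here purely for illustration.

```
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, Split}

def leftBranch(split: Split): String = split match {
  case c: ContinuousSplit =>
    s"feature ${c.featureIndex} <= ${c.threshold}"
  case c: CategoricalSplit =>
    s"feature ${c.featureIndex} in ${c.getLeftCategories.mkString("{", ",", "}")}"
}

// Continuous split on feature 0 at threshold 3.5.
println(leftBranch(new ContinuousSplit(0, 3.5)))
// Categorical split on feature 1 (3 categories): {0.0, 2.0} go left, {1.0} goes right.
println(leftBranch(new CategoricalSplit(1, Array(0.0, 2.0), 3)))
```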
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
new file mode 100644
index 0000000000..8e3bc3849d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.annotation.AlphaComponent
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for Decision Tree models.
+ *
+ * TODO: Add support for predicting probabilities and raw predictions
+ */
+@AlphaComponent
+trait DecisionTreeModel {
+
+ /** Root of the decision tree */
+ def rootNode: Node
+
+ /** Number of nodes in tree, including leaf nodes. */
+ def numNodes: Int = {
+ 1 + rootNode.numDescendants
+ }
+
+ /**
+ * Depth of the tree.
+ * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
+ */
+ lazy val depth: Int = {
+ rootNode.subtreeDepth
+ }
+
+ /** Summary of the model */
+ override def toString: String = {
+ // Implementing classes should generally override this method to be more descriptive.
+ s"DecisionTreeModel of depth $depth with $numNodes nodes"
+ }
+
+ /** Full description of model */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + rootNode.subtreeToString(2)
+ }
+}
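A quick sanity sketch of the DecisionTreeModel trait above; "model" is an assumed fitted tree model.

```
// depth 0 means a single leaf; numNodes counts the root plus all descendants.
println(s"depth = ${model.depth}, nodes = ${model.numNodes}")
assert(model.numNodes >= 2 * model.depth + 1) // holds for any binary tree of that depth
println(model.toDebugString)                  // indented if/else description of the tree
```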
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
new file mode 100644
index 0000000000..c84c8b4eb7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute,
+ NumericAttribute}
+import org.apache.spark.sql.types.StructField
+
+
+/**
+ * :: Experimental ::
+ *
+ * Helper utilities for tree-based algorithms
+ */
+@Experimental
+object MetadataUtils {
+
+ /**
+ * Examine a schema to identify the number of classes in a label column.
+ * Returns None if the number of labels is not specified, or if the label column is continuous.
+ */
+ def getNumClasses(labelSchema: StructField): Option[Int] = {
+ Attribute.fromStructField(labelSchema) match {
+ case numAttr: NumericAttribute => None
+ case binAttr: BinaryAttribute => Some(2)
+ case nomAttr: NominalAttribute => nomAttr.getNumValues
+ }
+ }
+
+ /**
+ * Examine a schema to identify categorical (Binary and Nominal) features.
+ *
+ * @param featuresSchema Schema of the features column.
+ * If a feature does not have metadata, it is assumed to be continuous.
+ * If a feature is Nominal, then it must have the number of values
+ * specified.
+ * @return Map: feature index --> number of categories.
+ * The map's set of keys will be the set of categorical feature indices.
+ */
+ def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = {
+ val metadata = AttributeGroup.fromStructField(featuresSchema)
+ if (metadata.attributes.isEmpty) {
+ HashMap.empty[Int, Int]
+ } else {
+ metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) =>
+ if (attr == null) {
+ Iterator()
+ } else {
+ attr match {
+ case numAttr: NumericAttribute => Iterator()
+ case binAttr: BinaryAttribute => Iterator(idx -> 2)
+ case nomAttr: NominalAttribute =>
+ nomAttr.getNumValues match {
+ case Some(numValues: Int) => Iterator(idx -> numValues)
+ case None => throw new IllegalArgumentException(s"Feature $idx is marked as" +
+ " Nominal (categorical), but it does not have the number of values specified.")
+ }
+ }
+ }
+ }.toMap
+ }
+ }
+
+}
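
A hedged sketch (not in the patch) of how label and feature columns carrying the ML attribute metadata read by these helpers could be built; it assumes the ml.attribute builders (`defaultAttr`, `withNumValues`, `toStructField()`) behave as used elsewhere in this change.

```
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.util.MetadataUtils

// A 3-class label column: getNumClasses should return Some(3).
val labelField = NominalAttribute.defaultAttr.withName("label").withNumValues(3).toStructField()
assert(MetadataUtils.getNumClasses(labelField) == Some(3))

// A features vector with one continuous and one 4-category feature:
// only the categorical one appears in the result, i.e. Map(1 -> 4).
val featureAttrs: Array[Attribute] = Array(
  NumericAttribute.defaultAttr.withName("f0"),
  NominalAttribute.defaultAttr.withName("f1").withNumValues(4))
val featuresField = new AttributeGroup("features", featureAttrs).toStructField()
println(MetadataUtils.getCategoricalFeatures(featuresField))
```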
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b9d0c56dd1..dfe3a0b691 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -1147,7 +1147,10 @@ object DecisionTree extends Serializable with Logging {
}
}
- assert(splits.length > 0)
+ // TODO: Do not fail; just ignore the useless feature.
+ assert(splits.length > 0,
+ s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
+ " Please remove this feature and then try again.")
// set number of splits accordingly
metadata.setNumSplits(featureIndex, splits.length)
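
The assertion above still fails hard on a feature with a single unique value; the TODO is to eventually ignore such features. Purely as an illustrative workaround (not part of this patch, and assuming dense feature vectors), constant features can be dropped up front:

```
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Keep only features that take more than one distinct value in the dataset.
// Note: this makes one pass over the data per feature, so it is only a sketch.
def dropConstantFeatures(data: RDD[LabeledPoint]): RDD[LabeledPoint] = {
  val numFeatures = data.first().features.size
  val keep = (0 until numFeatures).filter { i =>
    data.map(_.features(i)).distinct().count() > 1
  }
  data.map(lp => LabeledPoint(lp.label, Vectors.dense(keep.map(lp.features(_)).toArray)))
}
```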
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index c02c79f094..0e31c7ed58 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -81,11 +81,11 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
/**
* Method to validate a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @param validationInput Validation dataset:
- RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- Should be different from and follow the same distribution as input.
- e.g., these two datasets could be created from an original dataset
- by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+ * @param validationInput Validation dataset.
+ * This dataset should be different from the training dataset,
+ * but it should follow the same distribution.
+ * E.g., these two datasets could be created from an original dataset
+ * by using [[org.apache.spark.rdd.RDD.randomSplit()]]
* @return a gradient boosted trees model that can be used for prediction
*/
def runWithValidation(
@@ -194,8 +194,6 @@ object GradientBoostedTrees extends Logging {
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
- val startingModel = new GradientBoostedTreesModel(
- Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
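
The reworded doc comment above points at RDD.randomSplit for producing the validation set. A brief sketch of that usage, with a hypothetical `data: RDD[LabeledPoint]` and parameter values chosen only for illustration:

```
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

// Hold out 20% of the data; both splits follow the same distribution.
val Array(train, validation) = data.randomSplit(Array(0.8, 0.2), seed = 42L)
val boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression)
boostingStrategy.numIterations = 50
// runWithValidation can stop before numIterations once the validation error
// stops improving.
val model = new GradientBoostedTrees(boostingStrategy).runWithValidation(train, validation)
```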
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index db01f2e229..055e60c7d9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -249,7 +249,7 @@ private class RandomForest (
nodeIdCache.get.deleteAllCheckpoints()
} catch {
case e:IOException =>
- logWarning(s"delete all chackpoints failed. Error reason: ${e.getMessage}")
+ logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 664c8df019..2d6b01524f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -89,14 +89,14 @@ object BoostingStrategy {
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: Algo): BoostingStrategy = {
- val treeStragtegy = Strategy.defaultStategy(algo)
- treeStragtegy.maxDepth = 3
+ val treeStrategy = Strategy.defaultStategy(algo)
+ treeStrategy.maxDepth = 3
algo match {
case Algo.Classification =>
- treeStragtegy.numClasses = 2
- new BoostingStrategy(treeStragtegy, LogLoss)
+ treeStrategy.numClasses = 2
+ new BoostingStrategy(treeStrategy, LogLoss)
case Algo.Regression =>
- new BoostingStrategy(treeStragtegy, SquaredError)
+ new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index 6f570b4e09..2bdef73c4a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -45,9 +45,8 @@ object AbsoluteError extends Loss {
if (label - prediction < 0) 1.0 else -1.0
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val err = label - prediction
math.abs(err)
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 24ee9f3d51..778c24526d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -47,10 +47,9 @@ object LogLoss extends Loss {
- 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val margin = 2.0 * label * prediction
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index d3b82b752f..64ffccbce0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -22,6 +22,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
* Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -57,6 +58,5 @@ trait Loss extends Serializable {
* @param label True label.
* @return Measure of model error on datapoint.
*/
- def computeError(prediction: Double, label: Double): Double
-
+ private[mllib] def computeError(prediction: Double, label: Double): Double
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 58857ae15e..a5582d3ef3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -45,9 +45,8 @@ object SquaredError extends Loss {
2.0 * (prediction - label)
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val err = prediction - label
err * err
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index c9bafd60fb..331af42853 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -113,11 +113,13 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
}
- override protected def formatVersion: String = "1.0"
+ override protected def formatVersion: String = DecisionTreeModel.formatVersion
}
object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
+ private[spark] def formatVersion: String = "1.0"
+
private[tree] object SaveLoadV1_0 {
def thisFormatVersion: String = "1.0"
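
The refactoring above only moves the version constant into the companion object; save/load behavior is unchanged. A hedged reminder of that round trip, with the SparkContext `sc`, the model, and the path all assumed:

```
import org.apache.spark.mllib.tree.model.DecisionTreeModel

// save() records the format version ("1.0") in the metadata that load() checks.
existingModel.save(sc, "/tmp/dtree-model")
val reloaded: DecisionTreeModel = DecisionTreeModel.load(sc, "/tmp/dtree-model")
```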
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 4f72bb8014..708ba04b56 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -175,7 +175,7 @@ class Node (
}
}
-private[tree] object Node {
+private[spark] object Node {
/**
* Return a node with the given node id (but nothing else set).
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index fef3d2acb2..8341219bfa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -38,6 +38,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
+
/**
* :: Experimental ::
* Represents a random forest model.
@@ -47,7 +48,7 @@ import org.apache.spark.util.Utils
*/
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
- extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
+ extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
combiningStrategy = if (algo == Classification) Vote else Average)
with Saveable {
@@ -58,11 +59,13 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis
RandomForestModel.SaveLoadV1_0.thisClassName)
}
- override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+ override protected def formatVersion: String = RandomForestModel.formatVersion
}
object RandomForestModel extends Loader[RandomForestModel] {
+ private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
override def load(sc: SparkContext, path: String): RandomForestModel = {
val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -102,15 +105,13 @@ class GradientBoostedTreesModel(
extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
with Saveable {
- require(trees.size == treeWeights.size)
+ require(trees.length == treeWeights.length)
override def save(sc: SparkContext, path: String): Unit = {
TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
GradientBoostedTreesModel.SaveLoadV1_0.thisClassName)
}
- override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
-
/**
* Method to compute error or loss for every iteration of gradient boosting.
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
@@ -138,7 +139,7 @@ class GradientBoostedTreesModel(
evaluationArray(0) = predictionAndError.values.mean()
val broadcastTrees = sc.broadcast(trees)
- (1 until numIterations).map { nTree =>
+ (1 until numIterations).foreach { nTree =>
predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
val currentTree = broadcastTrees.value(nTree)
val currentTreeWeight = localTreeWeights(nTree)
@@ -155,6 +156,7 @@ class GradientBoostedTreesModel(
evaluationArray
}
+ override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion
}
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
@@ -200,17 +202,17 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
loss: Loss): RDD[(Double, Double)] = {
val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
- iter.map {
- case (lp, (pred, error)) => {
- val newPred = pred + tree.predict(lp.features) * treeWeight
- val newError = loss.computeError(newPred, lp.label)
- (newPred, newError)
- }
+ iter.map { case (lp, (pred, error)) =>
+ val newPred = pred + tree.predict(lp.features) * treeWeight
+ val newError = loss.computeError(newPred, lp.label)
+ (newPred, newError)
}
}
newPredError
}
+ private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -340,12 +342,12 @@ private[tree] sealed class TreeEnsembleModel(
}
/**
- * Get number of trees in forest.
+ * Get number of trees in ensemble.
*/
- def numTrees: Int = trees.size
+ def numTrees: Int = trees.length
/**
- * Get total number of nodes, summed over all trees in the forest.
+ * Get total number of nodes, summed over all trees in the ensemble.
*/
def totalNumNodes: Int = trees.map(_.numNodes).sum
}
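
As an aside on the evaluateEachIteration cleanup above (map replaced by foreach, formatVersion moved), a hedged usage sketch; the fitted `gbtModel`, the held-out `validation` RDD, and the choice of loss are all assumptions:

```
import org.apache.spark.mllib.tree.loss.AbsoluteError

// One entry per boosting iteration: error of the partial ensemble on `validation`.
val errors: Array[Double] = gbtModel.evaluateEachIteration(validation, AbsoluteError)
errors.zipWithIndex.foreach { case (err, i) =>
  println(s"after iteration ${i + 1}: validation error = $err")
}
```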