path: root/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
author: Joseph K. Bradley <joseph@databricks.com> 2015-04-17 13:15:36 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-04-17 13:15:36 -0700
commit: a83571acc938582865efb41645aa1e414f339e46 (patch)
tree: cd11b3bea5e50946c015d22672f269970f664f9b /mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
parent: 50ab8a6543ad5c31e89c16df374d0cb13222fd1e (diff)
[SPARK-6113] [ml] Stabilize DecisionTree API
This is a PR for cleaning up and finalizing the DecisionTree API. PRs for ensembles will follow once this is merged.

### Goal

Here is the description copied from the JIRA (for both trees and ensembles):

> **Issue**: The APIs for DecisionTree and ensembles (RandomForests and GradientBoostedTrees) have been experimental for a long time. The API has become very convoluted because trees and ensembles have many, many variants, some of which we have added incrementally without a long-term design.
> **Proposal**: This JIRA is for discussing changes required to finalize the APIs. After we discuss, I will make a PR to update the APIs and make them non-Experimental. This will require making many breaking changes; see the design doc for details.
> **[Design doc](https://docs.google.com/document/d/1rJ_DZinyDG3PkYkAKSsQlY0QgCeefn4hUv7GsPkzBP4)**: This outlines current issues and the proposed API.

Overall code layout:
* The old API in mllib.tree.* will remain the same.
* The new API will reside in ml.classification.* and ml.regression.*

### Summary of changes

Old API
* Exactly the same, except I made 1 method in Loss private (but that is not a breaking change since that method was introduced after the Spark 1.3 release).

New APIs
* Under Pipeline API
* The new API preserves functionality, except:
  * The new API does NOT store prob (probability of label in classification). I want to have it store the full vector of probabilities, but feel that should be in a later PR.
* Use abstractions for parameters, estimators, and models to avoid code duplication
* Limit parameters to relevant algorithms
* For enum-like types, only expose Strings
  * We can make these pluggable later on by adding new parameters. That is a far-future item.

Test suites
* I organized DecisionTreeSuite, but I made absolutely no changes to the tests themselves.
* The test suites for the new API only test (a) similarity with the results of the old API and (b) elements of the new API.
* After code is moved to this new API, we should move the tests from the old suites which test the internals.

### Details

#### Changed names

Parameters:
* useNodeIdCache -> cacheNodeIds

#### Other changes

* Split: Changed categories to a set instead of a list.

#### Non-decision-tree changes

* AttributeGroup
  * Added parentheses to the toMetadata and toStructField methods. (These were removed in a previous PR, but I ran into an issue with the Scala compiler not being able to disambiguate between a toMetadata method with no parentheses and a toMetadata method which takes 1 argument.)
* Attributes
  * Renamed: toMetadata -> toMetadataImpl
  * Added toMetadata methods which return ML metadata (keyed with "ML_ATTR").
  * NominalAttribute: Added a getNumValues method which examines both numValues and values.
* Params.inheritValues: Checks whether the parent param really belongs to the child (to allow Estimator-Model pairs with different sets of parameters).

### Questions for reviewers

* Is "DecisionTreeClassificationModel" too long a name?
* Is this OK in the docs?
```
class DecisionTreeRegressor
  extends TreeRegressor[DecisionTreeRegressionModel]
  with DecisionTreeParams[DecisionTreeRegressor]
  with TreeRegressorParams[DecisionTreeRegressor]
```

### Future

We should open up the abstractions at some point. E.g., it would be useful to be able to set tree-related parameters in one place and then pass those to multiple tree-based algorithms.

Follow-up JIRAs will be (in this order):
* Tree ensembles
* Deprecate old tree code
* Move DecisionTree implementation code to the new API.
* Move tests from the old suites which test the internals.
* Update programming guide
* Python API
* Change RandomForest* to always use bootstrapping, even when numTrees = 1
* Provide the probability of the predicted label for classification. After we move code to the new API and update it to maintain probabilities for all labels, then we can add the probabilities to the new API.

CC: mengxr manishamde codedeft chouqin MechCoder

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5530 from jkbradley/dt-api-dt and squashes the following commits:

6aae255 [Joseph K. Bradley] Changed tree abstractions not to take type parameters, and for setters to return this.type instead
ec17947 [Joseph K. Bradley] Updates based on code review. Main changes were: moving public types from ml.impl.tree to ml.tree, modifying CategoricalSplit to take an Array of categories but store a Set internally, making more types sealed or final
5626c81 [Joseph K. Bradley] style fixes
f8fbd24 [Joseph K. Bradley] imported reorg of DecisionTreeSuite from old PR. small cleanups
7ef63ed [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example (for real this time)
e11673f [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example
119f407 [Joseph K. Bradley] added DecisionTreeClassifier example
0bdc486 [Joseph K. Bradley] fixed issues after param PR was merged
f9fbb60 [Joseph K. Bradley] Done with DecisionTreeClassifier, but no save/load yet. Need to add example as well
2532c9a [Joseph K. Bradley] partial move to spark.ml API, not done yet
c72c1a0 [Joseph K. Bradley] Copied changes for common items, plus DecisionTreeClassifier from original PR
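
For context, here is a minimal sketch of how the new Pipeline-API classifier described in this commit might be used. The training DataFrame, the column names, and the specific setter calls are assumptions based on the parameter descriptions above (e.g. the useNodeIdCache -> cacheNodeIds rename and String-valued enum-like parameters), not code taken from this diff.

```scala
// Illustrative sketch only: assumes a DataFrame `training` with "label" and
// "features" columns is already available, and that the estimator exposes the
// usual Params-style setters for the parameters discussed in this commit.
import org.apache.spark.ml.classification.DecisionTreeClassifier

val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setMaxDepth(5)
  .setImpurity("gini")     // enum-like parameters are exposed as Strings
  .setCacheNodeIds(true)   // renamed from useNodeIdCache in the old API

val model = dt.fit(training)   // returns a DecisionTreeClassificationModel

// The fitted model mixes in the DecisionTreeModel trait added in this file:
println(s"Learned a tree of depth ${model.depth} with ${model.numNodes} nodes")
println(model.toDebugString)
```

Because the setters return this.type (per the commit notes), configuration can be chained as shown without losing the concrete estimator type.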
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala  60
1 file changed, 60 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
new file mode 100644
index 0000000000..8e3bc3849d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.annotation.AlphaComponent
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for Decision Tree models.
+ *
+ * TODO: Add support for predicting probabilities and raw predictions
+ */
+@AlphaComponent
+trait DecisionTreeModel {
+
+  /** Root of the decision tree */
+  def rootNode: Node
+
+  /** Number of nodes in tree, including leaf nodes. */
+  def numNodes: Int = {
+    1 + rootNode.numDescendants
+  }
+
+  /**
+   * Depth of the tree.
+   * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
+   */
+  lazy val depth: Int = {
+    rootNode.subtreeDepth
+  }
+
+  /** Summary of the model */
+  override def toString: String = {
+    // Implementing classes should generally override this method to be more descriptive.
+    s"DecisionTreeModel of depth $depth with $numNodes nodes"
+  }
+
+  /** Full description of model */
+  def toDebugString: String = {
+    val header = toString + "\n"
+    header + rootNode.subtreeToString(2)
+  }
+}
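
To make the recursion behind numNodes and depth concrete, the sketch below mirrors the Node abstraction this trait depends on with a hypothetical toy hierarchy. Spark's actual Node, LeafNode, and InternalNode classes are defined elsewhere in this PR and differ in detail; ToyNode, ToyLeaf, and ToyInternal here are illustrative names only.

```scala
// Hypothetical toy mirror of the Node interface used by DecisionTreeModel, for illustration only.
object ToyTreeDemo extends App {
  sealed trait ToyNode {
    def numDescendants: Int   // nodes strictly below this one
    def subtreeDepth: Int     // 0 for a leaf
  }

  case class ToyLeaf(prediction: Double) extends ToyNode {
    override def numDescendants: Int = 0
    override def subtreeDepth: Int = 0
  }

  case class ToyInternal(left: ToyNode, right: ToyNode) extends ToyNode {
    // Count both children plus everything beneath them.
    override def numDescendants: Int = 2 + left.numDescendants + right.numDescendants
    override def subtreeDepth: Int = 1 + math.max(left.subtreeDepth, right.subtreeDepth)
  }

  // Depth-1 tree: one internal node with two leaves (the example from the scaladoc above).
  val root: ToyNode = ToyInternal(ToyLeaf(0.0), ToyLeaf(1.0))
  println(1 + root.numDescendants)  // 3 nodes, as DecisionTreeModel.numNodes would compute
  println(root.subtreeDepth)        // depth 1, as DecisionTreeModel.depth would compute
}
```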