aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-07-07 08:58:08 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-07 08:58:08 -0700
commit1dbc4a155f3697a3973909806be42a1be6017d12 (patch)
treeaef8a830fbeca56fcfa2ca5b050959f62b2c5c25 /mllib
parent0a63d7ab8a58d3e48d01740729a7832f1834efe8 (diff)
downloadspark-1dbc4a155f3697a3973909806be42a1be6017d12.tar.gz
spark-1dbc4a155f3697a3973909806be42a1be6017d12.tar.bz2
spark-1dbc4a155f3697a3973909806be42a1be6017d12.zip
[SPARK-8711] [ML] Add additional methods to PySpark ML tree models
Add numNodes and depth to treeModels, add treeWeights to ensemble Models. Add __repr__ to all models. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #7095 from MechCoder/missing_methods_tree and squashes the following commits: 23b08be [MechCoder] private [spark] 38a0860 [MechCoder] rename pyTreeWeights to javaTreeWeights 6d16ad8 [MechCoder] Fix Python 3 Error 47d7023 [MechCoder] Use np.allclose and treeEnsembleModel -> TreeEnsembleMethods 819098c [MechCoder] [SPARK-8711] [ML] Add additional methods ot PySpark ML tree models
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala5
1 files changed, 5 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 1929f9d021..22873909c3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.tree
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
* Abstraction for Decision Tree models.
@@ -70,6 +71,10 @@ private[ml] trait TreeEnsembleModel {
/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]
+ /** Weights used by the python wrappers. */
+ // Note: An array cannot be returned directly due to serialization problems.
+ private[spark] def javaTreeWeights: Vector = Vectors.dense(treeWeights)
+
/** Summary of the model */
override def toString: String = {
// Implementing classes should generally override this method to be more descriptive.