Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala'):
 -rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index c4ab673d9a..f38e1ec7c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -396,12 +396,14 @@ private[ml] object EnsembleModelReadWrite {
sql: SQLContext,
extraMetadata: JObject): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
- val treesMetadataJson: Array[(Int, String)] = instance.trees.zipWithIndex.map {
+ val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map {
case (tree, treeID) =>
- treeID -> DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext)
+ (treeID,
+ DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext),
+ instance.treeWeights(treeID))
}
val treesMetadataPath = new Path(path, "treesMetadata").toString
- sql.createDataFrame(treesMetadataJson).toDF("treeID", "metadata")
+ sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights")
.write.parquet(treesMetadataPath)
val dataPath = new Path(path, "data").toString
val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap {
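The write side now persists each tree's weight next to its params metadata in the treesMetadata Parquet directory. Below is a minimal, self-contained sketch of the resulting layout; the JSON strings, output path, and SparkSession setup are illustrative stand-ins, not the private Spark internals themselves.

    import org.apache.spark.sql.SparkSession

    object TreesMetadataLayout {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]").appName("TreesMetadataLayout").getOrCreate()
        import spark.implicits._

        // (treeID, params-metadata JSON, per-tree weight), as in saveImpl above.
        val rows = Seq(
          (0, """{"class":"...DecisionTreeRegressionModel"}""", 1.0),
          (1, """{"class":"...DecisionTreeRegressionModel"}""", 0.1))

        rows.toDF("treeID", "metadata", "weights")
          .write.mode("overwrite").parquet("/tmp/treesMetadata")

        // The schema gains the third column: treeID (int), metadata (string),
        // weights (double).
        spark.read.parquet("/tmp/treesMetadata").printSchema()
        spark.stop()
      }
    }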
@@ -424,7 +426,7 @@ private[ml] object EnsembleModelReadWrite {
path: String,
sql: SQLContext,
className: String,
- treeClassName: String): (Metadata, Array[(Metadata, Node)]) = {
+ treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
import sql.implicits._
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
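The loader's return type widens from a pair to a triple so the recovered weights reach the caller. A self-contained sketch of the caller-side effect, with String standing in for the real Metadata and Node types and loadImplLike standing in for the private loadImpl:

    object TripleReturnSketch {
      type Metadata = String
      type Node = String

      // Stand-in for the widened loadImpl signature shown above.
      def loadImplLike(): (Metadata, Array[(Metadata, Node)], Array[Double]) =
        ("ensembleMeta", Array(("meta0", "root0"), ("meta1", "root1")), Array(1.0, 0.1))

      def main(args: Array[String]): Unit = {
        // Callers previously bound two values; they now bind three.
        val (metadata, treesData, treeWeights) = loadImplLike()
        assert(treeWeights.length == treesData.length)
      }
    }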
@@ -436,12 +438,15 @@ private[ml] object EnsembleModelReadWrite {
}
val treesMetadataPath = new Path(path, "treesMetadata").toString
- val treesMetadataRDD: RDD[(Int, Metadata)] = sql.read.parquet(treesMetadataPath)
- .select("treeID", "metadata").as[(Int, String)].rdd.map {
- case (treeID: Int, json: String) =>
- treeID -> DefaultParamsReader.parseMetadata(json, treeClassName)
+ val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath)
+ .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map {
+ case (treeID: Int, json: String, weights: Double) =>
+ treeID -> (DefaultParamsReader.parseMetadata(json, treeClassName), weights)
}
- val treesMetadata: Array[Metadata] = treesMetadataRDD.sortByKey().values.collect()
+
+ val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect()
+ val treesMetadata = treesMetadataWeights.map(_._1)
+ val treesWeights = treesMetadataWeights.map(_._2)
val dataPath = new Path(path, "data").toString
val nodeData: Dataset[EnsembleNodeData] =
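Parquet preserves no row order, so the read side keys every row by treeID, sorts, and only then splits the (Metadata, Double) pairs into two arrays that stay index-aligned. A runnable sketch of that pattern with stand-in data; the local master and placeholder JSON strings are assumptions:

    import org.apache.spark.sql.SparkSession

    object SortThenSplit {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]").appName("SortThenSplit").getOrCreate()
        val sc = spark.sparkContext

        // Rows as they might come back from Parquet: arbitrary order.
        val byTree = sc.parallelize(Seq(
          (2, ("{...tree 2...}", 0.8)),
          (0, ("{...tree 0...}", 1.0)),
          (1, ("{...tree 1...}", 0.9))))

        // Sort by treeID, drop the key, then split into parallel arrays.
        val pairs = byTree.sortByKey().values.collect()
        val metadataJsons = pairs.map(_._1)
        val weights = pairs.map(_._2)

        assert(weights.sameElements(Array(1.0, 0.9, 0.8)))
        spark.stop()
      }
    }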
@@ -452,7 +457,7 @@ private[ml] object EnsembleModelReadWrite {
treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
}
val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
- (metadata, treesMetadata.zip(rootNodes))
+ (metadata, treesMetadata.zip(rootNodes), treesWeights)
}
/**
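Because treesMetadata, rootNodes, and treesWeights were all collected in treeID order, a positional zip pairs each tree's metadata with the matching root node. A small stand-in sketch, with Strings in place of the real Metadata and Node types:

    object ZipAlignment {
      def main(args: Array[String]): Unit = {
        val treesMetadata = Array("meta0", "meta1") // sorted by treeID
        val rootNodes     = Array("root0", "root1") // sorted by treeID
        val treesWeights  = Array(1.0, 0.1)

        // Positional zip is safe only because both arrays share treeID order.
        val treesData = treesMetadata.zip(rootNodes)
        assert(treesData(1) == ("meta1", "root1"))
        assert(treesData.length == treesWeights.length)
      }
    }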