author     Joseph K. Bradley <joseph@databricks.com>   2016-03-16 14:18:35 -0700
committer  Joseph K. Bradley <joseph@databricks.com>   2016-03-16 14:18:35 -0700
commit     6fc2b6541fd5ab73b289af5f7296fc602b5b4dce
tree       ec8da69765b849a72e0faf5914f25a6dbd4d21f6
parent     3f06eb72ca0c3e5779a702c7c677229e0c480751
[SPARK-11888][ML] Decision tree persistence in spark.ml
### What changes were proposed in this pull request?

Made these MLReadable and MLWritable: DecisionTreeClassifier, DecisionTreeClassificationModel, DecisionTreeRegressor, DecisionTreeRegressionModel
* The shared implementation is in treeModels.scala
* I use case classes to create a DataFrame to save, and I use the Dataset API to parse loaded files.

Other changes:
* Made CategoricalSplit.numCategories public (to use in persistence)
* Fixed a bug in DefaultReadWriteTest.testEstimatorAndModelReadWrite, where it did not call the checkModelData function passed as an argument. This caused an error in LDASuite, which I fixed.

### How was this patch tested?

Persistence is tested via unit tests. For each algorithm, there are 2 non-trivial trees (depth 2). One is built with continuous features, and one with categorical; this ensures that both types of splits are tested.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #11581 from jkbradley/dt-io.
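As a quick illustration of the API this patch enables (not part of the patch itself), here is a minimal save/load round trip for the classifier and its fitted model. The `training` DataFrame and the `/tmp` paths are placeholders; `training` is assumed to have label/features columns with the usual metadata.

```scala
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}

val dt = new DecisionTreeClassifier()
  .setMaxDepth(2)
  .setImpurity("entropy")

// Estimator persistence via DefaultParamsWritable / DefaultParamsReadable
dt.write.overwrite().save("/tmp/spark-dtc")
val sameDt = DecisionTreeClassifier.load("/tmp/spark-dtc")

// Model persistence via MLWritable / MLReadable
val model: DecisionTreeClassificationModel = dt.fit(training)
model.write.overwrite().save("/tmp/spark-dtc-model")
val sameModel = DecisionTreeClassificationModel.load("/tmp/spark-dtc-model")
```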
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala70
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala34
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala68
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala132
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala25
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala18
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala50
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala35
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala1
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala (renamed from mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala)37
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala2
23 files changed, 428 insertions, 71 deletions
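For context while reading the writer/reader code in the diff below: a saved model directory holds JSON metadata plus a Parquet file of per-node rows (NodeData). A hedged sketch of inspecting that layout, assuming a SQLContext named `sqlContext` is in scope and using a placeholder path:

```scala
// Illustrative only; not part of the patch.
val modelPath = "/tmp/spark-dtc-model"                  // placeholder path
val nodeData = sqlContext.read.parquet(modelPath + "/data")
nodeData.orderBy("id").show()                           // one row per tree node; IDs follow a pre-order traversal
```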
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index bcbedc8bc1..6ea1abb49b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -17,11 +17,16 @@
package org.apache.spark.ml.classification
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
@@ -29,6 +34,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+
/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
@@ -41,7 +47,7 @@ import org.apache.spark.sql.DataFrame
final class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
- with DecisionTreeParams with TreeClassifierParams {
+ with DecisionTreeClassifierParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("dtc"))
@@ -115,10 +121,13 @@ final class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0")
@Experimental
-object DecisionTreeClassifier {
+object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] {
/** Accessor for supported impurities: entropy, gini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeClassifier = super.load(path)
}
/**
@@ -135,7 +144,7 @@ final class DecisionTreeClassificationModel private[ml] (
@Since("1.6.0")override val numFeatures: Int,
@Since("1.5.0")override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
- with DecisionTreeModel with Serializable {
+ with DecisionTreeModel with DecisionTreeClassifierParams with MLWritable with Serializable {
require(rootNode != null,
"DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
@@ -200,12 +209,57 @@ final class DecisionTreeClassificationModel private[ml] (
private[ml] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new DecisionTreeClassificationModel.DecisionTreeClassificationModelWriter(this)
}
-private[ml] object DecisionTreeClassificationModel {
+@Since("2.0.0")
+object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[DecisionTreeClassificationModel] =
+ new DecisionTreeClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeClassificationModel = super.load(path)
+
+ private[DecisionTreeClassificationModel]
+ class DecisionTreeClassificationModelWriter(instance: DecisionTreeClassificationModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numClasses" -> instance.numClasses)
+ DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+ val (nodeData, _) = NodeData.build(instance.rootNode, 0)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
+ }
+ }
+
+ private class DecisionTreeClassificationModelReader
+ extends MLReader[DecisionTreeClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[DecisionTreeClassificationModel].getName
+
+ override def load(path: String): DecisionTreeClassificationModel = {
+ implicit val format = DefaultFormats
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val root = loadTreeNodes(path, metadata, sqlContext)
+ val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldDecisionTreeModel,
parent: DecisionTreeClassifier,
categoricalFeatures: Map[Int, Int],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index d7d6c0f5fa..42411d2d8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -101,7 +101,26 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
}
/** Decodes a param value from JSON. */
- def jsonDecode(json: String): T = {
+ def jsonDecode(json: String): T = Param.jsonDecode[T](json)
+
+ private[this] val stringRepresentation = s"${parent}__$name"
+
+ override final def toString: String = stringRepresentation
+
+ override final def hashCode: Int = toString.##
+
+ override final def equals(obj: Any): Boolean = {
+ obj match {
+ case p: Param[_] => (p.parent == parent) && (p.name == name)
+ case _ => false
+ }
+ }
+}
+
+private[ml] object Param {
+
+ /** Decodes a param value from JSON. */
+ def jsonDecode[T](json: String): T = {
parse(json) match {
case JString(x) =>
x.asInstanceOf[T]
@@ -116,19 +135,6 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
s"${this.getClass.getName} must override jsonDecode to support its value type.")
}
}
-
- private[this] val stringRepresentation = s"${parent}__$name"
-
- override final def toString: String = stringRepresentation
-
- override final def hashCode: Int = toString.##
-
- override final def equals(obj: Any): Boolean = {
- obj match {
- case p: Param[_] => (p.parent == parent) && (p.name == name)
- case _ => false
- }
- }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 6e46292451..428bc7a6d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -17,12 +17,17 @@
package org.apache.spark.ml.regression
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
@@ -31,6 +36,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
+
/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
@@ -41,7 +47,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
- with DecisionTreeRegressorParams {
+ with DecisionTreeRegressorParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("dtr"))
@@ -107,9 +113,12 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
@Since("1.4.0")
@Experimental
-object DecisionTreeRegressor {
+object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] {
/** Accessor for supported impurities: variance */
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeRegressor = super.load(path)
}
/**
@@ -125,13 +134,13 @@ final class DecisionTreeRegressionModel private[ml] (
override val rootNode: Node,
override val numFeatures: Int)
extends PredictionModel[Vector, DecisionTreeRegressionModel]
- with DecisionTreeModel with DecisionTreeRegressorParams with Serializable {
+ with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {
/** @group setParam */
def setVarianceCol(value: String): this.type = set(varianceCol, value)
require(rootNode != null,
- "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+ "DecisionTreeRegressionModel given null rootNode, but it requires a non-null rootNode.")
/**
* Construct a decision tree regression model.
@@ -200,12 +209,55 @@ final class DecisionTreeRegressionModel private[ml] (
private[ml] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new DecisionTreeRegressionModel.DecisionTreeRegressionModelWriter(this)
}
-private[ml] object DecisionTreeRegressionModel {
+@Since("2.0.0")
+object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[DecisionTreeRegressionModel] =
+ new DecisionTreeRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): DecisionTreeRegressionModel = super.load(path)
+
+ private[DecisionTreeRegressionModel]
+ class DecisionTreeRegressionModelWriter(instance: DecisionTreeRegressionModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures)
+ DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+ val (nodeData, _) = NodeData.build(instance.rootNode, 0)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
+ }
+ }
+
+ private class DecisionTreeRegressionModelReader
+ extends MLReader[DecisionTreeRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): DecisionTreeRegressionModel = {
+ implicit val format = DefaultFormats
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val root = loadTreeNodes(path, metadata, sqlContext)
+ val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldDecisionTreeModel,
parent: DecisionTreeRegressor,
categoricalFeatures: Map[Int, Int],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 78199cc2df..9d895b8fac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.tree
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
import org.apache.spark.mllib.tree.model.{Split => OldSplit}
@@ -76,7 +76,7 @@ private[tree] object Split {
final class CategoricalSplit private[ml] (
override val featureIndex: Int,
_leftCategories: Array[Double],
- private val numCategories: Int)
+ @Since("2.0.0") val numCategories: Int)
extends Split {
require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 40ed95773e..3e72e85d10 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,7 +17,15 @@
package org.apache.spark.ml.tree
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.util.DefaultParamsReader
import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.sql.SQLContext
/**
* Abstraction for Decision Tree models.
@@ -101,3 +109,127 @@ private[ml] trait TreeEnsembleModel {
/** Total number of nodes, summed over all trees in the ensemble. */
lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
}
+
+/** Helper classes for tree model persistence */
+private[ml] object DecisionTreeModelReadWrite {
+
+ /**
+ * Info for a [[org.apache.spark.ml.tree.Split]]
+ *
+ * @param featureIndex Index of feature split on
+ * @param leftCategoriesOrThreshold For categorical feature, set of leftCategories.
+ * For continuous feature, threshold.
+ * @param numCategories For categorical feature, number of categories.
+ * For continuous feature, -1.
+ */
+ case class SplitData(
+ featureIndex: Int,
+ leftCategoriesOrThreshold: Array[Double],
+ numCategories: Int) {
+
+ def getSplit: Split = {
+ if (numCategories != -1) {
+ new CategoricalSplit(featureIndex, leftCategoriesOrThreshold, numCategories)
+ } else {
+ assert(leftCategoriesOrThreshold.length == 1, s"DecisionTree split data expected" +
+ s" 1 threshold for ContinuousSplit, but found thresholds: " +
+ leftCategoriesOrThreshold.mkString(", "))
+ new ContinuousSplit(featureIndex, leftCategoriesOrThreshold(0))
+ }
+ }
+ }
+
+ object SplitData {
+ def apply(split: Split): SplitData = split match {
+ case s: CategoricalSplit =>
+ SplitData(s.featureIndex, s.leftCategories, s.numCategories)
+ case s: ContinuousSplit =>
+ SplitData(s.featureIndex, Array(s.threshold), -1)
+ }
+ }
+
+ /**
+ * Info for a [[Node]]
+ *
+ * @param id Index used for tree reconstruction. Indices follow a pre-order traversal.
+ * @param impurityStats Stats array. Impurity type is stored in metadata.
+ * @param gain Gain, or arbitrary value if leaf node.
+ * @param leftChild Left child index, or arbitrary value if leaf node.
+ * @param rightChild Right child index, or arbitrary value if leaf node.
+ * @param split Split info, or arbitrary value if leaf node.
+ */
+ case class NodeData(
+ id: Int,
+ prediction: Double,
+ impurity: Double,
+ impurityStats: Array[Double],
+ gain: Double,
+ leftChild: Int,
+ rightChild: Int,
+ split: SplitData)
+
+ object NodeData {
+ /**
+ * Create [[NodeData]] instances for this node and all children.
+ *
+ * @param id Current ID. IDs are assigned via a pre-order traversal.
+ * @return (sequence of nodes in pre-order traversal order, largest ID in subtree)
+ * The nodes are returned in pre-order traversal (root first) so that it is easy to
+ * get the ID of the subtree's root node.
+ */
+ def build(node: Node, id: Int): (Seq[NodeData], Int) = node match {
+ case n: InternalNode =>
+ val (leftNodeData, leftIdx) = build(n.leftChild, id + 1)
+ val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1)
+ val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats,
+ n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split))
+ (thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx)
+ case _: LeafNode =>
+ (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
+ -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
+ id)
+ }
+ }
+
+ def loadTreeNodes(
+ path: String,
+ metadata: DefaultParamsReader.Metadata,
+ sqlContext: SQLContext): Node = {
+ import sqlContext.implicits._
+ implicit val format = DefaultFormats
+
+ // Get impurity to construct ImpurityCalculator for each node
+ val impurityType: String = {
+ val impurityJson: JValue = metadata.getParamValue("impurity")
+ Param.jsonDecode[String](compact(render(impurityJson)))
+ }
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).as[NodeData]
+
+ // Load all nodes, sorted by ID.
+ val nodes: Array[NodeData] = data.collect().sortBy(_.id)
+ // Sanity checks; could remove
+ assert(nodes.head.id == 0, s"Decision Tree load failed. Expected smallest node ID to be 0," +
+ s" but found ${nodes.head.id}")
+ assert(nodes.last.id == nodes.length - 1, s"Decision Tree load failed. Expected largest" +
+ s" node ID to be ${nodes.length - 1}, but found ${nodes.last.id}")
+ // We fill `finalNodes` in reverse order. Since node IDs are assigned via a pre-order
+ // traversal, this guarantees that child nodes will be built before parent nodes.
+ val finalNodes = new Array[Node](nodes.length)
+ nodes.reverseIterator.foreach { case n: NodeData =>
+ val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats)
+ val node = if (n.leftChild != -1) {
+ val leftChild = finalNodes(n.leftChild)
+ val rightChild = finalNodes(n.rightChild)
+ new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild,
+ n.split.getSplit, impurityStats)
+ } else {
+ new LeafNode(n.prediction, n.impurity, impurityStats)
+ }
+ finalNodes(n.id) = node
+ }
+ // Return the root node
+ finalNodes.head
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 7a651a37ac..3f2d0c7198 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -217,6 +217,9 @@ private[ml] object TreeClassifierParams {
final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
}
+private[ml] trait DecisionTreeClassifierParams
+ extends DecisionTreeParams with TreeClassifierParams
+
/**
* Parameters for Decision Tree-based regression algorithms.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 7b2504361a..329548f95a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -273,7 +273,29 @@ private[ml] object DefaultParamsReader {
sparkVersion: String,
params: JValue,
metadata: JValue,
- metadataJson: String)
+ metadataJson: String) {
+
+ /**
+ * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
+ * This can be useful for getting a Param value before an instance of [[Params]]
+ * is available.
+ */
+ def getParamValue(paramName: String): JValue = {
+ implicit val format = DefaultFormats
+ params match {
+ case JObject(pairs) =>
+ val values = pairs.filter { case (pName, jsonValue) =>
+ pName == paramName
+ }.map(_._2)
+ assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
+ s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
+ values.head
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Cannot recognize JSON metadata: $metadataJson.")
+ }
+ }
+ }
/**
* Load metadata from file.
@@ -302,6 +324,7 @@ private[ml] object DefaultParamsReader {
/**
* Extract Params from metadata, and set them in the instance.
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+ * TODO: Move to [[Metadata]] method
*/
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
implicit val format = DefaultFormats
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 4637dcceea..b2c6e2bba4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -179,3 +179,21 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
}
}
+
+private[spark] object ImpurityCalculator {
+
+ /**
+ * Create an [[ImpurityCalculator]] instance of the given impurity type and with
+ * the given stats.
+ */
+ def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
+ impurity match {
+ case "gini" => new GiniCalculator(stats)
+ case "entropy" => new EntropyCalculator(stats)
+ case "variance" => new VarianceCalculator(stats)
+ case _ =>
+ throw new IllegalArgumentException(
+ s"ImpurityCalculator builder did not recognize impurity type: $impurity")
+ }
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 0d923dfeff..1f23682621 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -29,7 +29,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index f470f4ada6..74841058a2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 9a63cef2a8..75061464e5 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index a1575300a8..fa3b28ed4f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
index 9477e8d2bf..8413ea0e0a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index a90535d11a..b6f793f6de 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -28,7 +28,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
-import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 6d68364499..2b07524815 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -18,10 +18,10 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
@@ -30,7 +30,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
-class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class DecisionTreeClassifierSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import DecisionTreeClassifierSuite.compareAPIs
@@ -338,25 +339,34 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
- test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification)
- val newModel = DecisionTreeClassificationModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = DecisionTreeClassificationModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ test("read/write") {
+ def checkModelData(
+ model: DecisionTreeClassificationModel,
+ model2: DecisionTreeClassificationModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
+ assert(model.numClasses === model2.numClasses)
}
+
+ val dt = new DecisionTreeClassifier()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")
+
+ // Categorical splits with tree depth 2
+ val categoricalData: DataFrame =
+ TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
+ testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+
+ // Continuous splits with tree depth 2
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+
+ // Continuous splits with tree depth 0
+ testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
+ checkModelData)
}
- */
}
private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 039141aeb6..29efd675ab 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -18,10 +18,10 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 4c7c56782c..b896099e31 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -18,9 +18,9 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 56b335a33a..662e3fc679 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -18,8 +18,8 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
@@ -28,7 +28,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
-class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class DecisionTreeRegressorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import DecisionTreeRegressorSuite.compareAPIs
@@ -120,7 +121,33 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: test("model save/load") SPARK-6725
+ test("read/write") {
+ def checkModelData(
+ model: DecisionTreeRegressionModel,
+ model2: DecisionTreeRegressionModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
+ }
+
+ val dt = new DecisionTreeRegressor()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ // Categorical splits with tree depth 2
+ val categoricalData: DataFrame =
+ TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
+ testEstimatorAndModelReadWrite(dt, categoricalData,
+ TreeTests.allParamSettings, checkModelData)
+
+ // Continuous splits with tree depth 2
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ testEstimatorAndModelReadWrite(dt, continuousData,
+ TreeTests.allParamSettings, checkModelData)
+
+ // Continuous splits with tree depth 0
+ testEstimatorAndModelReadWrite(dt, continuousData,
+ TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
+ }
}
private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 244db8637b..db68606397 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index efb117f8f9..6be0c8bca0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index d5c238e9ae..9d922291a6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity.GiniCalculator
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index 5561f6f0ef..12808b0305 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -15,12 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.ml.impl
+package org.apache.spark.ml.tree.impl
import scala.collection.JavaConverters._
-import org.apache.spark.SparkContext
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.tree._
@@ -154,4 +153,36 @@ private[ml] object TreeTests extends SparkFunSuite {
new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
))
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ *
+ * This set of Params is for all Decision Tree-based models.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "checkpointInterval" -> 7,
+ "seed" -> 543L,
+ "maxDepth" -> 2,
+ "maxBins" -> 20,
+ "minInstancesPerNode" -> 2,
+ "minInfoGain" -> 1e-14,
+ "maxMemoryInMB" -> 257,
+ "cacheNodeIds" -> true
+ )
+
+ /** Data for tree read/write tests which produces a non-trivial tree. */
+ def getTreeReadWriteData(sc: SparkContext): RDD[LabeledPoint] = {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 2.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 2.0)))
+ sc.parallelize(arr)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 8e5365af84..16280473c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -33,6 +33,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* Checks "overwrite" option and params.
* This saves to and loads from [[tempDir]], but creates a subdirectory with a random name
* in order to avoid conflicts from multiple calls to this method.
+ *
* @param instance ML instance to test saving/loading
* @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
* @tparam T ML instance type
@@ -85,6 +86,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* - Compare model data
*
* This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+ *
* @param estimator Estimator to test
* @param dataset Dataset to pass to [[Estimator.fit()]]
* @param testParams Set of [[Param]] values to set in estimator