aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala1
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala11
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala7
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala30
19 files changed, 113 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 63d2fa31c7..1f2c9b75b6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.sql.DataFrame
class PipelineSuite extends SparkFunSuite {
@@ -65,6 +66,8 @@ class PipelineSuite extends SparkFunSuite {
.setStages(Array(estimator0, transformer1, estimator2, transformer3))
val pipelineModel = pipeline.fit(dataset0)
+ MLTestingUtils.checkCopy(pipelineModel)
+
assert(pipelineModel.stages.length === 4)
assert(pipelineModel.stages(0).eq(model0))
assert(pipelineModel.stages(1).eq(transformer1))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index c7bbf1ce07..4b7c5d3f23 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
@@ -244,6 +245,9 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(newTree)
+
val predictions = newTree.transform(newData)
.select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
.collect()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index d4b5896c12..e3909bccaa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -92,6 +93,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
.setCheckpointInterval(2)
val model = gbt.fit(df)
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
sc.checkpointDir = None
Utils.deleteRecursively(tempDir)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index e354e161c6..cce39f382f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -135,6 +136,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
lr.setFitIntercept(false)
val model = lr.fit(dataset)
assert(model.intercept === 0.0)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
}
test("logistic regression with setters") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index bd8e819f69..977f0e0b70 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -70,6 +70,10 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(ova.getLabelCol === "label")
assert(ova.getPredictionCol === "prediction")
val ovaModel = ova.fit(dataset)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(ovaModel)
+
assert(ovaModel.models.size === numClasses)
val transformedDataset = ovaModel.transform(dataset)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 6ca4b5aa5f..b4403ec300 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -135,6 +136,9 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val model = rf.fit(df)
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
val predictions = model.transform(df)
.select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)
.collect()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index ec85e0d151..0eba34fda6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index c452054bec..c04dda41ee 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}
@@ -51,6 +52,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
.foreach { case Row(vector1: Vector, vector2: Vector) =>
assert(vector1.equals(vector2), "Transformed vector is different with expected.")
}
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
}
test("MinMaxScaler arguments max must be larger than min") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index d0ae36b28c..30c500f87a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices}
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -56,6 +57,9 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
.setK(3)
.fit(df)
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(pca)
+
pca.transform(df).select("pca_features", "expected").collect().foreach {
case Row(x: Vector, y: Vector) =>
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index b111036087..2d24914cb9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.util.MLlibTestSparkContext
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -38,6 +39,10 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(indexer)
+
val transformed = indexer.transform(df)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 03120c828c..8cb0a2cf14 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
@@ -109,6 +110,10 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L
test("Throws error when given RDDs with different size vectors") {
val vectorIndexer = getIndexer
val model = vectorIndexer.fit(densePoints1) // vectors of length 3
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
model.transform(densePoints1) // should work
model.transform(sparsePoints1) // should work
intercept[SparkException] {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index adcda0e623..a2e46f2029 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -62,6 +63,9 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
.setSeed(42L)
.fit(docDF)
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
model.transform(docDF).select("result", "expected").collect().foreach {
case Row(vector1: Vector, vector2: Vector) =>
assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 2e5cfe7027..eadc80e0e6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -28,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.recommendation.ALS._
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -374,6 +375,9 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
}
logInfo(s"Test RMSE is $rmse.")
assert(rmse < targetRMSE)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
}
test("exact rank-1 matrix") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 33aa9d0d62..b092bcd6a7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite}
@@ -61,6 +62,16 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
}
+ test("copied model must have the same parent") {
+ val categoricalFeatures = Map(0 -> 2, 1-> 2)
+ val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
+ val model = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(8).fit(df)
+ MLTestingUtils.checkCopy(model)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dbdce0c9de..a68197b591 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -82,6 +83,9 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setMaxDepth(2)
.setMaxIter(2)
val model = gbt.fit(df)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
val preds = model.transform(df)
val predictions = preds.select("prediction").map(_.getDouble(0))
// Checks based on SPARK-8736 (to ensure it is not doing classification)
@@ -104,6 +108,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.checkpointDir = None
Utils.deleteRecursively(tempDir)
+
}
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 21ad8225bd..2aaee71ecc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
@@ -72,6 +73,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(lir.getFitIntercept)
assert(lir.getStandardization)
val model = lir.fit(dataset)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
model.transform(dataset)
.select("label", "prediction")
.collect()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 992ce95624..7b1b3f1148 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -91,7 +92,11 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
- val importances = rf.fit(df).featureImportances
+ val model = rf.fit(df)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+ val importances = model.featureImportances
val mostImportantFeature = importances.argmax
assert(mostImportantFeature === 1)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index db64511a76..aaca08bb61 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
@@ -53,6 +54,10 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setEvaluator(eval)
.setNumFolds(3)
val cvModel = cv.fit(dataset)
+
+ // copied model must have the same paren.
+ MLTestingUtils.checkCopy(cvModel)
+
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
new file mode 100644
index 0000000000..d290cc9b06
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param.ParamMap
+
+object MLTestingUtils {
+ def checkCopy(model: Model[_]): Unit = {
+ val copied = model.copy(ParamMap.empty)
+ .asInstanceOf[Model[_]]
+ assert(copied.parent.uid == model.parent.uid)
+ assert(copied.parent == model.parent)
+ }
+}