author    Bryan Cutler <cutlerb@gmail.com>    2017-04-06 09:41:32 +0200
committer Nick Pentreath <nickp@za.ibm.com>   2017-04-06 09:41:32 +0200
commit    e156b5dd39dc1992077fe06e0f8be810c49c8255 (patch)
tree      bb4e220194e41342b2dfbb6502804080d850b972 /mllib/src
parent    5142e5d4e09c7cb36cf1d792934a21c5305c6d42 (diff)
[SPARK-19953][ML] Random Forest Models use parent UID when being fit
## What changes were proposed in this pull request?

The ML `RandomForestClassificationModel` and `RandomForestRegressionModel` were not using the estimator parent UID when being fit. This change fixes that so the models can be properly identified with their parents.

## How was this patch tested?

Existing tests. Added a check to verify that the model uid matches that of the parent, then renamed `checkCopy` to `checkCopyAndUids` and verified that it is called by one test for each ML algorithm.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #17296 from BryanCutler/rfmodels-use-parent-uid-SPARK-19953.
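For context (this note and sketch are not part of the patch): every Spark ML pipeline stage carries a `uid` string, and `Identifiable.randomUID` generates a fresh random one when no uid is supplied. The bug was that the random forest models were constructed without the estimator's uid, so each model received a new random uid. A minimal, runnable sketch of the pattern the fix enforces, using simplified stand-in classes (`RFClassifier`, `RFModel`, and this `randomUID` are illustrative, not the real Spark ML API):

```scala
import java.util.UUID

object UidPropagationSketch {

  // Mirrors the idea of Identifiable.randomUID: a prefix plus a random suffix.
  def randomUID(prefix: String): String =
    prefix + "_" + UUID.randomUUID().toString.takeRight(12)

  class RFModel(val uid: String)

  class RFClassifier(val uid: String = randomUID("rfc")) {
    def fit(): RFModel = {
      // Before the fix: the equivalent of `new RFModel(randomUID("rfc"))`,
      // so model.uid != this.uid and the model could not be traced back
      // to the estimator that produced it.
      new RFModel(uid) // the fix: thread the estimator's own uid through
    }
  }

  def main(args: Array[String]): Unit = {
    val est = new RFClassifier()
    val model = est.fit()
    // The invariant that the renamed checkCopyAndUids helper asserts:
    assert(est.uid == model.uid, "Model uid does not match parent estimator")
    println(s"estimator uid = ${est.uid}, model uid = ${model.uid}")
  }
}
```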
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 6
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala | 1
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala | 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala | 9
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala | 9
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala | 8
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala | 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala | 8
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala | 7
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 7
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 7
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala | 7
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 7
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala | 35
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala | 8
41 files changed, 98 insertions, 100 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index ce834f1d17..ab4c235209 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -140,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") (
.map(_.asInstanceOf[DecisionTreeClassificationModel])
val numFeatures = oldDataset.first().features.size
- val m = new RandomForestClassificationModel(trees, numFeatures, numClasses)
+ val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
instr.logSuccess(m)
m
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 2f524a8c57..a58da50fad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -131,7 +131,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
.map(_.asInstanceOf[DecisionTreeRegressionModel])
val numFeatures = oldDataset.first().features.size
- val m = new RandomForestRegressionModel(trees, numFeatures)
+ val m = new RandomForestRegressionModel(uid, trees, numFeatures)
instr.logSuccess(m)
m
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index dafc6c200f..4cdbf845ae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -79,7 +79,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.setStages(Array(estimator0, transformer1, estimator2, transformer3))
val pipelineModel = pipeline.fit(dataset0)
- MLTestingUtils.checkCopy(pipelineModel)
+ MLTestingUtils.checkCopyAndUids(pipeline, pipelineModel)
assert(pipelineModel.stages.length === 4)
assert(pipelineModel.stages(0).eq(model0))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 964fcfbdd8..918ab27e27 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -249,8 +249,7 @@ class DecisionTreeClassifierSuite
val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(newTree)
+ MLTestingUtils.checkCopyAndUids(dt, newTree)
val predictions = newTree.transform(newData)
.select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 0cddb37281..1f79e0d4e6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -97,8 +97,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.getProbabilityCol === "probability")
assert(model.hasParent)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(gbt, model)
}
test("setThreshold, getThreshold") {
@@ -261,8 +260,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
.setSeed(123)
val model = gbt.fit(df)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(gbt, model)
sc.checkpointDir = None
Utils.deleteRecursively(tempDir)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index c763a4cef1..2f87afc23f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -124,8 +124,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(model.hasParent)
assert(model.numFeatures === 2)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(lsvc, model)
}
test("linear svc doesn't fit intercept when fitIntercept is off") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index f0648d0936..c858b9bbfc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -142,8 +142,7 @@ class LogisticRegressionSuite
assert(model.intercept !== 0.0)
assert(model.hasParent)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(lr, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 7700099caa..ce54c3df4f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -74,8 +74,8 @@ class MultilayerPerceptronClassifierSuite
.setMaxIter(100)
.setSolver("l-bfgs")
val model = trainer.fit(dataset)
- MLTestingUtils.checkCopy(model)
val result = model.transform(dataset)
+ MLTestingUtils.checkCopyAndUids(trainer, model)
val predictionAndLabels = result.select("prediction", "label").collect()
predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
assert(p == l)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index d41c5b533d..b56f8e19ca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -149,6 +149,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
validateModelFit(pi, theta, model)
assert(model.hasParent)
+ MLTestingUtils.checkCopyAndUids(nb, model)
val validationDataset =
generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index aacb7921b8..c02e38ad64 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -76,8 +76,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(ova.getPredictionCol === "prediction")
val ovaModel = ova.fit(dataset)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(ovaModel)
+ MLTestingUtils.checkCopyAndUids(ova, ovaModel)
assert(ovaModel.models.length === numClasses)
val transformedDataset = ovaModel.transform(dataset)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index c3003cec73..ca2954d2f3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -141,8 +141,7 @@ class RandomForestClassifierSuite
val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val model = rf.fit(df)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(rf, model)
val predictions = model.transform(df)
.select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 200a892f6c..fa7471fa2d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -47,8 +47,7 @@ class BisectingKMeansSuite
assert(bkm.getMinDivisibleClusterSize === 1.0)
val model = bkm.setMaxIter(1).fit(dataset)
- // copied model must have the same parent
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(bkm, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 61da897b66..08b800b7e4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -77,8 +77,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(gm.getTol === 0.01)
val model = gm.setMaxIter(1).fit(dataset)
- // copied model must have the same parent
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(gm, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index ca05b9c389..119fe1dead 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -52,8 +52,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(kmeans.getTol === 1e-4)
val model = kmeans.setMaxIter(1).fit(dataset)
- // copied model must have the same parent
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(kmeans, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 75aa0be61a..b4fe63a89f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -176,7 +176,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2)
val model = lda.fit(dataset)
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(lda, model)
assert(model.isInstanceOf[LocalLDAModel])
assert(model.vocabSize === vocabSize)
@@ -221,7 +221,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2)
val model_ = lda.fit(dataset)
- MLTestingUtils.checkCopy(model_)
+ MLTestingUtils.checkCopyAndUids(lda, model_)
assert(model_.isInstanceOf[DistributedLDAModel])
val model = model_.asInstanceOf[DistributedLDAModel]
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index cc81da5c66..7175c721bf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -94,7 +94,8 @@ class BucketedRandomProjectionLSHSuite
unitVectors.foreach { v: Vector =>
assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
}
- MLTestingUtils.checkCopy(brpModel)
+
+ MLTestingUtils.checkCopyAndUids(brp, brpModel)
}
test("BucketedRandomProjectionLSH: test of LSH property") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index d6925da97d..c83909c449 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -119,7 +119,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
test("Test Chi-Square selector: numTopFeatures") {
val selector = new ChiSqSelector()
.setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1)
- ChiSqSelectorSuite.testSelector(selector, dataset)
+ val model = ChiSqSelectorSuite.testSelector(selector, dataset)
+ MLTestingUtils.checkCopyAndUids(selector, model)
}
test("Test Chi-Square selector: percentile") {
@@ -166,11 +167,13 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
object ChiSqSelectorSuite {
- private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = {
- selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect()
+ private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = {
+ val selectorModel = selector.fit(dataset)
+ selectorModel.transform(dataset).select("filtered", "topFeature").collect()
.foreach { case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
}
+ selectorModel
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index 69d3033bb2..f213145f1b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
@@ -68,10 +68,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
val cv = new CountVectorizer()
.setInputCol("words")
.setOutputCol("features")
- .fit(df)
- assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
+ val cvm = cv.fit(df)
+ MLTestingUtils.checkCopyAndUids(cv, cvm)
+ assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
- cv.transform(df).select("features", "expected").collect().foreach {
+ cvm.transform(df).select("features", "expected").collect().foreach {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index 5325d95526..005edf73d2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.VectorImplicits._
@@ -65,10 +65,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val df = data.zip(expected).toSeq.toDF("features", "expected")
- val idfModel = new IDF()
+ val idfEst = new IDF()
.setInputCol("features")
.setOutputCol("idfValue")
- .fit(df)
+ val idfModel = idfEst.fit(df)
+
+ MLTestingUtils.checkCopyAndUids(idfEst, idfModel)
idfModel.transform(df).select("idfValue", "expected").collect().foreach {
case Row(x: Vector, y: Vector) =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
index a9b559f7ba..dd4dd62b8c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{MLTestingUtils, SchemaUtils}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
@@ -58,6 +58,8 @@ private[ml] object LSHTest {
val outputCol = model.getOutputCol
val transformedData = model.transform(dataset)
+ MLTestingUtils.checkCopyAndUids(lsh, model)
+
// Check output column type
SchemaUtils.checkColumnType(
transformedData.schema, model.getOutputCol, DataTypes.createArrayType(new VectorUDT))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
index a12174493b..918da4f938 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
@@ -50,8 +50,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1")
}
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(scaler, model)
}
test("MaxAbsScaler read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index 0ddf097a6e..96df68dbdf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -63,7 +63,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCol("values")
val model = mh.fit(dataset)
assert(mh.uid === model.uid)
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(mh, model)
}
test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index b79eeb2d75..51db74eb73 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -53,8 +53,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
assert(vector1.equals(vector2), "Transformed vector is different with expected.")
}
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(scaler, model)
}
test("MinMaxScaler arguments max must be larger than min") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index a60e87590f..3067a52a4d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -58,12 +58,12 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
.setInputCol("features")
.setOutputCol("pca_features")
.setK(3)
- .fit(df)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(pca)
+ val pcaModel = pca.fit(df)
- pca.transform(df).select("pca_features", "expected").collect().foreach {
+ MLTestingUtils.checkCopyAndUids(pca, pcaModel)
+
+ pcaModel.transform(df).select("pca_features", "expected").collect().foreach {
case Row(x: Vector, y: Vector) =>
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 5cfd59e6b8..fbebd75d70 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -37,7 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val formula = new RFormula().setFormula("id ~ v1 + v2")
val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
val model = formula.fit(original)
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(formula, model)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = Seq(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
index a928f93633..350ba44baa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -77,10 +77,11 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
test("Standardization with default parameter") {
val df0 = data.zip(resWithStd).toSeq.toDF("features", "expected")
- val standardScaler0 = new StandardScaler()
+ val standardScalerEst0 = new StandardScaler()
.setInputCol("features")
.setOutputCol("standardized_features")
- .fit(df0)
+ val standardScaler0 = standardScalerEst0.fit(df0)
+ MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0)
assertResult(standardScaler0.transform(df0))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 8d9042b31e..5634d4210f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -45,12 +45,11 @@ class StringIndexerSuite
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
- .fit(df)
+ val indexerModel = indexer.fit(df)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(indexer)
+ MLTestingUtils.checkCopyAndUids(indexer, indexerModel)
- val transformed = indexer.transform(df)
+ val transformed = indexerModel.transform(df)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("a", "c", "b"))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index b28ce2ab45..f2cca8aa82 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -114,8 +114,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
val vectorIndexer = getIndexer
val model = vectorIndexer.fit(densePoints1) // vectors of length 3
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(vectorIndexer, model)
model.transform(densePoints1) // should work
model.transform(sparsePoints1) // should work
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 2043a16c15..a6a1c2b4f3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -57,15 +57,14 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val docDF = doc.zip(expected).toDF("text", "expected")
- val model = new Word2Vec()
+ val w2v = new Word2Vec()
.setVectorSize(3)
.setInputCol("text")
.setOutputCol("result")
.setSeed(42L)
- .fit(docDF)
+ val model = w2v.fit(docDF)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(w2v, model)
// These expectations are just magic values, characterizing the current
// behavior. The test needs to be updated to be more general, see SPARK-11502
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 6bec057511..6806cb03bc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.ml.fpm
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -121,7 +122,9 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.setMinConfidence(0.5678)
assert(fpGrowth.getMinSupport === 0.4567)
assert(model.getMinConfidence === 0.5678)
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(fpGrowth, model)
+ ParamsSuite.checkParams(fpGrowth)
+ ParamsSuite.checkParams(model)
}
test("read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index a177ed13bf..7574af3d77 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -409,8 +409,7 @@ class ALSSuite
logInfo(s"Test RMSE is $rmse.")
assert(rmse < targetRMSE)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(als, model)
}
test("exact rank-1 matrix") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 708185a094..fb39e50a83 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -83,8 +83,7 @@ class AFTSurvivalRegressionSuite
.setQuantilesCol("quantiles")
.fit(datasetUnivariate)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(aftr, model)
model.transform(datasetUnivariate)
.select("label", "prediction", "quantiles")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 0e91284d03..642f266891 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -69,11 +69,12 @@ class DecisionTreeRegressorSuite
test("copied model must have the same parent") {
val categoricalFeatures = Map(0 -> 2, 1 -> 2)
val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
- val model = new DecisionTreeRegressor()
+ val dtr = new DecisionTreeRegressor()
.setImpurity("variance")
.setMaxDepth(2)
- .setMaxBins(8).fit(df)
- MLTestingUtils.checkCopy(model)
+ .setMaxBins(8)
+ val model = dtr.fit(df)
+ MLTestingUtils.checkCopyAndUids(dtr, model)
}
test("predictVariance") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 03c2f97797..2da25f7e01 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -90,8 +90,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
.setMaxIter(2)
val model = gbt.fit(df)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(gbt, model)
val preds = model.transform(df)
val predictions = preds.select("prediction").rdd.map(_.getDouble(0))
// Checks based on SPARK-8736 (to ensure it is not doing classification)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 401911763f..f7c7c001a3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -197,8 +197,7 @@ class GeneralizedLinearRegressionSuite
val model = glr.setFamily("gaussian").setLink("identity")
.fit(datasetGaussianIdentity)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(glr, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index f41a3601b1..180f5f7ce5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -93,8 +93,7 @@ class IsotonicRegressionSuite
val model = ir.fit(dataset)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(ir, model)
model.transform(dataset)
.select("label", "features", "prediction", "weight")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index c6a267b728..e7bd4eb9e0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -148,8 +148,7 @@ class LinearRegressionSuite
assert(lir.getSolver == "auto")
val model = lir.fit(datasetWithDenseFeature)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
+ MLTestingUtils.checkCopyAndUids(lir, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 3bf0445ebd..8b8e8a655f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -90,6 +90,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
val model = rf.fit(df)
+ MLTestingUtils.checkCopyAndUids(rf, model)
+
val importances = model.featureImportances
val mostImportantFeature = importances.argmax
assert(mostImportantFeature === 1)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 7116265474..2b4e6b53e4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -58,8 +58,7 @@ class CrossValidatorSuite
.setNumFolds(3)
val cvModel = cv.fit(dataset)
- // copied model must have the same paren.
- MLTestingUtils.checkCopy(cvModel)
+ MLTestingUtils.checkCopyAndUids(cv, cvModel)
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 4463a9b6e5..a34f930aa1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -45,18 +45,18 @@ class TrainValidationSplitSuite
.addGrid(lr.maxIter, Array(0, 10))
.build()
val eval = new BinaryClassificationEvaluator
- val cv = new TrainValidationSplit()
+ val tvs = new TrainValidationSplit()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
.setSeed(42L)
- val cvModel = cv.fit(dataset)
- val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
- assert(cv.getTrainRatio === 0.5)
+ val tvsModel = tvs.fit(dataset)
+ val parent = tvsModel.bestModel.parent.asInstanceOf[LogisticRegression]
+ assert(tvs.getTrainRatio === 0.5)
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
- assert(cvModel.validationMetrics.length === lrParamMaps.length)
+ assert(tvsModel.validationMetrics.length === lrParamMaps.length)
}
test("train validation with linear regression") {
@@ -71,28 +71,27 @@ class TrainValidationSplitSuite
.addGrid(trainer.maxIter, Array(0, 10))
.build()
val eval = new RegressionEvaluator()
- val cv = new TrainValidationSplit()
+ val tvs = new TrainValidationSplit()
.setEstimator(trainer)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
.setSeed(42L)
- val cvModel = cv.fit(dataset)
+ val tvsModel = tvs.fit(dataset)
- // copied model must have the same paren.
- MLTestingUtils.checkCopy(cvModel)
+ MLTestingUtils.checkCopyAndUids(tvs, tvsModel)
- val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
+ val parent = tvsModel.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
- assert(cvModel.validationMetrics.length === lrParamMaps.length)
+ assert(tvsModel.validationMetrics.length === lrParamMaps.length)
eval.setMetricName("r2")
- val cvModel2 = cv.fit(dataset)
- val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
+ val tvsModel2 = tvs.fit(dataset)
+ val parent2 = tvsModel2.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent2.getRegParam === 0.001)
assert(parent2.getMaxIter === 10)
- assert(cvModel2.validationMetrics.length === lrParamMaps.length)
+ assert(tvsModel2.validationMetrics.length === lrParamMaps.length)
}
test("transformSchema should check estimatorParamMaps") {
@@ -104,17 +103,17 @@ class TrainValidationSplitSuite
.addGrid(est.inputCol, Array("input1", "input2"))
.build()
- val cv = new TrainValidationSplit()
+ val tvs = new TrainValidationSplit()
.setEstimator(est)
.setEstimatorParamMaps(paramMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
- cv.transformSchema(new StructType()) // This should pass.
+ tvs.transformSchema(new StructType()) // This should pass.
val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
- cv.setEstimatorParamMaps(invalidParamMaps)
+ tvs.setEstimatorParamMaps(invalidParamMaps)
intercept[IllegalArgumentException] {
- cv.transformSchema(new StructType())
+ tvs.transformSchema(new StructType())
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 578f31c8e7..bef79e634f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -31,11 +31,15 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
object MLTestingUtils extends SparkFunSuite {
- def checkCopy(model: Model[_]): Unit = {
+
+ def checkCopyAndUids[T <: Estimator[_]](estimator: T, model: Model[_]): Unit = {
+ assert(estimator.uid === model.uid, "Model uid does not match parent estimator")
+
+ // copied model must have the same parent
val copied = model.copy(ParamMap.empty)
.asInstanceOf[Model[_]]
- assert(copied.parent.uid == model.parent.uid)
assert(copied.parent == model.parent)
+ assert(copied.parent.uid == model.parent.uid)
}
def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
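For reference, a hedged sketch of the call pattern now used throughout the updated suites: fit the estimator, then verify copy semantics and uid propagation in a single call. It assumes a `SparkSession` named `spark` and the spark-mllib test classes (where `MLTestingUtils` lives) on the classpath; `KMeans` stands in for any of the estimators above.

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.util.MLTestingUtils

// Tiny illustrative dataset; any DataFrame with a "features" column works.
val dataset = spark.createDataFrame(Seq(
  (Vectors.dense(0.0, 0.0), 0),
  (Vectors.dense(1.0, 1.0), 1)
)).toDF("features", "label")

val kmeans = new KMeans().setK(2).setSeed(1L).setMaxIter(1)
val model = kmeans.fit(dataset)

// One call now covers what checkCopy checked plus the new uid assertion:
// model.uid === estimator.uid, and the copied model keeps the same parent.
MLTestingUtils.checkCopyAndUids(kmeans, model)
```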