author    Yanbo Liang <ybliang8@gmail.com>    2017-03-08 02:05:01 -0800
committer Yanbo Liang <ybliang8@gmail.com>    2017-03-08 02:05:01 -0800
commit    1fa58868bc6635ff2119264665bd3d00b4b1253a (patch)
tree      00aef01313e44e66f08f57b57da004da4c325d20
parent    314e48a3584bad4b486b046bbf0159d64ba857bc (diff)
[ML][MINOR] Separate estimator and model params for read/write test.
## What changes were proposed in this pull request?

Since we allow an ```Estimator``` and its ```Model``` to not always share the same params (see ```ALSParams``` and ```ALSModelParams```), we should pass test params for the estimator and the model separately to the function ```testEstimatorAndModelReadWrite```.

## How was this patch tested?

Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #17151 from yanboliang/test-rw.
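The change is mechanical at every call site: the single ```testParams``` argument is replaced by one map for the estimator and one for the model. A minimal before/after sketch, mirroring the ```KMeansSuite``` hunk below (suites whose estimator and model share params simply pass the same map twice):

```scala
// Before: one param map served both the Estimator and the Model.
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)

// After: estimator params and model params are passed separately.
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
  KMeansSuite.allParamSettings, checkModelData)
```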
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 8
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala | 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala | 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 35
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 5
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala | 1
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala | 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala | 14
23 files changed, 59 insertions(+), 54 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index c711e7fa9d..10de50306a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite
// Categorical splits with tree depth 2
val categoricalData: DataFrame =
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
- testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
+ allParamSettings, checkModelData)
// Continuous splits with tree depth 2
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
- testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
// Continuous splits with tree depth 0
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
- checkModelData)
+ allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 0598943c3d..0cddb37281 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
- testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index fe47176a4a..4c63a2a88c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -232,7 +232,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
val svm = new LinearSVC()
testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
- checkModelData)
+ LinearSVCSuite.allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index d89a958eed..affaa57374 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2089,7 +2089,7 @@ class LogisticRegressionSuite
}
val lr = new LogisticRegression()
testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings,
- checkModelData)
+ LogisticRegressionSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 37d7991fe8..4d5d299d14 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
assert(model.theta === model2.theta)
}
val nb = new NaiveBayes()
- testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
+ NaiveBayesSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 44e1585ee5..c3003cec73 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -218,7 +218,8 @@ class RandomForestClassifierSuite
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
- testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 30513c1e27..200a892f6c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -138,8 +138,8 @@ class BisectingKMeansSuite
assert(model.clusterCenters === model2.clusterCenters)
}
val bisectingKMeans = new BisectingKMeans()
- testEstimatorAndModelReadWrite(
- bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
+ BisectingKMeansSuite.allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index c500c5b3e3..61da897b66 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
}
val gm = new GaussianMixture()
- testEstimatorAndModelReadWrite(gm, dataset,
+ testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings,
GaussianMixtureSuite.allParamSettings, checkModelData)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index e10127f7d1..ca05b9c389 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(model.clusterCenters === model2.clusterCenters)
}
val kmeans = new KMeans()
- testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
+ KMeansSuite.allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 9aa11fbdbe..75aa0be61a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
Vectors.dense(model2.getDocConcentration) absTol 1e-6)
}
val lda = new LDA()
- testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
+ LDASuite.allParamSettings, checkModelData)
}
test("read/write DistributedLDAModel") {
@@ -271,6 +272,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
}
val lda = new LDA()
testEstimatorAndModelReadWrite(lda, dataset,
+ LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index ab937685a5..91eac9e733 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -63,7 +63,7 @@ class BucketedRandomProjectionLSHSuite
}
val mh = new BucketedRandomProjectionLSH()
val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
- testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+ testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
}
test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 482e5d5426..d6925da97d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -151,7 +151,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.selectedFeatures === model2.selectedFeatures)
}
val nb = new ChiSqSelector
- testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings,
+ ChiSqSelectorSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index 3461cdf824..a2f009310f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -54,7 +54,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
val mh = new MinHashLSH()
val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
- testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+ testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
}
test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 74c7461401..076d55c180 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -99,8 +99,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
model2.freqItemsets.sort("items").collect())
}
val fPGrowth = new FPGrowth()
- testEstimatorAndModelReadWrite(
- fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
+ FPGrowthSuite.allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index e494ea89e6..a177ed13bf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -518,37 +518,26 @@ class ALSSuite
}
test("read/write") {
- import ALSSuite._
- val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
- val als = new ALS()
- allEstimatorParamSettings.foreach { case (p, v) =>
- als.set(als.getParam(p), v)
- }
val spark = this.spark
import spark.implicits._
- val model = als.fit(ratings.toDF())
-
- // Test Estimator save/load
- val als2 = testDefaultReadWrite(als)
- allEstimatorParamSettings.foreach { case (p, v) =>
- val param = als.getParam(p)
- assert(als.get(param).get === als2.get(param).get)
- }
+ import ALSSuite._
+ val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
- // Test Model save/load
- val model2 = testDefaultReadWrite(model)
- allModelParamSettings.foreach { case (p, v) =>
- val param = model.getParam(p)
- assert(model.get(param).get === model2.get(param).get)
- }
- assert(model.rank === model2.rank)
def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
df.select("id", "features").collect().map { case r =>
(r.getInt(0), r.getAs[Array[Float]](1))
}.toSet
}
- assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
- assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+
+ def checkModelData(model: ALSModel, model2: ALSModel): Unit = {
+ assert(model.rank === model2.rank)
+ assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
+ assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+ }
+
+ val als = new ALS()
+ testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings,
+ allModelParamSettings, checkModelData)
}
test("input type validation") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 3cd4b0ac30..708185a094 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -419,7 +419,8 @@ class AFTSurvivalRegressionSuite
}
val aft = new AFTSurvivalRegression()
testEstimatorAndModelReadWrite(aft, datasetMultivariate,
- AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
+ AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings,
+ checkModelData)
}
test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 15fa26e8b5..0e91284d03 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -165,16 +165,17 @@ class DecisionTreeRegressorSuite
val categoricalData: DataFrame =
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
testEstimatorAndModelReadWrite(dt, categoricalData,
- TreeTests.allParamSettings, checkModelData)
+ TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)
// Continuous splits with tree depth 2
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
testEstimatorAndModelReadWrite(dt, continuousData,
- TreeTests.allParamSettings, checkModelData)
+ TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)
// Continuous splits with tree depth 0
testEstimatorAndModelReadWrite(dt, continuousData,
+ TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dcf3f9a1ea..03c2f97797 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -184,7 +184,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
- testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index add28a72b6..401911763f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1418,6 +1418,7 @@ class GeneralizedLinearRegressionSuite
val glr = new GeneralizedLinearRegression()
testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
+ GeneralizedLinearRegressionSuite.allParamSettings,
GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 8cbb2acad2..f41a3601b1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -178,7 +178,7 @@ class IsotonicRegressionSuite
val ir = new IsotonicRegression()
testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
- checkModelData)
+ IsotonicRegressionSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 584a1b272f..6a51e75e12 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -985,7 +985,7 @@ class LinearRegressionSuite
}
val lr = new LinearRegression()
testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
- checkModelData)
+ LinearRegressionSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f9f8..3bf0445ebd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -124,7 +124,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
- testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 553b8725b3..bfe8f12258 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -85,11 +85,12 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* - Check Params on Estimator and Model
* - Compare model data
*
- * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+ * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s.
*
* @param estimator Estimator to test
* @param dataset Dataset to pass to [[Estimator.fit()]]
- * @param testParams Set of [[Param]] values to set in estimator
+ * @param testEstimatorParams Set of [[Param]] values to set in estimator
+ * @param testModelParams Set of [[Param]] values to set in model
* @param checkModelData Method which takes the original and loaded [[Model]] and compares their
* data. This method does not need to check [[Param]] values.
* @tparam E Type of [[Estimator]]
@@ -99,24 +100,25 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
estimator: E,
dataset: Dataset[_],
- testParams: Map[String, Any],
+ testEstimatorParams: Map[String, Any],
+ testModelParams: Map[String, Any],
checkModelData: (M, M) => Unit): Unit = {
// Set some Params to make sure set Params are serialized.
- testParams.foreach { case (p, v) =>
+ testEstimatorParams.foreach { case (p, v) =>
estimator.set(estimator.getParam(p), v)
}
val model = estimator.fit(dataset)
// Test Estimator save/load
val estimator2 = testDefaultReadWrite(estimator)
- testParams.foreach { case (p, v) =>
+ testEstimatorParams.foreach { case (p, v) =>
val param = estimator.getParam(p)
assert(estimator.get(param).get === estimator2.get(param).get)
}
// Test Model save/load
val model2 = testDefaultReadWrite(model)
- testParams.foreach { case (p, v) =>
+ testModelParams.foreach { case (p, v) =>
val param = model.getParam(p)
assert(model.get(param).get === model2.get(param).get)
}
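Reading the two ```DefaultReadWriteTest.scala``` hunks together, the updated helper looks roughly like this. It is reconstructed from the hunks above; the trailing ```checkModelData``` call is an assumption, since the diff context ends before the method's closing lines:

```scala
def testEstimatorAndModelReadWrite[
    E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
    estimator: E,
    dataset: Dataset[_],
    testEstimatorParams: Map[String, Any],
    testModelParams: Map[String, Any],
    checkModelData: (M, M) => Unit): Unit = {
  // Set some Params to make sure set Params are serialized.
  testEstimatorParams.foreach { case (p, v) =>
    estimator.set(estimator.getParam(p), v)
  }
  val model = estimator.fit(dataset)

  // Test Estimator save/load: every estimator param must round-trip.
  val estimator2 = testDefaultReadWrite(estimator)
  testEstimatorParams.foreach { case (p, v) =>
    val param = estimator.getParam(p)
    assert(estimator.get(param).get === estimator2.get(param).get)
  }

  // Test Model save/load: only the model's own params are checked, which is
  // why the model's params need only be a subset of the estimator's.
  val model2 = testDefaultReadWrite(model)
  testModelParams.foreach { case (p, v) =>
    val param = model.getParam(p)
    assert(model.get(param).get === model2.get(param).get)
  }
  checkModelData(model, model2)  // assumed: not shown in the hunk above
}
```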