Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--   mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 7
-rw-r--r--   mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala | 32
-rw-r--r--   mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 37
-rw-r--r--   mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala | 14
-rw-r--r--   mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala | 50
5 files changed, 123 insertions(+), 17 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 484026b1ba..7f5c3895ac 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -149,13 +149,6 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(pipeline2.stages(0).isInstanceOf[WritableStage])
val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage]
assert(writableStage.getIntParam === writableStage2.getIntParam)
-
- val path = new File(tempDir, pipeline.uid).getPath
- val stagesDir = new Path(path, "stages").toString
- val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir)
- assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)),
- s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" +
- s" to be saved to path: $expectedStagePath")
}
test("PipelineModel read/write: getStagePath") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
new file mode 100644
index 0000000000..d0e3fe7ad1
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+object ClassifierSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "rawPredictionCol" -> "myRawPrediction"
+ )
+
+}
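
Keying this map by Param name lets a generic helper apply it via getParam (see the DefaultReadWriteTest change below). For a concrete classifier the same settings amount to calling the public setters with non-default values; a minimal sketch, using LogisticRegression only as a stand-in:

    import org.apache.spark.ml.classification.LogisticRegression

    // Equivalent of ClassifierSuite.allParamSettings applied through the public
    // setter API; both values differ from the defaults ("prediction" and
    // "rawPrediction"), so save/load must round-trip them.
    val classifier = new LogisticRegression()
      .setPredictionCol("myPrediction")
      .setRawPredictionCol("myRawPrediction")
    assert(classifier.getPredictionCol == "myPrediction")
    assert(classifier.getRawPredictionCol == "myRawPrediction")
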
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 51b06b7eb6..48ce1bb630 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -873,15 +873,34 @@ class LogisticRegressionSuite
}
test("read/write") {
- // Set some Params to make sure set Params are serialized.
+ def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = {
+ assert(model.intercept === model2.intercept)
+ assert(model.coefficients.toArray === model2.coefficients.toArray)
+ assert(model.numClasses === model2.numClasses)
+ assert(model.numFeatures === model2.numFeatures)
+ }
val lr = new LogisticRegression()
- .setElasticNetParam(0.1)
- .setMaxIter(2)
- .fit(dataset)
- val lr2 = testDefaultReadWrite(lr)
- assert(lr.intercept === lr2.intercept)
- assert(lr.coefficients.toArray === lr2.coefficients.toArray)
- assert(lr.numClasses === lr2.numClasses)
- assert(lr.numFeatures === lr2.numFeatures)
+ testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings,
+ checkModelData)
}
}
+
+object LogisticRegressionSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = ProbabilisticClassifierSuite.allParamSettings ++ Map(
+ "probabilityCol" -> "myProbability",
+ "thresholds" -> Array(0.4, 0.6),
+ "regParam" -> 0.01,
+ "elasticNetParam" -> 0.1,
+ "maxIter" -> 2, // intentionally small
+ "fitIntercept" -> false,
+ "tol" -> 0.8,
+ "standardization" -> false,
+ "threshold" -> 0.6
+ )
+}
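
Expanded, the ++ chain above gives the full set of non-default values that the rewritten read/write test applies to LogisticRegression. Note that "threshold" -> 0.6 is consistent with "thresholds" -> Array(0.4, 0.6): for binary classification the threshold corresponds to 1 / (1 + thresholds(0) / thresholds(1)) = 0.6, and LogisticRegression checks the two params against each other, so contradictory values here would fail before reaching save/load.

    // LogisticRegressionSuite.allParamSettings, expanded:
    Map(
      "predictionCol" -> "myPrediction",        // from ClassifierSuite
      "rawPredictionCol" -> "myRawPrediction",  // from ClassifierSuite
      "probabilityCol" -> "myProbability",      // from ProbabilisticClassifierSuite
      "thresholds" -> Array(0.4, 0.6),          // from ProbabilisticClassifierSuite
      "regParam" -> 0.01,
      "elasticNetParam" -> 0.1,
      "maxIter" -> 2,
      "fitIntercept" -> false,
      "tol" -> 0.8,
      "standardization" -> false,
      "threshold" -> 0.6
    )
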
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index fb5f00e064..cfa75ecf38 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -57,3 +57,17 @@ class ProbabilisticClassifierSuite extends SparkFunSuite {
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
}
}
+
+object ProbabilisticClassifierSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = ClassifierSuite.allParamSettings ++ Map(
+ "probabilityCol" -> "myProbability",
+ "thresholds" -> Array(0.4, 0.6)
+ )
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index c37f0503f1..dd1e8acce9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -22,13 +22,17 @@ import java.io.{File, IOException}
import org.scalatest.Suite
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Model, Estimator}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.DataFrame
trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
/**
* Checks "overwrite" option and params.
+ * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name
+ * in order to avoid conflicts from multiple calls to this method.
* @param instance ML instance to test saving/loading
* @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
* @tparam T ML instance type
@@ -38,7 +42,10 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
instance: T,
testParams: Boolean = true): T = {
val uid = instance.uid
- val path = new File(tempDir, uid).getPath
+ val subdirName = Identifiable.randomUID("test")
+
+ val subdir = new File(tempDir, subdirName)
+ val path = new File(subdir, uid).getPath
instance.save(path)
intercept[IOException] {
@@ -69,6 +76,47 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
assert(another.uid === instance.uid)
another
}
+
+ /**
+ * Default test for Estimator, Model pairs:
+ * - Explicitly set Params, and train model
+ * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model
+ * - Check Params on Estimator and Model
+ *
+ * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+ * @param estimator Estimator to test
+ * @param dataset Dataset to pass to [[Estimator.fit()]]
+ * @param testParams Set of [[Param]] values to set in estimator
+ * @param checkModelData Method which takes the original and loaded [[Model]] and compares their
+ * data. This method does not need to check [[Param]] values.
+ * @tparam E Type of [[Estimator]]
+ * @tparam M Type of [[Model]] produced by estimator
+ */
+ def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable](
+ estimator: E,
+ dataset: DataFrame,
+ testParams: Map[String, Any],
+ checkModelData: (M, M) => Unit): Unit = {
+ // Set some Params to make sure set Params are serialized.
+ testParams.foreach { case (p, v) =>
+ estimator.set(estimator.getParam(p), v)
+ }
+ val model = estimator.fit(dataset)
+
+ // Test Estimator save/load
+ val estimator2 = testDefaultReadWrite(estimator)
+ testParams.foreach { case (p, v) =>
+ val param = estimator.getParam(p)
+ assert(estimator.get(param).get === estimator2.get(param).get)
+ }
+
+ // Test Model save/load
+ val model2 = testDefaultReadWrite(model)
+ testParams.foreach { case (p, v) =>
+ val param = model.getParam(p)
+ assert(model.get(param).get === model2.get(param).get)
+ }
+ }
}
class MyParams(override val uid: String) extends Params with Writable {
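
The helper is meant to be reused by other algorithm suites following the same pattern as the LogisticRegressionSuite change above: mix in DefaultReadWriteTest, define an allParamSettings map of non-default values, and supply a model-data comparison. A sketch of that shape, where MyClassifier, MyClassifierModel, and the compared fields are hypothetical placeholders rather than real Spark classes:

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.ml.util.DefaultReadWriteTest
    import org.apache.spark.mllib.util.MLlibTestSparkContext

    class MyClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
      with DefaultReadWriteTest {

      test("read/write") {
        // Only model data is compared here; Param values on both the Estimator
        // and the Model are checked inside testEstimatorAndModelReadWrite.
        def checkModelData(model: MyClassifierModel, model2: MyClassifierModel): Unit = {
          assert(model.numClasses === model2.numClasses)
          assert(model.numFeatures === model2.numFeatures)
        }
        // `dataset` is assumed to be prepared in beforeAll(), as in LogisticRegressionSuite.
        testEstimatorAndModelReadWrite(new MyClassifier(), dataset,
          MyClassifierSuite.allParamSettings, checkModelData)
      }
    }

    object MyClassifierSuite {
      // Non-default values for every Param, excluding input columns.
      val allParamSettings: Map[String, Any] =
        ClassifierSuite.allParamSettings ++ Map("maxIter" -> 2)
    }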