From 6eb7008b7f33a36b06d0615b68cc21ed90ad1d8a Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Tue, 17 Nov 2015 14:03:49 -0800
Subject: [SPARK-11763][ML] Add save,load to LogisticRegression Estimator

Add save/load to LogisticRegression Estimator, and refactor tests a little to
make it easier to add similar support to other Estimator, Model pairs.

Moved LogisticRegressionReader/Writer to within LogisticRegressionModel

CC: mengxr

Author: Joseph K. Bradley

Closes #9749 from jkbradley/lr-io-2.
---
 .../scala/org/apache/spark/ml/PipelineSuite.scala  |  7 ---
 .../spark/ml/classification/ClassifierSuite.scala  | 32 ++++++++++++++
 .../classification/LogisticRegressionSuite.scala   | 37 ++++++++++++----
 .../ProbabilisticClassifierSuite.scala             | 14 ++++++
 .../spark/ml/util/DefaultReadWriteTest.scala       | 50 +++++++++++++++++++++-
 5 files changed, 123 insertions(+), 17 deletions(-)
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala

(limited to 'mllib/src/test/scala')

diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 484026b1ba..7f5c3895ac 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -149,13 +149,6 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     assert(pipeline2.stages(0).isInstanceOf[WritableStage])
     val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage]
     assert(writableStage.getIntParam === writableStage2.getIntParam)
-
-    val path = new File(tempDir, pipeline.uid).getPath
-    val stagesDir = new Path(path, "stages").toString
-    val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir)
-    assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)),
-      s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" +
-        s" to be saved to path: $expectedStagePath")
   }

   test("PipelineModel read/write: getStagePath") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
new file mode 100644
index 0000000000..d0e3fe7ad1
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+object ClassifierSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "rawPredictionCol" -> "myRawPrediction"
+  )
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 51b06b7eb6..48ce1bb630 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -873,15 +873,34 @@ class LogisticRegressionSuite
   }

   test("read/write") {
-    // Set some Params to make sure set Params are serialized.
+    def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = {
+      assert(model.intercept === model2.intercept)
+      assert(model.coefficients.toArray === model2.coefficients.toArray)
+      assert(model.numClasses === model2.numClasses)
+      assert(model.numFeatures === model2.numFeatures)
+    }
     val lr = new LogisticRegression()
-      .setElasticNetParam(0.1)
-      .setMaxIter(2)
-      .fit(dataset)
-    val lr2 = testDefaultReadWrite(lr)
-    assert(lr.intercept === lr2.intercept)
-    assert(lr.coefficients.toArray === lr2.coefficients.toArray)
-    assert(lr.numClasses === lr2.numClasses)
-    assert(lr.numFeatures === lr2.numFeatures)
+    testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings,
+      checkModelData)
   }
 }
+
+object LogisticRegressionSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = ProbabilisticClassifierSuite.allParamSettings ++ Map(
+    "probabilityCol" -> "myProbability",
+    "thresholds" -> Array(0.4, 0.6),
+    "regParam" -> 0.01,
+    "elasticNetParam" -> 0.1,
+    "maxIter" -> 2,  // intentionally small
+    "fitIntercept" -> false,
+    "tol" -> 0.8,
+    "standardization" -> false,
+    "threshold" -> 0.6
+  )
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index fb5f00e064..cfa75ecf38 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -57,3 +57,17 @@ class ProbabilisticClassifierSuite extends SparkFunSuite {
     assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
   }
 }
+
+object ProbabilisticClassifierSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = ClassifierSuite.allParamSettings ++ Map(
+    "probabilityCol" -> "myProbability",
+    "thresholds" -> Array(0.4, 0.6)
+  )
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index c37f0503f1..dd1e8acce9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -22,13 +22,17 @@ import java.io.{File, IOException}
 import org.scalatest.Suite

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Model, Estimator}
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.DataFrame

 trait DefaultReadWriteTest extends TempDirectory { self: Suite =>

   /**
    * Checks "overwrite" option and params.
+   * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name
+   * in order to avoid conflicts from multiple calls to this method.
    * @param instance ML instance to test saving/loading
    * @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
    * @tparam T ML instance type
@@ -38,7 +42,10 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
       instance: T,
       testParams: Boolean = true): T = {
     val uid = instance.uid
-    val path = new File(tempDir, uid).getPath
+    val subdirName = Identifiable.randomUID("test")
+
+    val subdir = new File(tempDir, subdirName)
+    val path = new File(subdir, uid).getPath

     instance.save(path)
     intercept[IOException] {
@@ -69,6 +76,47 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     assert(another.uid === instance.uid)
     another
   }
+
+  /**
+   * Default test for Estimator, Model pairs:
+   *  - Explicitly set Params, and train model
+   *  - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model
+   *  - Check Params on Estimator and Model
+   *
+   * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+   * @param estimator Estimator to test
+   * @param dataset Dataset to pass to [[Estimator.fit()]]
+   * @param testParams Set of [[Param]] values to set in estimator
+   * @param checkModelData Method which takes the original and loaded [[Model]] and compares their
+   *                       data. This method does not need to check [[Param]] values.
+   * @tparam E Type of [[Estimator]]
+   * @tparam M Type of [[Model]] produced by estimator
+   */
+  def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable](
+      estimator: E,
+      dataset: DataFrame,
+      testParams: Map[String, Any],
+      checkModelData: (M, M) => Unit): Unit = {
+    // Set some Params to make sure set Params are serialized.
+    testParams.foreach { case (p, v) =>
+      estimator.set(estimator.getParam(p), v)
+    }
+    val model = estimator.fit(dataset)

+    // Test Estimator save/load
+    val estimator2 = testDefaultReadWrite(estimator)
+    testParams.foreach { case (p, v) =>
+      val param = estimator.getParam(p)
+      assert(estimator.get(param).get === estimator2.get(param).get)
+    }

+    // Test Model save/load
+    val model2 = testDefaultReadWrite(model)
+    testParams.foreach { case (p, v) =>
+      val param = model.getParam(p)
+      assert(model.get(param).get === model2.get(param).get)
+    }
+  }
 }

 class MyParams(override val uid: String) extends Params with Writable {
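
For context (this note is not part of the patch, which is limited to test code), a minimal sketch of the Estimator save/load round trip that SPARK-11763 enables; the save path and the Param values below are illustrative assumptions, not taken from the commit:

    import org.apache.spark.ml.classification.LogisticRegression

    // Configure an Estimator (not yet fitted) and persist it with its Params.
    val lr = new LogisticRegression()
      .setMaxIter(2)
      .setRegParam(0.01)
    lr.save("/tmp/spark-lr-estimator")   // hypothetical path

    // Restore it later; the loaded Estimator should carry the same uid and Param values.
    val sameLR = LogisticRegression.load("/tmp/spark-lr-estimator")
    assert(sameLR.getMaxIter === lr.getMaxIter)

testEstimatorAndModelReadWrite automates this round trip for both the Estimator and the Model it produces, so a concrete suite only needs to supply a Param map (such as allParamSettings) and a model-comparison function.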