 -rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala           | 91
 -rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala                              |  1
 -rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala                               |  7
 -rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala              | 32
 -rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala      | 37
 -rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala | 14
 -rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala                   | 50
 7 files changed, 173 insertions(+), 59 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a88f526741..71c2533bcb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
@Experimental
class LogisticRegression(override val uid: String)
extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
- with LogisticRegressionParams with Logging {
+ with LogisticRegressionParams with Writable with Logging {
def this() = this(Identifiable.randomUID("logreg"))
@@ -385,6 +385,12 @@ class LogisticRegression(override val uid: String)
}
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
+
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+object LogisticRegression extends Readable[LogisticRegression] {
+ override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression]
}
/**
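
For illustration, not part of the patch: with the Writable/Readable mixins added above, the estimator's Params round-trip through a path. A minimal sketch, assuming the save/load entry points from ReadWrite.scala; the path and Param values are placeholders.

    val lr = new LogisticRegression().setMaxIter(2).setRegParam(0.01)
    lr.save("/tmp/lr-estimator")                                  // DefaultParamsWriter: uid + set Params as JSON metadata
    val lr2 = LogisticRegression.read.load("/tmp/lr-estimator")   // DefaultParamsReader: fresh instance with Params restored
    assert(lr2.getMaxIter == 2 && lr2.getRegParam == 0.01)
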
@@ -517,61 +523,62 @@ class LogisticRegressionModel private[ml] (
*
* For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]].
* An option to save [[summary]] may be added in the future.
+ *
+ * This also does not save the [[parent]] currently.
*/
- override def write: Writer = new LogisticRegressionWriter(this)
-}
-
-
-/** [[Writer]] instance for [[LogisticRegressionModel]] */
-private[classification] class LogisticRegressionWriter(instance: LogisticRegressionModel)
- extends Writer with Logging {
-
- private case class Data(
- numClasses: Int,
- numFeatures: Int,
- intercept: Double,
- coefficients: Vector)
-
- override protected def saveImpl(path: String): Unit = {
- // Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
- // Save model data: numClasses, numFeatures, intercept, coefficients
- val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
- instance.coefficients)
- val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
- }
+ override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
}
object LogisticRegressionModel extends Readable[LogisticRegressionModel] {
- override def read: Reader[LogisticRegressionModel] = new LogisticRegressionReader
+ override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader
override def load(path: String): LogisticRegressionModel = read.load(path)
-}
+ /** [[Writer]] instance for [[LogisticRegressionModel]] */
+ private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
+ extends Writer with Logging {
+
+ private case class Data(
+ numClasses: Int,
+ numFeatures: Int,
+ intercept: Double,
+ coefficients: Vector)
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: numClasses, numFeatures, intercept, coefficients
+ val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
+ instance.coefficients)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+ }
+ }
-private[classification] class LogisticRegressionReader extends Reader[LogisticRegressionModel] {
+ private[classification] class LogisticRegressionModelReader
+ extends Reader[LogisticRegressionModel] {
- /** Checked against metadata when loading model */
- private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
+ /** Checked against metadata when loading model */
+ private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
- override def load(path: String): LogisticRegressionModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ override def load(path: String): LogisticRegressionModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
- val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.format("parquet").load(dataPath)
- .select("numClasses", "numFeatures", "intercept", "coefficients").head()
- // We will need numClasses, numFeatures in the future for multinomial logreg support.
- // val numClasses = data.getInt(0)
- // val numFeatures = data.getInt(1)
- val intercept = data.getDouble(2)
- val coefficients = data.getAs[Vector](3)
- val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept)
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.format("parquet").load(dataPath)
+ .select("numClasses", "numFeatures", "intercept", "coefficients").head()
+ // We will need numClasses, numFeatures in the future for multinomial logreg support.
+ // val numClasses = data.getInt(0)
+ // val numFeatures = data.getInt(1)
+ val intercept = data.getDouble(2)
+ val coefficients = data.getAs[Vector](3)
+ val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept)
- DefaultParamsReader.getAndSetParams(model, metadata)
- model
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
}
}
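
For illustration, not part of the patch: the writer above persists Params as JSON metadata plus a single-row Parquet file under data/, and the reader reverses both steps. A minimal round-trip sketch; the path is a placeholder and dataset is assumed to be a labeled DataFrame already in scope.

    val model = new LogisticRegression().setMaxIter(2).fit(dataset)
    model.save("/tmp/lr-model")      // writes metadata/ (Params) and data/ (numClasses, numFeatures, intercept, coefficients)
    val model2 = LogisticRegressionModel.load("/tmp/lr-model")
    assert(model.intercept == model2.intercept)
    // Per the updated scaladoc, the training summary and the parent are not persisted.
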
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 3169c9e9af..dddb72af5b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -217,6 +217,7 @@ private[ml] object DefaultParamsWriter {
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
* @tparam T ML instance type
+ * TODO: Consider adding check for correct class name.
*/
private[ml] class DefaultParamsReader[T] extends Reader[T] {
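
To show the pattern this reader/writer pair targets, here is a hypothetical Params-only class (names are made up; it mirrors MyParams in DefaultReadWriteTest below, and has to live inside the org.apache.spark.ml package since these helpers are private[ml]). A one-line write plus a companion read is all such a class needs, exactly as LogisticRegression does above.

    import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
    import org.apache.spark.ml.util._

    class MyStage(override val uid: String) extends Params with Writable {
      val threshold: DoubleParam = new DoubleParam(this, "threshold", "decision threshold")
      def setThreshold(value: Double): this.type = set(threshold, value)
      override def copy(extra: ParamMap): MyStage = defaultCopy(extra)
      override def write: Writer = new DefaultParamsWriter(this)   // simple Params only, no model data
    }

    object MyStage extends Readable[MyStage] {
      override def read: Reader[MyStage] = new DefaultParamsReader[MyStage]
    }
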
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 484026b1ba..7f5c3895ac 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -149,13 +149,6 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(pipeline2.stages(0).isInstanceOf[WritableStage])
val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage]
assert(writableStage.getIntParam === writableStage2.getIntParam)
-
- val path = new File(tempDir, pipeline.uid).getPath
- val stagesDir = new Path(path, "stages").toString
- val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir)
- assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)),
- s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" +
- s" to be saved to path: $expectedStagePath")
}
test("PipelineModel read/write: getStagePath") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
new file mode 100644
index 0000000000..d0e3fe7ad1
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+object ClassifierSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "rawPredictionCol" -> "myRawPrediction"
+ )
+
+}
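
For illustration, not part of the patch: these settings are consumed by pushing each entry through the Param machinery by name, the same way testEstimatorAndModelReadWrite does in DefaultReadWriteTest further down. A minimal sketch, with classifier standing in for any concrete classifier instance:

    ClassifierSuite.allParamSettings.foreach { case (name, value) =>
      classifier.set(classifier.getParam(name), value)   // getParam resolves the Param by its name
    }
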
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 51b06b7eb6..48ce1bb630 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -873,15 +873,34 @@ class LogisticRegressionSuite
}
test("read/write") {
- // Set some Params to make sure set Params are serialized.
+ def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = {
+ assert(model.intercept === model2.intercept)
+ assert(model.coefficients.toArray === model2.coefficients.toArray)
+ assert(model.numClasses === model2.numClasses)
+ assert(model.numFeatures === model2.numFeatures)
+ }
val lr = new LogisticRegression()
- .setElasticNetParam(0.1)
- .setMaxIter(2)
- .fit(dataset)
- val lr2 = testDefaultReadWrite(lr)
- assert(lr.intercept === lr2.intercept)
- assert(lr.coefficients.toArray === lr2.coefficients.toArray)
- assert(lr.numClasses === lr2.numClasses)
- assert(lr.numFeatures === lr2.numFeatures)
+ testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings,
+ checkModelData)
}
}
+
+object LogisticRegressionSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = ProbabilisticClassifierSuite.allParamSettings ++ Map(
+ "probabilityCol" -> "myProbability",
+ "thresholds" -> Array(0.4, 0.6),
+ "regParam" -> 0.01,
+ "elasticNetParam" -> 0.1,
+ "maxIter" -> 2, // intentionally small
+ "fitIntercept" -> false,
+ "tol" -> 0.8,
+ "standardization" -> false,
+ "threshold" -> 0.6
+ )
+}
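
One detail worth keeping in mind: Map's ++ lets right-hand entries win, so any key set both here and in ProbabilisticClassifierSuite.allParamSettings ends up with the value defined in this suite. A tiny illustration with made-up values:

    val inherited = Map("thresholds" -> Array(0.5, 0.5))
    val merged = inherited ++ Map("thresholds" -> Array(0.4, 0.6))
    // merged("thresholds") is Array(0.4, 0.6): the right-hand map takes precedence.
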
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index fb5f00e064..cfa75ecf38 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -57,3 +57,17 @@ class ProbabilisticClassifierSuite extends SparkFunSuite {
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
}
}
+
+object ProbabilisticClassifierSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = ClassifierSuite.allParamSettings ++ Map(
+ "probabilityCol" -> "myProbability",
+ "thresholds" -> Array(0.4, 0.6)
+ )
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index c37f0503f1..dd1e8acce9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -22,13 +22,17 @@ import java.io.{File, IOException}
import org.scalatest.Suite
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Model, Estimator}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.DataFrame
trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
/**
* Checks "overwrite" option and params.
+ * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name
+ * in order to avoid conflicts from multiple calls to this method.
* @param instance ML instance to test saving/loading
* @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
* @tparam T ML instance type
@@ -38,7 +42,10 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
instance: T,
testParams: Boolean = true): T = {
val uid = instance.uid
- val path = new File(tempDir, uid).getPath
+ val subdirName = Identifiable.randomUID("test")
+
+ val subdir = new File(tempDir, subdirName)
+ val path = new File(subdir, uid).getPath
instance.save(path)
intercept[IOException] {
@@ -69,6 +76,47 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
assert(another.uid === instance.uid)
another
}
+
+ /**
+ * Default test for Estimator, Model pairs:
+ * - Explicitly set Params, and train model
+ * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model
+ * - Check Params on Estimator and Model
+ *
+ * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+ * @param estimator Estimator to test
+ * @param dataset Dataset to pass to [[Estimator.fit()]]
+ * @param testParams Set of [[Param]] values to set in estimator
+ * @param checkModelData Method which takes the original and loaded [[Model]] and compares their
+ * data. This method does not need to check [[Param]] values.
+ * @tparam E Type of [[Estimator]]
+ * @tparam M Type of [[Model]] produced by estimator
+ */
+ def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable](
+ estimator: E,
+ dataset: DataFrame,
+ testParams: Map[String, Any],
+ checkModelData: (M, M) => Unit): Unit = {
+ // Set some Params to make sure set Params are serialized.
+ testParams.foreach { case (p, v) =>
+ estimator.set(estimator.getParam(p), v)
+ }
+ val model = estimator.fit(dataset)
+
+ // Test Estimator save/load
+ val estimator2 = testDefaultReadWrite(estimator)
+ testParams.foreach { case (p, v) =>
+ val param = estimator.getParam(p)
+ assert(estimator.get(param).get === estimator2.get(param).get)
+ }
+
+ // Test Model save/load
+ val model2 = testDefaultReadWrite(model)
+ testParams.foreach { case (p, v) =>
+ val param = model.getParam(p)
+ assert(model.get(param).get === model2.get(param).get)
+ }
+ }
}
class MyParams(override val uid: String) extends Params with Writable {
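
For illustration, not part of the patch: any suite mixing in DefaultReadWriteTest can reuse the helper above the way LogisticRegressionSuite does, by supplying a settings map and a data-equality check. MyEstimator, MyModel, and MyEstimatorSuite.allParamSettings are hypothetical stand-ins.

    // Inside a class extending SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest:
    test("read/write") {
      def checkModelData(m: MyModel, m2: MyModel): Unit = {
        assert(m.coefficients.toArray === m2.coefficients.toArray)   // compare learned state; Params are checked by the helper
      }
      val estimator = new MyEstimator()   // hypothetical Estimator[MyModel] with Writable
      testEstimatorAndModelReadWrite(estimator, dataset, MyEstimatorSuite.allParamSettings, checkModelData)
    }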