author    Xusen Yin <yinxusen@gmail.com>          2015-11-19 22:01:02 -0800
committer Xiangrui Meng <meng@databricks.com>     2015-11-19 22:01:02 -0800
commit    4114ce20fbe820f111e55e891ae3889b0e6e0006 (patch)
tree      424724003d17d9eb2d09327a5d0f1e79613a9ed1 /mllib
parent    7ee7d5a3c4ff77d2cee2afce36ff41f6302e6315 (diff)
[SPARK-11846] Add save/load for AFTSurvivalRegression and IsotonicRegression
https://issues.apache.org/jira/browse/SPARK-11846

mengxr

Author: Xusen Yin <yinxusen@gmail.com>

Closes #9836 from yinxusen/SPARK-11846.
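For context, a minimal usage sketch of the API this patch adds (not part of the patch itself; it assumes a SQLContext and a `training` DataFrame with "features", "label" and "censor" columns are in scope, and the paths are placeholders):

    import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}

    // Fit an AFT model as before.
    val model = new AFTSurvivalRegression().setMaxIter(2).fit(training)

    // New in this patch: the fitted model can be persisted and reloaded.
    model.write.overwrite().save("/tmp/aft-model")
    val restored = AFTSurvivalRegressionModel.load("/tmp/aft-model")

    // The restored model carries the same Params and model data.
    restored.transform(training).show()

IsotonicRegression and IsotonicRegressionModel gain the same write/load support in this patch.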
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala       78
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala          83
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala  37
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala     34
4 files changed, 210 insertions(+), 22 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index b7d095872f..aedfb48058 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -21,20 +21,20 @@ import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}
+import org.apache.hadoop.fs.Path
-import org.apache.spark.{SparkException, Logging}
-import org.apache.spark.annotation.{Since, Experimental}
-import org.apache.spark.ml.{Model, Estimator}
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
-import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.{Logging, SparkException}
/**
* Params for accelerated failure time (AFT) regression.
@@ -120,7 +120,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
@Experimental
@Since("1.6.0")
class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String)
- extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging {
+ extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams
+ with DefaultParamsWritable with Logging {
@Since("1.6.0")
def this() = this(Identifiable.randomUID("aftSurvReg"))
@@ -243,6 +244,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra)
}
+@Since("1.6.0")
+object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression] {
+
+ @Since("1.6.0")
+ override def load(path: String): AFTSurvivalRegression = super.load(path)
+}
+
/**
* :: Experimental ::
* Model produced by [[AFTSurvivalRegression]].
@@ -254,7 +262,7 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("1.6.0") val coefficients: Vector,
@Since("1.6.0") val intercept: Double,
@Since("1.6.0") val scale: Double)
- extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams {
+ extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable {
/** @group setParam */
@Since("1.6.0")
@@ -312,6 +320,58 @@ class AFTSurvivalRegressionModel private[ml] (
copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra)
.setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: MLWriter =
+ new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this)
+}
+
+@Since("1.6.0")
+object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] {
+
+ @Since("1.6.0")
+ override def read: MLReader[AFTSurvivalRegressionModel] = new AFTSurvivalRegressionModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): AFTSurvivalRegressionModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[AFTSurvivalRegressionModel]] */
+ private[AFTSurvivalRegressionModel] class AFTSurvivalRegressionModelWriter (
+ instance: AFTSurvivalRegressionModel
+ ) extends MLWriter with Logging {
+
+ private case class Data(coefficients: Vector, intercept: Double, scale: Double)
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: coefficients, intercept, scale
+ val data = Data(instance.coefficients, instance.intercept, instance.scale)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class AFTSurvivalRegressionModelReader extends MLReader[AFTSurvivalRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[AFTSurvivalRegressionModel].getName
+
+ override def load(path: String): AFTSurvivalRegressionModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath)
+ .select("coefficients", "intercept", "scale").head()
+ val coefficients = data.getAs[Vector](0)
+ val intercept = data.getDouble(1)
+ val scale = data.getDouble(2)
+ val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
/**
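(For orientation, not part of the patch: the writer above lays the model out on disk as a metadata directory plus a single-row Parquet file, so the saved data can also be inspected directly. A sketch, assuming a model was saved to the placeholder path /tmp/aft-model and a sqlContext is in scope:)

    import org.apache.spark.mllib.linalg.Vector

    // Layout produced by AFTSurvivalRegressionModelWriter:
    //   /tmp/aft-model/metadata  - JSON with class name, uid, Spark version and Params
    //   /tmp/aft-model/data      - one-row Parquet file: coefficients, intercept, scale
    val row = sqlContext.read.parquet("/tmp/aft-model/data")
      .select("coefficients", "intercept", "scale")
      .head()
    val coefficients = row.getAs[Vector](0)
    val intercept = row.getDouble(1)
    val scale = row.getDouble(2)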
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index a1fe01b047..bbb1c7ac0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -17,18 +17,22 @@
package org.apache.spark.ml.regression
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
-import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel}
+import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
+import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel
/**
@@ -127,7 +131,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
@Since("1.5.0")
@Experimental
class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String)
- extends Estimator[IsotonicRegressionModel] with IsotonicRegressionBase {
+ extends Estimator[IsotonicRegressionModel]
+ with IsotonicRegressionBase with DefaultParamsWritable {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("isoReg"))
@@ -179,6 +184,13 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
}
}
+@Since("1.6.0")
+object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] {
+
+ @Since("1.6.0")
+ override def load(path: String): IsotonicRegression = super.load(path)
+}
+
/**
* :: Experimental ::
* Model fitted by IsotonicRegression.
@@ -194,7 +206,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
class IsotonicRegressionModel private[ml] (
override val uid: String,
private val oldModel: MLlibIsotonicRegressionModel)
- extends Model[IsotonicRegressionModel] with IsotonicRegressionBase {
+ extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable {
/** @group setParam */
@Since("1.5.0")
@@ -240,4 +252,61 @@ class IsotonicRegressionModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false)
}
+
+ @Since("1.6.0")
+ override def write: MLWriter =
+ new IsotonicRegressionModelWriter(this)
+}
+
+@Since("1.6.0")
+object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] {
+
+ @Since("1.6.0")
+ override def read: MLReader[IsotonicRegressionModel] = new IsotonicRegressionModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): IsotonicRegressionModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[IsotonicRegressionModel]] */
+ private[IsotonicRegressionModel] class IsotonicRegressionModelWriter (
+ instance: IsotonicRegressionModel
+ ) extends MLWriter with Logging {
+
+ private case class Data(
+ boundaries: Array[Double],
+ predictions: Array[Double],
+ isotonic: Boolean)
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: boundaries, predictions, isotonic
+ val data = Data(
+ instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class IsotonicRegressionModelReader extends MLReader[IsotonicRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[IsotonicRegressionModel].getName
+
+ override def load(path: String): IsotonicRegressionModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath)
+ .select("boundaries", "predictions", "isotonic").head()
+ val boundaries = data.getAs[Seq[Double]](0).toArray
+ val predictions = data.getAs[Seq[Double]](1).toArray
+ val isotonic = data.getBoolean(2)
+ val model = new IsotonicRegressionModel(
+ metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic))
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
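(Again for context, not part of the patch: an analogous sketch for IsotonicRegressionModel, assuming a `dataset` DataFrame with "label", "features" and "weight" columns and a placeholder path:)

    import org.apache.spark.ml.regression.{IsotonicRegression, IsotonicRegressionModel}

    val model = new IsotonicRegression().setIsotonic(true).fit(dataset)

    model.write.overwrite().save("/tmp/isotonic-model")
    val loaded = IsotonicRegressionModel.load("/tmp/isotonic-model")

    // The wrapped mllib model is rebuilt from the saved boundaries, predictions and isotonic flag.
    assert(loaded.boundaries == model.boundaries)
    assert(loaded.predictions == model.predictions)
    assert(loaded.getIsotonic == model.getIsotonic)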
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 359f310271..d718ef63b5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -21,14 +21,15 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
-import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row}
-class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class AFTSurvivalRegressionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var datasetUnivariate: DataFrame = _
@transient var datasetMultivariate: DataFrame = _
@@ -332,4 +333,32 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
assert(prediction ~== model.predict(features) relTol 1E-5)
}
}
+
+ test("read/write") {
+ def checkModelData(
+ model: AFTSurvivalRegressionModel,
+ model2: AFTSurvivalRegressionModel): Unit = {
+ assert(model.intercept === model2.intercept)
+ assert(model.coefficients === model2.coefficients)
+ assert(model.scale === model2.scale)
+ }
+ val aft = new AFTSurvivalRegression()
+ testEstimatorAndModelReadWrite(aft, datasetMultivariate,
+ AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
+ }
+}
+
+object AFTSurvivalRegressionSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "fitIntercept" -> true,
+ "maxIter" -> 2,
+ "tol" -> 0.01
+ )
}
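(A complementary sketch, not part of the patch: the estimator itself, not just the fitted model, now round-trips its Params via DefaultParamsWritable/DefaultParamsReadable. The settings mirror allParamSettings above; `tempDir` is a placeholder scratch directory:)

    import java.io.File
    import org.apache.spark.ml.regression.AFTSurvivalRegression

    val aft = new AFTSurvivalRegression()
      .setPredictionCol("myPrediction")
      .setFitIntercept(true)
      .setMaxIter(2)
      .setTol(0.01)

    val path = new File(tempDir, "aft-estimator").getPath
    aft.write.overwrite().save(path)

    // Only Params are written for the estimator; they are restored on load.
    val loaded = AFTSurvivalRegression.load(path)
    assert(loaded.getPredictionCol == "myPrediction")
    assert(loaded.getMaxIter == 2)
    assert(loaded.getTol == 0.01)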
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 59f4193abc..f067c29d27 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -19,12 +19,14 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
-class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class IsotonicRegressionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
private def generateIsotonicInput(labels: Seq[Double]): DataFrame = {
sqlContext.createDataFrame(
labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) }
@@ -164,4 +166,32 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
}
+
+ test("read/write") {
+ val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18))
+
+ def checkModelData(model: IsotonicRegressionModel, model2: IsotonicRegressionModel): Unit = {
+ assert(model.boundaries === model2.boundaries)
+ assert(model.predictions === model2.predictions)
+ assert(model.isotonic === model2.isotonic)
+ }
+
+ val ir = new IsotonicRegression()
+ testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
+ checkModelData)
+ }
+}
+
+object IsotonicRegressionSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "isotonic" -> true,
+ "featureIndex" -> 0
+ )
}