aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-03-09 11:59:22 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-03-09 11:59:22 -0800
commit0dd06485c4222a896c0d1ee6a04d30043de3626c (patch)
treef1ffc18c78e4dcb4caf3872c34a4a6fc2616e223 /mllib
parentcad29a40b24a8e89f2d906e263866546f8ab6071 (diff)
downloadspark-0dd06485c4222a896c0d1ee6a04d30043de3626c.tar.gz
spark-0dd06485c4222a896c0d1ee6a04d30043de3626c.tar.bz2
spark-0dd06485c4222a896c0d1ee6a04d30043de3626c.zip
[SPARK-13615][ML] GeneralizedLinearRegression supports save/load
## What changes were proposed in this pull request? ```GeneralizedLinearRegression``` supports ```save/load```. cc mengxr ## How was this patch tested? unit test. Author: Yanbo Liang <ybliang8@gmail.com> Closes #11465 from yanboliang/spark-13615.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala74
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala32
2 files changed, 96 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index a850dfee0a..de1dff9421 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.regression
import breeze.stats.distributions.{Gaussian => GD}
+import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.{Experimental, Since}
@@ -26,7 +27,7 @@ import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.optim._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
@@ -106,7 +107,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String)
extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel]
- with GeneralizedLinearRegressionBase with Logging {
+ with GeneralizedLinearRegressionBase with DefaultParamsWritable with Logging {
import GeneralizedLinearRegression._
@@ -236,10 +237,13 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}
@Since("2.0.0")
-private[ml] object GeneralizedLinearRegression {
+object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLinearRegression] {
+
+ @Since("2.0.0")
+ override def load(path: String): GeneralizedLinearRegression = super.load(path)
/** Set of family and link pairs that GeneralizedLinearRegression supports. */
- lazy val supportedFamilyAndLinkPairs = Set(
+ private[ml] lazy val supportedFamilyAndLinkPairs = Set(
Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
@@ -247,12 +251,12 @@ private[ml] object GeneralizedLinearRegression {
)
/** Set of family names that GeneralizedLinearRegression supports. */
- lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
+ private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
/** Set of link names that GeneralizedLinearRegression supports. */
- lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
+ private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
- val epsilon: Double = 1E-16
+ private[ml] val epsilon: Double = 1E-16
/**
* Wrapper of family and link combination used in the model.
@@ -552,7 +556,7 @@ class GeneralizedLinearRegressionModel private[ml] (
@Since("2.0.0") val coefficients: Vector,
@Since("2.0.0") val intercept: Double)
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
- with GeneralizedLinearRegressionBase {
+ with GeneralizedLinearRegressionBase with MLWritable {
import GeneralizedLinearRegression._
@@ -574,4 +578,58 @@ class GeneralizedLinearRegressionModel private[ml] (
copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
.setParent(parent)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
+}
+
+@Since("2.0.0")
+object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GeneralizedLinearRegressionModel] =
+ new GeneralizedLinearRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GeneralizedLinearRegressionModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[GeneralizedLinearRegressionModel]] */
+ private[GeneralizedLinearRegressionModel]
+ class GeneralizedLinearRegressionModelWriter(instance: GeneralizedLinearRegressionModel)
+ extends MLWriter with Logging {
+
+ private case class Data(intercept: Double, coefficients: Vector)
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: intercept, coefficients
+ val data = Data(instance.intercept, instance.coefficients)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class GeneralizedLinearRegressionModelReader
+ extends MLReader[GeneralizedLinearRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GeneralizedLinearRegressionModel].getName
+
+ override def load(path: String): GeneralizedLinearRegressionModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath)
+ .select("intercept", "coefficients").head()
+ val intercept = data.getDouble(0)
+ val coefficients = data.getAs[Vector](1)
+
+ val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 8bfa9855ce..618304ad19 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors}
import org.apache.spark.mllib.random._
@@ -30,7 +30,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GeneralizedLinearRegressionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
private val seed: Int = 42
@transient var datasetGaussianIdentity: DataFrame = _
@@ -464,10 +465,37 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark
}
}
}
+
+ test("read/write") {
+ def checkModelData(
+ model: GeneralizedLinearRegressionModel,
+ model2: GeneralizedLinearRegressionModel): Unit = {
+ assert(model.intercept === model2.intercept)
+ assert(model.coefficients.toArray === model2.coefficients.toArray)
+ }
+
+ val glr = new GeneralizedLinearRegression()
+ testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
+ GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
+ }
}
object GeneralizedLinearRegressionSuite {
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "family" -> "poisson",
+ "link" -> "log",
+ "fitIntercept" -> true,
+ "maxIter" -> 2, // intentionally small
+ "tol" -> 0.8,
+ "regParam" -> 0.01,
+ "predictionCol" -> "myPrediction")
+
def generateGeneralizedLinearRegressionInput(
intercept: Double,
coefficients: Array[Double],