From 045a4f045821dcf60442f0600c2df1b79bddb536 Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Wed, 18 Nov 2015 13:06:25 -0800 Subject: [SPARK-6790][ML] Add spark.ml LinearRegression import/export This replaces [https://github.com/apache/spark/pull/9656] with updates. fayeshine should be the main author when this PR is committed. CC: mengxr fayeshine Author: Wenjian Huang Author: Joseph K. Bradley Closes #9814 from jkbradley/fayeshine-patch-6790. --- .../spark/ml/regression/LinearRegression.scala | 77 +++++++++++++++++++++- .../ml/regression/LinearRegressionSuite.scala | 34 +++++++++- 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 913140e581..ca55d5915e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.ml.feature.Instance @@ -30,7 +31,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -65,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Logging { + with LinearRegressionParams with Writable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -341,6 +342,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object LinearRegression extends Readable[LinearRegression] { + + @Since("1.6.0") + override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] + + @Since("1.6.0") + override def load(path: String): LinearRegression = read.load(path) } /** @@ -354,7 +368,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with Writable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -422,6 +436,63 @@ class LinearRegressionModel private[ml] ( if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } + + /** + * Returns a [[Writer]] instance for this ML instance. + * + * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. + */ + @Since("1.6.0") + override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) +} + +@Since("1.6.0") +object LinearRegressionModel extends Readable[LinearRegressionModel] { + + @Since("1.6.0") + override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + + @Since("1.6.0") + override def load(path: String): LinearRegressionModel = read.load(path) + + /** [[Writer]] instance for [[LinearRegressionModel]] */ + private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) + extends Writer with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } + } + + private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + + override def load(path: String): LinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index a1d86fe8fe..2bdc0e184d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -22,14 +22,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ @@ -854,4 +855,33 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } + + test("read/write") { + def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + } + val lr = new LinearRegression() + testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, + checkModelData) + } +} + +object LinearRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> true, + "tol" -> 0.8, + "standardization" -> false, + "solver" -> "l-bfgs" + ) } -- cgit v1.2.3