aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Wenjian Huang <nextrush@163.com> 2015-11-18 13:06:25 -0800
committer Xiangrui Meng <meng@databricks.com> 2015-11-18 13:06:25 -0800
commit 045a4f045821dcf60442f0600c2df1b79bddb536 (patch)
tree a2ba86ade1009cef997bf0bfed4e9d4a34152384 /mllib
parent 09ad9533d5760652de59fa4830c24cb8667958ac (diff)
download spark-045a4f045821dcf60442f0600c2df1b79bddb536.tar.gz
spark-045a4f045821dcf60442f0600c2df1b79bddb536.tar.bz2
spark-045a4f045821dcf60442f0600c2df1b79bddb536.zip
[SPARK-6790][ML] Add spark.ml LinearRegression import/export
This replaces [https://github.com/apache/spark/pull/9656] with updates. fayeshine should be the main author when this PR is committed. CC: mengxr fayeshine Author: Wenjian Huang <nextrush@163.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #9814 from jkbradley/fayeshine-patch-6790.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 77
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 34
2 files changed, 106 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 913140e581..ca55d5915e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import breeze.stats.distributions.StudentsT
+import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.ml.feature.Instance
@@ -30,7 +31,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
@@ -65,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
@Experimental
class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String)
extends Regressor[Vector, LinearRegression, LinearRegressionModel]
- with LinearRegressionParams with Logging {
+ with LinearRegressionParams with Writable with Logging {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("linReg"))
@@ -341,6 +342,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
@Since("1.4.0")
override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
+ /** Returns a [[Writer]] for this estimator: a `DefaultParamsWriter` constructed around `this`. */
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object LinearRegression extends Readable[LinearRegression] {
+ /** Returns a new `DefaultParamsReader` typed to [[LinearRegression]]. */
+ @Since("1.6.0")
+ override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression]
+ /** Loads a [[LinearRegression]] from `path` by delegating to `read.load(path)`. */
+ @Since("1.6.0")
+ override def load(path: String): LinearRegression = read.load(path)
}
/**
@@ -354,7 +368,7 @@ class LinearRegressionModel private[ml] (
val coefficients: Vector,
val intercept: Double)
extends RegressionModel[Vector, LinearRegressionModel]
- with LinearRegressionParams {
+ with LinearRegressionParams with Writable {
private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
@@ -422,6 +436,63 @@ class LinearRegressionModel private[ml] (
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
newModel.setParent(parent)
}
+
+ /**
+ * Returns a [[Writer]] for this model, namely a new `LinearRegressionModelWriter` wrapping it.
+ *
+ * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]].
+ * An option to save [[summary]] may be added in the future.
+ *
+ * This also does not save the [[parent]] currently.
+ */
+ @Since("1.6.0")
+ override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this)
+}
+
+@Since("1.6.0")
+object LinearRegressionModel extends Readable[LinearRegressionModel] {
+ /** Returns a new `LinearRegressionModelReader`. */
+ @Since("1.6.0")
+ override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader
+ /** Loads a [[LinearRegressionModel]] from `path` by delegating to `read.load(path)`. */
+ @Since("1.6.0")
+ override def load(path: String): LinearRegressionModel = read.load(path)
+
+ /** [[Writer]] for [[LinearRegressionModel]]: saves metadata plus intercept and coefficients. */
+ private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel)
+ extends Writer with Logging {
+ // Shape of the saved model data; the reader below selects columns named after these fields.
+ private case class Data(intercept: Double, coefficients: Vector)
+ /** Saves metadata for `instance`, then one parquet row of model data under `path`/"data". */
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: intercept, coefficients
+ val data = Data(instance.intercept, instance.coefficients)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+ }
+ }
+ /** [[Reader]] that rebuilds a [[LinearRegressionModel]] from saved metadata and parquet data. */
+ private class LinearRegressionModelReader extends Reader[LinearRegressionModel] {
+
+ /** Class name passed to `DefaultParamsReader.loadMetadata` when loading a model. */
+ private val className = "org.apache.spark.ml.regression.LinearRegressionModel"
+ /** Loads metadata and the first data row under `path`, then builds the model and sets Params. */
+ override def load(path: String): LinearRegressionModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ // Model data is read from the same "data" sub-path that the writer saves to.
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.format("parquet").load(dataPath)
+ .select("intercept", "coefficients").head()
+ val intercept = data.getDouble(0)
+ val coefficients = data.getAs[Vector](1)
+ val model = new LinearRegressionModel(metadata.uid, coefficients, intercept)
+ // Apply the Params from the loaded metadata to the newly constructed model.
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index a1d86fe8fe..2bdc0e184d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -22,14 +22,15 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class LinearRegressionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
private val seed: Int = 42
@transient var datasetWithDenseFeature: DataFrame = _
@@ -854,4 +855,33 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
}
+ // checkModelData (equal intercept and coefficients) is handed to testEstimatorAndModelReadWrite.
+ test("read/write") {
+ def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = {
+ assert(model.intercept === model2.intercept)
+ assert(model.coefficients === model2.coefficients)
+ }
+ val lr = new LinearRegression()
+ testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
+ checkModelData)
+ }
+}
+
+object LinearRegressionSuite {
+
+ /**
+ * Mapping from Param names (as strings) to valid settings meant to differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "regParam" -> 0.01,
+ "elasticNetParam" -> 0.1,
+ "maxIter" -> 2, // intentionally small
+ "fitIntercept" -> true, // NOTE(review): may equal the default instead of differing; confirm
+ "tol" -> 0.8,
+ "standardization" -> false,
+ "solver" -> "l-bfgs"
+ )
}