author     Joseph K. Bradley <joseph@databricks.com>   2015-11-10 18:45:48 -0800
committer  Xiangrui Meng <meng@databricks.com>         2015-11-10 18:45:48 -0800
commit     6e101d2e9d6e08a6a63f7065c1e87a5338f763ea (patch)
tree       f93c013e57ee3644af985e1c5aae11659269e22e /mllib/src/main/scala/org
parent     745e45d5ff7fe251c0d5197b7e08b1f80807b005 (diff)
[SPARK-6726][ML] Import/export for spark.ml LogisticRegressionModel
This PR adds model save/load for spark.ml's LogisticRegressionModel. It also does minor refactoring of the default save/load classes to reuse code.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9606 from jkbradley/logreg-io2.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 68
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala                    | 74
2 files changed, 134 insertions(+), 8 deletions(-)
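
A minimal usage sketch of the API this patch introduces, assuming a training DataFrame named `training` (with "label" and "features" columns) is already in scope, and assuming Writer exposes a public save(path) wrapper around the saveImpl shown below (as in the existing ReadWrite.scala); "myModelPath" is a placeholder path, not something from the PR:

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}

// Fit a model as usual.
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
val model = lr.fit(training)

// New in this patch: persist the model. Metadata + Params go to <path>/metadata,
// and the intercept/coefficients go to <path>/data as Parquet
// (see LogisticRegressionWriter in the diff below).
model.write.save("myModelPath")

// Also new: the companion object is Readable, so the model can be loaded back.
val sameModel = LogisticRegressionModel.load("myModelPath")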
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index f5fca686df..a88f526741 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -21,13 +21,14 @@ import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
@@ -396,7 +397,7 @@ class LogisticRegressionModel private[ml] (
val coefficients: Vector,
val intercept: Double)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
- with LogisticRegressionParams {
+ with LogisticRegressionParams with Writable {
@deprecated("Use coefficients instead.", "1.6.0")
def weights: Vector = coefficients
@@ -510,8 +511,71 @@ class LogisticRegressionModel private[ml] (
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
if (probability(1) > getThreshold) 1 else 0
}
+
+ /**
+ * Returns a [[Writer]] instance for this ML instance.
+ *
+ * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]].
+ * An option to save [[summary]] may be added in the future.
+ */
+ override def write: Writer = new LogisticRegressionWriter(this)
+}
+
+
+/** [[Writer]] instance for [[LogisticRegressionModel]] */
+private[classification] class LogisticRegressionWriter(instance: LogisticRegressionModel)
+ extends Writer with Logging {
+
+ private case class Data(
+ numClasses: Int,
+ numFeatures: Int,
+ intercept: Double,
+ coefficients: Vector)
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: numClasses, numFeatures, intercept, coefficients
+ val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
+ instance.coefficients)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+ }
+}
+
+
+object LogisticRegressionModel extends Readable[LogisticRegressionModel] {
+
+ override def read: Reader[LogisticRegressionModel] = new LogisticRegressionReader
+
+ override def load(path: String): LogisticRegressionModel = read.load(path)
}
+
+private[classification] class LogisticRegressionReader extends Reader[LogisticRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
+
+ override def load(path: String): LogisticRegressionModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.format("parquet").load(dataPath)
+ .select("numClasses", "numFeatures", "intercept", "coefficients").head()
+ // We will need numClasses, numFeatures in the future for multinomial logreg support.
+ // val numClasses = data.getInt(0)
+ // val numFeatures = data.getInt(1)
+ val intercept = data.getDouble(2)
+ val coefficients = data.getAs[Vector](3)
+ val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+}
+
+
/**
* MultiClassSummarizer computes the number of distinct labels and corresponding counts,
* and validates the data to see if the labels used for k class multi-label classification
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index cbdf913ba8..85f888c9f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -175,6 +175,21 @@ trait Readable[T] {
private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ }
+}
+
+private[ml] object DefaultParamsWriter {
+
+ /**
+ * Saves metadata + Params to: path + "/metadata"
+ * - class
+ * - timestamp
+ * - sparkVersion
+ * - uid
+ * - paramMap
+ */
+ def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -201,14 +216,61 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
private[ml] class DefaultParamsReader[T] extends Reader[T] {
override def load(path: String): T = {
- implicit val format = DefaultFormats
+ val metadata = DefaultParamsReader.loadMetadata(path, sc)
+ val cls = Utils.classForName(metadata.className)
+ val instance =
+ cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
+ DefaultParamsReader.getAndSetParams(instance, metadata)
+ instance.asInstanceOf[T]
+ }
+}
+
+private[ml] object DefaultParamsReader {
+
+ /**
+ * All info from metadata file.
+ * @param params paramMap, as a [[JValue]]
+ * @param metadataStr Full metadata file String (for debugging)
+ */
+ case class Metadata(
+ className: String,
+ uid: String,
+ timestamp: Long,
+ sparkVersion: String,
+ params: JValue,
+ metadataStr: String)
+
+ /**
+ * Load metadata from file.
+ * @param expectedClassName If non empty, this is checked against the loaded metadata.
+ * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+ */
+ def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
val metadata = parse(metadataStr)
- val cls = Utils.classForName((metadata \ "class").extract[String])
+
+ implicit val format = DefaultFormats
+ val className = (metadata \ "class").extract[String]
val uid = (metadata \ "uid").extract[String]
- val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params]
- (metadata \ "paramMap") match {
+ val timestamp = (metadata \ "timestamp").extract[Long]
+ val sparkVersion = (metadata \ "sparkVersion").extract[String]
+ val params = metadata \ "paramMap"
+ if (expectedClassName.nonEmpty) {
+ require(className == expectedClassName, s"Error loading metadata: Expected class name" +
+ s" $expectedClassName but found class name $className")
+ }
+
+ Metadata(className, uid, timestamp, sparkVersion, params, metadataStr)
+ }
+
+ /**
+ * Extract Params from metadata, and set them in the instance.
+ * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+ */
+ def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
+ implicit val format = DefaultFormats
+ metadata.params match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val param = instance.getParam(paramName)
@@ -216,8 +278,8 @@ private[ml] class DefaultParamsReader[T] extends Reader[T] {
instance.set(param, value)
}
case _ =>
- throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.")
+ throw new IllegalArgumentException(
+ s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
}
- instance.asInstanceOf[T]
}
}
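
As a rough sanity check of what save leaves on disk, here is a sketch based only on the paths and read calls used in the diff above; `sc`, `sqlContext`, and "myModelPath" are assumed to be in scope / placeholders:

// The metadata is a single-line JSON record containing class, timestamp,
// sparkVersion, uid, and paramMap, written under <path>/metadata.
val metadataStr = sc.textFile("myModelPath/metadata", 1).first()
println(metadataStr)

// The model data is a one-row Parquet dataset under <path>/data with columns
// numClasses, numFeatures, intercept, and coefficients.
val modelData = sqlContext.read.parquet("myModelPath/data")
modelData.select("intercept", "coefficients").show()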