diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-11-10 18:45:48 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-11-10 18:45:48 -0800 |
commit | 6e101d2e9d6e08a6a63f7065c1e87a5338f763ea (patch) | |
tree | f93c013e57ee3644af985e1c5aae11659269e22e /mllib/src/main/scala/org | |
parent | 745e45d5ff7fe251c0d5197b7e08b1f80807b005 (diff) | |
download | spark-6e101d2e9d6e08a6a63f7065c1e87a5338f763ea.tar.gz spark-6e101d2e9d6e08a6a63f7065c1e87a5338f763ea.tar.bz2 spark-6e101d2e9d6e08a6a63f7065c1e87a5338f763ea.zip |
[SPARK-6726][ML] Import/export for spark.ml LogisticRegressionModel
This PR adds model save/load for spark.ml's LogisticRegressionModel. It also does minor refactoring of the default save/load classes to reuse code.
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #9606 from jkbradley/logreg-io2.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 68 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 74 |
2 files changed, 134 insertions(+), 8 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f5fca686df..a88f526741 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -21,13 +21,14 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.Experimental import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics @@ -396,7 +397,7 @@ class LogisticRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams { + with LogisticRegressionParams with Writable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients @@ -510,8 +511,71 @@ class LogisticRegressionModel private[ml] ( // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (probability(1) > getThreshold) 1 else 0 } + + /** + * Returns a [[Writer]] instance for this ML instance. + * + * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. 
+ */ + override def write: Writer = new LogisticRegressionWriter(this) +} + + +/** [[Writer]] instance for [[LogisticRegressionModel]] */ +private[classification] class LogisticRegressionWriter(instance: LogisticRegressionModel) + extends Writer with Logging { + + private case class Data( + numClasses: Int, + numFeatures: Int, + intercept: Double, + coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: numClasses, numFeatures, intercept, coefficients + val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, + instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } +} + + +object LogisticRegressionModel extends Readable[LogisticRegressionModel] { + + override def read: Reader[LogisticRegressionModel] = new LogisticRegressionReader + + override def load(path: String): LogisticRegressionModel = read.load(path) } + +private[classification] class LogisticRegressionReader extends Reader[LogisticRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" + + override def load(path: String): LogisticRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("numClasses", "numFeatures", "intercept", "coefficients").head() + // We will need numClasses, numFeatures in the future for multinomial logreg support. 
+ // val numClasses = data.getInt(0) + // val numFeatures = data.getInt(1) + val intercept = data.getDouble(2) + val coefficients = data.getAs[Vector](3) + val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } +} + + /** * MultiClassSummarizer computes the number of distinct labels and corresponding counts, * and validates the data to see if the labels used for k class multi-label classification diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index cbdf913ba8..85f888c9f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -175,6 +175,21 @@ trait Readable[T] { private[ml] class DefaultParamsWriter(instance: Params) extends Writer { override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + } +} + +private[ml] object DefaultParamsWriter { + + /** + * Saves metadata + Params to: path + "/metadata" + * - class + * - timestamp + * - sparkVersion + * - uid + * - paramMap + */ + def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -201,14 +216,61 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer { private[ml] class DefaultParamsReader[T] extends Reader[T] { override def load(path: String): T = { - implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + val instance = + cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] + DefaultParamsReader.getAndSetParams(instance, metadata) + instance.asInstanceOf[T] + } +} + +private[ml] object 
DefaultParamsReader { + + /** + * All info from metadata file. + * @param params paramMap, as a [[JValue]] + * @param metadataStr Full metadata file String (for debugging) + */ + case class Metadata( + className: String, + uid: String, + timestamp: Long, + sparkVersion: String, + params: JValue, + metadataStr: String) + + /** + * Load metadata from file. + * @param expectedClassName If non empty, this is checked against the loaded metadata. + * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata + */ + def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { val metadataPath = new Path(path, "metadata").toString val metadataStr = sc.textFile(metadataPath, 1).first() val metadata = parse(metadataStr) - val cls = Utils.classForName((metadata \ "class").extract[String]) + + implicit val format = DefaultFormats + val className = (metadata \ "class").extract[String] val uid = (metadata \ "uid").extract[String] - val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params] - (metadata \ "paramMap") match { + val timestamp = (metadata \ "timestamp").extract[Long] + val sparkVersion = (metadata \ "sparkVersion").extract[String] + val params = metadata \ "paramMap" + if (expectedClassName.nonEmpty) { + require(className == expectedClassName, s"Error loading metadata: Expected class name" + + s" $expectedClassName but found class name $className") + } + + Metadata(className, uid, timestamp, sparkVersion, params, metadataStr) + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. 
+ */ + def getAndSetParams(instance: Params, metadata: Metadata): Unit = { + implicit val format = DefaultFormats + metadata.params match { case JObject(pairs) => pairs.foreach { case (paramName, jsonValue) => val param = instance.getParam(paramName) @@ -216,8 +278,8 @@ private[ml] class DefaultParamsReader[T] extends Reader[T] { instance.set(param, value) } case _ => - throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.") + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") } - instance.asInstanceOf[T] } } |