author    Xusen Yin <yinxusen@gmail.com>    2016-03-31 11:17:32 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2016-03-31 11:17:32 -0700
commit    8b207f3b6a0eb617d38091f3b9001830ac3651fe (patch)
tree      b3bca571692fd67c3ae40d4e3af29a4ddecd056d /mllib/src
parent    a0a1991580ed24230f88cae9f5a4dfbe58f03b28 (diff)
[SPARK-11892][ML] Model export/import for spark.ml: OneVsRest
# What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-11892

Add save/load for spark.ml's OneVsRest and its model. Also add OneVsRest and OneVsRestModel to MetaAlgorithmReadWrite.

# How was this patch tested?

Tested with Scala unit tests.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #9934 from yinxusen/SPARK-11892.
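Concretely, the patch adds the standard MLWritable/MLReadable persistence hooks to both classes. A minimal usage sketch of the new API (the paths and the `training` DataFrame are illustrative, not part of the patch):

    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest, OneVsRestModel}

    val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10))
    ovr.write.save("/tmp/ovr")                    // persist the estimator
    val ovrModel = ovr.fit(training)              // training: a labeled DataFrame
    ovrModel.write.save("/tmp/ovrModel")          // persist the fitted model

    val sameOvr = OneVsRest.load("/tmp/ovr")      // restore via the new MLReadable companions
    val sameModel = OneVsRestModel.load("/tmp/ovrModel")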
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala       165
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala                   8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala   68
3 files changed, 223 insertions(+), 18 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index c41a611f1c..98b99a3485 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -21,22 +21,24 @@ import java.util.UUID
import scala.language.existentials
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject, _}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
-/**
- * Params for [[OneVsRest]].
- */
-private[ml] trait OneVsRestParams extends PredictorParams {
-
+private[ml] trait ClassifierTypeTrait {
// scalastyle:off structural.type
type ClassifierType = Classifier[F, E, M] forSome {
type F
@@ -44,6 +46,12 @@ private[ml] trait OneVsRestParams extends PredictorParams {
type E <: Classifier[F, E, M]
}
// scalastyle:on structural.type
+}
+
+/**
+ * Params for [[OneVsRest]].
+ */
+private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -57,6 +65,55 @@ private[ml] trait OneVsRestParams extends PredictorParams {
def getClassifier: ClassifierType = $(classifier)
}
+private[ml] object OneVsRestParams extends ClassifierTypeTrait {
+
+ def validateParams(instance: OneVsRestParams): Unit = {
+ def checkElement(elem: Params, name: String): Unit = elem match {
+ case stage: MLWritable => // good
+ case other =>
+        throw new UnsupportedOperationException("OneVsRest write will fail" +
+          s" because it contains $name which does not implement MLWritable." +
+          s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+ }
+
+ instance match {
+ case ovrModel: OneVsRestModel => ovrModel.models.foreach(checkElement(_, "model"))
+ case _ => // no need to check OneVsRest here
+ }
+
+ checkElement(instance.getClassifier, "classifier")
+ }
+
+ def saveImpl(
+ path: String,
+ instance: OneVsRestParams,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None): Unit = {
+
+ val params = instance.extractParamMap().toSeq
+ val jsonParams = render(params
+ .filter { case ParamPair(p, v) => p.name != "classifier" }
+ .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
+ .toList)
+
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+ val classifierPath = new Path(path, "classifier").toString
+ instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath)
+ }
+
+ def loadImpl(
+ path: String,
+ sc: SparkContext,
+ expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = {
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+ val classifierPath = new Path(path, "classifier").toString
+ val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc)
+ (metadata, estimator)
+ }
+}
+
/**
* :: Experimental ::
* Model produced by [[OneVsRest]].
@@ -73,10 +130,10 @@ private[ml] trait OneVsRestParams extends PredictorParams {
@Since("1.4.0")
@Experimental
final class OneVsRestModel private[ml] (
- @Since("1.4.0") override val uid: String,
- @Since("1.4.0") labelMetadata: Metadata,
+ @Since("1.4.0") override val uid: String,
+ private[ml] val labelMetadata: Metadata,
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
- extends Model[OneVsRestModel] with OneVsRestParams {
+ extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
@@ -143,6 +200,56 @@ final class OneVsRestModel private[ml] (
uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this)
+}
+
+@Since("2.0.0")
+object OneVsRestModel extends MLReadable[OneVsRestModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[OneVsRestModel] = new OneVsRestModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): OneVsRestModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[OneVsRestModel]] */
+ private[OneVsRestModel] class OneVsRestModelWriter(instance: OneVsRestModel) extends MLWriter {
+
+ OneVsRestParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
+ ("numClasses" -> instance.models.length)
+ OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson))
+ instance.models.zipWithIndex.foreach { case (model: MLWritable, idx) =>
+ val modelPath = new Path(path, s"model_$idx").toString
+ model.save(modelPath)
+ }
+ }
+ }
+
+ private class OneVsRestModelReader extends MLReader[OneVsRestModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[OneVsRestModel].getName
+
+ override def load(path: String): OneVsRestModel = {
+ implicit val format = DefaultFormats
+ val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+ val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val models = Range(0, numClasses).toArray.map { idx =>
+ val modelPath = new Path(path, s"model_$idx").toString
+ DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
+ }
+ val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
+ DefaultParamsReader.getAndSetParams(ovrModel, metadata)
+ ovrModel.set("classifier", classifier)
+ ovrModel
+ }
+ }
}
/**
@@ -158,7 +265,7 @@ final class OneVsRestModel private[ml] (
@Experimental
final class OneVsRest @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
- extends Estimator[OneVsRestModel] with OneVsRestParams {
+ extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("oneVsRest"))
@@ -243,4 +350,40 @@ final class OneVsRest @Since("1.4.0") (
}
copied
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new OneVsRest.OneVsRestWriter(this)
+}
+
+@Since("2.0.0")
+object OneVsRest extends MLReadable[OneVsRest] {
+
+ @Since("2.0.0")
+ override def read: MLReader[OneVsRest] = new OneVsRestReader
+
+ @Since("2.0.0")
+ override def load(path: String): OneVsRest = super.load(path)
+
+ /** [[MLWriter]] instance for [[OneVsRest]] */
+ private[OneVsRest] class OneVsRestWriter(instance: OneVsRest) extends MLWriter {
+
+ OneVsRestParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ OneVsRestParams.saveImpl(path, instance, sc)
+ }
+ }
+
+ private class OneVsRestReader extends MLReader[OneVsRest] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[OneVsRest].getName
+
+ override def load(path: String): OneVsRest = {
+ val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+ val ovr = new OneVsRest(metadata.uid)
+ DefaultParamsReader.getAndSetParams(ovr, metadata)
+ ovr.setClassifier(classifier)
+ }
+ }
}
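Read together, OneVsRestParams.saveImpl and the two writers above imply an on-disk layout roughly like the following (a sketch inferred from the code; the metadata location follows DefaultParamsWriter's usual convention, and the root path is illustrative):

    /tmp/ovrModel/
      metadata/              params JSON (classifier param excluded), plus labelMetadata and numClasses
      classifier/            the base classifier, saved through its own MLWriter
      model_0/               one fitted binary ClassificationModel per class label
      ...
      model_{numClasses-1}/

The reader inverts this: loadImpl restores the metadata and the classifier, then OneVsRestModelReader loads model_0 through model_{numClasses-1} and re-sets the classifier param on the assembled model.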
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 5a596cad06..39999ede30 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
-import org.apache.spark.ml.classification.OneVsRestParams
+import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
@@ -381,10 +381,8 @@ private[ml] object MetaAlgorithmReadWrite {
case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
- case ovr: OneVsRestParams =>
- // TODO: SPARK-11892: This case may require special handling.
- throw new UnsupportedOperationException(s"${instance.getClass.getName} write will fail" +
- s" because it cannot yet handle an estimator containing type: ${ovr.getClass.getName}.")
+ case ovr: OneVsRest => Array(ovr.getClassifier)
+ case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
case _: Params => Array()
}
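With these two cases in place, MetaAlgorithmReadWrite can enumerate the stages nested inside a OneVsRest instead of throwing, which is what lets meta-algorithms containing one be persisted. A minimal sketch of what this unblocks (column names and path are illustrative):

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
    import org.apache.spark.ml.feature.StringIndexer

    // Before this patch, saving a Pipeline containing OneVsRest failed with
    // UnsupportedOperationException; now the nested classifier is enumerated.
    val pipeline = new Pipeline().setStages(Array(
      new StringIndexer().setInputCol("category").setOutputCol("label"),
      new OneVsRest().setClassifier(new LogisticRegression())))
    pipeline.write.save("/tmp/ovrPipeline")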
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 2ae74a2090..51c1baf682 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.Metadata
-class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
+class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: DataFrame = _
@transient var rdd: RDD[LabeledPoint] = _
@@ -160,6 +160,70 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
require(m.getThreshold === 0.1, "copy should handle extra model params")
}
}
+
+ test("read/write: OneVsRest") {
+ val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+
+ val ova = new OneVsRest()
+ .setClassifier(lr)
+ .setLabelCol("myLabel")
+ .setFeaturesCol("myFeature")
+ .setPredictionCol("myPrediction")
+
+ val ova2 = testDefaultReadWrite(ova, testParams = false)
+ assert(ova.uid === ova2.uid)
+ assert(ova.getFeaturesCol === ova2.getFeaturesCol)
+ assert(ova.getLabelCol === ova2.getLabelCol)
+ assert(ova.getPredictionCol === ova2.getPredictionCol)
+
+ ova2.getClassifier match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ assert(lr.getRegParam === lr2.getRegParam)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRest expected classifier of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+ }
+
+ test("read/write: OneVsRestModel") {
+ def checkModelData(model: OneVsRestModel, model2: OneVsRestModel): Unit = {
+ assert(model.uid === model2.uid)
+ assert(model.getFeaturesCol === model2.getFeaturesCol)
+ assert(model.getLabelCol === model2.getLabelCol)
+ assert(model.getPredictionCol === model2.getPredictionCol)
+
+ val classifier = model.getClassifier.asInstanceOf[LogisticRegression]
+
+ model2.getClassifier match {
+ case lr2: LogisticRegression =>
+ assert(classifier.uid === lr2.uid)
+ assert(classifier.getMaxIter === lr2.getMaxIter)
+ assert(classifier.getRegParam === lr2.getRegParam)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ assert(model.labelMetadata === model2.labelMetadata)
+ model.models.zip(model2.models).foreach {
+ case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) =>
+ assert(lrModel1.uid === lrModel2.uid)
+ assert(lrModel1.coefficients === lrModel2.coefficients)
+ assert(lrModel1.intercept === lrModel2.intercept)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
+ s" LogisticRegressionModel but found ${other.getClass.getName}")
+ }
+ }
+
+ val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+ val ova = new OneVsRest().setClassifier(lr)
+ val ovaModel = ova.fit(dataset)
+ val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false)
+ checkModelData(ovaModel, newOvaModel)
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
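For reference, the testDefaultReadWrite helper mixed in from DefaultReadWriteTest round-trips an instance through a temporary directory and compares params. A minimal standalone equivalent, assuming the fitted ovaModel from the test above:

    // A sketch of the round trip performed by testDefaultReadWrite.
    val path = java.nio.file.Files.createTempDirectory("ovr-test").toString
    ovaModel.write.overwrite().save(path)
    val loaded = OneVsRestModel.load(path)
    assert(loaded.uid === ovaModel.uid)
    assert(loaded.models.length === ovaModel.models.length)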