author      Xusen Yin <yinxusen@gmail.com>              2016-03-17 10:19:10 -0700
committer   Joseph K. Bradley <joseph@databricks.com>   2016-03-17 10:19:10 -0700
commit      edf8b8775b81f5522680094bf24f372aa0c61447 (patch)
tree        a3c4255165d1674a23bacccdfa66e8f903714717 /mllib
parent      828213d4ca4b0e845c4d6d778455335f187158a4 (diff)
download    spark-edf8b8775b81f5522680094bf24f372aa0c61447.tar.gz
            spark-edf8b8775b81f5522680094bf24f372aa0c61447.tar.bz2
            spark-edf8b8775b81f5522680094bf24f372aa0c61447.zip
[SPARK-11891] Model export/import for RFormula and RFormulaModel
https://issues.apache.org/jira/browse/SPARK-11891

Author: Xusen Yin <yinxusen@gmail.com>

Closes #9884 from yinxusen/SPARK-11891.
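As a quick, hedged sketch of the API this patch adds (paths and column names are illustrative; `dataset` is assumed to be an existing DataFrame with columns "id", "a" and "b", as in the new test at the bottom of this diff):

    import org.apache.spark.ml.feature.{RFormula, RFormulaModel}

    // Estimator persistence via DefaultParamsWritable / DefaultParamsReadable
    val rFormula = new RFormula()
      .setFormula("id ~ a:b")
      .setFeaturesCol("myFeatures")
      .setLabelCol("myLabels")
    rFormula.write.overwrite().save("/tmp/spark-rformula")
    val sameFormula = RFormula.load("/tmp/spark-rformula")

    // Model persistence via the new MLWritable / MLReadable implementations
    val model: RFormulaModel = rFormula.fit(dataset)
    model.write.overwrite().save("/tmp/spark-rformula-model")
    val sameModel = RFormulaModel.load("/tmp/spark-rformula-model")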
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala       | 179
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala  |   5
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala  |  40
3 files changed, 207 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index ab5f4a1a9a..e7ca7ada74 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -20,12 +20,14 @@ package org.apache.spark.ml.feature
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types._
@@ -68,7 +70,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
* will be created from the specified response variable in the formula.
*/
@Experimental
-class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
+class RFormula(override val uid: String)
+ extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("rFormula"))
@@ -180,6 +183,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
}
+@Since("2.0.0")
+object RFormula extends DefaultParamsReadable[RFormula] {
+
+ @Since("2.0.0")
+ override def load(path: String): RFormula = super.load(path)
+}
+
/**
* :: Experimental ::
* A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
@@ -189,9 +199,9 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
@Experimental
class RFormulaModel private[feature](
override val uid: String,
- resolvedFormula: ResolvedRFormula,
- pipelineModel: PipelineModel)
- extends Model[RFormulaModel] with RFormulaBase {
+ private[ml] val resolvedFormula: ResolvedRFormula,
+ private[ml] val pipelineModel: PipelineModel)
+ extends Model[RFormulaModel] with RFormulaBase with MLWritable {
override def transform(dataset: DataFrame): DataFrame = {
checkCanTransform(dataset.schema)
@@ -246,14 +256,71 @@ class RFormulaModel private[feature](
!columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
"Label column already exists and is not of type DoubleType.")
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
+}
+
+@Since("2.0.0")
+object RFormulaModel extends MLReadable[RFormulaModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RFormulaModel] = new RFormulaModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RFormulaModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[RFormulaModel]] */
+ private[RFormulaModel] class RFormulaModelWriter(instance: RFormulaModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: resolvedFormula
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(instance.resolvedFormula))
+ .repartition(1).write.parquet(dataPath)
+ // Save pipeline model
+ val pmPath = new Path(path, "pipelineModel").toString
+ instance.pipelineModel.save(pmPath)
+ }
+ }
+
+ private class RFormulaModelReader extends MLReader[RFormulaModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RFormulaModel].getName
+
+ override def load(path: String): RFormulaModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
+ val label = data.getString(0)
+ val terms = data.getAs[Seq[Seq[String]]](1)
+ val hasIntercept = data.getBoolean(2)
+ val resolvedRFormula = ResolvedRFormula(label, terms, hasIntercept)
+
+ val pmPath = new Path(path, "pipelineModel").toString
+ val pipelineModel = PipelineModel.load(pmPath)
+
+ val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
/**
* Utility transformer for removing temporary columns from a DataFrame.
* TODO(ekl) make this a public transformer
*/
-private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
- override val uid = Identifiable.randomUID("columnPruner")
+private class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
+ extends Transformer with MLWritable {
+
+ def this(columnsToPrune: Set[String]) =
+ this(Identifiable.randomUID("columnPruner"), columnsToPrune)
override def transform(dataset: DataFrame): DataFrame = {
val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
@@ -265,6 +332,48 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
}
override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
+
+ override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
+}
+
+private object ColumnPruner extends MLReadable[ColumnPruner] {
+
+ override def read: MLReader[ColumnPruner] = new ColumnPrunerReader
+
+ override def load(path: String): ColumnPruner = super.load(path)
+
+ /** [[MLWriter]] instance for [[ColumnPruner]] */
+ private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter {
+
+ private case class Data(columnsToPrune: Seq[String])
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: columnsToPrune
+ val data = Data(instance.columnsToPrune.toSeq)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class ColumnPrunerReader extends MLReader[ColumnPruner] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[ColumnPruner].getName
+
+ override def load(path: String): ColumnPruner = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head()
+ val columnsToPrune = data.getAs[Seq[String]](0).toSet
+ val pruner = new ColumnPruner(metadata.uid, columnsToPrune)
+
+ DefaultParamsReader.getAndSetParams(pruner, metadata)
+ pruner
+ }
+ }
}
/**
@@ -278,11 +387,13 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
* by the value in the map.
*/
private class VectorAttributeRewriter(
- vectorCol: String,
- prefixesToRewrite: Map[String, String])
- extends Transformer {
+ override val uid: String,
+ val vectorCol: String,
+ val prefixesToRewrite: Map[String, String])
+ extends Transformer with MLWritable {
- override val uid = Identifiable.randomUID("vectorAttrRewriter")
+ def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
+ this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
override def transform(dataset: DataFrame): DataFrame = {
val metadata = {
@@ -315,4 +426,48 @@ private class VectorAttributeRewriter(
}
override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)
+
+ override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
+}
+
+private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
+
+ override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
+
+ override def load(path: String): VectorAttributeRewriter = super.load(path)
+
+ /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */
+ private[VectorAttributeRewriter]
+ class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter {
+
+ private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String])
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: vectorCol, prefixesToRewrite
+ val data = Data(instance.vectorCol, instance.prefixesToRewrite)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[VectorAttributeRewriter].getName
+
+ override def load(path: String): VectorAttributeRewriter = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
+ val vectorCol = data.getString(0)
+ val prefixesToRewrite = data.getAs[Map[String, String]](1)
+ val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
+
+ DefaultParamsReader.getAndSetParams(rewriter, metadata)
+ rewriter
+ }
+ }
}
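For reference, RFormulaModelWriter.saveImpl above writes three things under the save path: the Params metadata, a one-row parquet dataset holding the resolved formula, and the nested PipelineModel. A minimal sketch of inspecting that layout (assumptions: an active SparkContext `sc` and SQLContext `sqlContext`, and a model already saved at `savePath`; the directory names come from the writer, while the contents of each directory are implementation details):

    import org.apache.hadoop.fs.{FileSystem, Path}

    val savePath = "/tmp/spark-rformula-model"  // e.g. from the example near the commit message
    val fs = FileSystem.get(sc.hadoopConfiguration)

    // metadata/       uid, class name and Params, written by DefaultParamsWriter
    // data/           one-row parquet with (label, terms, hasIntercept)
    // pipelineModel/  the nested PipelineModel, saved recursively via its own save()
    Seq("metadata", "data", "pipelineModel").foreach { sub =>
      assert(fs.exists(new Path(savePath, sub)), s"missing $sub under $savePath")
    }

    // The resolved formula can be read back directly from the data directory:
    val resolved = sqlContext.read.parquet(new Path(savePath, "data").toString)
      .select("label", "terms", "hasIntercept")
      .head()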
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 010e7d2686..3d7a91dd39 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -221,10 +221,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
// TODO: SPARK-11892: This case may require special handling.
throw new UnsupportedOperationException("CrossValidator write will fail because it" +
" cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
- case rform: RFormulaModel =>
- // TODO: SPARK-11891: This case may require special handling.
- throw new UnsupportedOperationException("CrossValidator write will fail because it" +
- " cannot yet handle an estimator containing an RFormulaModel")
+ case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
case _: Params => Array()
}
val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
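The hunk above sits inside the recursion that builds a uid -> Params map when a validator is persisted; with this change an RFormulaModel contributes its nested pipelineModel to that recursion instead of aborting the write. A simplified, standalone sketch of the flattening pattern (not the actual CrossValidator code, which matches more stage types and can reach RFormulaModel's private[ml] pipelineModel directly):

    import org.apache.spark.ml.PipelineModel
    import org.apache.spark.ml.param.Params

    // Collect (uid -> stage) for a stage and anything nested inside it.
    // After this patch the real getUidMapImpl also matches RFormulaModel and
    // returns Array(rformModel.pipelineModel) so the nested stages are visited.
    def collectUidMap(stage: Params): Map[String, Params] = {
      val subStages: Seq[Params] = stage match {
        case pm: PipelineModel => pm.stages.toSeq
        case _ => Seq.empty
      }
      Map(stage.uid -> stage) ++ subStages.flatMap(collectUidMap)
    }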
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 16e565d8b5..e1b269b5b6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new RFormula())
}
@@ -252,4 +253,41 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
assert(attrs === expectedAttrs)
}
+
+ test("read/write: RFormula") {
+ val rFormula = new RFormula()
+ .setFormula("id ~ a:b")
+ .setFeaturesCol("myFeatures")
+ .setLabelCol("myLabels")
+
+ testDefaultReadWrite(rFormula)
+ }
+
+ test("read/write: RFormulaModel") {
+ def checkModelData(model: RFormulaModel, model2: RFormulaModel): Unit = {
+ assert(model.uid === model2.uid)
+
+ assert(model.resolvedFormula.label === model2.resolvedFormula.label)
+ assert(model.resolvedFormula.terms === model2.resolvedFormula.terms)
+ assert(model.resolvedFormula.hasIntercept === model2.resolvedFormula.hasIntercept)
+
+ assert(model.pipelineModel.uid === model2.pipelineModel.uid)
+
+ model.pipelineModel.stages.zip(model2.pipelineModel.stages).foreach {
+ case (transformer1, transformer2) =>
+ assert(transformer1.uid === transformer2.uid)
+ assert(transformer1.params === transformer2.params)
+ }
+ }
+
+ val dataset = sqlContext.createDataFrame(
+ Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
+ ).toDF("id", "a", "b")
+
+ val rFormula = new RFormula().setFormula("id ~ a:b")
+
+ val model = rFormula.fit(dataset)
+ val newModel = testDefaultReadWrite(model)
+ checkModelData(model, newModel)
+ }
}