aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGayathri Murali <gayathri.m.softie@gmail.com>2016-09-22 16:34:42 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-09-22 16:34:42 -0700
commitf4f6bd8c9884e3919509907307fda774f56b5ecc (patch)
treeabb388081543709adff02e6d5c3a9004d18beaf2
parent0d634875026ccf1eaf984996e9460d7673561f80 (diff)
downloadspark-f4f6bd8c9884e3919509907307fda774f56b5ecc.tar.gz
spark-f4f6bd8c9884e3919509907307fda774f56b5ecc.tar.bz2
spark-f4f6bd8c9884e3919509907307fda774f56b5ecc.zip
[SPARK-16240][ML] ML persistence backward compatibility for LDA
## What changes were proposed in this pull request? Allow Spark 2.x to load instances of LDA, LocalLDAModel, and DistributedLDAModel saved from Spark 1.6. ## How was this patch tested? I tested this manually, saving the 3 types from 1.6 and loading them into master (2.x). In the future, we can add generic tests for testing backwards compatibility across all ML models in SPARK-15573. Author: Joseph K. Bradley <joseph@databricks.com> Closes #15034 from jkbradley/lda-backwards.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala86
-rw-r--r--project/MimaExcludes.scala4
2 files changed, 72 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index b5a764b586..7773802854 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -18,6 +18,9 @@
package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.JsonAST.JObject
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.internal.Logging
@@ -26,19 +29,21 @@ import org.apache.spark.ml.linalg.{Matrix, Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
OnlineLDAOptimizer => OldOnlineLDAOptimizer}
import org.apache.spark.mllib.impl.PeriodicCheckpointer
-import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Vector => OldVector,
- Vectors => OldVectors}
+import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.VersionUtils
private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
@@ -80,6 +85,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* - Values should be >= 0
* - default = uniformly (1.0 / k), following the implementation from
* [[https://github.com/Blei-Lab/onlineldavb]].
+ *
* @group param
*/
@Since("1.6.0")
@@ -121,6 +127,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* - Value should be >= 0
* - default = (1.0 / k), following the implementation from
* [[https://github.com/Blei-Lab/onlineldavb]].
+ *
* @group param
*/
@Since("1.6.0")
@@ -354,6 +361,39 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
}
}
+private object LDAParams {
+
+ /**
+ * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
+ * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+.
+ *
+ * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with
+ * [[Param]] values extracted from metadata.
+ * @param metadata Loaded model metadata
+ */
+ def getAndSetParams(model: LDAParams, metadata: Metadata): Unit = {
+ VersionUtils.majorMinorVersion(metadata.sparkVersion) match {
+ case (1, 6) =>
+ implicit val format = DefaultFormats
+ metadata.params match {
+ case JObject(pairs) =>
+ pairs.foreach { case (paramName, jsonValue) =>
+ val origParam =
+ if (paramName == "topicDistribution") "topicDistributionCol" else paramName
+ val param = model.getParam(origParam)
+ val value = param.jsonDecode(compact(render(jsonValue)))
+ model.set(param, value)
+ }
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
+ }
+ case _ => // 2.0+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ }
+ }
+}
+
/**
* :: Experimental ::
@@ -418,11 +458,11 @@ sealed abstract class LDAModel private[ml] (
val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext)
val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML }
- dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
+ dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF()
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
" such as topicDistributionCol to produce results.")
- dataset.toDF
+ dataset.toDF()
}
}
@@ -578,18 +618,16 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
- .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
- "gammaShape")
- .head()
- val vocabSize = data.getAs[Int](0)
- val topicsMatrix = data.getAs[Matrix](1)
- val docConcentration = data.getAs[Vector](2)
- val topicConcentration = data.getAs[Double](3)
- val gammaShape = data.getAs[Double](4)
+ val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration")
+ val matrixConverted = MLUtils.convertMatrixColumnsToML(vectorConverted, "topicsMatrix")
+ val Row(vocabSize: Int, topicsMatrix: Matrix, docConcentration: Vector,
+ topicConcentration: Double, gammaShape: Double) =
+ matrixConverted.select("vocabSize", "topicsMatrix", "docConcentration",
+ "topicConcentration", "gammaShape").head()
val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
gammaShape)
val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ LDAParams.getAndSetParams(model, metadata)
model
}
}
@@ -735,9 +773,9 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val modelPath = new Path(path, "oldModel").toString
val oldModel = OldDistributedLDAModel.load(sc, modelPath)
- val model = new DistributedLDAModel(
- metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
+ oldModel, sparkSession, None)
+ LDAParams.getAndSetParams(model, metadata)
model
}
}
@@ -885,7 +923,7 @@ class LDA @Since("1.6.0") (
}
@Since("2.0.0")
-object LDA extends DefaultParamsReadable[LDA] {
+object LDA extends MLReadable[LDA] {
/** Get dataset for spark.mllib LDA */
private[clustering] def getOldDataset(
@@ -900,6 +938,20 @@ object LDA extends DefaultParamsReadable[LDA] {
}
}
+ private class LDAReader extends MLReader[LDA] {
+
+ private val className = classOf[LDA].getName
+
+ override def load(path: String): LDA = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val model = new LDA(metadata.uid)
+ LDAParams.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ override def read: MLReader[LDA] = new LDAReader
+
@Since("2.0.0")
override def load(path: String): LDA = super.load(path)
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0a56a6b19e..b6f64e5a70 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -44,7 +44,9 @@ object MimaExcludes {
// [SPARK-16853][SQL] Fixes encoder error in DataSet typed select
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select"),
// [SPARK-16967] Move Mesos to Module
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX")
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX"),
+ // [SPARK-16240] ML persistence backward compatibility for LDA
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$")
)
}