From f4f6bd8c9884e3919509907307fda774f56b5ecc Mon Sep 17 00:00:00 2001 From: Gayathri Murali Date: Thu, 22 Sep 2016 16:34:42 -0700 Subject: [SPARK-16240][ML] ML persistence backward compatibility for LDA ## What changes were proposed in this pull request? Allow Spark 2.x to load instances of LDA, LocalLDAModel, and DistributedLDAModel saved from Spark 1.6. ## How was this patch tested? I tested this manually, saving the 3 types from 1.6 and loading them into master (2.x). In the future, we can add generic tests for testing backwards compatibility across all ML models in SPARK-15573. Author: Joseph K. Bradley Closes #15034 from jkbradley/lda-backwards. --- .../scala/org/apache/spark/ml/clustering/LDA.scala | 86 +++++++++++++++++----- project/MimaExcludes.scala | 4 +- 2 files changed, 72 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index b5a764b586..7773802854 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -18,6 +18,9 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.json4s.DefaultFormats +import org.json4s.JsonAST.JObject +import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging @@ -26,19 +29,21 @@ import org.apache.spark.ml.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} import org.apache.spark.mllib.impl.PeriodicCheckpointer -import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Vector => OldVector, - Vectors => OldVectors} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.VersionUtils private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter @@ -80,6 +85,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * - Values should be >= 0 * - default = uniformly (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. + * * @group param */ @Since("1.6.0") @@ -121,6 +127,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * - Value should be >= 0 * - default = (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. + * * @group param */ @Since("1.6.0") @@ -354,6 +361,39 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM } } +private object LDAParams { + + /** + * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. + * + * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with + * [[Param]] values extracted from metadata. + * @param metadata Loaded model metadata + */ + def getAndSetParams(model: LDAParams, metadata: Metadata): Unit = { + VersionUtils.majorMinorVersion(metadata.sparkVersion) match { + case (1, 6) => + implicit val format = DefaultFormats + metadata.params match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val origParam = + if (paramName == "topicDistribution") "topicDistributionCol" else paramName + val param = model.getParam(origParam) + val value = param.jsonDecode(compact(render(jsonValue))) + model.set(param, value) + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") + } + case _ => // 2.0+ + DefaultParamsReader.getAndSetParams(model, metadata) + } + } +} + /** * :: Experimental :: @@ -418,11 +458,11 @@ sealed abstract class LDAModel private[ml] ( val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext) val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") - dataset.toDF + dataset.toDF() } } @@ -578,18 +618,16 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) - .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration", - "gammaShape") - .head() - val vocabSize = data.getAs[Int](0) - val topicsMatrix = data.getAs[Matrix](1) - val docConcentration = data.getAs[Vector](2) - val topicConcentration = data.getAs[Double](3) - val gammaShape = data.getAs[Double](4) + val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration") + val matrixConverted = MLUtils.convertMatrixColumnsToML(vectorConverted, "topicsMatrix") + val Row(vocabSize: Int, topicsMatrix: Matrix, docConcentration: Vector, + topicConcentration: Double, gammaShape: Double) = + matrixConverted.select("vocabSize", "topicsMatrix", "docConcentration", + "topicConcentration", "gammaShape").head() val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, gammaShape) val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession) - DefaultParamsReader.getAndSetParams(model, metadata) + LDAParams.getAndSetParams(model, metadata) model } } @@ -735,9 +773,9 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val modelPath = new Path(path, "oldModel").toString val oldModel = OldDistributedLDAModel.load(sc, modelPath) - val model = new DistributedLDAModel( - metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None) - DefaultParamsReader.getAndSetParams(model, metadata) + val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize, + oldModel, sparkSession, None) + LDAParams.getAndSetParams(model, metadata) model } } @@ -885,7 +923,7 @@ class LDA @Since("1.6.0") ( } @Since("2.0.0") -object LDA extends DefaultParamsReadable[LDA] { +object LDA extends MLReadable[LDA] { /** Get dataset for spark.mllib LDA */ private[clustering] def getOldDataset( @@ -900,6 +938,20 @@ object LDA extends DefaultParamsReadable[LDA] { } } + private class LDAReader extends MLReader[LDA] { + + private val className = classOf[LDA].getName + + override def load(path: String): LDA = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val model = new LDA(metadata.uid) + LDAParams.getAndSetParams(model, metadata) + model + } + } + + override def read: MLReader[LDA] = new LDAReader + @Since("2.0.0") override def load(path: String): LDA = super.load(path) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 0a56a6b19e..b6f64e5a70 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -44,7 +44,9 @@ object MimaExcludes { // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select"), // [SPARK-16967] Move Mesos to Module - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX"), + // [SPARK-16240] ML persistence backward compatibility for LDA + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$") ) } -- cgit v1.2.3