From 3e1d120cedb4bd9e1595e95d4d531cf61da6684d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 19 Nov 2015 23:43:18 -0800 Subject: [SPARK-11867] Add save/load for kmeans and naive bayes https://issues.apache.org/jira/browse/SPARK-11867 Author: Xusen Yin Closes #9849 from yinxusen/SPARK-11867. --- .../spark/ml/classification/NaiveBayes.scala | 68 +++++++++++++++++++--- .../org/apache/spark/ml/clustering/KMeans.scala | 67 +++++++++++++++++++-- 2 files changed, 122 insertions(+), 13 deletions(-) (limited to 'mllib/src/main/scala/org') diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index a14dcecbaf..c512a2cb8b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -17,12 +17,15 @@ package org.apache.spark.ml.classification +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { @Experimental class NaiveBayes(override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] - with NaiveBayesParams { + with NaiveBayesParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("nb")) @@ -102,6 +105,13 @@ class NaiveBayes(override val uid: String) override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) } +@Since("1.6.0") +object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { + + @Since("1.6.0") + override def load(path: String): NaiveBayes = super.load(path) +} + /** * :: Experimental :: * Model produced by [[NaiveBayes]] @@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, val theta: Matrix) - extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] + with NaiveBayesParams with MLWritable { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] ( s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } + @Since("1.6.0") + override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this) } -private[ml] object NaiveBayesModel { +@Since("1.6.0") +object NaiveBayesModel extends MLReadable[NaiveBayesModel] { /** Convert a model from the old API */ - def fromOld( + private[ml] def fromOld( oldModel: OldNaiveBayesModel, parent: NaiveBayes): NaiveBayesModel = { val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") @@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel { oldModel.theta.flatten, true) new NaiveBayesModel(uid, pi, theta) } + + @Since("1.6.0") + override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader + + @Since("1.6.0") + override def load(path: String): NaiveBayesModel = super.load(path) + + /** [[MLWriter]] instance for [[NaiveBayesModel]] */ + private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter { + + private case class Data(pi: Vector, theta: Matrix) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: pi, theta + val data = Data(instance.pi, instance.theta) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[NaiveBayesModel].getName + + override def load(path: String): NaiveBayesModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head() + val pi = data.getAs[Vector](0) + val theta = data.getAs[Matrix](1) + val model = new NaiveBayesModel(metadata.uid, pi, theta) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 509be63002..71e9684975 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,10 +17,12 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.util._ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -28,7 +30,6 @@ import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.{DataFrame, Row} - /** * Common params for KMeans and KMeansModel */ @@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Experimental class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + private val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with MLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -129,6 +131,52 @@ class KMeansModel private[ml] ( val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } parentModel.computeCost(data) } + + @Since("1.6.0") + override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) +} + +@Since("1.6.0") +object KMeansModel extends MLReadable[KMeansModel] { + + @Since("1.6.0") + override def read: MLReader[KMeansModel] = new KMeansModelReader + + @Since("1.6.0") + override def load(path: String): KMeansModel = super.load(path) + + /** [[MLWriter]] instance for [[KMeansModel]] */ + private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { + + private case class Data(clusterCenters: Array[Vector]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data = Data(instance.clusterCenters) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class KMeansModelReader extends MLReader[KMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[KMeansModel].getName + + override def load(path: String): KMeansModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head() + val clusterCenters = data.getAs[Seq[Vector]](0).toArray + val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** @@ -141,7 +189,7 @@ class KMeansModel private[ml] ( @Experimental class KMeans @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Estimator[KMeansModel] with KMeansParams { + extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { setDefault( k -> 2, @@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") ( } } +@Since("1.6.0") +object KMeans extends DefaultParamsReadable[KMeans] { + + @Since("1.6.0") + override def load(path: String): KMeans = super.load(path) +} + -- cgit v1.2.3