author    Xusen Yin <yinxusen@gmail.com>    2015-11-19 23:43:18 -0800
committer Xiangrui Meng <meng@databricks.com>    2015-11-19 23:43:18 -0800
commit    3e1d120cedb4bd9e1595e95d4d531cf61da6684d (patch)
tree      e9a26f005cf0df7162a58063208f6aa2ec15a7e2 /mllib/src/main/scala/org
parent    0fff8eb3e476165461658d4e16682ec64269fdfe (diff)
[SPARK-11867] Add save/load for kmeans and naive bayes
https://issues.apache.org/jira/browse/SPARK-11867

Author: Xusen Yin <yinxusen@gmail.com>

Closes #9849 from yinxusen/SPARK-11867.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala  68
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala           67
2 files changed, 122 insertions(+), 13 deletions(-)
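Before the diff itself, a minimal sketch of the user-facing API this commit enables (not part of the patch). `training` stands in for a DataFrame with "label" and "features" columns, and the paths are illustrative:

    import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}

    val nb = new NaiveBayes().setSmoothing(1.0)
    val model = nb.fit(training)

    // Persist the estimator (Params only) and the fitted model (Params plus pi/theta).
    nb.save("/tmp/nb-estimator")
    model.save("/tmp/nb-model")

    // Reload through the companion-object readers added in this commit.
    val sameNb = NaiveBayes.load("/tmp/nb-estimator")
    val sameModel = NaiveBayesModel.load("/tmp/nb-model")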
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index a14dcecbaf..c512a2cb8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -17,12 +17,15 @@
package org.apache.spark.ml.classification
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
+import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
@Experimental
class NaiveBayes(override val uid: String)
extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
- with NaiveBayesParams {
+ with NaiveBayesParams with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("nb"))
@@ -102,6 +105,13 @@ class NaiveBayes(override val uid: String)
override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
}
+@Since("1.6.0")
+object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
+
+ @Since("1.6.0")
+ override def load(path: String): NaiveBayes = super.load(path)
+}
+
/**
* :: Experimental ::
* Model produced by [[NaiveBayes]]
@@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] (
override val uid: String,
val pi: Vector,
val theta: Matrix)
- extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+ extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
+ with NaiveBayesParams with MLWritable {
import OldNaiveBayes.{Bernoulli, Multinomial}
@@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] (
s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
}
+ @Since("1.6.0")
+ override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this)
}
-private[ml] object NaiveBayesModel {
+@Since("1.6.0")
+object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
/** Convert a model from the old API */
- def fromOld(
+ private[ml] def fromOld(
oldModel: OldNaiveBayesModel,
parent: NaiveBayes): NaiveBayesModel = {
val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
@@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel {
oldModel.theta.flatten, true)
new NaiveBayesModel(uid, pi, theta)
}
+
+ @Since("1.6.0")
+ override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): NaiveBayesModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[NaiveBayesModel]] */
+ private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {
+
+ private case class Data(pi: Vector, theta: Matrix)
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: pi, theta
+ val data = Data(instance.pi, instance.theta)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[NaiveBayesModel].getName
+
+ override def load(path: String): NaiveBayesModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head()
+ val pi = data.getAs[Vector](0)
+ val theta = data.getAs[Matrix](1)
+ val model = new NaiveBayesModel(metadata.uid, pi, theta)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
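The writer above produces a two-part layout on disk: Params and class metadata go to JSON under `metadata/`, while `pi` and `theta` go to a one-row Parquet file under `data/`. A hedged sketch of inspecting such a saved model (paths illustrative, assuming a model was already saved there):

    import org.apache.spark.mllib.linalg.{Matrix, Vector}

    val path = "/tmp/nb-model"
    // Metadata is one JSON object per line, so it can be read back as JSON.
    sqlContext.read.json(path + "/metadata").show()

    // Model data mirrors the private Data(pi, theta) case class in the writer.
    val row = sqlContext.read.parquet(path + "/data").select("pi", "theta").head()
    val pi = row.getAs[Vector](0)     // log of class priors
    val theta = row.getAs[Matrix](1)  // log of conditional probabilities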
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 509be63002..71e9684975 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,10 +17,12 @@
package org.apache.spark.ml.clustering
-import org.apache.spark.annotation.{Since, Experimental}
-import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.util._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -28,7 +30,6 @@ import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.{DataFrame, Row}
-
/**
* Common params for KMeans and KMeansModel
*/
@@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Experimental
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
- private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
+ private val parentModel: MLlibKMeansModel)
+ extends Model[KMeansModel] with KMeansParams with MLWritable {
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
@@ -129,6 +131,52 @@ class KMeansModel private[ml] (
val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
parentModel.computeCost(data)
}
+
+ @Since("1.6.0")
+ override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+}
+
+@Since("1.6.0")
+object KMeansModel extends MLReadable[KMeansModel] {
+
+ @Since("1.6.0")
+ override def read: MLReader[KMeansModel] = new KMeansModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): KMeansModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[KMeansModel]] */
+ private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
+
+ private case class Data(clusterCenters: Array[Vector])
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: cluster centers
+ val data = Data(instance.clusterCenters)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class KMeansModelReader extends MLReader[KMeansModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[KMeansModel].getName
+
+ override def load(path: String): KMeansModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
+ val clusterCenters = data.getAs[Seq[Vector]](0).toArray
+ val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
/**
@@ -141,7 +189,7 @@ class KMeansModel private[ml] (
@Experimental
class KMeans @Since("1.5.0") (
@Since("1.5.0") override val uid: String)
- extends Estimator[KMeansModel] with KMeansParams {
+ extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {
setDefault(
k -> 2,
@@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") (
}
}
+@Since("1.6.0")
+object KMeans extends DefaultParamsReadable[KMeans] {
+
+ @Since("1.6.0")
+ override def load(path: String): KMeans = super.load(path)
+}
+
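The KMeans changes follow the same pattern; a minimal round-trip sketch (again not part of the patch; `dataset` stands in for a DataFrame with a "features" vector column, and the path is illustrative):

    import org.apache.spark.ml.clustering.{KMeans, KMeansModel}

    val kmeans = new KMeans().setK(2).setSeed(1L)
    val model = kmeans.fit(dataset)

    model.save("/tmp/kmeans-model")
    val sameModel = KMeansModel.load("/tmp/kmeans-model")

    // The reloaded model carries the same Params and cluster centers.
    sameModel.clusterCenters.foreach(println)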