aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala68
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala67
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala47
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala41
4 files changed, 195 insertions, 28 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index a14dcecbaf..c512a2cb8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -17,12 +17,15 @@
package org.apache.spark.ml.classification
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
+import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
@Experimental
class NaiveBayes(override val uid: String)
extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
- with NaiveBayesParams {
+ with NaiveBayesParams with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("nb"))
@@ -102,6 +105,13 @@ class NaiveBayes(override val uid: String)
override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
}
+@Since("1.6.0")
+object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
+
+ @Since("1.6.0")
+ override def load(path: String): NaiveBayes = super.load(path)
+}
+
/**
* :: Experimental ::
* Model produced by [[NaiveBayes]]
@@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] (
override val uid: String,
val pi: Vector,
val theta: Matrix)
- extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+ extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
+ with NaiveBayesParams with MLWritable {
import OldNaiveBayes.{Bernoulli, Multinomial}
@@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] (
s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
}
+ @Since("1.6.0")
+ override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this)
}
-private[ml] object NaiveBayesModel {
+@Since("1.6.0")
+object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
/** Convert a model from the old API */
- def fromOld(
+ private[ml] def fromOld(
oldModel: OldNaiveBayesModel,
parent: NaiveBayes): NaiveBayesModel = {
val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
@@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel {
oldModel.theta.flatten, true)
new NaiveBayesModel(uid, pi, theta)
}
+
+ @Since("1.6.0")
+ override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): NaiveBayesModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[NaiveBayesModel]] */
+ private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {
+
+ private case class Data(pi: Vector, theta: Matrix)
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: pi, theta
+ val data = Data(instance.pi, instance.theta)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[NaiveBayesModel].getName
+
+ override def load(path: String): NaiveBayesModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head()
+ val pi = data.getAs[Vector](0)
+ val theta = data.getAs[Matrix](1)
+ val model = new NaiveBayesModel(metadata.uid, pi, theta)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 509be63002..71e9684975 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,10 +17,12 @@
package org.apache.spark.ml.clustering
-import org.apache.spark.annotation.{Since, Experimental}
-import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.util._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -28,7 +30,6 @@ import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.{DataFrame, Row}
-
/**
* Common params for KMeans and KMeansModel
*/
@@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Experimental
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
- private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
+ private val parentModel: MLlibKMeansModel)
+ extends Model[KMeansModel] with KMeansParams with MLWritable {
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
@@ -129,6 +131,52 @@ class KMeansModel private[ml] (
val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
parentModel.computeCost(data)
}
+
+ @Since("1.6.0")
+ override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+}
+
+@Since("1.6.0")
+object KMeansModel extends MLReadable[KMeansModel] {
+
+ @Since("1.6.0")
+ override def read: MLReader[KMeansModel] = new KMeansModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): KMeansModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[KMeansModel]] */
+ private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
+
+ private case class Data(clusterCenters: Array[Vector])
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: cluster centers
+ val data = Data(instance.clusterCenters)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class KMeansModelReader extends MLReader[KMeansModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[KMeansModel].getName
+
+ override def load(path: String): KMeansModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
+ val clusterCenters = data.getAs[Seq[Vector]](0).toArray
+ val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
/**
@@ -141,7 +189,7 @@ class KMeansModel private[ml] (
@Experimental
class KMeans @Since("1.5.0") (
@Since("1.5.0") override val uid: String)
- extends Estimator[KMeansModel] with KMeansParams {
+ extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {
setDefault(
k -> 2,
@@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") (
}
}
+@Since("1.6.0")
+object KMeans extends DefaultParamsReadable[KMeans] {
+
+ @Since("1.6.0")
+ override def load(path: String): KMeans = super.load(path)
+}
+
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 98bc951116..082a6bcd21 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -21,15 +21,30 @@ import breeze.linalg.{Vector => BV}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial}
+import org.apache.spark.mllib.classification.NaiveBayesSuite._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.classification.NaiveBayesSuite._
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
+
+class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ val pi = Array(0.5, 0.1, 0.4).map(math.log)
+ val theta = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0
+ Array(0.10, 0.70, 0.10, 0.10), // label 1
+ Array(0.10, 0.10, 0.70, 0.10) // label 2
+ ).map(_.map(math.log))
-class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
+ dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
+ }
def validatePrediction(predictionAndLabels: DataFrame): Unit = {
val numOfErrorPredictions = predictionAndLabels.collect().count {
@@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
.select("features", "probability")
validateProbabilities(featureAndProbabilities, model, "bernoulli")
}
+
+ test("read/write") {
+ def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
+ assert(model.pi === model2.pi)
+ assert(model.theta === model2.theta)
+ }
+ val nb = new NaiveBayes()
+ testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+ }
+}
+
+object NaiveBayesSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "smoothing" -> 0.1
+ )
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c05f90550d..2724e51f31 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -25,16 +26,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
private[clustering] case class TestRow(features: Vector)
-object KMeansSuite {
- def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
- val sc = sql.sparkContext
- val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
- .map(v => new TestRow(v))
- sql.createDataFrame(rdd)
- }
-}
-
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
final val k = 5
@transient var dataset: DataFrame = _
@@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
}
+
+ test("read/write") {
+ def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
+ assert(model.clusterCenters === model2.clusterCenters)
+ }
+ val kmeans = new KMeans()
+ testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+ }
+}
+
+object KMeansSuite {
+ def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
+ val sc = sql.sparkContext
+ val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
+ .map(v => new TestRow(v))
+ sql.createDataFrame(rdd)
+ }
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "k" -> 3,
+ "maxIter" -> 2,
+ "tol" -> 0.01
+ )
}