author     Xusen Yin <yinxusen@gmail.com>        2015-11-19 23:43:18 -0800
committer  Xiangrui Meng <meng@databricks.com>   2015-11-19 23:43:18 -0800
commit     3e1d120cedb4bd9e1595e95d4d531cf61da6684d (patch)
tree       e9a26f005cf0df7162a58063208f6aa2ec15a7e2 /mllib/src/test/scala/org/apache
parent     0fff8eb3e476165461658d4e16682ec64269fdfe (diff)
[SPARK-11867] Add save/load for kmeans and naive bayes
https://issues.apache.org/jira/browse/SPARK-11867

Author: Xusen Yin <yinxusen@gmail.com>

Closes #9849 from yinxusen/SPARK-11867.
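For orientation before the diff: the suites below exercise the ML persistence API (MLWritable/MLReadable) that this patch wires up for NaiveBayes and KMeans, via the DefaultReadWriteTest helper. A minimal usage sketch of the API under test — `training` is an assumed DataFrame with "label" and "features" columns, and the save path is hypothetical:

    import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}

    // training: assumed DataFrame with "label" and "features" columns
    val model = new NaiveBayes().setSmoothing(0.1).fit(training)
    model.write.overwrite().save("/tmp/nbModel")        // hypothetical path
    val restored = NaiveBayesModel.load("/tmp/nbModel")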
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala  47
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala          41
2 files changed, 73 insertions(+), 15 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 98bc951116..082a6bcd21 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -21,15 +21,30 @@ import breeze.linalg.{Vector => BV}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial}
+import org.apache.spark.mllib.classification.NaiveBayesSuite._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.classification.NaiveBayesSuite._
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
+
-class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
+class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    val pi = Array(0.5, 0.1, 0.4).map(math.log)
+    val theta = Array(
+      Array(0.70, 0.10, 0.10, 0.10), // label 0
+      Array(0.10, 0.70, 0.10, 0.10), // label 1
+      Array(0.10, 0.10, 0.70, 0.10)  // label 2
+    ).map(_.map(math.log))
+
+    dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
+  }
  def validatePrediction(predictionAndLabels: DataFrame): Unit = {
    val numOfErrorPredictions = predictionAndLabels.collect().count {
@@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
      .select("features", "probability")
    validateProbabilities(featureAndProbabilities, model, "bernoulli")
  }
+
+  test("read/write") {
+    def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
+      assert(model.pi === model2.pi)
+      assert(model.theta === model2.theta)
+    }
+    val nb = new NaiveBayes()
+    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+  }
+}
+
+object NaiveBayesSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "smoothing" -> 0.1
+  )
}
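Note on the read/write test above: testEstimatorAndModelReadWrite comes from DefaultReadWriteTest. A hypothetical sketch of the round trip it is expected to perform — not the actual DefaultReadWriteTest implementation, and the helper name and sub-paths are illustrative:

    import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
    import org.apache.spark.sql.DataFrame

    // Sketch only: save the estimator with its params, reload and compare,
    // then fit, save the fitted model, reload, and compare as checkModelData does.
    def roundTrip(nb: NaiveBayes, data: DataFrame, path: String): Unit = {
      nb.write.overwrite().save(path + "/estimator")    // persist estimator and params
      val nb2 = NaiveBayes.load(path + "/estimator")    // params should survive the trip
      assert(nb.getSmoothing == nb2.getSmoothing)
      val model = nb.fit(data)
      model.write.overwrite().save(path + "/model")     // persist the fitted model
      val model2 = NaiveBayesModel.load(path + "/model")
      assert(model.pi == model2.pi)                     // as checkModelData asserts
      assert(model.theta == model2.theta)
    }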
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c05f90550d..2724e51f31 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -25,16 +26,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
private[clustering] case class TestRow(features: Vector)
-object KMeansSuite {
-  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
-    val sc = sql.sparkContext
-    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
-      .map(v => new TestRow(v))
-    sql.createDataFrame(rdd)
-  }
-}
-
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
  final val k = 5
  @transient var dataset: DataFrame = _
@@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
    assert(clusters === Set(0, 1, 2, 3, 4))
    assert(model.computeCost(dataset) < 0.1)
  }
+
+  test("read/write") {
+    def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
+      assert(model.clusterCenters === model2.clusterCenters)
+    }
+    val kmeans = new KMeans()
+    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+  }
+}
+
+object KMeansSuite {
+  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
+    val sc = sql.sparkContext
+    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
+      .map(v => new TestRow(v))
+    sql.createDataFrame(rdd)
+  }
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "k" -> 3,
+    "maxIter" -> 2,
+    "tol" -> 0.01
+  )
}
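Likewise for the KMeans side, a minimal save/load sketch under the same assumptions — `dataset` is an assumed DataFrame with a "features" column, and the path is hypothetical:

    import org.apache.spark.ml.clustering.{KMeans, KMeansModel}

    val kmeans = new KMeans().setK(3).setMaxIter(2).setTol(0.01)
    val model = kmeans.fit(dataset)                 // dataset: assumed "features" DataFrame
    model.write.overwrite().save("/tmp/kmModel")    // hypothetical path
    val restored = KMeansModel.load("/tmp/kmModel")
    // restored.clusterCenters should equal model.clusterCenters, as the test asserts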