author     Joseph K. Bradley <joseph@databricks.com>    2015-02-04 22:46:48 -0800
committer  Xiangrui Meng <meng@databricks.com>          2015-02-04 22:46:48 -0800
commit     975bcef467b35586e5224171071355409f451d2d
tree       47f11210ae0b4c3f920752dcec61ddf4683c3bca
parent     c23ac03c8c27e840498a192b088e00b27076765f
[SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes
This is a PR for Parquet-based model import/export. Please see the design doc on [the JIRA](https://issues.apache.org/jira/browse/SPARK-4587).

Note: This includes only a subset of regression and classification models:
* NaiveBayes, SVM, LogisticRegression
* LinearRegression, RidgeRegression, Lasso

Follow-up PRs will cover other models.

Sketch of current contents:
* New traits: Saveable, Loader
* Implementations for some algorithms
* Also: Added LogisticRegressionModel.getThreshold method (so that unit test could check the threshold)

CC: mengxr selvinsource

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #4233 from jkbradley/ml-import-export and squashes the following commits:

87c4eb8 [Joseph K. Bradley] small cleanups
12d9059 [Joseph K. Bradley] Many cleanups after code review. Major changes: Storing numFeatures, numClasses in model metadata. Improvements to unit tests
b4ee064 [Joseph K. Bradley] Reorganized save/load for regression and classification. Renamed concepts to Saveable, Loader
a34aef5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into ml-import-export
ee99228 [Joseph K. Bradley] scala style fix
79675d5 [Joseph K. Bradley] cleanups in LogisticRegression after rebasing after multinomial PR
d1e5882 [Joseph K. Bradley] organized imports
2935963 [Joseph K. Bradley] Added save/load and tests for most classification and regression models
c495dba [Joseph K. Bradley] made version for model import/export local to each model
1496852 [Joseph K. Bradley] Added save/load for NaiveBayes
8d46386 [Joseph K. Bradley] Added save/load to NaiveBayes
1577d70 [Joseph K. Bradley] fixed issues after rebasing on master (DataFrame patch)
64914a3 [Joseph K. Bradley] added getThreshold to SVMModel
b1fc5ec [Joseph K. Bradley] small cleanups
418ba1b [Joseph K. Bradley] Added save, load to mllib.classification.LogisticRegressionModel, plus test suite
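For orientation before the diff: the tests below exercise the new traits only through model.save(sc, path) and Model.load(sc, path). A minimal sketch of the contract those call sites imply (trait names come from the commit message; member details are inferred from the test code, not taken verbatim from this patch):

    import org.apache.spark.SparkContext

    // Models that can be persisted; this PR stores them as Parquet data
    // plus metadata such as numFeatures and numClasses.
    trait Saveable {
      def save(sc: SparkContext, path: String): Unit
    }

    // Implemented by a model's companion object to reconstruct a saved model.
    trait Loader[M <: Saveable] {
      def load(sc: SparkContext, path: String): M
    }

Each suite then performs the same round trip: save a small hand-built model to a temp directory, load it back, and assert that weights, intercept, and (where applicable) the threshold survive.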
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala  70
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala          40
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala                 36
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala                   24
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala        24
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala         24
6 files changed, 210 insertions, 8 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 3fb45938f7..d2b40f2cae 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.mllib.classification
-import scala.util.control.Breaks._
-import scala.util.Random
import scala.collection.JavaConversions._
+import scala.util.Random
+import scala.util.control.Breaks._
import org.scalatest.FunSuite
import org.scalatest.Matchers
@@ -28,6 +28,8 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
+
object LogisticRegressionSuite {
@@ -147,8 +149,25 @@ object LogisticRegressionSuite {
val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i)))
testData
}
+
+ /** Binary labels, 3 features */
+ private val binaryModel = new LogisticRegressionModel(
+ weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5, numFeatures = 3, numClasses = 2)
+
+ /** 3 classes, 2 features */
+ private val multiclassModel = new LogisticRegressionModel(
+ weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3)
+
+ private def checkModelsEqual(a: LogisticRegressionModel, b: LogisticRegressionModel): Unit = {
+ assert(a.weights == b.weights)
+ assert(a.intercept == b.intercept)
+ assert(a.numClasses == b.numClasses)
+ assert(a.numFeatures == b.numFeatures)
+ assert(a.getThreshold == b.getThreshold)
+ }
}
+
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
def validatePrediction(
predictions: Seq[Double],
@@ -462,6 +481,53 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
}
+ test("model save/load: binary classification") {
+ // NOTE: This will need to be generalized once there are multiple model format versions.
+ val model = LogisticRegressionSuite.binaryModel
+
+ model.clearThreshold()
+ assert(model.getThreshold.isEmpty)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = LogisticRegressionModel.load(sc, path)
+ LogisticRegressionSuite.checkModelsEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+
+ // Save model with threshold.
+ try {
+ model.setThreshold(0.7)
+ model.save(sc, path)
+ val sameModel = LogisticRegressionModel.load(sc, path)
+ LogisticRegressionSuite.checkModelsEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ test("model save/load: multiclass classification") {
+ // NOTE: This will need to be generalized once there are multiple model format versions.
+ val model = LogisticRegressionSuite.multiclassModel
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = LogisticRegressionModel.load(sc, path)
+ LogisticRegressionSuite.checkModelsEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
}
class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index e68fe89d6c..64dcc0fb9f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -25,6 +25,8 @@ import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
object NaiveBayesSuite {
@@ -58,6 +60,18 @@ object NaiveBayesSuite {
LabeledPoint(y, Vectors.dense(xi))
}
}
+
+ private val smallPi = Array(0.5, 0.3, 0.2).map(math.log)
+
+ private val smallTheta = Array(
+ Array(0.91, 0.03, 0.03, 0.03), // label 0
+ Array(0.03, 0.91, 0.03, 0.03), // label 1
+ Array(0.03, 0.03, 0.91, 0.03) // label 2
+ ).map(_.map(math.log))
+
+ /** Binary labels, 3 features */
+ private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
+ theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)))
}
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
@@ -74,12 +88,8 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
test("Naive Bayes") {
val nPoints = 10000
- val pi = Array(0.5, 0.3, 0.2).map(math.log)
- val theta = Array(
- Array(0.91, 0.03, 0.03, 0.03), // label 0
- Array(0.03, 0.91, 0.03, 0.03), // label 1
- Array(0.03, 0.03, 0.91, 0.03) // label 2
- ).map(_.map(math.log))
+ val pi = NaiveBayesSuite.smallPi
+ val theta = NaiveBayesSuite.smallTheta
val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
@@ -123,6 +133,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
NaiveBayes.train(sc.makeRDD(nan, 2))
}
}
+
+ test("model save/load") {
+ val model = NaiveBayesSuite.binaryModel
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = NaiveBayesModel.load(sc, path)
+ assert(model.labels === sameModel.labels)
+ assert(model.pi === sameModel.pi)
+ assert(model.theta === sameModel.theta)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index a2de7fbd41..6de098b383 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.util.Utils
object SVMSuite {
@@ -56,6 +57,9 @@ object SVMSuite {
y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
}
+ /** Binary labels, 3 features */
+ private val binaryModel = new SVMModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+
}
class SVMSuite extends FunSuite with MLlibTestSparkContext {
@@ -191,6 +195,38 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
// Turning off data validation should not throw an exception
new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
+
+ test("model save/load") {
+ // NOTE: This will need to be generalized once there are multiple model format versions.
+ val model = SVMSuite.binaryModel
+
+ model.clearThreshold()
+ assert(model.getThreshold.isEmpty)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = SVMModel.load(sc, path)
+ assert(model.weights == sameModel.weights)
+ assert(model.intercept == sameModel.intercept)
+ assert(sameModel.getThreshold.isEmpty)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+
+ // Save model with threshold.
+ try {
+ model.setThreshold(0.7)
+ model.save(sc, path)
+ val sameModel2 = SVMModel.load(sc, path)
+ assert(model.getThreshold.get == sameModel2.getThreshold.get)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 2668dcc14a..c9f5dc069e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -24,6 +24,13 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
+private object LassoSuite {
+
+ /** 3 features */
+ val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+}
class LassoSuite extends FunSuite with MLlibTestSparkContext {
@@ -115,6 +122,23 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ test("model save/load") {
+ val model = LassoSuite.model
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = LassoModel.load(sc, path)
+ assert(model.weights == sameModel.weights)
+ assert(model.intercept == sameModel.intercept)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 864622a929..3781931c2f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -24,6 +24,13 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
+private object LinearRegressionSuite {
+
+ /** 3 features */
+ val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+}
class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
@@ -124,6 +131,23 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
validatePrediction(
sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
}
+
+ test("model save/load") {
+ val model = LinearRegressionSuite.model
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = LinearRegressionModel.load(sc, path)
+ assert(model.weights == sameModel.weights)
+ assert(model.intercept == sameModel.intercept)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 18d3bf5ea4..43d61151e2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -25,6 +25,13 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
+private object RidgeRegressionSuite {
+
+ /** 3 features */
+ val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+}
class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
@@ -75,6 +82,23 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(ridgeErr < linearErr,
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
+
+ test("model save/load") {
+ val model = RidgeRegressionSuite.model
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = RidgeRegressionModel.load(sc, path)
+ assert(model.weights == sameModel.weights)
+ assert(model.intercept == sameModel.intercept)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {