path: root/mllib/src
author     Joseph K. Bradley <joseph@databricks.com>    2015-02-04 22:46:48 -0800
committer  Xiangrui Meng <meng@databricks.com>          2015-02-04 22:46:48 -0800
commit     975bcef467b35586e5224171071355409f451d2d (patch)
tree       47f11210ae0b4c3f920752dcec61ddf4683c3bca /mllib/src
parent     c23ac03c8c27e840498a192b088e00b27076765f (diff)
download   spark-975bcef467b35586e5224171071355409f451d2d.tar.gz
           spark-975bcef467b35586e5224171071355409f451d2d.tar.bz2
           spark-975bcef467b35586e5224171071355409f451d2d.zip
[SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes
This is a PR for Parquet-based model import/export. Please see the design doc on [the JIRA](https://issues.apache.org/jira/browse/SPARK-4587).

Note: This includes only a subset of regression and classification models:
* NaiveBayes, SVM, LogisticRegression
* LinearRegression, RidgeRegression, Lasso

Follow-up PRs will cover other models.

Sketch of current contents:
* New traits: Saveable, Loader
* Implementations for some algorithms
* Also: Added LogisticRegressionModel.getThreshold method (so that unit test could check the threshold)

CC: mengxr selvinsource

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #4233 from jkbradley/ml-import-export and squashes the following commits:

87c4eb8 [Joseph K. Bradley] small cleanups
12d9059 [Joseph K. Bradley] Many cleanups after code review. Major changes: Storing numFeatures, numClasses in model metadata. Improvements to unit tests
b4ee064 [Joseph K. Bradley] Reorganized save/load for regression and classification. Renamed concepts to Saveable, Loader
a34aef5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into ml-import-export
ee99228 [Joseph K. Bradley] scala style fix
79675d5 [Joseph K. Bradley] cleanups in LogisticRegression after rebasing after multinomial PR
d1e5882 [Joseph K. Bradley] organized imports
2935963 [Joseph K. Bradley] Added save/load and tests for most classification and regression models
c495dba [Joseph K. Bradley] made version for model import/export local to each model
1496852 [Joseph K. Bradley] Added save/load for NaiveBayes
8d46386 [Joseph K. Bradley] Added save/load to NaiveBayes
1577d70 [Joseph K. Bradley] fixed issues after rebasing on master (DataFrame patch)
64914a3 [Joseph K. Bradley] added getThreshold to SVMModel
b1fc5ec [Joseph K. Bradley] small cleanups
418ba1b [Joseph K. Bradley] Added save, load to mllib.classification.LogisticRegressionModel, plus test suite
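For reference, a minimal usage sketch of the save/load API introduced by this patch, based on the fixtures in the new test suites below. The local SparkContext setup and the /tmp output path are illustrative assumptions, not part of the patch; `save`, `load`, and `getThreshold` are the methods added in this diff.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.linalg.Vectors

// Illustrative only: a local SparkContext and a temporary output path are assumed.
val sc = new SparkContext(new SparkConf().setAppName("model-save-load").setMaster("local[*]"))

// Construct a small binary model directly (mirrors the fixture in LogisticRegressionSuite).
val model = new LogisticRegressionModel(
  weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5, numFeatures = 3, numClasses = 2)

// save() writes JSON metadata to <path>/metadata/ and Parquet data to <path>/data/.
val path = "/tmp/lrModel"
model.save(sc, path)

// The companion object extends Loader[LogisticRegressionModel] and reads the files back.
val sameModel = LogisticRegressionModel.load(sc, path)
assert(sameModel.weights == model.weights && sameModel.getThreshold == model.getThreshold)

sc.stop()
```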
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala        20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala         67
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala                 87
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala                        51
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala 95
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala                          33
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala               35
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala                22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala                38
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala        86
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala               1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala                       139
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala    70
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala            40
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala                   36
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala                     24
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala          24
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala           24
18 files changed, 863 insertions, 29 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index b7a1d90d24..348c1e8760 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -20,7 +20,9 @@ package org.apache.spark.mllib.classification
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
/**
* :: Experimental ::
@@ -53,3 +55,21 @@ trait ClassificationModel extends Serializable {
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}
+
+private[mllib] object ClassificationModel {
+
+ /**
+ * Helper method for loading GLM classification model metadata.
+ *
+ * @param modelClass String name for model class (used for error messages)
+ * @return (numFeatures, numClasses)
+ */
+ def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
+ metadata.select("numFeatures", "numClasses").take(1)(0) match {
+ case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
+ case _ => throw new Exception(s"$modelClass unable to load" +
+ s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index a469315a1b..5c9feb6fb2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -17,14 +17,17 @@
package org.apache.spark.mllib.classification
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD
+
/**
* Classification model trained using Multinomial/Binary Logistic Regression.
*
@@ -42,7 +45,22 @@ class LogisticRegressionModel (
override val intercept: Double,
val numFeatures: Int,
val numClasses: Int)
- extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+ extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+ with Saveable {
+
+ if (numClasses == 2) {
+ require(weights.size == numFeatures,
+ s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" +
+ s" numFeatures = $numFeatures, but weights.size = ${weights.size}")
+ } else {
+ val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures
+ val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1)
+ require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept,
+ s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" +
+ s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" +
+ s" or $weightsSizeWithIntercept (with intercept)," +
+ s" but was given weights of length ${weights.size}")
+ }
def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)
@@ -62,6 +80,13 @@ class LogisticRegressionModel (
/**
* :: Experimental ::
+ * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
+ */
+ @Experimental
+ def getThreshold: Option[Double] = threshold
+
+ /**
+ * :: Experimental ::
* Clears the threshold so that `predict` will output raw prediction scores.
*/
@Experimental
@@ -70,7 +95,9 @@ class LogisticRegressionModel (
this
}
- override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
+ override protected def predictPoint(
+ dataMatrix: Vector,
+ weightMatrix: Vector,
intercept: Double) = {
require(dataMatrix.size == numFeatures)
@@ -126,6 +153,40 @@ class LogisticRegressionModel (
bestClass.toDouble
}
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
+ numFeatures, numClasses, weights, intercept, threshold)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
+
+ override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ // Hard-code class name string in case it changes in the future
+ val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val (numFeatures, numClasses) =
+ ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
+ // numFeatures, numClasses, weights are checked in model initialization
+ val model =
+ new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses)
+ data.threshold match {
+ case Some(t) => model.setThreshold(t)
+ case None => model.clearThreshold()
+ }
+ model
+ case _ => throw new Exception(
+ s"LogisticRegressionModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index a967df857b..4bafd495f9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -19,11 +19,13 @@ package org.apache.spark.mllib.classification
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
-import org.apache.spark.{SparkException, Logging}
-import org.apache.spark.SparkContext._
+import org.apache.spark.{SparkContext, SparkException, Logging}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
/**
* Model for Naive Bayes Classifiers.
@@ -36,7 +38,7 @@ import org.apache.spark.rdd.RDD
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
- val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
+ val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {
private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -65,6 +67,85 @@ class NaiveBayesModel private[mllib] (
override def predict(testData: Vector): Double = {
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta)
+ NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object NaiveBayesModel extends Loader[NaiveBayesModel] {
+
+ import Loader._
+
+ private object SaveLoadV1_0 {
+
+ def thisFormatVersion = "1.0"
+
+ /** Hard-code class name string in case it changes in the future */
+ def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"
+
+ /** Model data for model import/export */
+ case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])
+
+ def save(sc: SparkContext, path: String, data: Data): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ // Create JSON metadata.
+ val metadataRDD =
+ sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
+ .toDataFrame("class", "version", "numFeatures", "numClasses")
+ metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+
+ // Create Parquet data.
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ dataRDD.saveAsParquetFile(dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): NaiveBayesModel = {
+ val sqlContext = new SQLContext(sc)
+ // Load Parquet data.
+ val dataRDD = sqlContext.parquetFile(dataPath(path))
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ checkSchema[Data](dataRDD.schema)
+ val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
+ assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+ val data = dataArray(0)
+ val labels = data.getAs[Seq[Double]](0).toArray
+ val pi = data.getAs[Seq[Double]](1).toArray
+ val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
+ new NaiveBayesModel(labels, pi, theta)
+ }
+ }
+
+ override def load(sc: SparkContext, path: String): NaiveBayesModel = {
+ val (loadedClassName, version, metadata) = loadMetadata(sc, path)
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val (numFeatures, numClasses) =
+ ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val model = SaveLoadV1_0.load(sc, path)
+ assert(model.pi.size == numClasses,
+ s"NaiveBayesModel.load expected $numClasses classes," +
+ s" but class priors vector pi had ${model.pi.size} elements")
+ assert(model.theta.size == numClasses,
+ s"NaiveBayesModel.load expected $numClasses classes," +
+ s" but class conditionals array theta had ${model.theta.size} elements")
+ assert(model.theta.forall(_.size == numFeatures),
+ s"NaiveBayesModel.load expected $numFeatures features," +
+ s" but class conditionals array theta had elements of size:" +
+ s" ${model.theta.map(_.size).mkString(",")}")
+ model
+ case _ => throw new Exception(
+ s"NaiveBayesModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index dd514ff8a3..24d31e62ba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -17,13 +17,16 @@
package org.apache.spark.mllib.classification
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.DataValidators
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD
+
/**
* Model for Support Vector Machines (SVMs).
*
@@ -33,7 +36,8 @@ import org.apache.spark.rdd.RDD
class SVMModel (
override val weights: Vector,
override val intercept: Double)
- extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+ extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+ with Saveable {
private var threshold: Option[Double] = Some(0.0)
@@ -51,6 +55,13 @@ class SVMModel (
/**
* :: Experimental ::
+ * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
+ */
+ @Experimental
+ def getThreshold: Option[Double] = threshold
+
+ /**
+ * :: Experimental ::
* Clears the threshold so that `predict` will output raw prediction scores.
*/
@Experimental
@@ -69,6 +80,42 @@ class SVMModel (
case None => margin
}
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
+ numFeatures = weights.size, numClasses = 2, weights, intercept, threshold)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object SVMModel extends Loader[SVMModel] {
+
+ override def load(sc: SparkContext, path: String): SVMModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ // Hard-code class name string in case it changes in the future
+ val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val (numFeatures, numClasses) =
+ ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
+ val model = new SVMModel(data.weights, data.intercept)
+ assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +
+ s" was given non-matching weights vector of size ${model.weights.size}")
+ assert(numClasses == 2,
+ s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes")
+ data.threshold match {
+ case Some(t) => model.setThreshold(t)
+ case None => model.clearThreshold()
+ }
+ model
+ case _ => throw new Exception(
+ s"SVMModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
new file mode 100644
index 0000000000..b60c0cdd0a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.classification.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+/**
+ * Helper class for import/export of GLM classification models.
+ */
+private[classification] object GLMClassificationModel {
+
+ object SaveLoadV1_0 {
+
+ def thisFormatVersion = "1.0"
+
+ /** Model data for import/export */
+ case class Data(weights: Vector, intercept: Double, threshold: Option[Double])
+
+ /**
+ * Helper method for saving GLM classification model metadata and data.
+ * @param modelClass String name for model class, to be saved with metadata
+ * @param numClasses Number of classes label can take, to be saved with metadata
+ */
+ def save(
+ sc: SparkContext,
+ path: String,
+ modelClass: String,
+ numFeatures: Int,
+ numClasses: Int,
+ weights: Vector,
+ intercept: Double,
+ threshold: Option[Double]): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ // Create JSON metadata.
+ val metadataRDD =
+ sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1)
+ .toDataFrame("class", "version", "numFeatures", "numClasses")
+ metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val data = Data(weights, intercept, threshold)
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ // TODO: repartition with 1 partition after SPARK-5532 gets fixed
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ /**
+ * Helper method for loading GLM classification model data.
+ *
+ * NOTE: Callers of this method should check numClasses, numFeatures on their own.
+ *
+ * @param modelClass String name for model class (used for error messages)
+ */
+ def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
+ val datapath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val dataRDD = sqlContext.parquetFile(datapath)
+ val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
+ assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
+ val data = dataArray(0)
+ assert(data.size == 3, s"Unable to load $modelClass data from: $datapath")
+ val (weights, intercept) = data match {
+ case Row(weights: Vector, intercept: Double, _) =>
+ (weights, intercept)
+ }
+ val threshold = if (data.isNullAt(2)) {
+ None
+ } else {
+ Some(data.getDouble(2))
+ }
+ Data(weights, intercept, threshold)
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 8ecd5c6ad9..1159e59fff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -17,9 +17,11 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.regression.impl.GLMRegressionModel
+import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
/**
@@ -32,7 +34,7 @@ class LassoModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
- with RegressionModel with Serializable {
+ with RegressionModel with Serializable with Saveable {
override protected def predictPoint(
dataMatrix: Vector,
@@ -40,12 +42,37 @@ class LassoModel (
intercept: Double): Double = {
weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object LassoModel extends Loader[LassoModel] {
+
+ override def load(sc: SparkContext, path: String): LassoModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ // Hard-code class name string in case it changes in the future
+ val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
+ new LassoModel(data.weights, data.intercept)
+ case _ => throw new Exception(
+ s"LassoModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
* Train a regression model with L1-regularization using Stochastic Gradient Descent.
* This solves the l1-regularized least squares regression formulation
- * f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1
+ * f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1
* Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
* its corresponding right hand side label y.
* See also the documentation for the precise formulation.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 81b6598377..0136dcfdce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -17,9 +17,12 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.regression.impl.GLMRegressionModel
+import org.apache.spark.mllib.util.{Saveable, Loader}
+import org.apache.spark.rdd.RDD
/**
* Regression model trained using LinearRegression.
@@ -30,7 +33,8 @@ import org.apache.spark.mllib.optimization._
class LinearRegressionModel (
override val weights: Vector,
override val intercept: Double)
- extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
+ extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable
+ with Saveable {
override protected def predictPoint(
dataMatrix: Vector,
@@ -38,12 +42,37 @@ class LinearRegressionModel (
intercept: Double): Double = {
weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object LinearRegressionModel extends Loader[LinearRegressionModel] {
+
+ override def load(sc: SparkContext, path: String): LinearRegressionModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ // Hard-code class name string in case it changes in the future
+ val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel"
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
+ new LinearRegressionModel(data.weights, data.intercept)
+ case _ => throw new Exception(
+ s"LinearRegressionModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
* Train a linear regression model with no regularization using Stochastic Gradient Descent.
* This solves the least squares regression formulation
- * f(weights) = 1/n ||A weights-y||^2
+ * f(weights) = 1/n ||A weights-y||^2^
* (which is the mean squared error).
* Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
* its corresponding right hand side label y.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index 64b02f7a6e..843e59bdfb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -19,8 +19,10 @@ package org.apache.spark.mllib.regression
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
@Experimental
trait RegressionModel extends Serializable {
@@ -48,3 +50,21 @@ trait RegressionModel extends Serializable {
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}
+
+private[mllib] object RegressionModel {
+
+ /**
+ * Helper method for loading GLM regression model metadata.
+ *
+ * @param modelClass String name for model class (used for error messages)
+ * @return numFeatures
+ */
+ def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = {
+ metadata.select("numFeatures").take(1)(0) match {
+ case Row(nFeatures: Int) => nFeatures
+ case _ => throw new Exception(s"$modelClass unable to load" +
+ s" numFeatures from metadata: ${Loader.metadataPath(path)}")
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 076ba35051..f2a5f1db1e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -17,10 +17,13 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.optimization._
+import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.regression.impl.GLMRegressionModel
+import org.apache.spark.mllib.util.{Loader, Saveable}
+import org.apache.spark.rdd.RDD
+
/**
* Regression model trained using RidgeRegression.
@@ -32,7 +35,7 @@ class RidgeRegressionModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
- with RegressionModel with Serializable {
+ with RegressionModel with Serializable with Saveable {
override protected def predictPoint(
dataMatrix: Vector,
@@ -40,12 +43,37 @@ class RidgeRegressionModel (
intercept: Double): Double = {
weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
+
+ override def load(sc: SparkContext, path: String): RidgeRegressionModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ // Hard-code class name string in case it changes in the future
+ val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel"
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
+ new RidgeRegressionModel(data.weights, data.intercept)
+ case _ => throw new Exception(
+ s"RidgeRegressionModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
* Train a regression model with L2-regularization using Stochastic Gradient Descent.
* This solves the l1-regularized least squares regression formulation
- * f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2
+ * f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^
* Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
* its corresponding right hand side label y.
* See also the documentation for the precise formulation.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
new file mode 100644
index 0000000000..00f25a8be9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+/**
+ * Helper methods for import/export of GLM regression models.
+ */
+private[regression] object GLMRegressionModel {
+
+ object SaveLoadV1_0 {
+
+ def thisFormatVersion = "1.0"
+
+ /** Model data for model import/export */
+ case class Data(weights: Vector, intercept: Double)
+
+ /**
+ * Helper method for saving GLM regression model metadata and data.
+ * @param modelClass String name for model class, to be saved with metadata
+ */
+ def save(
+ sc: SparkContext,
+ path: String,
+ modelClass: String,
+ weights: Vector,
+ intercept: Double): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ // Create JSON metadata.
+ val metadataRDD =
+ sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1)
+ .toDataFrame("class", "version", "numFeatures")
+ metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val data = Data(weights, intercept)
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ // TODO: repartition with 1 partition after SPARK-5532 gets fixed
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ /**
+ * Helper method for loading GLM regression model data.
+ * @param modelClass String name for model class (used for error messages)
+ * @param numFeatures Number of features, to be checked against loaded data.
+ * The length of the weights vector should equal numFeatures.
+ */
+ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
+ val datapath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val dataRDD = sqlContext.parquetFile(datapath)
+ val dataArray = dataRDD.select("weights", "intercept").take(1)
+ assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
+ val data = dataArray(0)
+ assert(data.size == 2, s"Unable to load $modelClass data from: $datapath")
+ data match {
+ case Row(weights: Vector, intercept: Double) =>
+ assert(weights.size == numFeatures, s"Expected $numFeatures features, but" +
+ s" found ${weights.size} features when loading $modelClass weights from $datapath")
+ Data(weights, intercept)
+ }
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a576096306..a25e625a40 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -53,7 +53,6 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
features.map(x => predict(x))
}
-
/**
* Predict values for the given data set using the model trained.
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
new file mode 100644
index 0000000000..56b77a7d12
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{DataType, StructType, StructField}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Trait for models and transformers which may be saved as files.
+ * This should be inherited by the class which implements model instances.
+ */
+@DeveloperApi
+trait Saveable {
+
+ /**
+ * Save this model to the given path.
+ *
+ * This saves:
+ * - human-readable (JSON) model metadata to path/metadata/
+ * - Parquet formatted data to path/data/
+ *
+ * The model may be loaded using [[Loader.load]].
+ *
+ * @param sc Spark context used to save model data.
+ * @param path Path specifying the directory in which to save this model.
+ * This directory and any intermediate directory will be created if needed.
+ */
+ def save(sc: SparkContext, path: String): Unit
+
+ /** Current version of model save/load format. */
+ protected def formatVersion: String
+
+}
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Trait for classes which can load models and transformers from files.
+ * This should be inherited by an object paired with the model class.
+ */
+@DeveloperApi
+trait Loader[M <: Saveable] {
+
+ /**
+ * Load a model from the given path.
+ *
+ * The model should have been saved by [[Saveable.save]].
+ *
+ * @param sc Spark context used for loading model files.
+ * @param path Path specifying the directory to which the model was saved.
+ * @return Model instance
+ */
+ def load(sc: SparkContext, path: String): M
+
+}
+
+/**
+ * Helper methods for loading models from files.
+ */
+private[mllib] object Loader {
+
+ /** Returns URI for path/data using the Hadoop filesystem */
+ def dataPath(path: String): String = new Path(path, "data").toUri.toString
+
+ /** Returns URI for path/metadata using the Hadoop filesystem */
+ def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString
+
+ /**
+ * Check the schema of loaded model data.
+ *
+ * This checks every field in the expected schema to make sure that a field with the same
+ * name and DataType appears in the loaded schema. Note that this does NOT check metadata
+ * or containsNull.
+ *
+ * @param loadedSchema Schema for model data loaded from file.
+ * @tparam Data Expected data type from which an expected schema can be derived.
+ */
+ def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = {
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ val expectedFields: Array[StructField] =
+ ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields
+ val loadedFields: Map[String, DataType] =
+ loadedSchema.map(field => field.name -> field.dataType).toMap
+ expectedFields.foreach { field =>
+ assert(loadedFields.contains(field.name), s"Unable to parse model data." +
+ s" Expected field with name ${field.name} was missing in loaded schema:" +
+ s" ${loadedFields.mkString(", ")}")
+ assert(loadedFields(field.name) == field.dataType,
+ s"Unable to parse model data. Expected field $field but found field" +
+ s" with different type: ${loadedFields(field.name)}")
+ }
+ }
+
+ /**
+ * Load metadata from the given path.
+ * @return (class name, version, metadata)
+ */
+ def loadMetadata(sc: SparkContext, path: String): (String, String, DataFrame) = {
+ val sqlContext = new SQLContext(sc)
+ val metadata = sqlContext.jsonFile(metadataPath(path))
+ val (clazz, version) = try {
+ val metadataArray = metadata.select("class", "version").take(1)
+ assert(metadataArray.size == 1)
+ metadataArray(0) match {
+ case Row(clazz: String, version: String) => (clazz, version)
+ }
+ } catch {
+ case e: Exception =>
+ throw new Exception(s"Unable to load model metadata from: ${metadataPath(path)}")
+ }
+ (clazz, version, metadata)
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 3fb45938f7..d2b40f2cae 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.mllib.classification
-import scala.util.control.Breaks._
-import scala.util.Random
import scala.collection.JavaConversions._
+import scala.util.Random
+import scala.util.control.Breaks._
import org.scalatest.FunSuite
import org.scalatest.Matchers
@@ -28,6 +28,8 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
+
object LogisticRegressionSuite {
@@ -147,8 +149,25 @@ object LogisticRegressionSuite {
val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i)))
testData
}
+
+ /** Binary labels, 3 features */
+ private val binaryModel = new LogisticRegressionModel(
+ weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5, numFeatures = 3, numClasses = 2)
+
+ /** 3 classes, 2 features */
+ private val multiclassModel = new LogisticRegressionModel(
+ weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3)
+
+ private def checkModelsEqual(a: LogisticRegressionModel, b: LogisticRegressionModel): Unit = {
+ assert(a.weights == b.weights)
+ assert(a.intercept == b.intercept)
+ assert(a.numClasses == b.numClasses)
+ assert(a.numFeatures == b.numFeatures)
+ assert(a.getThreshold == b.getThreshold)
+ }
}
+
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
def validatePrediction(
predictions: Seq[Double],
@@ -462,6 +481,53 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
}
+ test("model save/load: binary classification") {
+ // NOTE: This will need to be generalized once there are multiple model format versions.
+ val model = LogisticRegressionSuite.binaryModel
+
+ model.clearThreshold()
+ assert(model.getThreshold.isEmpty)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = LogisticRegressionModel.load(sc, path)
+ LogisticRegressionSuite.checkModelsEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+
+ // Save model with threshold.
+ try {
+ model.setThreshold(0.7)
+ model.save(sc, path)
+ val sameModel = LogisticRegressionModel.load(sc, path)
+ LogisticRegressionSuite.checkModelsEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ test("model save/load: multiclass classification") {
+ // NOTE: This will need to be generalized once there are multiple model format versions.
+ val model = LogisticRegressionSuite.multiclassModel
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = LogisticRegressionModel.load(sc, path)
+ LogisticRegressionSuite.checkModelsEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
}
class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index e68fe89d6c..64dcc0fb9f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -25,6 +25,8 @@ import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
object NaiveBayesSuite {
@@ -58,6 +60,18 @@ object NaiveBayesSuite {
LabeledPoint(y, Vectors.dense(xi))
}
}
+
+ private val smallPi = Array(0.5, 0.3, 0.2).map(math.log)
+
+ private val smallTheta = Array(
+ Array(0.91, 0.03, 0.03, 0.03), // label 0
+ Array(0.03, 0.91, 0.03, 0.03), // label 1
+ Array(0.03, 0.03, 0.91, 0.03) // label 2
+ ).map(_.map(math.log))
+
+ /** Binary labels, 3 features */
+ private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
+ theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)))
}
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
@@ -74,12 +88,8 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
test("Naive Bayes") {
val nPoints = 10000
- val pi = Array(0.5, 0.3, 0.2).map(math.log)
- val theta = Array(
- Array(0.91, 0.03, 0.03, 0.03), // label 0
- Array(0.03, 0.91, 0.03, 0.03), // label 1
- Array(0.03, 0.03, 0.91, 0.03) // label 2
- ).map(_.map(math.log))
+ val pi = NaiveBayesSuite.smallPi
+ val theta = NaiveBayesSuite.smallTheta
val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
@@ -123,6 +133,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
NaiveBayes.train(sc.makeRDD(nan, 2))
}
}
+
+ test("model save/load") {
+ val model = NaiveBayesSuite.binaryModel
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = NaiveBayesModel.load(sc, path)
+ assert(model.labels === sameModel.labels)
+ assert(model.pi === sameModel.pi)
+ assert(model.theta === sameModel.theta)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index a2de7fbd41..6de098b383 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.util.Utils
object SVMSuite {
@@ -56,6 +57,9 @@ object SVMSuite {
y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
}
+ /** Binary labels, 3 features */
+ private val binaryModel = new SVMModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+
}
class SVMSuite extends FunSuite with MLlibTestSparkContext {
@@ -191,6 +195,38 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
// Turning off data validation should not throw an exception
new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
+
+ test("model save/load") {
+ // NOTE: This will need to be generalized once there are multiple model format versions.
+ val model = SVMSuite.binaryModel
+
+ model.clearThreshold()
+ assert(model.getThreshold.isEmpty)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = SVMModel.load(sc, path)
+ assert(model.weights == sameModel.weights)
+ assert(model.intercept == sameModel.intercept)
+ assert(sameModel.getThreshold.isEmpty)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+
+ // Save model with threshold.
+ try {
+ model.setThreshold(0.7)
+ model.save(sc, path)
+ val sameModel2 = SVMModel.load(sc, path)
+ assert(model.getThreshold.get == sameModel2.getThreshold.get)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 2668dcc14a..c9f5dc069e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -24,6 +24,13 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
+private object LassoSuite {
+
+ /** 3 features */
+ val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+}
class LassoSuite extends FunSuite with MLlibTestSparkContext {
@@ -115,6 +122,23 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ test("model save/load") {
+ val model = LassoSuite.model
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = LassoModel.load(sc, path)
+ assert(model.weights == sameModel.weights)
+ assert(model.intercept == sameModel.intercept)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 864622a929..3781931c2f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -24,6 +24,13 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
+private object LinearRegressionSuite {
+
+ /** 3 features */
+ val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+}
class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
@@ -124,6 +131,23 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
validatePrediction(
sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
}
+
+ test("model save/load") {
+ val model = LinearRegressionSuite.model
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = LinearRegressionModel.load(sc, path)
+ assert(model.weights == sameModel.weights)
+ assert(model.intercept == sameModel.intercept)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 18d3bf5ea4..43d61151e2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -25,6 +25,13 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
+import org.apache.spark.util.Utils
+
+private object RidgeRegressionSuite {
+
+ /** 3 features */
+ val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
+}
class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
@@ -75,6 +82,23 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(ridgeErr < linearErr,
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
+
+ test("model save/load") {
+ val model = RidgeRegressionSuite.model
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = RidgeRegressionModel.load(sc, path)
+ assert(model.weights == sameModel.weights)
+ assert(model.intercept == sameModel.intercept)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {