author    Joseph K. Bradley <joseph@databricks.com>    2015-02-04 22:46:48 -0800
committer Xiangrui Meng <meng@databricks.com>          2015-02-04 22:47:46 -0800
commit    885bcbb079419b731d5a6a53e6e213be5a31c403 (patch)
tree      47f11210ae0b4c3f920752dcec61ddf4683c3bca /mllib/src/main
parent    59fb5c7462bea7862506cc81436737a7f810e77d (diff)
[SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes
This is a PR for Parquet-based model import/export. Please see the design doc on [the JIRA](https://issues.apache.org/jira/browse/SPARK-4587).

Note: This includes only a subset of regression and classification models:
* NaiveBayes, SVM, LogisticRegression
* LinearRegression, RidgeRegression, Lasso

Follow-up PRs will cover other models.

Sketch of current contents:
* New traits: Saveable, Loader
* Implementations for some algorithms
* Also: Added LogisticRegressionModel.getThreshold method (so that a unit test could check the threshold)

CC: mengxr selvinsource

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #4233 from jkbradley/ml-import-export and squashes the following commits:

87c4eb8 [Joseph K. Bradley] small cleanups
12d9059 [Joseph K. Bradley] Many cleanups after code review. Major changes: Storing numFeatures, numClasses in model metadata. Improvements to unit tests
b4ee064 [Joseph K. Bradley] Reorganized save/load for regression and classification. Renamed concepts to Saveable, Loader
a34aef5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into ml-import-export
ee99228 [Joseph K. Bradley] scala style fix
79675d5 [Joseph K. Bradley] cleanups in LogisticRegression after rebasing after multinomial PR
d1e5882 [Joseph K. Bradley] organized imports
2935963 [Joseph K. Bradley] Added save/load and tests for most classification and regression models
c495dba [Joseph K. Bradley] made version for model import/export local to each model
1496852 [Joseph K. Bradley] Added save/load for NaiveBayes
8d46386 [Joseph K. Bradley] Added save/load to NaiveBayes
1577d70 [Joseph K. Bradley] fixed issues after rebasing on master (DataFrame patch)
64914a3 [Joseph K. Bradley] added getThreshold to SVMModel
b1fc5ec [Joseph K. Bradley] small cleanups
418ba1b [Joseph K. Bradley] Added save, load to mllib.classification.LogisticRegressionModel, plus test suite

(cherry picked from commit 975bcef467b35586e5224171071355409f451d2d)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
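For context, the user-facing API surface this commit introduces boils down to a save method on each supported model plus a load method on its companion object. A minimal sketch of that usage, assuming an existing SparkContext `sc`; the weights, intercept, and output path below are illustrative placeholders, not values from this commit:

```scala
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.linalg.Vectors

// Build a model directly via the (weights, intercept) auxiliary constructor
// added in this patch; in practice the model would come from a trainer.
val model = new LogisticRegressionModel(Vectors.dense(0.5, -0.25, 1.0), 0.1)

// save() writes human-readable JSON metadata to <path>/metadata and the model
// data (weights, intercept, threshold) as Parquet to <path>/data.
model.save(sc, "myModelPath")

// The companion object implements Loader: it checks the class name and format
// version recorded in the metadata, then reconstructs the model, restoring
// the threshold if one was set.
val sameModel = LogisticRegressionModel.load(sc, "myModelPath")
```

The resulting directory layout (path/metadata for JSON, path/data for Parquet) is the one documented in the Saveable scaladoc added below.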
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala          20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala           67
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala                   87
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala                          51
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala  95
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala                            33
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala                 35
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala                  22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala                  38
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala          86
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala                 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala                          139
12 files changed, 653 insertions, 21 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index b7a1d90d24..348c1e8760 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -20,7 +20,9 @@ package org.apache.spark.mllib.classification
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
/**
* :: Experimental ::
@@ -53,3 +55,21 @@ trait ClassificationModel extends Serializable {
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}
+
+private[mllib] object ClassificationModel {
+
+ /**
+ * Helper method for loading GLM classification model metadata.
+ *
+ * @param modelClass String name for model class (used for error messages)
+ * @return (numFeatures, numClasses)
+ */
+ def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
+ metadata.select("numFeatures", "numClasses").take(1)(0) match {
+ case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
+ case _ => throw new Exception(s"$modelClass unable to load" +
+ s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index a469315a1b..5c9feb6fb2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -17,14 +17,17 @@
package org.apache.spark.mllib.classification
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD
+
/**
* Classification model trained using Multinomial/Binary Logistic Regression.
*
@@ -42,7 +45,22 @@ class LogisticRegressionModel (
override val intercept: Double,
val numFeatures: Int,
val numClasses: Int)
- extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+ extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+ with Saveable {
+
+ if (numClasses == 2) {
+ require(weights.size == numFeatures,
+ s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" +
+ s" numFeatures = $numFeatures, but weights.size = ${weights.size}")
+ } else {
+ val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures
+ val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1)
+ require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept,
+ s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" +
+ s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" +
+ s" or $weightsSizeWithIntercept (with intercept)," +
+ s" but was given weights of length ${weights.size}")
+ }
def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)
@@ -62,6 +80,13 @@ class LogisticRegressionModel (
/**
* :: Experimental ::
+ * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
+ */
+ @Experimental
+ def getThreshold: Option[Double] = threshold
+
+ /**
+ * :: Experimental ::
* Clears the threshold so that `predict` will output raw prediction scores.
*/
@Experimental
@@ -70,7 +95,9 @@ class LogisticRegressionModel (
this
}
- override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
+ override protected def predictPoint(
+ dataMatrix: Vector,
+ weightMatrix: Vector,
intercept: Double) = {
require(dataMatrix.size == numFeatures)
@@ -126,6 +153,40 @@ class LogisticRegressionModel (
bestClass.toDouble
}
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
+ numFeatures, numClasses, weights, intercept, threshold)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
+
+ override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ // Hard-code class name string in case it changes in the future
+ val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val (numFeatures, numClasses) =
+ ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
+ // numFeatures, numClasses, weights are checked in model initialization
+ val model =
+ new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses)
+ data.threshold match {
+ case Some(t) => model.setThreshold(t)
+ case None => model.clearThreshold()
+ }
+ model
+ case _ => throw new Exception(
+ s"LogisticRegressionModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index a967df857b..4bafd495f9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -19,11 +19,13 @@ package org.apache.spark.mllib.classification
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
-import org.apache.spark.{SparkException, Logging}
-import org.apache.spark.SparkContext._
+import org.apache.spark.{SparkContext, SparkException, Logging}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
/**
* Model for Naive Bayes Classifiers.
@@ -36,7 +38,7 @@ import org.apache.spark.rdd.RDD
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
- val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
+ val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {
private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -65,6 +67,85 @@ class NaiveBayesModel private[mllib] (
override def predict(testData: Vector): Double = {
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta)
+ NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object NaiveBayesModel extends Loader[NaiveBayesModel] {
+
+ import Loader._
+
+ private object SaveLoadV1_0 {
+
+ def thisFormatVersion = "1.0"
+
+ /** Hard-code class name string in case it changes in the future */
+ def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"
+
+ /** Model data for model import/export */
+ case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])
+
+ def save(sc: SparkContext, path: String, data: Data): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ // Create JSON metadata.
+ val metadataRDD =
+ sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
+ .toDataFrame("class", "version", "numFeatures", "numClasses")
+ metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+
+ // Create Parquet data.
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ dataRDD.saveAsParquetFile(dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): NaiveBayesModel = {
+ val sqlContext = new SQLContext(sc)
+ // Load Parquet data.
+ val dataRDD = sqlContext.parquetFile(dataPath(path))
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ checkSchema[Data](dataRDD.schema)
+ val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
+ assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+ val data = dataArray(0)
+ val labels = data.getAs[Seq[Double]](0).toArray
+ val pi = data.getAs[Seq[Double]](1).toArray
+ val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
+ new NaiveBayesModel(labels, pi, theta)
+ }
+ }
+
+ override def load(sc: SparkContext, path: String): NaiveBayesModel = {
+ val (loadedClassName, version, metadata) = loadMetadata(sc, path)
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val (numFeatures, numClasses) =
+ ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val model = SaveLoadV1_0.load(sc, path)
+ assert(model.pi.size == numClasses,
+ s"NaiveBayesModel.load expected $numClasses classes," +
+ s" but class priors vector pi had ${model.pi.size} elements")
+ assert(model.theta.size == numClasses,
+ s"NaiveBayesModel.load expected $numClasses classes," +
+ s" but class conditionals array theta had ${model.theta.size} elements")
+ assert(model.theta.forall(_.size == numFeatures),
+ s"NaiveBayesModel.load expected $numFeatures features," +
+ s" but class conditionals array theta had elements of size:" +
+ s" ${model.theta.map(_.size).mkString(",")}")
+ model
+ case _ => throw new Exception(
+ s"NaiveBayesModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index dd514ff8a3..24d31e62ba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -17,13 +17,16 @@
package org.apache.spark.mllib.classification
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.DataValidators
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD
+
/**
* Model for Support Vector Machines (SVMs).
*
@@ -33,7 +36,8 @@ import org.apache.spark.rdd.RDD
class SVMModel (
override val weights: Vector,
override val intercept: Double)
- extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+ extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+ with Saveable {
private var threshold: Option[Double] = Some(0.0)
@@ -51,6 +55,13 @@ class SVMModel (
/**
* :: Experimental ::
+ * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
+ */
+ @Experimental
+ def getThreshold: Option[Double] = threshold
+
+ /**
+ * :: Experimental ::
* Clears the threshold so that `predict` will output raw prediction scores.
*/
@Experimental
@@ -69,6 +80,42 @@ class SVMModel (
case None => margin
}
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
+ numFeatures = weights.size, numClasses = 2, weights, intercept, threshold)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object SVMModel extends Loader[SVMModel] {
+
+ override def load(sc: SparkContext, path: String): SVMModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ // Hard-code class name string in case it changes in the future
+ val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val (numFeatures, numClasses) =
+ ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
+ val model = new SVMModel(data.weights, data.intercept)
+ assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +
+ s" was given non-matching weights vector of size ${model.weights.size}")
+ assert(numClasses == 2,
+ s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes")
+ data.threshold match {
+ case Some(t) => model.setThreshold(t)
+ case None => model.clearThreshold()
+ }
+ model
+ case _ => throw new Exception(
+ s"SVMModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
new file mode 100644
index 0000000000..b60c0cdd0a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.classification.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+/**
+ * Helper class for import/export of GLM classification models.
+ */
+private[classification] object GLMClassificationModel {
+
+ object SaveLoadV1_0 {
+
+ def thisFormatVersion = "1.0"
+
+ /** Model data for import/export */
+ case class Data(weights: Vector, intercept: Double, threshold: Option[Double])
+
+ /**
+ * Helper method for saving GLM classification model metadata and data.
+ * @param modelClass String name for model class, to be saved with metadata
+ * @param numClasses Number of classes label can take, to be saved with metadata
+ */
+ def save(
+ sc: SparkContext,
+ path: String,
+ modelClass: String,
+ numFeatures: Int,
+ numClasses: Int,
+ weights: Vector,
+ intercept: Double,
+ threshold: Option[Double]): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ // Create JSON metadata.
+ val metadataRDD =
+ sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1)
+ .toDataFrame("class", "version", "numFeatures", "numClasses")
+ metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val data = Data(weights, intercept, threshold)
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ // TODO: repartition with 1 partition after SPARK-5532 gets fixed
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ /**
+ * Helper method for loading GLM classification model data.
+ *
+ * NOTE: Callers of this method should check numClasses, numFeatures on their own.
+ *
+ * @param modelClass String name for model class (used for error messages)
+ */
+ def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
+ val datapath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val dataRDD = sqlContext.parquetFile(datapath)
+ val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
+ assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
+ val data = dataArray(0)
+ assert(data.size == 3, s"Unable to load $modelClass data from: $datapath")
+ val (weights, intercept) = data match {
+ case Row(weights: Vector, intercept: Double, _) =>
+ (weights, intercept)
+ }
+ val threshold = if (data.isNullAt(2)) {
+ None
+ } else {
+ Some(data.getDouble(2))
+ }
+ Data(weights, intercept, threshold)
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 8ecd5c6ad9..1159e59fff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -17,9 +17,11 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.regression.impl.GLMRegressionModel
+import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
/**
@@ -32,7 +34,7 @@ class LassoModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
- with RegressionModel with Serializable {
+ with RegressionModel with Serializable with Saveable {
override protected def predictPoint(
dataMatrix: Vector,
@@ -40,12 +42,37 @@ class LassoModel (
intercept: Double): Double = {
weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object LassoModel extends Loader[LassoModel] {
+
+ override def load(sc: SparkContext, path: String): LassoModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ // Hard-code class name string in case it changes in the future
+ val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
+ new LassoModel(data.weights, data.intercept)
+ case _ => throw new Exception(
+ s"LassoModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
* Train a regression model with L1-regularization using Stochastic Gradient Descent.
* This solves the l1-regularized least squares regression formulation
- * f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1
+ * f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1
* Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
* its corresponding right hand side label y.
* See also the documentation for the precise formulation.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 81b6598377..0136dcfdce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -17,9 +17,12 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.regression.impl.GLMRegressionModel
+import org.apache.spark.mllib.util.{Saveable, Loader}
+import org.apache.spark.rdd.RDD
/**
* Regression model trained using LinearRegression.
@@ -30,7 +33,8 @@ import org.apache.spark.mllib.optimization._
class LinearRegressionModel (
override val weights: Vector,
override val intercept: Double)
- extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
+ extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable
+ with Saveable {
override protected def predictPoint(
dataMatrix: Vector,
@@ -38,12 +42,37 @@ class LinearRegressionModel (
intercept: Double): Double = {
weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object LinearRegressionModel extends Loader[LinearRegressionModel] {
+
+ override def load(sc: SparkContext, path: String): LinearRegressionModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ // Hard-code class name string in case it changes in the future
+ val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel"
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
+ new LinearRegressionModel(data.weights, data.intercept)
+ case _ => throw new Exception(
+ s"LinearRegressionModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
* Train a linear regression model with no regularization using Stochastic Gradient Descent.
* This solves the least squares regression formulation
- * f(weights) = 1/n ||A weights-y||^2
+ * f(weights) = 1/n ||A weights-y||^2^
* (which is the mean squared error).
* Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
* its corresponding right hand side label y.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index 64b02f7a6e..843e59bdfb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -19,8 +19,10 @@ package org.apache.spark.mllib.regression
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
@Experimental
trait RegressionModel extends Serializable {
@@ -48,3 +50,21 @@ trait RegressionModel extends Serializable {
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}
+
+private[mllib] object RegressionModel {
+
+ /**
+ * Helper method for loading GLM regression model metadata.
+ *
+ * @param modelClass String name for model class (used for error messages)
+ * @return numFeatures
+ */
+ def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = {
+ metadata.select("numFeatures").take(1)(0) match {
+ case Row(nFeatures: Int) => nFeatures
+ case _ => throw new Exception(s"$modelClass unable to load" +
+ s" numFeatures from metadata: ${Loader.metadataPath(path)}")
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 076ba35051..f2a5f1db1e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -17,10 +17,13 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.optimization._
+import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.regression.impl.GLMRegressionModel
+import org.apache.spark.mllib.util.{Loader, Saveable}
+import org.apache.spark.rdd.RDD
+
/**
* Regression model trained using RidgeRegression.
@@ -32,7 +35,7 @@ class RidgeRegressionModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
- with RegressionModel with Serializable {
+ with RegressionModel with Serializable with Saveable {
override protected def predictPoint(
dataMatrix: Vector,
@@ -40,12 +43,37 @@ class RidgeRegressionModel (
intercept: Double): Double = {
weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
+
+ override def load(sc: SparkContext, path: String): RidgeRegressionModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ // Hard-code class name string in case it changes in the future
+ val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel"
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
+ new RidgeRegressionModel(data.weights, data.intercept)
+ case _ => throw new Exception(
+ s"RidgeRegressionModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
/**
* Train a regression model with L2-regularization using Stochastic Gradient Descent.
* This solves the l1-regularized least squares regression formulation
- * f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2
+ * f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^
* Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
* its corresponding right hand side label y.
* See also the documentation for the precise formulation.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
new file mode 100644
index 0000000000..00f25a8be9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.Loader
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+/**
+ * Helper methods for import/export of GLM regression models.
+ */
+private[regression] object GLMRegressionModel {
+
+ object SaveLoadV1_0 {
+
+ def thisFormatVersion = "1.0"
+
+ /** Model data for model import/export */
+ case class Data(weights: Vector, intercept: Double)
+
+ /**
+ * Helper method for saving GLM regression model metadata and data.
+ * @param modelClass String name for model class, to be saved with metadata
+ */
+ def save(
+ sc: SparkContext,
+ path: String,
+ modelClass: String,
+ weights: Vector,
+ intercept: Double): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ // Create JSON metadata.
+ val metadataRDD =
+ sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1)
+ .toDataFrame("class", "version", "numFeatures")
+ metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val data = Data(weights, intercept)
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ // TODO: repartition with 1 partition after SPARK-5532 gets fixed
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ /**
+ * Helper method for loading GLM regression model data.
+ * @param modelClass String name for model class (used for error messages)
+ * @param numFeatures Number of features, to be checked against loaded data.
+ * The length of the weights vector should equal numFeatures.
+ */
+ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
+ val datapath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val dataRDD = sqlContext.parquetFile(datapath)
+ val dataArray = dataRDD.select("weights", "intercept").take(1)
+ assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
+ val data = dataArray(0)
+ assert(data.size == 2, s"Unable to load $modelClass data from: $datapath")
+ data match {
+ case Row(weights: Vector, intercept: Double) =>
+ assert(weights.size == numFeatures, s"Expected $numFeatures features, but" +
+ s" found ${weights.size} features when loading $modelClass weights from $datapath")
+ Data(weights, intercept)
+ }
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a576096306..a25e625a40 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -53,7 +53,6 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
features.map(x => predict(x))
}
-
/**
* Predict values for the given data set using the model trained.
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
new file mode 100644
index 0000000000..56b77a7d12
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{DataType, StructType, StructField}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Trait for models and transformers which may be saved as files.
+ * This should be inherited by the class which implements model instances.
+ */
+@DeveloperApi
+trait Saveable {
+
+ /**
+ * Save this model to the given path.
+ *
+ * This saves:
+ * - human-readable (JSON) model metadata to path/metadata/
+ * - Parquet formatted data to path/data/
+ *
+ * The model may be loaded using [[Loader.load]].
+ *
+ * @param sc Spark context used to save model data.
+ * @param path Path specifying the directory in which to save this model.
+ * This directory and any intermediate directory will be created if needed.
+ */
+ def save(sc: SparkContext, path: String): Unit
+
+ /** Current version of model save/load format. */
+ protected def formatVersion: String
+
+}
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Trait for classes which can load models and transformers from files.
+ * This should be inherited by an object paired with the model class.
+ */
+@DeveloperApi
+trait Loader[M <: Saveable] {
+
+ /**
+ * Load a model from the given path.
+ *
+ * The model should have been saved by [[Saveable.save]].
+ *
+ * @param sc Spark context used for loading model files.
+ * @param path Path specifying the directory to which the model was saved.
+ * @return Model instance
+ */
+ def load(sc: SparkContext, path: String): M
+
+}
+
+/**
+ * Helper methods for loading models from files.
+ */
+private[mllib] object Loader {
+
+ /** Returns URI for path/data using the Hadoop filesystem */
+ def dataPath(path: String): String = new Path(path, "data").toUri.toString
+
+ /** Returns URI for path/metadata using the Hadoop filesystem */
+ def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString
+
+ /**
+ * Check the schema of loaded model data.
+ *
+ * This checks every field in the expected schema to make sure that a field with the same
+ * name and DataType appears in the loaded schema. Note that this does NOT check metadata
+ * or containsNull.
+ *
+ * @param loadedSchema Schema for model data loaded from file.
+ * @tparam Data Expected data type from which an expected schema can be derived.
+ */
+ def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = {
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ val expectedFields: Array[StructField] =
+ ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields
+ val loadedFields: Map[String, DataType] =
+ loadedSchema.map(field => field.name -> field.dataType).toMap
+ expectedFields.foreach { field =>
+ assert(loadedFields.contains(field.name), s"Unable to parse model data." +
+ s" Expected field with name ${field.name} was missing in loaded schema:" +
+ s" ${loadedFields.mkString(", ")}")
+ assert(loadedFields(field.name) == field.dataType,
+ s"Unable to parse model data. Expected field $field but found field" +
+ s" with different type: ${loadedFields(field.name)}")
+ }
+ }
+
+ /**
+ * Load metadata from the given path.
+ * @return (class name, version, metadata)
+ */
+ def loadMetadata(sc: SparkContext, path: String): (String, String, DataFrame) = {
+ val sqlContext = new SQLContext(sc)
+ val metadata = sqlContext.jsonFile(metadataPath(path))
+ val (clazz, version) = try {
+ val metadataArray = metadata.select("class", "version").take(1)
+ assert(metadataArray.size == 1)
+ metadataArray(0) match {
+ case Row(clazz: String, version: String) => (clazz, version)
+ }
+ } catch {
+ case e: Exception =>
+ throw new Exception(s"Unable to load model metadata from: ${metadataPath(path)}")
+ }
+ (clazz, version, metadata)
+ }
+
+}
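As a closing note, the Saveable/Loader pair defined in modelSaveLoad.scala is the extension point the follow-up PRs mentioned in the commit message are expected to use. The sketch below is not part of this commit: the model class and package are hypothetical, and because the Loader helper object is private[mllib], the code is assumed to live under org.apache.spark.mllib. It only illustrates how a new model would typically wire into the two traits:

```scala
package org.apache.spark.mllib.example

import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.{Loader, Saveable}

/**
 * Hypothetical model, shown only to illustrate the pairing introduced by this
 * commit: the model class mixes in Saveable, and its companion object
 * implements Loader and dispatches on the (class name, format version) pair
 * recorded in the saved metadata, as the GLM and NaiveBayes models above do.
 */
class MyModel(val weights: Array[Double]) extends Saveable {

  override def save(sc: SparkContext, path: String): Unit = {
    // A real implementation would write JSON metadata to
    // Loader.metadataPath(path) and Parquet data to Loader.dataPath(path),
    // following the pattern in GLMClassificationModel.SaveLoadV1_0.save.
  }

  override protected def formatVersion: String = "1.0"
}

object MyModel extends Loader[MyModel] {

  override def load(sc: SparkContext, path: String): MyModel = {
    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
    (loadedClassName, version) match {
      case (className, "1.0") if className == classOf[MyModel].getName =>
        // A real implementation would read the Parquet data back here and
        // validate it against `metadata` before reconstructing the model.
        new MyModel(Array.empty[Double])
      case _ => throw new Exception(
        s"MyModel.load did not recognize model with (className, version)" +
          s" ($loadedClassName, $version)")
    }
  }
}
```

Keeping the versioned SaveLoadV1_0 logic local to each model, as this commit does, lets every model evolve its on-disk format independently while the shared Loader.checkSchema and Loader.loadMetadata helpers handle the common validation.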