diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-02-04 22:46:48 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-02-04 22:47:46 -0800 |
commit | 885bcbb079419b731d5a6a53e6e213be5a31c403 (patch) | |
tree | 47f11210ae0b4c3f920752dcec61ddf4683c3bca /mllib/src/main | |
parent | 59fb5c7462bea7862506cc81436737a7f810e77d (diff) | |
download | spark-885bcbb079419b731d5a6a53e6e213be5a31c403.tar.gz spark-885bcbb079419b731d5a6a53e6e213be5a31c403.tar.bz2 spark-885bcbb079419b731d5a6a53e6e213be5a31c403.zip |
[SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes
This is a PR for Parquet-based model import/export. Please see the design doc on [the JIRA](https://issues.apache.org/jira/browse/SPARK-4587).
Note: This includes only a subset of regression and classification models:
* NaiveBayes, SVM, LogisticRegression
* LinearRegression, RidgeRegression, Lasso
Follow-up PRs will cover other models.
Sketch of current contents:
* New traits: Saveable, Loader
* Implementations for some algorithms
* Also: Added LogisticRegressionModel.getThreshold method (so that unit test could check the threshold)
CC: mengxr selvinsource
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #4233 from jkbradley/ml-import-export and squashes the following commits:
87c4eb8 [Joseph K. Bradley] small cleanups
12d9059 [Joseph K. Bradley] Many cleanups after code review. Major changes: Storing numFeatures, numClasses in model metadata. Improvements to unit tests
b4ee064 [Joseph K. Bradley] Reorganized save/load for regression and classification. Renamed concepts to Saveable, Loader
a34aef5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into ml-import-export
ee99228 [Joseph K. Bradley] scala style fix
79675d5 [Joseph K. Bradley] cleanups in LogisticRegression after rebasing after multinomial PR
d1e5882 [Joseph K. Bradley] organized imports
2935963 [Joseph K. Bradley] Added save/load and tests for most classification and regression models
c495dba [Joseph K. Bradley] made version for model import/export local to each model
1496852 [Joseph K. Bradley] Added save/load for NaiveBayes
8d46386 [Joseph K. Bradley] Added save/load to NaiveBayes
1577d70 [Joseph K. Bradley] fixed issues after rebasing on master (DataFrame patch)
64914a3 [Joseph K. Bradley] added getThreshold to SVMModel
b1fc5ec [Joseph K. Bradley] small cleanups
418ba1b [Joseph K. Bradley] Added save, load to mllib.classification.LogisticRegressionModel, plus test suite
(cherry picked from commit 975bcef467b35586e5224171071355409f451d2d)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src/main')
12 files changed, 653 insertions, 21 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index b7a1d90d24..348c1e8760 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -20,7 +20,9 @@ package org.apache.spark.mllib.classification import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} /** * :: Experimental :: @@ -53,3 +55,21 @@ trait ClassificationModel extends Serializable { def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object ClassificationModel { + + /** + * Helper method for loading GLM classification model metadata. + * + * @param modelClass String name for model class (used for error messages) + * @return (numFeatures, numClasses) + */ + def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = { + metadata.select("numFeatures", "numClasses").take(1)(0) match { + case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses) + case _ => throw new Exception(s"$modelClass unable to load" + + s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index a469315a1b..5c9feb6fb2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,14 +17,17 @@ package org.apache.spark.mllib.classification +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} import org.apache.spark.rdd.RDD + /** * Classification model trained using Multinomial/Binary Logistic Regression. * @@ -42,7 +45,22 @@ class LogisticRegressionModel ( override val intercept: Double, val numFeatures: Int, val numClasses: Int) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Saveable { + + if (numClasses == 2) { + require(weights.size == numFeatures, + s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" + + s" numFeatures = $numFeatures, but weights.size = ${weights.size}") + } else { + val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures + val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1) + require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept, + s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" + + s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" + + s" or $weightsSizeWithIntercept (with intercept)," + + s" but was given weights of length ${weights.size}") + } def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) @@ -62,6 +80,13 @@ class LogisticRegressionModel ( /** * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + */ + @Experimental + def getThreshold: Option[Double] = threshold + + /** + * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. */ @Experimental @@ -70,7 +95,9 @@ class LogisticRegressionModel ( this } - override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double) = { require(dataMatrix.size == numFeatures) @@ -126,6 +153,40 @@ class LogisticRegressionModel ( bestClass.toDouble } } + + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + numFeatures, numClasses, weights, intercept, threshold) + } + + override protected def formatVersion: String = "1.0" +} + +object LogisticRegressionModel extends Loader[LogisticRegressionModel] { + + override def load(sc: SparkContext, path: String): LogisticRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = + ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path) + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + // numFeatures, numClasses, weights are checked in model initialization + val model = + new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses) + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"LogisticRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a967df857b..4bafd495f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -19,11 +19,13 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.SparkContext._ +import org.apache.spark.{SparkContext, SparkException, Logging} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + /** * Model for Naive Bayes Classifiers. @@ -36,7 +38,7 @@ import org.apache.spark.rdd.RDD class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], - val theta: Array[Array[Double]]) extends ClassificationModel with Serializable { + val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable { private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM[Double](theta.length, theta(0).length) @@ -65,6 +67,85 @@ class NaiveBayesModel private[mllib] ( override def predict(testData: Vector): Double = { labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) } + + override def save(sc: SparkContext, path: String): Unit = { + val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta) + NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) + } + + override protected def formatVersion: String = "1.0" +} + +object NaiveBayesModel extends Loader[NaiveBayesModel] { + + import Loader._ + + private object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Create JSON metadata. + val metadataRDD = + sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1) + .toDataFrame("class", "version", "numFeatures", "numClasses") + metadataRDD.toJSON.saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1) + dataRDD.saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + new NaiveBayesModel(labels, pi, theta) + } + } + + override def load(sc: SparkContext, path: String): NaiveBayesModel = { + val (loadedClassName, version, metadata) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = + ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path) + val model = SaveLoadV1_0.load(sc, path) + assert(model.pi.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class priors vector pi had ${model.pi.size} elements") + assert(model.theta.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class conditionals array theta had ${model.theta.size} elements") + assert(model.theta.forall(_.size == numFeatures), + s"NaiveBayesModel.load expected $numFeatures features," + + s" but class conditionals array theta had elements of size:" + + s" ${model.theta.map(_.size).mkString(",")}") + model + case _ => throw new Exception( + s"NaiveBayesModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index dd514ff8a3..24d31e62ba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -17,13 +17,16 @@ package org.apache.spark.mllib.classification +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.DataValidators +import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} import org.apache.spark.rdd.RDD + /** * Model for Support Vector Machines (SVMs). * @@ -33,7 +36,8 @@ import org.apache.spark.rdd.RDD class SVMModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Saveable { private var threshold: Option[Double] = Some(0.0) @@ -51,6 +55,13 @@ class SVMModel ( /** * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + */ + @Experimental + def getThreshold: Option[Double] = threshold + + /** + * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. */ @Experimental @@ -69,6 +80,42 @@ class SVMModel ( case None => margin } } + + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + numFeatures = weights.size, numClasses = 2, weights, intercept, threshold) + } + + override protected def formatVersion: String = "1.0" +} + +object SVMModel extends Loader[SVMModel] { + + override def load(sc: SparkContext, path: String): SVMModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = + ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path) + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + val model = new SVMModel(data.weights, data.intercept) + assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" + + s" was given non-matching weights vector of size ${model.weights.size}") + assert(numClasses == 2, + s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes") + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"SVMModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala new file mode 100644 index 0000000000..b60c0cdd0a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification.impl + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +/** + * Helper class for import/export of GLM classification models. + */ +private[classification] object GLMClassificationModel { + + object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + /** Model data for import/export */ + case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + + /** + * Helper method for saving GLM classification model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + * @param numClasses Number of classes label can take, to be saved with metadata + */ + def save( + sc: SparkContext, + path: String, + modelClass: String, + numFeatures: Int, + numClasses: Int, + weights: Vector, + intercept: Double, + threshold: Option[Double]): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Create JSON metadata. + val metadataRDD = + sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1) + .toDataFrame("class", "version", "numFeatures", "numClasses") + metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val data = Data(weights, intercept, threshold) + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1) + // TODO: repartition with 1 partition after SPARK-5532 gets fixed + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Helper method for loading GLM classification model data. + * + * NOTE: Callers of this method should check numClasses, numFeatures on their own. + * + * @param modelClass String name for model class (used for error messages) + */ + def loadData(sc: SparkContext, path: String, modelClass: String): Data = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(datapath) + val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + val data = dataArray(0) + assert(data.size == 3, s"Unable to load $modelClass data from: $datapath") + val (weights, intercept) = data match { + case Row(weights: Vector, intercept: Double, _) => + (weights, intercept) + } + val threshold = if (data.isNullAt(2)) { + None + } else { + Some(data.getDouble(2)) + } + Data(weights, intercept, threshold) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 8ecd5c6ad9..1159e59fff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -17,9 +17,11 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD /** @@ -32,7 +34,7 @@ class LassoModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -40,12 +42,37 @@ class LassoModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object LassoModel extends Loader[LassoModel] { + + override def load(sc: SparkContext, path: String): LassoModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new LassoModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LassoModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a regression model with L1-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 81b6598377..0136dcfdce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -17,9 +17,12 @@ package org.apache.spark.mllib.regression -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.rdd.RDD /** * Regression model trained using LinearRegression. @@ -30,7 +33,8 @@ import org.apache.spark.mllib.optimization._ class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable + with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -38,12 +42,37 @@ class LinearRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object LinearRegressionModel extends Loader[LinearRegressionModel] { + + override def load(sc: SparkContext, path: String): LinearRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new LinearRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LinearRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a linear regression model with no regularization using Stochastic Gradient Descent. * This solves the least squares regression formulation - * f(weights) = 1/n ||A weights-y||^2 + * f(weights) = 1/n ||A weights-y||^2^ * (which is the mean squared error). * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 64b02f7a6e..843e59bdfb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -19,8 +19,10 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} @Experimental trait RegressionModel extends Serializable { @@ -48,3 +50,21 @@ trait RegressionModel extends Serializable { def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object RegressionModel { + + /** + * Helper method for loading GLM regression model metadata. + * + * @param modelClass String name for model class (used for error messages) + * @return numFeatures + */ + def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = { + metadata.select("numFeatures").take(1)(0) match { + case Row(nFeatures: Int) => nFeatures + case _ => throw new Exception(s"$modelClass unable to load" + + s" numFeatures from metadata: ${Loader.metadataPath(path)}") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 076ba35051..f2a5f1db1e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -17,10 +17,13 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.optimization._ +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Loader, Saveable} +import org.apache.spark.rdd.RDD + /** * Regression model trained using RidgeRegression. @@ -32,7 +35,7 @@ class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -40,12 +43,37 @@ class RidgeRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object RidgeRegressionModel extends Loader[RidgeRegressionModel] { + + override def load(sc: SparkContext, path: String): RidgeRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new RidgeRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"RidgeRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a regression model with L2-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala new file mode 100644 index 0000000000..00f25a8be9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression.impl + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +/** + * Helper methods for import/export of GLM regression models. + */ +private[regression] object GLMRegressionModel { + + object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + /** Model data for model import/export */ + case class Data(weights: Vector, intercept: Double) + + /** + * Helper method for saving GLM regression model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + */ + def save( + sc: SparkContext, + path: String, + modelClass: String, + weights: Vector, + intercept: Double): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Create JSON metadata. + val metadataRDD = + sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1) + .toDataFrame("class", "version", "numFeatures") + metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val data = Data(weights, intercept) + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1) + // TODO: repartition with 1 partition after SPARK-5532 gets fixed + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Helper method for loading GLM regression model data. + * @param modelClass String name for model class (used for error messages) + * @param numFeatures Number of features, to be checked against loaded data. + * The length of the weights vector should equal numFeatures. + */ + def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(datapath) + val dataArray = dataRDD.select("weights", "intercept").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + val data = dataArray(0) + assert(data.size == 2, s"Unable to load $modelClass data from: $datapath") + data match { + case Row(weights: Vector, intercept: Double) => + assert(weights.size == numFeatures, s"Expected $numFeatures features, but" + + s" found ${weights.size} features when loading $modelClass weights from $datapath") + Data(weights, intercept) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a576096306..a25e625a40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -53,7 +53,6 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } - /** * Predict values for the given data set using the model trained. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala new file mode 100644 index 0000000000..56b77a7d12 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.{DataType, StructType, StructField} + + +/** + * :: DeveloperApi :: + * + * Trait for models and transformers which may be saved as files. + * This should be inherited by the class which implements model instances. + */ +@DeveloperApi +trait Saveable { + + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Loader.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * This directory and any intermediate directory will be created if needed. + */ + def save(sc: SparkContext, path: String): Unit + + /** Current version of model save/load format. */ + protected def formatVersion: String + +} + +/** + * :: DeveloperApi :: + * + * Trait for classes which can load models and transformers from files. + * This should be inherited by an object paired with the model class. + */ +@DeveloperApi +trait Loader[M <: Saveable] { + + /** + * Load a model from the given path. + * + * The model should have been saved by [[Saveable.save]]. + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + def load(sc: SparkContext, path: String): M + +} + +/** + * Helper methods for loading models from files. + */ +private[mllib] object Loader { + + /** Returns URI for path/data using the Hadoop filesystem */ + def dataPath(path: String): String = new Path(path, "data").toUri.toString + + /** Returns URI for path/metadata using the Hadoop filesystem */ + def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString + + /** + * Check the schema of loaded model data. + * + * This checks every field in the expected schema to make sure that a field with the same + * name and DataType appears in the loaded schema. Note that this does NOT check metadata + * or containsNull. + * + * @param loadedSchema Schema for model data loaded from file. + * @tparam Data Expected data type from which an expected schema can be derived. + */ + def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = { + // Check schema explicitly since erasure makes it hard to use match-case for checking. + val expectedFields: Array[StructField] = + ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields + val loadedFields: Map[String, DataType] = + loadedSchema.map(field => field.name -> field.dataType).toMap + expectedFields.foreach { field => + assert(loadedFields.contains(field.name), s"Unable to parse model data." + + s" Expected field with name ${field.name} was missing in loaded schema:" + + s" ${loadedFields.mkString(", ")}") + assert(loadedFields(field.name) == field.dataType, + s"Unable to parse model data. Expected field $field but found field" + + s" with different type: ${loadedFields(field.name)}") + } + } + + /** + * Load metadata from the given path. + * @return (class name, version, metadata) + */ + def loadMetadata(sc: SparkContext, path: String): (String, String, DataFrame) = { + val sqlContext = new SQLContext(sc) + val metadata = sqlContext.jsonFile(metadataPath(path)) + val (clazz, version) = try { + val metadataArray = metadata.select("class", "version").take(1) + assert(metadataArray.size == 1) + metadataArray(0) match { + case Row(clazz: String, version: String) => (clazz, version) + } + } catch { + case e: Exception => + throw new Exception(s"Unable to load model metadata from: ${metadataPath(path)}") + } + (clazz, version, metadata) + } + +} |