path: root/mllib
author     Joseph K. Bradley <joseph@databricks.com>    2015-11-18 21:44:01 -0800
committer  Xiangrui Meng <meng@databricks.com>          2015-11-18 21:44:01 -0800
commit     d02d5b9295b169c3ebb0967453b2835edb8a121f (patch)
tree       a4449633255a2e32a30fa165801e49decc74c28a /mllib
parent     fc3f77b42d62ca789d0ee07403795978961991c7 (diff)
[SPARK-11842][ML] Small cleanups to existing Readers and Writers
Updates:
* Add repartition(1) to save() methods' saving of data for LogisticRegressionModel, LinearRegressionModel.
* Strengthen privacy to class and companion object for Writers and Readers.
* Change LogisticRegressionSuite read/write test to fit intercept.
* Add Since versions for read/write methods in Pipeline, LogisticRegression.
* Switch from hand-written class names in Readers to using getClass.

CC: mengxr CC: yanboliang Would you mind taking a look at this PR? mengxr might not be able to soon. Thank you!

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9829 from jkbradley/ml-io-cleanups.
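Two of the mechanical changes below, deriving the reader's expected class name from the type itself and coalescing the one-row model data before saving, look roughly like the following standalone sketch. This is illustrative only, not the Spark source: ToyData, SaveSketch, the local-mode setup, and the output path are names assumed for the example.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    // Hypothetical one-row payload standing in for a model's intercept/coefficients.
    case class ToyData(intercept: Double, coefficient: Double)

    object SaveSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("save-sketch").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)

        // A hand-written string such as "org.apache.spark.ml.Pipeline" goes stale on rename;
        // classOf[...].getName always matches the actual type.
        val className = classOf[ToyData].getName

        // The model data is a single row, so repartition(1) keeps the save from emitting
        // one near-empty Parquet part file per input partition.
        val data = ToyData(intercept = 0.5, coefficient = 1.25)
        sqlContext.createDataFrame(Seq(data))
          .repartition(1)
          .write
          .parquet("/tmp/save-sketch/data")   // assumed output path

        println(s"Saved data for $className")
        sc.stop()
      }
    }

In the actual diff, this shows up as sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) in the LogisticRegressionModel and LinearRegressionModel writers, and as classOf[...].getName in each Reader.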
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala                                | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala       | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala                 |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala                             |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala                    |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala                  |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala                   |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala                      |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala             |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala  |  2
10 files changed, 38 insertions(+), 25 deletions(-)
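The Pipeline, LogisticRegression, and ALS diffs below also narrow the visibility of the nested Writer/Reader classes to the companion object. A minimal standalone sketch of that pattern follows; ToyWriter and ToyModel are hypothetical names, not the Spark classes.

    // Public supertype returned by write, so the concrete writer class never leaks.
    abstract class ToyWriter {
      def save(path: String): Unit
    }

    class ToyModel(val uid: String) {
      // Callers see only the public ToyWriter type.
      def write: ToyWriter = new ToyModel.ToyModelWriter(this)
    }

    object ToyModel {
      // private[ToyModel]: usable by the ToyModel class and this companion object only.
      private[ToyModel] class ToyModelWriter(instance: ToyModel) extends ToyWriter {
        override def save(path: String): Unit =
          println(s"saving ${instance.uid} to $path")
      }

      // Fully private: only load() below ever constructs it.
      private class ToyModelReader {
        // Derived from the type instead of a hand-written string, so renames stay in sync.
        private val className = classOf[ToyModel].getName
        def load(path: String): ToyModel = new ToyModel(uid = s"$className loaded from $path")
      }

      def load(path: String): ToyModel = new ToyModelReader().load(path)
    }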
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index b0f22e042e..6f15b37abc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -27,7 +27,7 @@ import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{SparkContext, Logging}
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.MLReader
import org.apache.spark.ml.util.MLWriter
@@ -174,16 +174,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
}
+ @Since("1.6.0")
override def write: MLWriter = new Pipeline.PipelineWriter(this)
}
+@Since("1.6.0")
object Pipeline extends MLReadable[Pipeline] {
+ @Since("1.6.0")
override def read: MLReader[Pipeline] = new PipelineReader
+ @Since("1.6.0")
override def load(path: String): Pipeline = super.load(path)
- private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter {
+ private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter {
SharedReadWrite.validateStages(instance.getStages)
@@ -191,10 +195,10 @@ object Pipeline extends MLReadable[Pipeline] {
SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
}
- private[ml] class PipelineReader extends MLReader[Pipeline] {
+ private class PipelineReader extends MLReader[Pipeline] {
/** Checked against metadata when loading model */
- private val className = "org.apache.spark.ml.Pipeline"
+ private val className = classOf[Pipeline].getName
override def load(path: String): Pipeline = {
val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
@@ -333,18 +337,22 @@ class PipelineModel private[ml] (
new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
}
+ @Since("1.6.0")
override def write: MLWriter = new PipelineModel.PipelineModelWriter(this)
}
+@Since("1.6.0")
object PipelineModel extends MLReadable[PipelineModel] {
import Pipeline.SharedReadWrite
+ @Since("1.6.0")
override def read: MLReader[PipelineModel] = new PipelineModelReader
+ @Since("1.6.0")
override def load(path: String): PipelineModel = super.load(path)
- private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {
+ private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {
SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
@@ -352,10 +360,10 @@ object PipelineModel extends MLReadable[PipelineModel] {
instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
}
- private[ml] class PipelineModelReader extends MLReader[PipelineModel] {
+ private class PipelineModelReader extends MLReader[PipelineModel] {
/** Checked against metadata when loading model */
- private val className = "org.apache.spark.ml.PipelineModel"
+ private val className = classOf[PipelineModel].getName
override def load(path: String): PipelineModel = {
val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a3cc49f7f0..418bbdc9a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -525,18 +525,23 @@ class LogisticRegressionModel private[ml] (
*
* This also does not save the [[parent]] currently.
*/
+ @Since("1.6.0")
override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
}
+@Since("1.6.0")
object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
+ @Since("1.6.0")
override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader
+ @Since("1.6.0")
override def load(path: String): LogisticRegressionModel = super.load(path)
/** [[MLWriter]] instance for [[LogisticRegressionModel]] */
- private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
+ private[LogisticRegressionModel]
+ class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
extends MLWriter with Logging {
private case class Data(
@@ -552,15 +557,15 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
instance.coefficients)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
- private[classification] class LogisticRegressionModelReader
+ private class LogisticRegressionModelReader
extends MLReader[LogisticRegressionModel] {
/** Checked against metadata when loading model */
- private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
+ private val className = classOf[LogisticRegressionModel].getName
override def load(path: String): LogisticRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
@@ -603,7 +608,7 @@ private[classification] class MultiClassSummarizer extends Serializable {
* @return This MultilabelSummarizer
*/
def add(label: Double, weight: Double = 1.0): this.type = {
- require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+ require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
if (weight == 0.0) return this
@@ -839,7 +844,7 @@ private class LogisticAggregator(
instance match { case Instance(label, weight, features) =>
require(dim == features.size, s"Dimensions mismatch when adding new instance." +
s" Expecting $dim but got ${features.size}.")
- require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+ require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
if (weight == 0.0) return this
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 4969cf4245..b9e2144c0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -266,7 +266,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] {
- private val className = "org.apache.spark.ml.feature.CountVectorizerModel"
+ private val className = classOf[CountVectorizerModel].getName
override def load(path: String): CountVectorizerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 0e00ef6f2e..f7b0f29a27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -155,7 +155,7 @@ object IDFModel extends MLReadable[IDFModel] {
private class IDFModelReader extends MLReader[IDFModel] {
- private val className = "org.apache.spark.ml.feature.IDFModel"
+ private val className = classOf[IDFModel].getName
override def load(path: String): IDFModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index ed24eabb50..c2866f5ece 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -210,7 +210,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] {
- private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"
+ private val className = classOf[MinMaxScalerModel].getName
override def load(path: String): MinMaxScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 1f689c1da1..6d545219eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -180,7 +180,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
private class StandardScalerModelReader extends MLReader[StandardScalerModel] {
- private val className = "org.apache.spark.ml.feature.StandardScalerModel"
+ private val className = classOf[StandardScalerModel].getName
override def load(path: String): StandardScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 97a2e4f6d6..5c40c35eea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -210,7 +210,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
private class StringIndexerModelReader extends MLReader[StringIndexerModel] {
- private val className = "org.apache.spark.ml.feature.StringIndexerModel"
+ private val className = classOf[StringIndexerModel].getName
override def load(path: String): StringIndexerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 795b73c4c2..4d35177ad9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -237,7 +237,7 @@ object ALSModel extends MLReadable[ALSModel] {
@Since("1.6.0")
override def load(path: String): ALSModel = super.load(path)
- private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter {
+ private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
val extraMetadata = render("rank" -> instance.rank)
@@ -249,10 +249,10 @@ object ALSModel extends MLReadable[ALSModel] {
}
}
- private[recommendation] class ALSModelReader extends MLReader[ALSModel] {
+ private class ALSModelReader extends MLReader[ALSModel] {
/** Checked against metadata when loading model */
- private val className = "org.apache.spark.ml.recommendation.ALSModel"
+ private val className = classOf[ALSModel].getName
override def load(path: String): ALSModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 7ba1a60eda..70ccec766c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -467,14 +467,14 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
// Save model data: intercept, coefficients
val data = Data(instance.intercept, instance.coefficients)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] {
/** Checked against metadata when loading model */
- private val className = "org.apache.spark.ml.regression.LinearRegressionModel"
+ private val className = classOf[LinearRegressionModel].getName
override def load(path: String): LinearRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 48ce1bb630..a9a6ff8a78 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -898,7 +898,7 @@ object LogisticRegressionSuite {
"regParam" -> 0.01,
"elasticNetParam" -> 0.1,
"maxIter" -> 2, // intentionally small
- "fitIntercept" -> false,
+ "fitIntercept" -> true,
"tol" -> 0.8,
"standardization" -> false,
"threshold" -> 0.6