author     Xiangrui Meng <meng@databricks.com>   2015-11-18 18:34:01 -0800
committer  Joseph K. Bradley <joseph@databricks.com>   2015-11-18 18:34:01 -0800
commit     e99d3392068bc929c900a4cc7b50e9e2b437a23a (patch)
tree       ecc75c1b1d75173742d95c4156cce653c12418b2 /mllib/src/main
parent     59a501359a267fbdb7689058693aa788703e54b1 (diff)
[SPARK-11839][ML] refactor save/write traits
* add "ML" prefix to reader/writer/readable/writable to avoid name collision with java.util.* * define `DefaultParamsReadable/Writable` and use them to save some code * use `super.load` instead so people can jump directly to the doc of `Readable.load`, which documents the Java compatibility issues jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #9827 from mengxr/SPARK-11839.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 40
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 29
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala | 23
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 32
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 24
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 27
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 30
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 40
25 files changed, 174 insertions, 306 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 25f0c696f4..b0f22e042e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -29,8 +29,8 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
-import org.apache.spark.ml.util.Reader
-import org.apache.spark.ml.util.Writer
+import org.apache.spark.ml.util.MLReader
+import org.apache.spark.ml.util.MLWriter
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -89,7 +89,7 @@ abstract class PipelineStage extends Params with Logging {
* an identity transformer.
*/
@Experimental
-class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable {
+class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable {
def this() = this(Identifiable.randomUID("pipeline"))
@@ -174,16 +174,16 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with W
theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
}
- override def write: Writer = new Pipeline.PipelineWriter(this)
+ override def write: MLWriter = new Pipeline.PipelineWriter(this)
}
-object Pipeline extends Readable[Pipeline] {
+object Pipeline extends MLReadable[Pipeline] {
- override def read: Reader[Pipeline] = new PipelineReader
+ override def read: MLReader[Pipeline] = new PipelineReader
- override def load(path: String): Pipeline = read.load(path)
+ override def load(path: String): Pipeline = super.load(path)
- private[ml] class PipelineWriter(instance: Pipeline) extends Writer {
+ private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter {
SharedReadWrite.validateStages(instance.getStages)
@@ -191,7 +191,7 @@ object Pipeline extends Readable[Pipeline] {
SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
}
- private[ml] class PipelineReader extends Reader[Pipeline] {
+ private[ml] class PipelineReader extends MLReader[Pipeline] {
/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.Pipeline"
@@ -202,7 +202,7 @@ object Pipeline extends Readable[Pipeline] {
}
}
- /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */
+ /** Methods for [[MLReader]] and [[MLWriter]] shared between [[Pipeline]] and [[PipelineModel]] */
private[ml] object SharedReadWrite {
import org.json4s.JsonDSL._
@@ -210,7 +210,7 @@ object Pipeline extends Readable[Pipeline] {
/** Check that all stages are Writable */
def validateStages(stages: Array[PipelineStage]): Unit = {
stages.foreach {
- case stage: Writable => // good
+ case stage: MLWritable => // good
case other =>
throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" +
s" because it contains a stage which does not implement Writable. Non-Writable stage:" +
@@ -245,7 +245,7 @@ object Pipeline extends Readable[Pipeline] {
// Save stages
val stagesDir = new Path(path, "stages").toString
- stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) =>
+ stages.zipWithIndex.foreach { case (stage: MLWritable, idx: Int) =>
stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir))
}
}
@@ -285,7 +285,7 @@ object Pipeline extends Readable[Pipeline] {
val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
val cls = Utils.classForName(stageMetadata.className)
- cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath)
+ cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath)
}
(metadata.uid, stages)
}
@@ -308,7 +308,7 @@ object Pipeline extends Readable[Pipeline] {
class PipelineModel private[ml] (
override val uid: String,
val stages: Array[Transformer])
- extends Model[PipelineModel] with Writable with Logging {
+ extends Model[PipelineModel] with MLWritable with Logging {
/** A Java/Python-friendly auxiliary constructor. */
private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
@@ -333,18 +333,18 @@ class PipelineModel private[ml] (
new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
}
- override def write: Writer = new PipelineModel.PipelineModelWriter(this)
+ override def write: MLWriter = new PipelineModel.PipelineModelWriter(this)
}
-object PipelineModel extends Readable[PipelineModel] {
+object PipelineModel extends MLReadable[PipelineModel] {
import Pipeline.SharedReadWrite
- override def read: Reader[PipelineModel] = new PipelineModelReader
+ override def read: MLReader[PipelineModel] = new PipelineModelReader
- override def load(path: String): PipelineModel = read.load(path)
+ override def load(path: String): PipelineModel = super.load(path)
- private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer {
+ private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {
SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
@@ -352,7 +352,7 @@ object PipelineModel extends Readable[PipelineModel] {
instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
}
- private[ml] class PipelineModelReader extends Reader[PipelineModel] {
+ private[ml] class PipelineModelReader extends MLReader[PipelineModel] {
/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.PipelineModel"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 71c2533bcb..a3cc49f7f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -29,9 +29,9 @@ import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
@@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
@Experimental
class LogisticRegression(override val uid: String)
extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
- with LogisticRegressionParams with Writable with Logging {
+ with LogisticRegressionParams with DefaultParamsWritable with Logging {
def this() = this(Identifiable.randomUID("logreg"))
@@ -385,12 +385,11 @@ class LogisticRegression(override val uid: String)
}
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
-
- override def write: Writer = new DefaultParamsWriter(this)
}
-object LogisticRegression extends Readable[LogisticRegression] {
- override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression]
+object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
+
+ override def load(path: String): LogisticRegression = super.load(path)
}
/**
@@ -403,7 +402,7 @@ class LogisticRegressionModel private[ml] (
val coefficients: Vector,
val intercept: Double)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
- with LogisticRegressionParams with Writable {
+ with LogisticRegressionParams with MLWritable {
@deprecated("Use coefficients instead.", "1.6.0")
def weights: Vector = coefficients
@@ -519,26 +518,26 @@ class LogisticRegressionModel private[ml] (
}
/**
- * Returns a [[Writer]] instance for this ML instance.
+ * Returns a [[MLWriter]] instance for this ML instance.
*
* For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]].
* An option to save [[summary]] may be added in the future.
*
* This also does not save the [[parent]] currently.
*/
- override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
+ override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
}
-object LogisticRegressionModel extends Readable[LogisticRegressionModel] {
+object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
- override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader
+ override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader
- override def load(path: String): LogisticRegressionModel = read.load(path)
+ override def load(path: String): LogisticRegressionModel = super.load(path)
- /** [[Writer]] instance for [[LogisticRegressionModel]] */
+ /** [[MLWriter]] instance for [[LogisticRegressionModel]] */
private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
- extends Writer with Logging {
+ extends MLWriter with Logging {
private case class Data(
numClasses: Int,
@@ -558,7 +557,7 @@ object LogisticRegressionModel extends Readable[LogisticRegressionModel] {
}
private[classification] class LogisticRegressionModelReader
- extends Reader[LogisticRegressionModel] {
+ extends MLReader[LogisticRegressionModel] {
/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index e2be6547d8..63c0658148 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
*/
@Experimental
final class Binarizer(override val uid: String)
- extends Transformer with Writable with HasInputCol with HasOutputCol {
+ extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("binarizer"))
@@ -86,17 +86,11 @@ final class Binarizer(override val uid: String)
}
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object Binarizer extends Readable[Binarizer] {
-
- @Since("1.6.0")
- override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer]
+object Binarizer extends DefaultParamsReadable[Binarizer] {
@Since("1.6.0")
- override def load(path: String): Binarizer = read.load(path)
+ override def load(path: String): Binarizer = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 7095fbd70a..324353a96a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
*/
@Experimental
final class Bucketizer(override val uid: String)
- extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable {
+ extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("bucketizer"))
@@ -93,12 +93,9 @@ final class Bucketizer(override val uid: String)
override def copy(extra: ParamMap): Bucketizer = {
defaultCopy[Bucketizer](extra).setParent(parent)
}
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
-object Bucketizer extends Readable[Bucketizer] {
+object Bucketizer extends DefaultParamsReadable[Bucketizer] {
/** We require splits to be of length >= 3 and to be in strictly increasing order. */
private[feature] def checkSplits(splits: Array[Double]): Boolean = {
@@ -140,8 +137,5 @@ object Bucketizer extends Readable[Bucketizer] {
}
@Since("1.6.0")
- override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer]
-
- @Since("1.6.0")
- override def load(path: String): Bucketizer = read.load(path)
+ override def load(path: String): Bucketizer = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 5ff9bfb7d1..4969cf4245 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -107,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
*/
@Experimental
class CountVectorizer(override val uid: String)
- extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable {
+ extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("cntVec"))
@@ -171,16 +171,10 @@ class CountVectorizer(override val uid: String)
}
override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object CountVectorizer extends Readable[CountVectorizer] {
-
- @Since("1.6.0")
- override def read: Reader[CountVectorizer] = new DefaultParamsReader
+object CountVectorizer extends DefaultParamsReadable[CountVectorizer] {
@Since("1.6.0")
override def load(path: String): CountVectorizer = super.load(path)
@@ -193,7 +187,7 @@ object CountVectorizer extends Readable[CountVectorizer] {
*/
@Experimental
class CountVectorizerModel(override val uid: String, val vocabulary: Array[String])
- extends Model[CountVectorizerModel] with CountVectorizerParams with Writable {
+ extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable {
import CountVectorizerModel._
@@ -251,14 +245,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
}
@Since("1.6.0")
- override def write: Writer = new CountVectorizerModelWriter(this)
+ override def write: MLWriter = new CountVectorizerModelWriter(this)
}
@Since("1.6.0")
-object CountVectorizerModel extends Readable[CountVectorizerModel] {
+object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
private[CountVectorizerModel]
- class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer {
+ class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter {
private case class Data(vocabulary: Seq[String])
@@ -270,7 +264,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] {
}
}
- private class CountVectorizerModelReader extends Reader[CountVectorizerModel] {
+ private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] {
private val className = "org.apache.spark.ml.feature.CountVectorizerModel"
@@ -288,7 +282,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] {
}
@Since("1.6.0")
- override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader
+ override def read: MLReader[CountVectorizerModel] = new CountVectorizerModelReader
@Since("1.6.0")
override def load(path: String): CountVectorizerModel = super.load(path)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
index 6ea5a61617..6bed72164a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType
*/
@Experimental
class DCT(override val uid: String)
- extends UnaryTransformer[Vector, Vector, DCT] with Writable {
+ extends UnaryTransformer[Vector, Vector, DCT] with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("dct"))
@@ -69,17 +69,11 @@ class DCT(override val uid: String)
}
override protected def outputDataType: DataType = new VectorUDT
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object DCT extends Readable[DCT] {
-
- @Since("1.6.0")
- override def read: Reader[DCT] = new DefaultParamsReader[DCT]
+object DCT extends DefaultParamsReadable[DCT] {
@Since("1.6.0")
- override def load(path: String): DCT = read.load(path)
+ override def load(path: String): DCT = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 6d2ea675f5..9e15835429 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{ArrayType, StructType}
*/
@Experimental
class HashingTF(override val uid: String)
- extends Transformer with HasInputCol with HasOutputCol with Writable {
+ extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("hashingTF"))
@@ -77,17 +77,11 @@ class HashingTF(override val uid: String)
}
override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object HashingTF extends Readable[HashingTF] {
-
- @Since("1.6.0")
- override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF]
+object HashingTF extends DefaultParamsReadable[HashingTF] {
@Since("1.6.0")
- override def load(path: String): HashingTF = read.load(path)
+ override def load(path: String): HashingTF = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 53ad34ef12..0e00ef6f2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -62,7 +62,8 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
* Compute the Inverse Document Frequency (IDF) given a collection of documents.
*/
@Experimental
-final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable {
+final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase
+ with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("idf"))
@@ -87,16 +88,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
}
override def copy(extra: ParamMap): IDF = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object IDF extends Readable[IDF] {
-
- @Since("1.6.0")
- override def read: Reader[IDF] = new DefaultParamsReader
+object IDF extends DefaultParamsReadable[IDF] {
@Since("1.6.0")
override def load(path: String): IDF = super.load(path)
@@ -110,7 +105,7 @@ object IDF extends Readable[IDF] {
class IDFModel private[ml] (
override val uid: String,
idfModel: feature.IDFModel)
- extends Model[IDFModel] with IDFBase with Writable {
+ extends Model[IDFModel] with IDFBase with MLWritable {
import IDFModel._
@@ -140,13 +135,13 @@ class IDFModel private[ml] (
def idf: Vector = idfModel.idf
@Since("1.6.0")
- override def write: Writer = new IDFModelWriter(this)
+ override def write: MLWriter = new IDFModelWriter(this)
}
@Since("1.6.0")
-object IDFModel extends Readable[IDFModel] {
+object IDFModel extends MLReadable[IDFModel] {
- private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer {
+ private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter {
private case class Data(idf: Vector)
@@ -158,7 +153,7 @@ object IDFModel extends Readable[IDFModel] {
}
}
- private class IDFModelReader extends Reader[IDFModel] {
+ private class IDFModelReader extends MLReader[IDFModel] {
private val className = "org.apache.spark.ml.feature.IDFModel"
@@ -176,7 +171,7 @@ object IDFModel extends Readable[IDFModel] {
}
@Since("1.6.0")
- override def read: Reader[IDFModel] = new IDFModelReader
+ override def read: MLReader[IDFModel] = new IDFModelReader
@Since("1.6.0")
override def load(path: String): IDFModel = super.load(path)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 9df6b311cc..2181119f04 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.types._
@Since("1.6.0")
@Experimental
class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
- with HasInputCols with HasOutputCol with Writable {
+ with HasInputCols with HasOutputCol with DefaultParamsWritable {
@Since("1.6.0")
def this() = this(Identifiable.randomUID("interaction"))
@@ -224,19 +224,13 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
require($(inputCols).length > 0, "Input cols must have non-zero length.")
require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
}
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object Interaction extends Readable[Interaction] {
-
- @Since("1.6.0")
- override def read: Reader[Interaction] = new DefaultParamsReader[Interaction]
+object Interaction extends DefaultParamsReadable[Interaction] {
@Since("1.6.0")
- override def load(path: String): Interaction = read.load(path)
+ override def load(path: String): Interaction = super.load(path)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 24d964fae8..ed24eabb50 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -88,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
*/
@Experimental
class MinMaxScaler(override val uid: String)
- extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable {
+ extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("minMaxScal"))
@@ -118,16 +118,10 @@ class MinMaxScaler(override val uid: String)
}
override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object MinMaxScaler extends Readable[MinMaxScaler] {
-
- @Since("1.6.0")
- override def read: Reader[MinMaxScaler] = new DefaultParamsReader
+object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] {
@Since("1.6.0")
override def load(path: String): MinMaxScaler = super.load(path)
@@ -147,7 +141,7 @@ class MinMaxScalerModel private[ml] (
override val uid: String,
val originalMin: Vector,
val originalMax: Vector)
- extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable {
+ extends Model[MinMaxScalerModel] with MinMaxScalerParams with MLWritable {
import MinMaxScalerModel._
@@ -195,14 +189,14 @@ class MinMaxScalerModel private[ml] (
}
@Since("1.6.0")
- override def write: Writer = new MinMaxScalerModelWriter(this)
+ override def write: MLWriter = new MinMaxScalerModelWriter(this)
}
@Since("1.6.0")
-object MinMaxScalerModel extends Readable[MinMaxScalerModel] {
+object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
private[MinMaxScalerModel]
- class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer {
+ class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter {
private case class Data(originalMin: Vector, originalMax: Vector)
@@ -214,7 +208,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] {
}
}
- private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] {
+ private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] {
private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"
@@ -231,7 +225,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] {
}
@Since("1.6.0")
- override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader
+ override def read: MLReader[MinMaxScalerModel] = new MinMaxScalerModelReader
@Since("1.6.0")
override def load(path: String): MinMaxScalerModel = super.load(path)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
index 4a17acd951..65414ecbef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
*/
@Experimental
class NGram(override val uid: String)
- extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable {
+ extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("ngram"))
@@ -66,17 +66,11 @@ class NGram(override val uid: String)
}
override protected def outputDataType: DataType = new ArrayType(StringType, false)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object NGram extends Readable[NGram] {
-
- @Since("1.6.0")
- override def read: Reader[NGram] = new DefaultParamsReader[NGram]
+object NGram extends DefaultParamsReadable[NGram] {
@Since("1.6.0")
- override def load(path: String): NGram = read.load(path)
+ override def load(path: String): NGram = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index 9df6a091d5..c2d514fd96 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.DataType
*/
@Experimental
class Normalizer(override val uid: String)
- extends UnaryTransformer[Vector, Vector, Normalizer] with Writable {
+ extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("normalizer"))
@@ -56,17 +56,11 @@ class Normalizer(override val uid: String)
}
override protected def outputDataType: DataType = new VectorUDT()
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object Normalizer extends Readable[Normalizer] {
-
- @Since("1.6.0")
- override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer]
+object Normalizer extends DefaultParamsReadable[Normalizer] {
@Since("1.6.0")
- override def load(path: String): Normalizer = read.load(path)
+ override def load(path: String): Normalizer = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 4e2adfaafa..d70164eaf0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
*/
@Experimental
class OneHotEncoder(override val uid: String) extends Transformer
- with HasInputCol with HasOutputCol with Writable {
+ with HasInputCol with HasOutputCol with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("oneHot"))
@@ -165,17 +165,11 @@ class OneHotEncoder(override val uid: String) extends Transformer
}
override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object OneHotEncoder extends Readable[OneHotEncoder] {
-
- @Since("1.6.0")
- override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder]
+object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] {
@Since("1.6.0")
- override def load(path: String): OneHotEncoder = read.load(path)
+ override def load(path: String): OneHotEncoder = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 4941539832..08610593fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType
*/
@Experimental
class PolynomialExpansion(override val uid: String)
- extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with Writable {
+ extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("poly"))
@@ -63,9 +63,6 @@ class PolynomialExpansion(override val uid: String)
override protected def outputDataType: DataType = new VectorUDT()
override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
/**
@@ -81,7 +78,7 @@ class PolynomialExpansion(override val uid: String)
* current index and increment it properly for sparse input.
*/
@Since("1.6.0")
-object PolynomialExpansion extends Readable[PolynomialExpansion] {
+object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] {
private def choose(n: Int, k: Int): Int = {
Range(n, n - k, -1).product / Range(k, 1, -1).product
@@ -182,8 +179,5 @@ object PolynomialExpansion extends Readable[PolynomialExpansion] {
}
@Since("1.6.0")
- override def read: Reader[PolynomialExpansion] = new DefaultParamsReader[PolynomialExpansion]
-
- @Since("1.6.0")
- override def load(path: String): PolynomialExpansion = read.load(path)
+ override def load(path: String): PolynomialExpansion = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 2da5c966d2..7bf67c6325 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -60,7 +60,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w
*/
@Experimental
final class QuantileDiscretizer(override val uid: String)
- extends Estimator[Bucketizer] with QuantileDiscretizerBase with Writable {
+ extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("quantileDiscretizer"))
@@ -93,13 +93,10 @@ final class QuantileDiscretizer(override val uid: String)
}
override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging {
+object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
/**
* Sampling from the given dataset to collect quantile statistics.
*/
@@ -179,8 +176,5 @@ object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging {
}
@Since("1.6.0")
- override def read: Reader[QuantileDiscretizer] = new DefaultParamsReader[QuantileDiscretizer]
-
- @Since("1.6.0")
- override def load(path: String): QuantileDiscretizer = read.load(path)
+ override def load(path: String): QuantileDiscretizer = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index c115064ff3..3a735017ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.StructType
*/
@Experimental
@Since("1.6.0")
-class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer with Writable {
+class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer
+ with DefaultParamsWritable {
@Since("1.6.0")
def this() = this(Identifiable.randomUID("sql"))
@@ -77,17 +78,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
@Since("1.6.0")
override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object SQLTransformer extends Readable[SQLTransformer] {
-
- @Since("1.6.0")
- override def read: Reader[SQLTransformer] = new DefaultParamsReader[SQLTransformer]
+object SQLTransformer extends DefaultParamsReadable[SQLTransformer] {
@Since("1.6.0")
- override def load(path: String): SQLTransformer = read.load(path)
+ override def load(path: String): SQLTransformer = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index ab04e5418d..1f689c1da1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -59,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
*/
@Experimental
class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel]
- with StandardScalerParams with Writable {
+ with StandardScalerParams with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("stdScal"))
@@ -96,16 +96,10 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
}
override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object StandardScaler extends Readable[StandardScaler] {
-
- @Since("1.6.0")
- override def read: Reader[StandardScaler] = new DefaultParamsReader
+object StandardScaler extends DefaultParamsReadable[StandardScaler] {
@Since("1.6.0")
override def load(path: String): StandardScaler = super.load(path)
@@ -119,7 +113,7 @@ object StandardScaler extends Readable[StandardScaler] {
class StandardScalerModel private[ml] (
override val uid: String,
scaler: feature.StandardScalerModel)
- extends Model[StandardScalerModel] with StandardScalerParams with Writable {
+ extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {
import StandardScalerModel._
@@ -165,14 +159,14 @@ class StandardScalerModel private[ml] (
}
@Since("1.6.0")
- override def write: Writer = new StandardScalerModelWriter(this)
+ override def write: MLWriter = new StandardScalerModelWriter(this)
}
@Since("1.6.0")
-object StandardScalerModel extends Readable[StandardScalerModel] {
+object StandardScalerModel extends MLReadable[StandardScalerModel] {
private[StandardScalerModel]
- class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer {
+ class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {
private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
@@ -184,7 +178,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] {
}
}
- private class StandardScalerModelReader extends Reader[StandardScalerModel] {
+ private class StandardScalerModelReader extends MLReader[StandardScalerModel] {
private val className = "org.apache.spark.ml.feature.StandardScalerModel"
@@ -204,7 +198,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] {
}
@Since("1.6.0")
- override def read: Reader[StandardScalerModel] = new StandardScalerModelReader
+ override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader
@Since("1.6.0")
override def load(path: String): StandardScalerModel = super.load(path)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index f1146988dc..318808596d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -86,7 +86,7 @@ private[spark] object StopWords {
*/
@Experimental
class StopWordsRemover(override val uid: String)
- extends Transformer with HasInputCol with HasOutputCol with Writable {
+ extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("stopWords"))
@@ -154,17 +154,11 @@ class StopWordsRemover(override val uid: String)
}
override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object StopWordsRemover extends Readable[StopWordsRemover] {
-
- @Since("1.6.0")
- override def read: Reader[StopWordsRemover] = new DefaultParamsReader[StopWordsRemover]
+object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {
@Since("1.6.0")
- override def load(path: String): StopWordsRemover = read.load(path)
+ override def load(path: String): StopWordsRemover = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index f16f6afc00..97a2e4f6d6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -65,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
*/
@Experimental
class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
- with StringIndexerBase with Writable {
+ with StringIndexerBase with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("strIdx"))
@@ -93,16 +93,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
}
override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object StringIndexer extends Readable[StringIndexer] {
-
- @Since("1.6.0")
- override def read: Reader[StringIndexer] = new DefaultParamsReader
+object StringIndexer extends DefaultParamsReadable[StringIndexer] {
@Since("1.6.0")
override def load(path: String): StringIndexer = super.load(path)
@@ -122,7 +116,7 @@ object StringIndexer extends Readable[StringIndexer] {
class StringIndexerModel (
override val uid: String,
val labels: Array[String])
- extends Model[StringIndexerModel] with StringIndexerBase with Writable {
+ extends Model[StringIndexerModel] with StringIndexerBase with MLWritable {
import StringIndexerModel._
@@ -199,10 +193,10 @@ class StringIndexerModel (
}
@Since("1.6.0")
-object StringIndexerModel extends Readable[StringIndexerModel] {
+object StringIndexerModel extends MLReadable[StringIndexerModel] {
private[StringIndexerModel]
- class StringIndexModelWriter(instance: StringIndexerModel) extends Writer {
+ class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter {
private case class Data(labels: Array[String])
@@ -214,7 +208,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] {
}
}
- private class StringIndexerModelReader extends Reader[StringIndexerModel] {
+ private class StringIndexerModelReader extends MLReader[StringIndexerModel] {
private val className = "org.apache.spark.ml.feature.StringIndexerModel"
@@ -232,7 +226,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] {
}
@Since("1.6.0")
- override def read: Reader[StringIndexerModel] = new StringIndexerModelReader
+ override def read: MLReader[StringIndexerModel] = new StringIndexerModelReader
@Since("1.6.0")
override def load(path: String): StringIndexerModel = super.load(path)
@@ -249,7 +243,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] {
*/
@Experimental
class IndexToString private[ml] (override val uid: String)
- extends Transformer with HasInputCol with HasOutputCol with Writable {
+ extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
def this() =
this(Identifiable.randomUID("idxToStr"))
@@ -316,17 +310,11 @@ class IndexToString private[ml] (override val uid: String)
override def copy(extra: ParamMap): IndexToString = {
defaultCopy(extra)
}
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object IndexToString extends Readable[IndexToString] {
-
- @Since("1.6.0")
- override def read: Reader[IndexToString] = new DefaultParamsReader[IndexToString]
+object IndexToString extends DefaultParamsReadable[IndexToString] {
@Since("1.6.0")
- override def load(path: String): IndexToString = read.load(path)
+ override def load(path: String): IndexToString = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 0e4445d1e2..8ad7bbedaa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
*/
@Experimental
class Tokenizer(override val uid: String)
- extends UnaryTransformer[String, Seq[String], Tokenizer] with Writable {
+ extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("tok"))
@@ -46,19 +46,13 @@ class Tokenizer(override val uid: String)
override protected def outputDataType: DataType = new ArrayType(StringType, true)
override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object Tokenizer extends Readable[Tokenizer] {
-
- @Since("1.6.0")
- override def read: Reader[Tokenizer] = new DefaultParamsReader[Tokenizer]
+object Tokenizer extends DefaultParamsReadable[Tokenizer] {
@Since("1.6.0")
- override def load(path: String): Tokenizer = read.load(path)
+ override def load(path: String): Tokenizer = super.load(path)
}
/**
@@ -70,7 +64,7 @@ object Tokenizer extends Readable[Tokenizer] {
*/
@Experimental
class RegexTokenizer(override val uid: String)
- extends UnaryTransformer[String, Seq[String], RegexTokenizer] with Writable {
+ extends UnaryTransformer[String, Seq[String], RegexTokenizer] with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("regexTok"))
@@ -145,17 +139,11 @@ class RegexTokenizer(override val uid: String)
override protected def outputDataType: DataType = new ArrayType(StringType, true)
override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object RegexTokenizer extends Readable[RegexTokenizer] {
-
- @Since("1.6.0")
- override def read: Reader[RegexTokenizer] = new DefaultParamsReader[RegexTokenizer]
+object RegexTokenizer extends DefaultParamsReadable[RegexTokenizer] {
@Since("1.6.0")
- override def load(path: String): RegexTokenizer = read.load(path)
+ override def load(path: String): RegexTokenizer = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 7e54205292..0feec05498 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types._
*/
@Experimental
class VectorAssembler(override val uid: String)
- extends Transformer with HasInputCols with HasOutputCol with Writable {
+ extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -120,19 +120,13 @@ class VectorAssembler(override val uid: String)
}
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object VectorAssembler extends Readable[VectorAssembler] {
-
- @Since("1.6.0")
- override def read: Reader[VectorAssembler] = new DefaultParamsReader[VectorAssembler]
+object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
@Since("1.6.0")
- override def load(path: String): VectorAssembler = read.load(path)
+ override def load(path: String): VectorAssembler = super.load(path)
private[feature] def assemble(vv: Any*): Vector = {
val indices = ArrayBuilder.make[Int]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 911582b55b..5410a50bc2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType
*/
@Experimental
final class VectorSlicer(override val uid: String)
- extends Transformer with HasInputCol with HasOutputCol with Writable {
+ extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("vectorSlicer"))
@@ -151,13 +151,10 @@ final class VectorSlicer(override val uid: String)
}
override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object VectorSlicer extends Readable[VectorSlicer] {
+object VectorSlicer extends DefaultParamsReadable[VectorSlicer] {
/** Return true if given feature indices are valid */
private[feature] def validIndices(indices: Array[Int]): Boolean = {
@@ -174,8 +171,5 @@ object VectorSlicer extends Readable[VectorSlicer] {
}
@Since("1.6.0")
- override def read: Reader[VectorSlicer] = new DefaultParamsReader[VectorSlicer]
-
- @Since("1.6.0")
- override def load(path: String): VectorSlicer = read.load(path)
+ override def load(path: String): VectorSlicer = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index d92514d2e2..795b73c4c2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -185,7 +185,7 @@ class ALSModel private[ml] (
val rank: Int,
@transient val userFactors: DataFrame,
@transient val itemFactors: DataFrame)
- extends Model[ALSModel] with ALSModelParams with Writable {
+ extends Model[ALSModel] with ALSModelParams with MLWritable {
/** @group setParam */
def setUserCol(value: String): this.type = set(userCol, value)
@@ -225,19 +225,19 @@ class ALSModel private[ml] (
}
@Since("1.6.0")
- override def write: Writer = new ALSModel.ALSModelWriter(this)
+ override def write: MLWriter = new ALSModel.ALSModelWriter(this)
}
@Since("1.6.0")
-object ALSModel extends Readable[ALSModel] {
+object ALSModel extends MLReadable[ALSModel] {
@Since("1.6.0")
- override def read: Reader[ALSModel] = new ALSModelReader
+ override def read: MLReader[ALSModel] = new ALSModelReader
@Since("1.6.0")
- override def load(path: String): ALSModel = read.load(path)
+ override def load(path: String): ALSModel = super.load(path)
- private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer {
+ private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
val extraMetadata = render("rank" -> instance.rank)
@@ -249,7 +249,7 @@ object ALSModel extends Readable[ALSModel] {
}
}
- private[recommendation] class ALSModelReader extends Reader[ALSModel] {
+ private[recommendation] class ALSModelReader extends MLReader[ALSModel] {
/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.recommendation.ALSModel"
@@ -309,7 +309,8 @@ object ALSModel extends Readable[ALSModel] {
* preferences rather than explicit ratings given to items.
*/
@Experimental
-class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable {
+class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams
+ with DefaultParamsWritable {
import org.apache.spark.ml.recommendation.ALS.Rating
@@ -391,9 +392,6 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w
}
override def copy(extra: ParamMap): ALS = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@@ -406,7 +404,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w
* than 2 billion.
*/
@DeveloperApi
-object ALS extends Readable[ALS] with Logging {
+object ALS extends DefaultParamsReadable[ALS] with Logging {
/**
* :: DeveloperApi ::
@@ -416,10 +414,7 @@ object ALS extends Readable[ALS] with Logging {
case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
@Since("1.6.0")
- override def read: Reader[ALS] = new DefaultParamsReader[ALS]
-
- @Since("1.6.0")
- override def load(path: String): ALS = read.load(path)
+ override def load(path: String): ALS = super.load(path)
/** Trait for least squares solvers applied to the normal equation. */
private[recommendation] trait LeastSquaresNESolver extends Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index f7c44f0a51..7ba1a60eda 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -66,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
@Experimental
class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String)
extends Regressor[Vector, LinearRegression, LinearRegressionModel]
- with LinearRegressionParams with Writable with Logging {
+ with LinearRegressionParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("linReg"))
@@ -345,19 +345,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
@Since("1.4.0")
override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@Since("1.6.0")
-object LinearRegression extends Readable[LinearRegression] {
-
- @Since("1.6.0")
- override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression]
+object LinearRegression extends DefaultParamsReadable[LinearRegression] {
@Since("1.6.0")
- override def load(path: String): LinearRegression = read.load(path)
+ override def load(path: String): LinearRegression = super.load(path)
}
/**
@@ -371,7 +365,7 @@ class LinearRegressionModel private[ml] (
val coefficients: Vector,
val intercept: Double)
extends RegressionModel[Vector, LinearRegressionModel]
- with LinearRegressionParams with Writable {
+ with LinearRegressionParams with MLWritable {
private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
@@ -441,7 +435,7 @@ class LinearRegressionModel private[ml] (
}
/**
- * Returns a [[Writer]] instance for this ML instance.
+ * Returns a [[MLWriter]] instance for this ML instance.
*
* For [[LinearRegressionModel]], this does NOT currently save the training [[summary]].
* An option to save [[summary]] may be added in the future.
@@ -449,21 +443,21 @@ class LinearRegressionModel private[ml] (
* This also does not save the [[parent]] currently.
*/
@Since("1.6.0")
- override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this)
+ override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this)
}
@Since("1.6.0")
-object LinearRegressionModel extends Readable[LinearRegressionModel] {
+object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
@Since("1.6.0")
- override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader
+ override def read: MLReader[LinearRegressionModel] = new LinearRegressionModelReader
@Since("1.6.0")
- override def load(path: String): LinearRegressionModel = read.load(path)
+ override def load(path: String): LinearRegressionModel = super.load(path)
- /** [[Writer]] instance for [[LinearRegressionModel]] */
+ /** [[MLWriter]] instance for [[LinearRegressionModel]] */
private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel)
- extends Writer with Logging {
+ extends MLWriter with Logging {
private case class Data(intercept: Double, coefficients: Vector)
@@ -477,7 +471,7 @@ object LinearRegressionModel extends Readable[LinearRegressionModel] {
}
}
- private class LinearRegressionModelReader extends Reader[LinearRegressionModel] {
+ private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] {
/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.regression.LinearRegressionModel"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index d8ce907af5..ff9322dba1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
/**
- * Trait for [[Writer]] and [[Reader]].
+ * Trait for [[MLWriter]] and [[MLReader]].
*/
private[util] sealed trait BaseReadWrite {
private var optionSQLContext: Option[SQLContext] = None
@@ -64,7 +64,7 @@ private[util] sealed trait BaseReadWrite {
*/
@Experimental
@Since("1.6.0")
-abstract class Writer extends BaseReadWrite with Logging {
+abstract class MLWriter extends BaseReadWrite with Logging {
protected var shouldOverwrite: Boolean = false
@@ -111,16 +111,16 @@ abstract class Writer extends BaseReadWrite with Logging {
}
/**
- * Trait for classes that provide [[Writer]].
+ * Trait for classes that provide [[MLWriter]].
*/
@Since("1.6.0")
-trait Writable {
+trait MLWritable {
/**
- * Returns a [[Writer]] instance for this ML instance.
+ * Returns an [[MLWriter]] instance for this ML instance.
*/
@Since("1.6.0")
- def write: Writer
+ def write: MLWriter
/**
* Saves this ML instance to the input path, a shortcut of `write.save(path)`.
@@ -130,13 +130,18 @@ trait Writable {
def save(path: String): Unit = write.save(path)
}
+private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>
+
+ override def write: MLWriter = new DefaultParamsWriter(this)
+}
+
/**
* Abstract class for utility classes that can load ML instances.
* @tparam T ML instance type
*/
@Experimental
@Since("1.6.0")
-abstract class Reader[T] extends BaseReadWrite {
+abstract class MLReader[T] extends BaseReadWrite {
/**
* Loads the ML component from the input path.
@@ -149,18 +154,18 @@ abstract class Reader[T] extends BaseReadWrite {
}
/**
- * Trait for objects that provide [[Reader]].
+ * Trait for objects that provide [[MLReader]].
* @tparam T ML instance type
*/
@Experimental
@Since("1.6.0")
-trait Readable[T] {
+trait MLReadable[T] {
/**
- * Returns a [[Reader]] instance for this class.
+ * Returns an [[MLReader]] instance for this class.
*/
@Since("1.6.0")
- def read: Reader[T]
+ def read: MLReader[T]
/**
* Reads an ML instance from the input path, a shortcut of `read.load(path)`.
@@ -171,13 +176,18 @@ trait Readable[T] {
def load(path: String): T = read.load(path)
}
+private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {
+
+ override def read: MLReader[T] = new DefaultParamsReader
+}
+
/**
- * Default [[Writer]] implementation for transformers and estimators that contain basic
+ * Default [[MLWriter]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
* @param instance object to save
*/
-private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
+private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
@@ -218,13 +228,13 @@ private[ml] object DefaultParamsWriter {
}
/**
- * Default [[Reader]] implementation for transformers and estimators that contain basic
+ * Default [[MLReader]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
* @tparam T ML instance type
* TODO: Consider adding check for correct class name.
*/
-private[ml] class DefaultParamsReader[T] extends Reader[T] {
+private[ml] class DefaultParamsReader[T] extends MLReader[T] {
override def load(path: String): T = {
val metadata = DefaultParamsReader.loadMetadata(path, sc)