diff options
Diffstat (limited to 'mllib/src')
27 files changed, 190 insertions, 321 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 25f0c696f4..b0f22e042e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -29,8 +29,8 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Reader -import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util.MLReader +import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -89,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. */ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable { def this() = this(Identifiable.randomUID("pipeline")) @@ -174,16 +174,16 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with W theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } - override def write: Writer = new Pipeline.PipelineWriter(this) + override def write: MLWriter = new Pipeline.PipelineWriter(this) } -object Pipeline extends Readable[Pipeline] { +object Pipeline extends MLReadable[Pipeline] { - override def read: Reader[Pipeline] = new PipelineReader + override def read: MLReader[Pipeline] = new PipelineReader - override def load(path: String): Pipeline = read.load(path) + override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,7 +191,7 @@ object Pipeline extends Readable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends Reader[Pipeline] { + private[ml] class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.Pipeline" @@ -202,7 +202,7 @@ object Pipeline extends Readable[Pipeline] { } } - /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + /** Methods for [[MLReader]] and [[MLWriter]] shared between [[Pipeline]] and [[PipelineModel]] */ private[ml] object SharedReadWrite { import org.json4s.JsonDSL._ @@ -210,7 +210,7 @@ object Pipeline extends Readable[Pipeline] { /** Check that all stages are Writable */ def validateStages(stages: Array[PipelineStage]): Unit = { stages.foreach { - case stage: Writable => // good + case stage: MLWritable => // good case other => throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + s" because it contains a stage which does not implement Writable. Non-Writable stage:" + @@ -245,7 +245,7 @@ object Pipeline extends Readable[Pipeline] { // Save stages val stagesDir = new Path(path, "stages").toString - stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => + stages.zipWithIndex.foreach { case (stage: MLWritable, idx: Int) => stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) } } @@ -285,7 +285,7 @@ object Pipeline extends Readable[Pipeline] { val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) } (metadata.uid, stages) } @@ -308,7 +308,7 @@ object Pipeline extends Readable[Pipeline] { class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) - extends Model[PipelineModel] with Writable with Logging { + extends Model[PipelineModel] with MLWritable with Logging { /** A Java/Python-friendly auxiliary constructor. */ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { @@ -333,18 +333,18 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } - override def write: Writer = new PipelineModel.PipelineModelWriter(this) + override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } -object PipelineModel extends Readable[PipelineModel] { +object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite - override def read: Reader[PipelineModel] = new PipelineModelReader + override def read: MLReader[PipelineModel] = new PipelineModelReader - override def load(path: String): PipelineModel = read.load(path) + override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { + private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,7 +352,7 @@ object PipelineModel extends Readable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends Reader[PipelineModel] { + private[ml] class PipelineModelReader extends MLReader[PipelineModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.PipelineModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 71c2533bcb..a3cc49f7f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,9 +29,9 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas @Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Writable with Logging { + with LogisticRegressionParams with DefaultParamsWritable with Logging { def this() = this(Identifiable.randomUID("logreg")) @@ -385,12 +385,11 @@ class LogisticRegression(override val uid: String) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) - - override def write: Writer = new DefaultParamsWriter(this) } -object LogisticRegression extends Readable[LogisticRegression] { - override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression] +object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + + override def load(path: String): LogisticRegression = super.load(path) } /** @@ -403,7 +402,7 @@ class LogisticRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams with Writable { + with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients @@ -519,26 +518,26 @@ class LogisticRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * * This also does not save the [[parent]] currently. */ - override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } -object LogisticRegressionModel extends Readable[LogisticRegressionModel] { +object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { - override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader + override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader - override def load(path: String): LogisticRegressionModel = read.load(path) + override def load(path: String): LogisticRegressionModel = super.load(path) - /** [[Writer]] instance for [[LogisticRegressionModel]] */ + /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data( numClasses: Int, @@ -558,7 +557,7 @@ object LogisticRegressionModel extends Readable[LogisticRegressionModel] { } private[classification] class LogisticRegressionModelReader - extends Reader[LogisticRegressionModel] { + extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index e2be6547d8..63c0658148 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with Writable with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("binarizer")) @@ -86,17 +86,11 @@ final class Binarizer(override val uid: String) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Binarizer extends Readable[Binarizer] { - - @Since("1.6.0") - override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] +object Binarizer extends DefaultParamsReadable[Binarizer] { @Since("1.6.0") - override def load(path: String): Binarizer = read.load(path) + override def load(path: String): Binarizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 7095fbd70a..324353a96a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ @Experimental final class Bucketizer(override val uid: String) - extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable { + extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("bucketizer")) @@ -93,12 +93,9 @@ final class Bucketizer(override val uid: String) override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } -object Bucketizer extends Readable[Bucketizer] { +object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** We require splits to be of length >= 3 and to be in strictly increasing order. */ private[feature] def checkSplits(splits: Array[Double]): Boolean = { @@ -140,8 +137,5 @@ object Bucketizer extends Readable[Bucketizer] { } @Since("1.6.0") - override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer] - - @Since("1.6.0") - override def load(path: String): Bucketizer = read.load(path) + override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 5ff9bfb7d1..4969cf4245 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -107,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("cntVec")) @@ -171,16 +171,10 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object CountVectorizer extends Readable[CountVectorizer] { - - @Since("1.6.0") - override def read: Reader[CountVectorizer] = new DefaultParamsReader +object CountVectorizer extends DefaultParamsReadable[CountVectorizer] { @Since("1.6.0") override def load(path: String): CountVectorizer = super.load(path) @@ -193,7 +187,7 @@ object CountVectorizer extends Readable[CountVectorizer] { */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable { import CountVectorizerModel._ @@ -251,14 +245,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin } @Since("1.6.0") - override def write: Writer = new CountVectorizerModelWriter(this) + override def write: MLWriter = new CountVectorizerModelWriter(this) } @Since("1.6.0") -object CountVectorizerModel extends Readable[CountVectorizerModel] { +object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private[CountVectorizerModel] - class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter { private case class Data(vocabulary: Seq[String]) @@ -270,7 +264,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } } - private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { private val className = "org.apache.spark.ml.feature.CountVectorizerModel" @@ -288,7 +282,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } @Since("1.6.0") - override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + override def read: MLReader[CountVectorizerModel] = new CountVectorizerModelReader @Since("1.6.0") override def load(path: String): CountVectorizerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 6ea5a61617..6bed72164a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class DCT(override val uid: String) - extends UnaryTransformer[Vector, Vector, DCT] with Writable { + extends UnaryTransformer[Vector, Vector, DCT] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("dct")) @@ -69,17 +69,11 @@ class DCT(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object DCT extends Readable[DCT] { - - @Since("1.6.0") - override def read: Reader[DCT] = new DefaultParamsReader[DCT] +object DCT extends DefaultParamsReadable[DCT] { @Since("1.6.0") - override def load(path: String): DCT = read.load(path) + override def load(path: String): DCT = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 6d2ea675f5..9e15835429 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{ArrayType, StructType} */ @Experimental class HashingTF(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("hashingTF")) @@ -77,17 +77,11 @@ class HashingTF(override val uid: String) } override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object HashingTF extends Readable[HashingTF] { - - @Since("1.6.0") - override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF] +object HashingTF extends DefaultParamsReadable[HashingTF] { @Since("1.6.0") - override def load(path: String): HashingTF = read.load(path) + override def load(path: String): HashingTF = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 53ad34ef12..0e00ef6f2e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -62,7 +62,8 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idf")) @@ -87,16 +88,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IDF extends Readable[IDF] { - - @Since("1.6.0") - override def read: Reader[IDF] = new DefaultParamsReader +object IDF extends DefaultParamsReadable[IDF] { @Since("1.6.0") override def load(path: String): IDF = super.load(path) @@ -110,7 +105,7 @@ object IDF extends Readable[IDF] { class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase with Writable { + extends Model[IDFModel] with IDFBase with MLWritable { import IDFModel._ @@ -140,13 +135,13 @@ class IDFModel private[ml] ( def idf: Vector = idfModel.idf @Since("1.6.0") - override def write: Writer = new IDFModelWriter(this) + override def write: MLWriter = new IDFModelWriter(this) } @Since("1.6.0") -object IDFModel extends Readable[IDFModel] { +object IDFModel extends MLReadable[IDFModel] { - private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter { private case class Data(idf: Vector) @@ -158,7 +153,7 @@ object IDFModel extends Readable[IDFModel] { } } - private class IDFModelReader extends Reader[IDFModel] { + private class IDFModelReader extends MLReader[IDFModel] { private val className = "org.apache.spark.ml.feature.IDFModel" @@ -176,7 +171,7 @@ object IDFModel extends Readable[IDFModel] { } @Since("1.6.0") - override def read: Reader[IDFModel] = new IDFModelReader + override def read: MLReader[IDFModel] = new IDFModelReader @Since("1.6.0") override def load(path: String): IDFModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 9df6b311cc..2181119f04 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.types._ @Since("1.6.0") @Experimental class Interaction @Since("1.6.0") (override val uid: String) extends Transformer - with HasInputCols with HasOutputCol with Writable { + with HasInputCols with HasOutputCol with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("interaction")) @@ -224,19 +224,13 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer require($(inputCols).length > 0, "Input cols must have non-zero length.") require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Interaction extends Readable[Interaction] { - - @Since("1.6.0") - override def read: Reader[Interaction] = new DefaultParamsReader[Interaction] +object Interaction extends DefaultParamsReadable[Interaction] { @Since("1.6.0") - override def load(path: String): Interaction = read.load(path) + override def load(path: String): Interaction = super.load(path) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 24d964fae8..ed24eabb50 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -88,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -118,16 +118,10 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object MinMaxScaler extends Readable[MinMaxScaler] { - - @Since("1.6.0") - override def read: Reader[MinMaxScaler] = new DefaultParamsReader +object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] { @Since("1.6.0") override def load(path: String): MinMaxScaler = super.load(path) @@ -147,7 +141,7 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with MLWritable { import MinMaxScalerModel._ @@ -195,14 +189,14 @@ class MinMaxScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new MinMaxScalerModelWriter(this) + override def write: MLWriter = new MinMaxScalerModelWriter(this) } @Since("1.6.0") -object MinMaxScalerModel extends Readable[MinMaxScalerModel] { +object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private[MinMaxScalerModel] - class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter { private case class Data(originalMin: Vector, originalMax: Vector) @@ -214,7 +208,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } } - private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" @@ -231,7 +225,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } @Since("1.6.0") - override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + override def read: MLReader[MinMaxScalerModel] = new MinMaxScalerModelReader @Since("1.6.0") override def load(path: String): MinMaxScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 4a17acd951..65414ecbef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class NGram(override val uid: String) - extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable { + extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("ngram")) @@ -66,17 +66,11 @@ class NGram(override val uid: String) } override protected def outputDataType: DataType = new ArrayType(StringType, false) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object NGram extends Readable[NGram] { - - @Since("1.6.0") - override def read: Reader[NGram] = new DefaultParamsReader[NGram] +object NGram extends DefaultParamsReadable[NGram] { @Since("1.6.0") - override def load(path: String): NGram = read.load(path) + override def load(path: String): NGram = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 9df6a091d5..c2d514fd96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class Normalizer(override val uid: String) - extends UnaryTransformer[Vector, Vector, Normalizer] with Writable { + extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("normalizer")) @@ -56,17 +56,11 @@ class Normalizer(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT() - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Normalizer extends Readable[Normalizer] { - - @Since("1.6.0") - override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer] +object Normalizer extends DefaultParamsReadable[Normalizer] { @Since("1.6.0") - override def load(path: String): Normalizer = read.load(path) + override def load(path: String): Normalizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 4e2adfaafa..d70164eaf0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental class OneHotEncoder(override val uid: String) extends Transformer - with HasInputCol with HasOutputCol with Writable { + with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("oneHot")) @@ -165,17 +165,11 @@ class OneHotEncoder(override val uid: String) extends Transformer } override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object OneHotEncoder extends Readable[OneHotEncoder] { - - @Since("1.6.0") - override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder] +object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] { @Since("1.6.0") - override def load(path: String): OneHotEncoder = read.load(path) + override def load(path: String): OneHotEncoder = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 4941539832..08610593fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class PolynomialExpansion(override val uid: String) - extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with Writable { + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("poly")) @@ -63,9 +63,6 @@ class PolynomialExpansion(override val uid: String) override protected def outputDataType: DataType = new VectorUDT() override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } /** @@ -81,7 +78,7 @@ class PolynomialExpansion(override val uid: String) * current index and increment it properly for sparse input. */ @Since("1.6.0") -object PolynomialExpansion extends Readable[PolynomialExpansion] { +object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product @@ -182,8 +179,5 @@ object PolynomialExpansion extends Readable[PolynomialExpansion] { } @Since("1.6.0") - override def read: Reader[PolynomialExpansion] = new DefaultParamsReader[PolynomialExpansion] - - @Since("1.6.0") - override def load(path: String): PolynomialExpansion = read.load(path) + override def load(path: String): PolynomialExpansion = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 2da5c966d2..7bf67c6325 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -60,7 +60,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w */ @Experimental final class QuantileDiscretizer(override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase with Writable { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("quantileDiscretizer")) @@ -93,13 +93,10 @@ final class QuantileDiscretizer(override val uid: String) } override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { +object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { /** * Sampling from the given dataset to collect quantile statistics. */ @@ -179,8 +176,5 @@ object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { } @Since("1.6.0") - override def read: Reader[QuantileDiscretizer] = new DefaultParamsReader[QuantileDiscretizer] - - @Since("1.6.0") - override def load(path: String): QuantileDiscretizer = read.load(path) + override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index c115064ff3..3a735017ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.StructType */ @Experimental @Since("1.6.0") -class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer with Writable { +class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer + with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("sql")) @@ -77,17 +78,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor @Since("1.6.0") override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object SQLTransformer extends Readable[SQLTransformer] { - - @Since("1.6.0") - override def read: Reader[SQLTransformer] = new DefaultParamsReader[SQLTransformer] +object SQLTransformer extends DefaultParamsReadable[SQLTransformer] { @Since("1.6.0") - override def load(path: String): SQLTransformer = read.load(path) + override def load(path: String): SQLTransformer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index ab04e5418d..1f689c1da1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -59,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams with Writable { + with StandardScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stdScal")) @@ -96,16 +96,10 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StandardScaler extends Readable[StandardScaler] { - - @Since("1.6.0") - override def read: Reader[StandardScaler] = new DefaultParamsReader +object StandardScaler extends DefaultParamsReadable[StandardScaler] { @Since("1.6.0") override def load(path: String): StandardScaler = super.load(path) @@ -119,7 +113,7 @@ object StandardScaler extends Readable[StandardScaler] { class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams with Writable { + extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ @@ -165,14 +159,14 @@ class StandardScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new StandardScalerModelWriter(this) + override def write: MLWriter = new StandardScalerModelWriter(this) } @Since("1.6.0") -object StandardScalerModel extends Readable[StandardScalerModel] { +object StandardScalerModel extends MLReadable[StandardScalerModel] { private[StandardScalerModel] - class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) @@ -184,7 +178,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } } - private class StandardScalerModelReader extends Reader[StandardScalerModel] { + private class StandardScalerModelReader extends MLReader[StandardScalerModel] { private val className = "org.apache.spark.ml.feature.StandardScalerModel" @@ -204,7 +198,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } @Since("1.6.0") - override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader @Since("1.6.0") override def load(path: String): StandardScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index f1146988dc..318808596d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -86,7 +86,7 @@ private[spark] object StopWords { */ @Experimental class StopWordsRemover(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stopWords")) @@ -154,17 +154,11 @@ class StopWordsRemover(override val uid: String) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StopWordsRemover extends Readable[StopWordsRemover] { - - @Since("1.6.0") - override def read: Reader[StopWordsRemover] = new DefaultParamsReader[StopWordsRemover] +object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { @Since("1.6.0") - override def load(path: String): StopWordsRemover = read.load(path) + override def load(path: String): StopWordsRemover = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f16f6afc00..97a2e4f6d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -65,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase with Writable { + with StringIndexerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("strIdx")) @@ -93,16 +93,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StringIndexer extends Readable[StringIndexer] { - - @Since("1.6.0") - override def read: Reader[StringIndexer] = new DefaultParamsReader +object StringIndexer extends DefaultParamsReadable[StringIndexer] { @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -122,7 +116,7 @@ object StringIndexer extends Readable[StringIndexer] { class StringIndexerModel ( override val uid: String, val labels: Array[String]) - extends Model[StringIndexerModel] with StringIndexerBase with Writable { + extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { import StringIndexerModel._ @@ -199,10 +193,10 @@ class StringIndexerModel ( } @Since("1.6.0") -object StringIndexerModel extends Readable[StringIndexerModel] { +object StringIndexerModel extends MLReadable[StringIndexerModel] { private[StringIndexerModel] - class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { private case class Data(labels: Array[String]) @@ -214,7 +208,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } } - private class StringIndexerModelReader extends Reader[StringIndexerModel] { + private class StringIndexerModelReader extends MLReader[StringIndexerModel] { private val className = "org.apache.spark.ml.feature.StringIndexerModel" @@ -232,7 +226,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } @Since("1.6.0") - override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + override def read: MLReader[StringIndexerModel] = new StringIndexerModelReader @Since("1.6.0") override def load(path: String): StringIndexerModel = super.load(path) @@ -249,7 +243,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { */ @Experimental class IndexToString private[ml] (override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idxToStr")) @@ -316,17 +310,11 @@ class IndexToString private[ml] (override val uid: String) override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IndexToString extends Readable[IndexToString] { - - @Since("1.6.0") - override def read: Reader[IndexToString] = new DefaultParamsReader[IndexToString] +object IndexToString extends DefaultParamsReadable[IndexToString] { @Since("1.6.0") - override def load(path: String): IndexToString = read.load(path) + override def load(path: String): IndexToString = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0e4445d1e2..8ad7bbedaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class Tokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], Tokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("tok")) @@ -46,19 +46,13 @@ class Tokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Tokenizer extends Readable[Tokenizer] { - - @Since("1.6.0") - override def read: Reader[Tokenizer] = new DefaultParamsReader[Tokenizer] +object Tokenizer extends DefaultParamsReadable[Tokenizer] { @Since("1.6.0") - override def load(path: String): Tokenizer = read.load(path) + override def load(path: String): Tokenizer = super.load(path) } /** @@ -70,7 +64,7 @@ object Tokenizer extends Readable[Tokenizer] { */ @Experimental class RegexTokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], RegexTokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], RegexTokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("regexTok")) @@ -145,17 +139,11 @@ class RegexTokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object RegexTokenizer extends Readable[RegexTokenizer] { - - @Since("1.6.0") - override def read: Reader[RegexTokenizer] = new DefaultParamsReader[RegexTokenizer] +object RegexTokenizer extends DefaultParamsReadable[RegexTokenizer] { @Since("1.6.0") - override def load(path: String): RegexTokenizer = read.load(path) + override def load(path: String): RegexTokenizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7e54205292..0feec05498 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ */ @Experimental class VectorAssembler(override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with Writable { + extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecAssembler")) @@ -120,19 +120,13 @@ class VectorAssembler(override val uid: String) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorAssembler extends Readable[VectorAssembler] { - - @Since("1.6.0") - override def read: Reader[VectorAssembler] = new DefaultParamsReader[VectorAssembler] +object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { @Since("1.6.0") - override def load(path: String): VectorAssembler = read.load(path) + override def load(path: String): VectorAssembler = super.load(path) private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 911582b55b..5410a50bc2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType */ @Experimental final class VectorSlicer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vectorSlicer")) @@ -151,13 +151,10 @@ final class VectorSlicer(override val uid: String) } override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorSlicer extends Readable[VectorSlicer] { +object VectorSlicer extends DefaultParamsReadable[VectorSlicer] { /** Return true if given feature indices are valid */ private[feature] def validIndices(indices: Array[Int]): Boolean = { @@ -174,8 +171,5 @@ object VectorSlicer extends Readable[VectorSlicer] { } @Since("1.6.0") - override def read: Reader[VectorSlicer] = new DefaultParamsReader[VectorSlicer] - - @Since("1.6.0") - override def load(path: String): VectorSlicer = read.load(path) + override def load(path: String): VectorSlicer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index d92514d2e2..795b73c4c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -185,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams with Writable { + extends Model[ALSModel] with ALSModelParams with MLWritable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -225,19 +225,19 @@ class ALSModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new ALSModel.ALSModelWriter(this) + override def write: MLWriter = new ALSModel.ALSModelWriter(this) } @Since("1.6.0") -object ALSModel extends Readable[ALSModel] { +object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") - override def read: Reader[ALSModel] = new ALSModelReader + override def read: MLReader[ALSModel] = new ALSModelReader @Since("1.6.0") - override def load(path: String): ALSModel = read.load(path) + override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,7 +249,7 @@ object ALSModel extends Readable[ALSModel] { } } - private[recommendation] class ALSModelReader extends Reader[ALSModel] { + private[recommendation] class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.recommendation.ALSModel" @@ -309,7 +309,8 @@ object ALSModel extends Readable[ALSModel] { * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams + with DefaultParamsWritable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -391,9 +392,6 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w } override def copy(extra: ParamMap): ALS = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @@ -406,7 +404,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w * than 2 billion. */ @DeveloperApi -object ALS extends Readable[ALS] with Logging { +object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: @@ -416,10 +414,7 @@ object ALS extends Readable[ALS] with Logging { case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) @Since("1.6.0") - override def read: Reader[ALS] = new DefaultParamsReader[ALS] - - @Since("1.6.0") - override def load(path: String): ALS = read.load(path) + override def load(path: String): ALS = super.load(path) /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f7c44f0a51..7ba1a60eda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -66,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Writable with Logging { + with LinearRegressionParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -345,19 +345,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object LinearRegression extends Readable[LinearRegression] { - - @Since("1.6.0") - override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] +object LinearRegression extends DefaultParamsReadable[LinearRegression] { @Since("1.6.0") - override def load(path: String): LinearRegression = read.load(path) + override def load(path: String): LinearRegression = super.load(path) } /** @@ -371,7 +365,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with Writable { + with LinearRegressionParams with MLWritable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -441,7 +435,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. @@ -449,21 +443,21 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) } @Since("1.6.0") -object LinearRegressionModel extends Readable[LinearRegressionModel] { +object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") - override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + override def read: MLReader[LinearRegressionModel] = new LinearRegressionModelReader @Since("1.6.0") - override def load(path: String): LinearRegressionModel = read.load(path) + override def load(path: String): LinearRegressionModel = super.load(path) - /** [[Writer]] instance for [[LinearRegressionModel]] */ + /** [[MLWriter]] instance for [[LinearRegressionModel]] */ private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data(intercept: Double, coefficients: Vector) @@ -477,7 +471,7 @@ object LinearRegressionModel extends Readable[LinearRegressionModel] { } } - private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.regression.LinearRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d8ce907af5..ff9322dba1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** - * Trait for [[Writer]] and [[Reader]]. + * Trait for [[MLWriter]] and [[MLReader]]. */ private[util] sealed trait BaseReadWrite { private var optionSQLContext: Option[SQLContext] = None @@ -64,7 +64,7 @@ private[util] sealed trait BaseReadWrite { */ @Experimental @Since("1.6.0") -abstract class Writer extends BaseReadWrite with Logging { +abstract class MLWriter extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -111,16 +111,16 @@ abstract class Writer extends BaseReadWrite with Logging { } /** - * Trait for classes that provide [[Writer]]. + * Trait for classes that provide [[MLWriter]]. */ @Since("1.6.0") -trait Writable { +trait MLWritable { /** - * Returns a [[Writer]] instance for this ML instance. + * Returns an [[MLWriter]] instance for this ML instance. */ @Since("1.6.0") - def write: Writer + def write: MLWriter /** * Saves this ML instance to the input path, a shortcut of `write.save(path)`. @@ -130,13 +130,18 @@ trait Writable { def save(path: String): Unit = write.save(path) } +private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => + + override def write: MLWriter = new DefaultParamsWriter(this) +} + /** * Abstract class for utility classes that can load ML instances. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -abstract class Reader[T] extends BaseReadWrite { +abstract class MLReader[T] extends BaseReadWrite { /** * Loads the ML component from the input path. @@ -149,18 +154,18 @@ abstract class Reader[T] extends BaseReadWrite { } /** - * Trait for objects that provide [[Reader]]. + * Trait for objects that provide [[MLReader]]. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -trait Readable[T] { +trait MLReadable[T] { /** - * Returns a [[Reader]] instance for this class. + * Returns an [[MLReader]] instance for this class. */ @Since("1.6.0") - def read: Reader[T] + def read: MLReader[T] /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. @@ -171,13 +176,18 @@ trait Readable[T] { def load(path: String): T = read.load(path) } +private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader +} + /** - * Default [[Writer]] implementation for transformers and estimators that contain basic + * Default [[MLWriter]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @param instance object to save */ -private[ml] class DefaultParamsWriter(instance: Params) extends Writer { +private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter { override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) @@ -218,13 +228,13 @@ private[ml] object DefaultParamsWriter { } /** - * Default [[Reader]] implementation for transformers and estimators that contain basic + * Default [[MLReader]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @tparam T ML instance type * TODO: Consider adding check for correct class name. */ -private[ml] class DefaultParamsReader[T] extends Reader[T] { +private[ml] class DefaultParamsReader[T] extends MLReader[T] { override def load(path: String): T = { val metadata = DefaultParamsReader.loadMetadata(path, sc) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 7f5c3895ac..12aba6bc6d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -179,8 +179,8 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } -/** Used to test [[Pipeline]] with [[Writable]] stages */ -class WritableStage(override val uid: String) extends Transformer with Writable { +/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +class WritableStage(override val uid: String) extends Transformer with MLWritable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -192,21 +192,21 @@ class WritableStage(override val uid: String) extends Transformer with Writable override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) override def transform(dataset: DataFrame): DataFrame = dataset override def transformSchema(schema: StructType): StructType = schema } -object WritableStage extends Readable[WritableStage] { +object WritableStage extends MLReadable[WritableStage] { - override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage] + override def read: MLReader[WritableStage] = new DefaultParamsReader[WritableStage] - override def load(path: String): WritableStage = read.load(path) + override def load(path: String): WritableStage = super.load(path) } -/** Used to test [[Pipeline]] with non-[[Writable]] stages */ +/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index dd1e8acce9..84d06b43d6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -38,7 +38,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable]( + def testDefaultReadWrite[T <: Params with MLWritable]( instance: T, testParams: Boolean = true): T = { val uid = instance.uid @@ -52,7 +52,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => instance.save(path) } instance.write.overwrite().save(path) - val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] val newInstance = loader.load(path) assert(newInstance.uid === instance.uid) @@ -92,7 +92,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam E Type of [[Estimator]] * @tparam M Type of [[Model]] produced by estimator */ - def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable]( + def testEstimatorAndModelReadWrite[ + E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( estimator: E, dataset: DataFrame, testParams: Map[String, Any], @@ -119,7 +120,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } } -class MyParams(override val uid: String) extends Params with Writable { +class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -145,14 +146,14 @@ class MyParams(override val uid: String) extends Params with Writable { override def copy(extra: ParamMap): Params = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) } -object MyParams extends Readable[MyParams] { +object MyParams extends MLReadable[MyParams] { - override def read: Reader[MyParams] = new DefaultParamsReader[MyParams] + override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams] - override def load(path: String): MyParams = read.load(path) + override def load(path: String): MyParams = super.load(path) } class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext |