diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 29 |
1 files changed, 14 insertions, 15 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 71c2533bcb..a3cc49f7f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,9 +29,9 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas @Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Writable with Logging { + with LogisticRegressionParams with DefaultParamsWritable with Logging { def this() = this(Identifiable.randomUID("logreg")) @@ -385,12 +385,11 @@ class LogisticRegression(override val uid: String) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) - - override def write: Writer = new DefaultParamsWriter(this) } -object LogisticRegression extends Readable[LogisticRegression] { - override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression] +object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + + override def load(path: String): LogisticRegression = super.load(path) } /** @@ -403,7 +402,7 @@ class LogisticRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams with Writable { + with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients @@ -519,26 +518,26 @@ class LogisticRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * * This also does not save the [[parent]] currently. */ - override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } -object LogisticRegressionModel extends Readable[LogisticRegressionModel] { +object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { - override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader + override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader - override def load(path: String): LogisticRegressionModel = read.load(path) + override def load(path: String): LogisticRegressionModel = super.load(path) - /** [[Writer]] instance for [[LogisticRegressionModel]] */ + /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data( numClasses: Int, @@ -558,7 +557,7 @@ object LogisticRegressionModel extends Readable[LogisticRegressionModel] { } private[classification] class LogisticRegressionModelReader - extends Reader[LogisticRegressionModel] { + extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" |