aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala29
1 files changed, 14 insertions, 15 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 71c2533bcb..a3cc49f7f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -29,9 +29,9 @@ import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
@@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
@Experimental
class LogisticRegression(override val uid: String)
extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
- with LogisticRegressionParams with Writable with Logging {
+ with LogisticRegressionParams with DefaultParamsWritable with Logging {
def this() = this(Identifiable.randomUID("logreg"))
@@ -385,12 +385,11 @@ class LogisticRegression(override val uid: String)
}
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
-
- override def write: Writer = new DefaultParamsWriter(this)
}
-object LogisticRegression extends Readable[LogisticRegression] {
- override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression]
+object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
+
+ override def load(path: String): LogisticRegression = super.load(path)
}
/**
@@ -403,7 +402,7 @@ class LogisticRegressionModel private[ml] (
val coefficients: Vector,
val intercept: Double)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
- with LogisticRegressionParams with Writable {
+ with LogisticRegressionParams with MLWritable {
@deprecated("Use coefficients instead.", "1.6.0")
def weights: Vector = coefficients
@@ -519,26 +518,26 @@ class LogisticRegressionModel private[ml] (
}
/**
- * Returns a [[Writer]] instance for this ML instance.
+ * Returns a [[MLWriter]] instance for this ML instance.
*
* For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]].
* An option to save [[summary]] may be added in the future.
*
* This also does not save the [[parent]] currently.
*/
- override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
+ override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
}
-object LogisticRegressionModel extends Readable[LogisticRegressionModel] {
+object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
- override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader
+ override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader
- override def load(path: String): LogisticRegressionModel = read.load(path)
+ override def load(path: String): LogisticRegressionModel = super.load(path)
- /** [[Writer]] instance for [[LogisticRegressionModel]] */
+ /** [[MLWriter]] instance for [[LogisticRegressionModel]] */
private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
- extends Writer with Logging {
+ extends MLWriter with Logging {
private case class Data(
numClasses: Int,
@@ -558,7 +557,7 @@ object LogisticRegressionModel extends Readable[LogisticRegressionModel] {
}
private[classification] class LogisticRegressionModelReader
- extends Reader[LogisticRegressionModel] {
+ extends MLReader[LogisticRegressionModel] {
/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"