aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala17
1 files changed, 10 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 36c242bb5f..3ebb78f792 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -20,14 +20,14 @@ package org.apache.spark.ml.regression
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
-import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
- OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint
@@ -59,9 +59,12 @@ private[regression] trait LinearRegressionParams extends PredictorParams
* - L2 + L1 (elastic net)
*/
@AlphaComponent
-class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
+class LinearRegression(override val uid: String)
+ extends Regressor[Vector, LinearRegression, LinearRegressionModel]
with LinearRegressionParams with Logging {
+ def this() = this(Identifiable.randomUID("linReg"))
+
/**
* Set the regularization parameter.
* Default is 0.0.
@@ -128,7 +131,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
s"and the intercept will be the mean of the label; as a result, training is not needed.")
if (handlePersistence) instances.unpersist()
- return new LinearRegressionModel(this, Vectors.sparse(numFeatures, Seq()), yMean)
+ return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean)
}
val featuresMean = summarizer.mean.toArray
@@ -182,7 +185,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
if (handlePersistence) instances.unpersist()
// TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
- new LinearRegressionModel(this, weights.compressed, intercept)
+ copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
}
}
@@ -193,7 +196,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
*/
@AlphaComponent
class LinearRegressionModel private[ml] (
- override val parent: LinearRegression,
+ override val uid: String,
val weights: Vector,
val intercept: Double)
extends RegressionModel[Vector, LinearRegressionModel]
@@ -204,7 +207,7 @@ class LinearRegressionModel private[ml] (
}
override def copy(extra: ParamMap): LinearRegressionModel = {
- copyValues(new LinearRegressionModel(parent, weights, intercept), extra)
+ copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
}
}