about | summary | refs | log | tree | commit | diff
path: root/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala | 25
1 files changed, 8 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index d7eaa5a926..3d64f7f296 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.optim
import org.apache.spark.Logging
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
@@ -122,16 +123,6 @@ private[ml] class WeightedLeastSquares(
private[ml] object WeightedLeastSquares {
/**
- * Case class for weighted observations.
- * @param w weight, must be positive
- * @param a features
- * @param b label
- */
- case class Instance(w: Double, a: Vector, b: Double) {
- require(w >= 0.0, s"Weight cannot be negative: $w.")
- }
-
- /**
* Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
*/
// TODO: consolidate aggregates for summary statistics
@@ -168,8 +159,8 @@ private[ml] object WeightedLeastSquares {
* Adds an instance.
*/
def add(instance: Instance): this.type = {
- val Instance(w, a, b) = instance
- val ak = a.size
+ val Instance(l, w, f) = instance
+ val ak = f.size
if (!initialized) {
init(ak)
}
@@ -177,11 +168,11 @@ private[ml] object WeightedLeastSquares {
count += 1L
wSum += w
wwSum += w * w
- bSum += w * b
- bbSum += w * b * b
- BLAS.axpy(w, a, aSum)
- BLAS.axpy(w * b, a, abSum)
- BLAS.spr(w, a, aaSum)
+ bSum += w * l
+ bbSum += w * l * l
+ BLAS.axpy(w, f, aSum)
+ BLAS.axpy(w * l, f, abSum)
+ BLAS.spr(w, f, aaSum)
this
}