aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala21
1 files changed, 5 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 8cca926f1c..fe41863bce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.regression
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
-
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.util.MLUtils._
/**
* :: DeveloperApi ::
@@ -124,16 +123,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
run(input, initialWeights)
}
- /** Prepends one to the input vector. */
- private def prependOne(vector: Vector): Vector = {
- val vector1 = vector.toBreeze match {
- case dv: BDV[Double] => BDV.vertcat(BDV.ones[Double](1), dv)
- case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv)
- case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
- }
- Vectors.fromBreeze(vector1)
- }
-
/**
* Run the algorithm with the configured parameters on an input RDD
* of LabeledPoint entries starting from the initial weights provided.
@@ -147,23 +136,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
// Prepend an extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
- input.map(labeledPoint => (labeledPoint.label, prependOne(labeledPoint.features)))
+ input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
} else {
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
}
val initialWeightsWithIntercept = if (addIntercept) {
- prependOne(initialWeights)
+ appendBias(initialWeights)
} else {
initialWeights
}
val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
- val intercept = if (addIntercept) weightsWithIntercept(0) else 0.0
+ val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0
val weights =
if (addIntercept) {
- Vectors.dense(weightsWithIntercept.toArray.slice(1, weightsWithIntercept.size))
+ Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
} else {
weightsWithIntercept
}