aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala8
1 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index f98b0b536d..b9621530ef 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -119,7 +119,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint]) : M = {
val nfeatures: Int = input.first().features.length
- val initialWeights = Array.fill(nfeatures)(1.0)
+ val initialWeights = new Array[Double](nfeatures)
run(input, initialWeights)
}
@@ -134,15 +134,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
throw new SparkException("Input validation failed.")
}
- // Add a extra variable consisting of all 1.0's for the intercept.
+ // Prepend an extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
- input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*)))
+ input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0)))
} else {
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
}
val initialWeightsWithIntercept = if (addIntercept) {
- Array(1.0, initialWeights:_*)
+ initialWeights.+:(1.0)
} else {
initialWeights
}