aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala70
1 files changed, 50 insertions, 20 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 4dfd1f0ab8..8c4c5db525 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -17,11 +17,9 @@
package org.apache.spark.mllib.feature
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
-
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
@@ -77,8 +75,8 @@ class StandardScalerModel private[mllib] (
require(mean.size == variance.size)
- private lazy val factor: BDV[Double] = {
- val f = BDV.zeros[Double](variance.size)
+ private lazy val factor: Array[Double] = {
+ val f = Array.ofDim[Double](variance.size)
var i = 0
while (i < f.size) {
f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
@@ -87,6 +85,11 @@ class StandardScalerModel private[mllib] (
f
}
+ // Since `shift` will be only used in `withMean` branch, we have it as
+ // `lazy val` so it will be evaluated in that branch. Note that we don't
+ // want to create this array multiple times in `transform` function.
+ private lazy val shift: Array[Double] = mean.toArray
+
/**
* Applies standardization transformation on a vector.
*
@@ -97,30 +100,57 @@ class StandardScalerModel private[mllib] (
override def transform(vector: Vector): Vector = {
require(mean.size == vector.size)
if (withMean) {
- vector.toBreeze match {
- case dv: BDV[Double] =>
- val output = vector.toBreeze.copy
- var i = 0
- while (i < output.length) {
- output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0)
- i += 1
+ // By default, Scala generates Java methods for member variables. So every time when
+ // the member variables are accessed, `invokespecial` will be called which is expensive.
+ // This can be avoid by having a local reference of `shift`.
+ val localShift = shift
+ vector match {
+ case dv: DenseVector =>
+ val values = dv.values.clone()
+ val size = values.size
+ if (withStd) {
+ // Having a local reference of `factor` to avoid overhead as the comment before.
+ val localFactor = factor
+ var i = 0
+ while (i < size) {
+ values(i) = (values(i) - localShift(i)) * localFactor(i)
+ i += 1
+ }
+ } else {
+ var i = 0
+ while (i < size) {
+ values(i) -= localShift(i)
+ i += 1
+ }
}
- Vectors.fromBreeze(output)
+ Vectors.dense(values)
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else if (withStd) {
- vector.toBreeze match {
- case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor)
- case sv: BSV[Double] =>
+ // Having a local reference of `factor` to avoid overhead as the comment before.
+ val localFactor = factor
+ vector match {
+ case dv: DenseVector =>
+ val values = dv.values.clone()
+ val size = values.size
+ var i = 0
+ while(i < size) {
+ values(i) *= localFactor(i)
+ i += 1
+ }
+ Vectors.dense(values)
+ case sv: SparseVector =>
// For sparse vector, the `index` array inside sparse vector object will not be changed,
// so we can re-use it to save memory.
- val output = new BSV[Double](sv.index, sv.data.clone(), sv.length)
+ val indices = sv.indices
+ val values = sv.values.clone()
+ val nnz = values.size
var i = 0
- while (i < output.data.length) {
- output.data(i) *= factor(output.index(i))
+ while (i < nnz) {
+ values(i) *= localFactor(indices(i))
i += 1
}
- Vectors.fromBreeze(output)
+ Vectors.sparse(sv.size, indices, values)
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else {