aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala29
1 files changed, 19 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
index dfad25d57c..a9c2e23717 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
@@ -17,10 +17,10 @@
package org.apache.spark.mllib.feature
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => brzNorm}
+import breeze.linalg.{norm => brzNorm}
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
/**
* :: Experimental ::
@@ -47,22 +47,31 @@ class Normalizer(p: Double) extends VectorTransformer {
* @return normalized vector. If the norm of the input is zero, it will return the input vector.
*/
override def transform(vector: Vector): Vector = {
- var norm = brzNorm(vector.toBreeze, p)
+ val norm = brzNorm(vector.toBreeze, p)
if (norm != 0.0) {
// For dense vector, we've to allocate new memory for new output vector.
// However, for sparse vector, the `index` array will not be changed,
// so we can re-use it to save memory.
- vector.toBreeze match {
- case dv: BDV[Double] => Vectors.fromBreeze(dv :/ norm)
- case sv: BSV[Double] =>
- val output = new BSV[Double](sv.index, sv.data.clone(), sv.length)
+ vector match {
+ case dv: DenseVector =>
+ val values = dv.values.clone()
+ val size = values.size
var i = 0
- while (i < output.data.length) {
- output.data(i) /= norm
+ while (i < size) {
+ values(i) /= norm
i += 1
}
- Vectors.fromBreeze(output)
+ Vectors.dense(values)
+ case sv: SparseVector =>
+ val values = sv.values.clone()
+ val nnz = values.size
+ var i = 0
+ while (i < nnz) {
+ values(i) /= norm
+ i += 1
+ }
+ Vectors.sparse(sv.size, sv.indices, values)
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else {