aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorSean Owen <sowen@cloudera.com>2016-08-27 08:48:56 +0100
committerSean Owen <sowen@cloudera.com>2016-08-27 08:48:56 +0100
commite07baf14120bc94b783649dabf5fffea58bff0de (patch)
tree557979925874c18034e793057a9706c3ee6924fa /mllib/src/main
parent9fbced5b25c2f24d50c50516b4b7737f7e3eaf86 (diff)
downloadspark-e07baf14120bc94b783649dabf5fffea58bff0de.tar.gz
spark-e07baf14120bc94b783649dabf5fffea58bff0de.tar.bz2
spark-e07baf14120bc94b783649dabf5fffea58bff0de.zip
[SPARK-17001][ML] Enable standardScaler to standardize sparse vectors when withMean=True
## What changes were proposed in this pull request? Allow centering / mean scaling of sparse vectors in StandardScaler, if requested. This is for compatibility with `VectorAssembler` in common usages. ## How was this patch tested? Jenkins tests, including new cases to reflect the new behavior. Author: Sean Owen <sowen@cloudera.com> Closes #14663 from srowen/SPARK-17001.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala41
2 files changed, 22 insertions, 22 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 2494cf51a2..d76d556280 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -41,8 +41,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
/**
* Whether to center the data with mean before scaling.
- * It will build a dense output, so this does not work on sparse input
- * and will raise an exception.
+ * It will build a dense output, so take care when applying to sparse input.
* Default: false
* @group param
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 3e86c6c59c..7667936a3f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -32,7 +32,7 @@ import org.apache.spark.rdd.RDD
* which is computed as the square root of the unbiased sample variance.
*
* @param withMean False by default. Centers the data with mean before scaling. It will build a
- * dense output, so this does not work on sparse input and will raise an exception.
+ * dense output, so take care when applying to sparse input.
* @param withStd True by default. Scales the data to unit standard deviation.
*/
@Since("1.1.0")
@@ -139,26 +139,27 @@ class StandardScalerModel @Since("1.3.0") (
// the member variables are accessed, `invokespecial` will be called which is expensive.
// This can be avoided by having a local reference of `shift`.
val localShift = shift
- vector match {
- case DenseVector(vs) =>
- val values = vs.clone()
- val size = values.length
- if (withStd) {
- var i = 0
- while (i < size) {
- values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
- i += 1
- }
- } else {
- var i = 0
- while (i < size) {
- values(i) -= localShift(i)
- i += 1
- }
- }
- Vectors.dense(values)
- case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+ // Must have a copy of the values since it will be modified in place
+ val values = vector match {
+ // specially handle DenseVector because its toArray does not clone already
+ case d: DenseVector => d.values.clone()
+ case v: Vector => v.toArray
+ }
+ val size = values.length
+ if (withStd) {
+ var i = 0
+ while (i < size) {
+ values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
+ i += 1
+ }
+ } else {
+ var i = 0
+ while (i < size) {
+ values(i) -= localShift(i)
+ i += 1
+ }
}
+ Vectors.dense(values)
} else if (withStd) {
vector match {
case DenseVector(vs) =>