diff options
author | Sean Owen <sowen@cloudera.com> | 2016-08-27 08:48:56 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-08-27 08:48:56 +0100 |
commit | e07baf14120bc94b783649dabf5fffea58bff0de (patch) | |
tree | 557979925874c18034e793057a9706c3ee6924fa /mllib/src/main | |
parent | 9fbced5b25c2f24d50c50516b4b7737f7e3eaf86 (diff) | |
download | spark-e07baf14120bc94b783649dabf5fffea58bff0de.tar.gz spark-e07baf14120bc94b783649dabf5fffea58bff0de.tar.bz2 spark-e07baf14120bc94b783649dabf5fffea58bff0de.zip |
[SPARK-17001][ML] Enable standardScaler to standardize sparse vectors when withMean=True
## What changes were proposed in this pull request?
Allow centering / mean scaling of sparse vectors in StandardScaler, if requested. This is for compatibility with `VectorAssembler` in common usages.
## How was this patch tested?
Jenkins tests, including new caes to reflect the new behavior.
Author: Sean Owen <sowen@cloudera.com>
Closes #14663 from srowen/SPARK-17001.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 3 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala | 41 |
2 files changed, 22 insertions, 22 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 2494cf51a2..d76d556280 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -41,8 +41,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with /** * Whether to center the data with mean before scaling. - * It will build a dense output, so this does not work on sparse input - * and will raise an exception. + * It will build a dense output, so take care when applying to sparse input. * Default: false * @group param */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 3e86c6c59c..7667936a3f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -32,7 +32,7 @@ import org.apache.spark.rdd.RDD * which is computed as the square root of the unbiased sample variance. * * @param withMean False by default. Centers the data with mean before scaling. It will build a - * dense output, so this does not work on sparse input and will raise an exception. + * dense output, so take care when applying to sparse input. * @param withStd True by default. Scales the data to unit standard deviation. */ @Since("1.1.0") @@ -139,26 +139,27 @@ class StandardScalerModel @Since("1.3.0") ( // the member variables are accessed, `invokespecial` will be called which is expensive. // This can be avoid by having a local reference of `shift`. val localShift = shift - vector match { - case DenseVector(vs) => - val values = vs.clone() - val size = values.length - if (withStd) { - var i = 0 - while (i < size) { - values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0 - i += 1 - } - } else { - var i = 0 - while (i < size) { - values(i) -= localShift(i) - i += 1 - } - } - Vectors.dense(values) - case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + // Must have a copy of the values since it will be modified in place + val values = vector match { + // specially handle DenseVector because its toArray does not clone already + case d: DenseVector => d.values.clone() + case v: Vector => v.toArray + } + val size = values.length + if (withStd) { + var i = 0 + while (i < size) { + values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0 + i += 1 + } + } else { + var i = 0 + while (i < size) { + values(i) -= localShift(i) + i += 1 + } } + Vectors.dense(values) } else if (withStd) { vector match { case DenseVector(vs) => |