author    Sean Owen <sowen@cloudera.com>  2016-08-27 08:48:56 +0100
committer Sean Owen <sowen@cloudera.com>  2016-08-27 08:48:56 +0100
commit    e07baf14120bc94b783649dabf5fffea58bff0de (patch)
tree      557979925874c18034e793057a9706c3ee6924fa /mllib
parent    9fbced5b25c2f24d50c50516b4b7737f7e3eaf86 (diff)
[SPARK-17001][ML] Enable standardScaler to standardize sparse vectors when withMean=True
## What changes were proposed in this pull request?

Allow centering / mean scaling of sparse vectors in StandardScaler, if requested. This is for compatibility with `VectorAssembler` in common usages.

## How was this patch tested?

Jenkins tests, including new cases to reflect the new behavior.

Author: Sean Owen <sowen@cloudera.com>

Closes #14663 from srowen/SPARK-17001.
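As a rough illustration of the behavior this patch enables (not part of the patch itself; it assumes an existing `SparkSession` named `spark`, and the data and column names are hypothetical, loosely mirroring the new test case), the following minimal Scala sketch fits a `StandardScaler` with `withMean = true` on a DataFrame that contains sparse vectors. With this change the transform densifies such rows instead of raising an exception:

```scala
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.ml.linalg.Vectors

// Illustrative mix of sparse and dense feature vectors
// (e.g. as typically produced by VectorAssembler).
val df = spark.createDataFrame(Seq(
  (0, Vectors.sparse(3, Array(0, 1), Array(-2.0, 2.3))),
  (1, Vectors.sparse(3, Array(1, 2), Array(-5.1, 1.0))),
  (2, Vectors.dense(1.7, -0.6, 3.3))
)).toDF("id", "features")

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("standardized_features")
  .setWithMean(true)   // previously raised IllegalArgumentException on sparse input
  .setWithStd(false)

// Centering subtracts the column means, so the standardized vectors come out dense.
scaler.fit(df).transform(df).show(truncate = false)
```

Note that centering still produces dense output, so memory use can grow substantially on wide, mostly-sparse data; the doc changes below call this out.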
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala         |  3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala      | 41
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala    | 16
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala | 69
4 files changed, 76 insertions(+), 53 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 2494cf51a2..d76d556280 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -41,8 +41,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
/**
* Whether to center the data with mean before scaling.
- * It will build a dense output, so this does not work on sparse input
- * and will raise an exception.
+ * It will build a dense output, so take care when applying to sparse input.
* Default: false
* @group param
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 3e86c6c59c..7667936a3f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -32,7 +32,7 @@ import org.apache.spark.rdd.RDD
* which is computed as the square root of the unbiased sample variance.
*
* @param withMean False by default. Centers the data with mean before scaling. It will build a
- * dense output, so this does not work on sparse input and will raise an exception.
+ * dense output, so take care when applying to sparse input.
* @param withStd True by default. Scales the data to unit standard deviation.
*/
@Since("1.1.0")
@@ -139,26 +139,27 @@ class StandardScalerModel @Since("1.3.0") (
// the member variables are accessed, `invokespecial` will be called which is expensive.
// This can be avoided by having a local reference of `shift`.
val localShift = shift
- vector match {
- case DenseVector(vs) =>
- val values = vs.clone()
- val size = values.length
- if (withStd) {
- var i = 0
- while (i < size) {
- values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
- i += 1
- }
- } else {
- var i = 0
- while (i < size) {
- values(i) -= localShift(i)
- i += 1
- }
- }
- Vectors.dense(values)
- case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+ // Must have a copy of the values since it will be modified in place
+ val values = vector match {
+ // specially handle DenseVector because its toArray does not clone already
+ case d: DenseVector => d.values.clone()
+ case v: Vector => v.toArray
+ }
+ val size = values.length
+ if (withStd) {
+ var i = 0
+ while (i < size) {
+ values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
+ i += 1
+ }
+ } else {
+ var i = 0
+ while (i < size) {
+ values(i) -= localShift(i)
+ i += 1
+ }
}
+ Vectors.dense(values)
} else if (withStd) {
vector match {
case DenseVector(vs) =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
index 2243a0f972..827ecb0fad 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -114,6 +114,22 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
assertResult(standardScaler3.transform(df3))
}
+ test("sparse data and withMean") {
+ val someSparseData = Array(
+ Vectors.sparse(3, Array(0, 1), Array(-2.0, 2.3)),
+ Vectors.sparse(3, Array(1, 2), Array(-5.1, 1.0)),
+ Vectors.dense(1.7, -0.6, 3.3)
+ )
+ val df = spark.createDataFrame(someSparseData.zip(resWithMean)).toDF("features", "expected")
+ val standardScaler = new StandardScaler()
+ .setInputCol("features")
+ .setOutputCol("standardized_features")
+ .setWithMean(true)
+ .setWithStd(false)
+ .fit(df)
+ assertResult(standardScaler.transform(df))
+ }
+
test("StandardScaler read/write") {
val t = new StandardScaler()
.setInputCol("myInputCol")
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index b4e26b2aeb..a5769631e5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -207,23 +207,17 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
+ val data1 = sparseData.map(equivalentModel1.transform)
val data2 = sparseData.map(equivalentModel2.transform)
+ val data3 = sparseData.map(equivalentModel3.transform)
- withClue("Standardization with mean can not be applied on sparse input.") {
- intercept[IllegalArgumentException] {
- sparseData.map(equivalentModel1.transform)
- }
- }
-
- withClue("Standardization with mean can not be applied on sparse input.") {
- intercept[IllegalArgumentException] {
- sparseData.map(equivalentModel3.transform)
- }
- }
-
+ val data1RDD = equivalentModel1.transform(dataRDD)
val data2RDD = equivalentModel2.transform(dataRDD)
+ val data3RDD = equivalentModel3.transform(dataRDD)
- val summary = computeSummary(data2RDD)
+ val summary1 = computeSummary(data1RDD)
+ val summary2 = computeSummary(data2RDD)
+ val summary3 = computeSummary(data3RDD)
assert((sparseData, data2, data2RDD.collect()).zipped.forall {
case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
@@ -231,13 +225,23 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
case _ => false
}, "The vector type should be preserved after standardization.")
+ assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+ assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
- assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
- assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+ assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+ assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+ assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary3.variance !~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+ assert(data1(4) ~== Vectors.dense(0.56854, -0.069068, 0.116377) absTol 1E-5)
+ assert(data1(5) ~== Vectors.dense(-0.296998, 0.872775, 0.116377) absTol 1E-5)
assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
+ assert(data3(4) ~== Vectors.dense(1.116666, -0.183333, 0.183333) absTol 1E-5)
+ assert(data3(5) ~== Vectors.dense(-0.583333, 2.316666, 0.183333) absTol 1E-5)
}
test("Standardization with sparse input") {
@@ -252,24 +256,17 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
val model2 = standardizer2.fit(dataRDD)
val model3 = standardizer3.fit(dataRDD)
+ val data1 = sparseData.map(model1.transform)
val data2 = sparseData.map(model2.transform)
+ val data3 = sparseData.map(model3.transform)
- withClue("Standardization with mean can not be applied on sparse input.") {
- intercept[IllegalArgumentException] {
- sparseData.map(model1.transform)
- }
- }
-
- withClue("Standardization with mean can not be applied on sparse input.") {
- intercept[IllegalArgumentException] {
- sparseData.map(model3.transform)
- }
- }
-
+ val data1RDD = model1.transform(dataRDD)
val data2RDD = model2.transform(dataRDD)
+ val data3RDD = model3.transform(dataRDD)
-
- val summary = computeSummary(data2RDD)
+ val summary1 = computeSummary(data1RDD)
+ val summary2 = computeSummary(data2RDD)
+ val summary3 = computeSummary(data3RDD)
assert((sparseData, data2, data2RDD.collect()).zipped.forall {
case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
@@ -277,13 +274,23 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
case _ => false
}, "The vector type should be preserved after standardization.")
+ assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+ assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
- assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
- assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+ assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+ assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+ assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary3.variance !~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+ assert(data1(4) ~== Vectors.dense(0.56854, -0.069068, 0.116377) absTol 1E-5)
+ assert(data1(5) ~== Vectors.dense(-0.296998, 0.872775, 0.116377) absTol 1E-5)
assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
+ assert(data3(4) ~== Vectors.dense(1.116666, -0.183333, 0.183333) absTol 1E-5)
+ assert(data3(5) ~== Vectors.dense(-0.583333, 2.316666, 0.183333) absTol 1E-5)
}
test("Standardization with constant input when means and stds are provided") {