aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOctavian Geagla <ogeagla@gmail.com>2015-02-01 09:21:14 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-01 09:21:14 -0800
commitbdb0680d37614ccdec8933d2dec53793825e43d7 (patch)
tree4a665b0a605a63b8b19886022d4f5246e3fdedc4
parent80bd715a3e2c39449ed5e4d4e7058d75281ef3cb (diff)
downloadspark-bdb0680d37614ccdec8933d2dec53793825e43d7.tar.gz
spark-bdb0680d37614ccdec8933d2dec53793825e43d7.tar.bz2
spark-bdb0680d37614ccdec8933d2dec53793825e43d7.zip
[SPARK-5207] [MLLIB] StandardScalerModel mean and variance re-use
This seems complete, the duplication of tests for provided means/variances might be overkill, would appreciate some feedback. Author: Octavian Geagla <ogeagla@gmail.com> Closes #4140 from ogeagla/SPARK-5207 and squashes the following commits: fa64dfa [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] change StandardScalerModel to take stddev instead of variance 9078fe0 [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] Incorporate code review feedback: change arg ordering, add dev api annotations, do better null checking, add another test and some doc for this. 997d2e0 [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] make withMean and withStd public, add constructor which uses defaults, un-refactor test class 64408a4 [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] change StandardScalerModel contructor to not be private to mllib, added tests for newly-exposed functionality
-rw-r--r--docs/mllib-feature-extraction.md11
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala71
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala258
3 files changed, 267 insertions, 73 deletions
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 197bc77d50..d4a61a7fbf 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -240,11 +240,11 @@ following parameters in the constructor:
* `withMean` False by default. Centers the data with mean before scaling. It will build a dense
output, so this does not work on sparse input and will raise an exception.
-* `withStd` True by default. Scales the data to unit variance.
+* `withStd` True by default. Scales the data to unit standard deviation.
We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) method in
`StandardScaler` which can take an input of `RDD[Vector]`, learn the summary statistics, and then
-return a model which can transform the input dataset into unit variance and/or zero mean features
+return a model which can transform the input dataset into unit standard deviation and/or zero mean features
depending how we configure the `StandardScaler`.
This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer)
@@ -257,7 +257,7 @@ for that feature.
### Example
The example below demonstrates how to load a dataset in libsvm format, and standardize the features
-so that the new features have unit variance and/or zero mean.
+so that the new features have unit standard deviation and/or zero mean.
<div class="codetabs">
<div data-lang="scala">
@@ -271,6 +271,8 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val scaler1 = new StandardScaler().fit(data.map(x => x.features))
val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features))
+// scaler3 is an identical model to scaler2, and will produce identical transformations
+val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean)
// data1 will be unit variance.
val data1 = data.map(x => (x.label, scaler1.transform(x.features)))
@@ -294,6 +296,9 @@ features = data.map(lambda x: x.features)
scaler1 = StandardScaler().fit(features)
scaler2 = StandardScaler(withMean=True, withStd=True).fit(features)
+# scaler3 is an identical model to scaler2, and will produce identical transformations
+scaler3 = StandardScalerModel(scaler2.std, scaler2.mean)
+
# data1 will be unit variance.
data1 = label.zip(scaler1.transform(features))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 2f2c6f94e9..6ae6917eae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -18,14 +18,14 @@
package org.apache.spark.mllib.feature
import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
/**
* :: Experimental ::
- * Standardizes features by removing the mean and scaling to unit variance using column summary
+ * Standardizes features by removing the mean and scaling to unit std using column summary
* statistics on the samples in the training set.
*
* @param withMean False by default. Centers the data with mean before scaling. It will build a
@@ -52,7 +52,11 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2))
- new StandardScalerModel(withMean, withStd, summary.mean, summary.variance)
+ new StandardScalerModel(
+ Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))),
+ summary.mean,
+ withStd,
+ withMean)
}
}
@@ -60,28 +64,43 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
* :: Experimental ::
* Represents a StandardScaler model that can transform vectors.
*
- * @param withMean whether to center the data before scaling
- * @param withStd whether to scale the data to have unit standard deviation
+ * @param std column standard deviation values
* @param mean column mean values
- * @param variance column variance values
+ * @param withStd whether to scale the data to have unit standard deviation
+ * @param withMean whether to center the data before scaling
*/
@Experimental
-class StandardScalerModel private[mllib] (
- val withMean: Boolean,
- val withStd: Boolean,
+class StandardScalerModel (
+ val std: Vector,
val mean: Vector,
- val variance: Vector) extends VectorTransformer {
-
- require(mean.size == variance.size)
+ var withStd: Boolean,
+ var withMean: Boolean) extends VectorTransformer {
- private lazy val factor: Array[Double] = {
- val f = Array.ofDim[Double](variance.size)
- var i = 0
- while (i < f.size) {
- f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
- i += 1
+ def this(std: Vector, mean: Vector) {
+ this(std, mean, withStd = std != null, withMean = mean != null)
+ require(this.withStd || this.withMean,
+ "at least one of std or mean vectors must be provided")
+ if (this.withStd && this.withMean) {
+ require(mean.size == std.size,
+ "mean and std vectors must have equal size if both are provided")
}
- f
+ }
+
+ def this(std: Vector) = this(std, null)
+
+ @DeveloperApi
+ def setWithMean(withMean: Boolean): this.type = {
+ require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null")
+ this.withMean = withMean
+ this
+ }
+
+ @DeveloperApi
+ def setWithStd(withStd: Boolean): this.type = {
+ require(!(withStd && this.std == null),
+ "cannot set withStd to true while std is null")
+ this.withStd = withStd
+ this
}
// Since `shift` will be only used in `withMean` branch, we have it as
@@ -93,8 +112,8 @@ class StandardScalerModel private[mllib] (
* Applies standardization transformation on a vector.
*
* @param vector Vector to be standardized.
- * @return Standardized vector. If the variance of a column is zero, it will return default `0.0`
- * for the column with zero variance.
+ * @return Standardized vector. If the std of a column is zero, it will return default `0.0`
+ * for the column with zero std.
*/
override def transform(vector: Vector): Vector = {
require(mean.size == vector.size)
@@ -108,11 +127,9 @@ class StandardScalerModel private[mllib] (
val values = vs.clone()
val size = values.size
if (withStd) {
- // Having a local reference of `factor` to avoid overhead as the comment before.
- val localFactor = factor
var i = 0
while (i < size) {
- values(i) = (values(i) - localShift(i)) * localFactor(i)
+ values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
i += 1
}
} else {
@@ -126,15 +143,13 @@ class StandardScalerModel private[mllib] (
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else if (withStd) {
- // Having a local reference of `factor` to avoid overhead as the comment before.
- val localFactor = factor
vector match {
case DenseVector(vs) =>
val values = vs.clone()
val size = values.size
var i = 0
while(i < size) {
- values(i) *= localFactor(i)
+ values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0)
i += 1
}
Vectors.dense(values)
@@ -145,7 +160,7 @@ class StandardScalerModel private[mllib] (
val nnz = values.size
var i = 0
while (i < nnz) {
- values(i) *= localFactor(indices(i))
+ values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0)
i += 1
}
Vectors.sparse(size, indices, values)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index e9e510b6f5..7f94564b2a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -27,23 +27,109 @@ import org.apache.spark.rdd.RDD
class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
+ // When the input data is all constant, the variance is zero. The standardization against
+ // zero variance is not well-defined, but we decide to just set it into zero here.
+ val constantData = Array(
+ Vectors.dense(2.0),
+ Vectors.dense(2.0),
+ Vectors.dense(2.0)
+ )
+
+ val sparseData = Array(
+ Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
+ Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))),
+ Vectors.sparse(3, Seq((1, -5.1))),
+ Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))),
+ Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))),
+ Vectors.sparse(3, Seq((1, 1.9)))
+ )
+
+ val denseData = Array(
+ Vectors.dense(-2.0, 2.3, 0),
+ Vectors.dense(0.0, -1.0, -3.0),
+ Vectors.dense(0.0, -5.1, 0.0),
+ Vectors.dense(3.8, 0.0, 1.9),
+ Vectors.dense(1.7, -0.6, 0.0),
+ Vectors.dense(0.0, 1.9, 0.0)
+ )
+
private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
data.treeAggregate(new MultivariateOnlineSummarizer)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2))
}
+ test("Standardization with dense input when means and stds are provided") {
+
+ val dataRDD = sc.parallelize(denseData, 3)
+
+ val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+ val standardizer2 = new StandardScaler()
+ val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+
+ val model1 = standardizer1.fit(dataRDD)
+ val model2 = standardizer2.fit(dataRDD)
+ val model3 = standardizer3.fit(dataRDD)
+
+ val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean)
+ val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
+ val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
+
+ val data1 = denseData.map(equivalentModel1.transform)
+ val data2 = denseData.map(equivalentModel2.transform)
+ val data3 = denseData.map(equivalentModel3.transform)
+
+ val data1RDD = equivalentModel1.transform(dataRDD)
+ val data2RDD = equivalentModel2.transform(dataRDD)
+ val data3RDD = equivalentModel3.transform(dataRDD)
+
+ val summary = computeSummary(dataRDD)
+ val summary1 = computeSummary(data1RDD)
+ val summary2 = computeSummary(data2RDD)
+ val summary3 = computeSummary(data3RDD)
+
+ assert((denseData, data1, data1RDD.collect()).zipped.forall {
+ case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+ case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+ case _ => false
+ }, "The vector type should be preserved after standardization.")
+
+ assert((denseData, data2, data2RDD.collect()).zipped.forall {
+ case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+ case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+ case _ => false
+ }, "The vector type should be preserved after standardization.")
+
+ assert((denseData, data3, data3RDD.collect()).zipped.forall {
+ case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+ case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+ case _ => false
+ }, "The vector type should be preserved after standardization.")
+
+ assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+ assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+ assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+
+ assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+
+ assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+
+ assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary3.variance ~== summary.variance absTol 1E-5)
+
+ assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5)
+ assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5)
+ assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5)
+ assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5)
+ assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5)
+ assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5)
+ }
+
test("Standardization with dense input") {
- val data = Array(
- Vectors.dense(-2.0, 2.3, 0),
- Vectors.dense(0.0, -1.0, -3.0),
- Vectors.dense(0.0, -5.1, 0.0),
- Vectors.dense(3.8, 0.0, 1.9),
- Vectors.dense(1.7, -0.6, 0.0),
- Vectors.dense(0.0, 1.9, 0.0)
- )
- val dataRDD = sc.parallelize(data, 3)
+ val dataRDD = sc.parallelize(denseData, 3)
val standardizer1 = new StandardScaler(withMean = true, withStd = true)
val standardizer2 = new StandardScaler()
@@ -53,9 +139,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
val model2 = standardizer2.fit(dataRDD)
val model3 = standardizer3.fit(dataRDD)
- val data1 = data.map(model1.transform)
- val data2 = data.map(model2.transform)
- val data3 = data.map(model3.transform)
+ val data1 = denseData.map(model1.transform)
+ val data2 = denseData.map(model2.transform)
+ val data3 = denseData.map(model3.transform)
val data1RDD = model1.transform(dataRDD)
val data2RDD = model2.transform(dataRDD)
@@ -66,19 +152,19 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
val summary2 = computeSummary(data2RDD)
val summary3 = computeSummary(data3RDD)
- assert((data, data1, data1RDD.collect()).zipped.forall {
+ assert((denseData, data1, data1RDD.collect()).zipped.forall {
case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
case _ => false
}, "The vector type should be preserved after standardization.")
- assert((data, data2, data2RDD.collect()).zipped.forall {
+ assert((denseData, data2, data2RDD.collect()).zipped.forall {
case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
case _ => false
}, "The vector type should be preserved after standardization.")
- assert((data, data3, data3RDD.collect()).zipped.forall {
+ assert((denseData, data3, data3RDD.collect()).zipped.forall {
case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
case _ => false
@@ -106,17 +192,58 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
}
+ test("Standardization with sparse input when means and stds are provided") {
+
+ val dataRDD = sc.parallelize(sparseData, 3)
+
+ val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+ val standardizer2 = new StandardScaler()
+ val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+
+ val model1 = standardizer1.fit(dataRDD)
+ val model2 = standardizer2.fit(dataRDD)
+ val model3 = standardizer3.fit(dataRDD)
+
+ val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean)
+ val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
+ val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
+
+ val data2 = sparseData.map(equivalentModel2.transform)
+
+ withClue("Standardization with mean can not be applied on sparse input.") {
+ intercept[IllegalArgumentException] {
+ sparseData.map(equivalentModel1.transform)
+ }
+ }
+
+ withClue("Standardization with mean can not be applied on sparse input.") {
+ intercept[IllegalArgumentException] {
+ sparseData.map(equivalentModel3.transform)
+ }
+ }
+
+ val data2RDD = equivalentModel2.transform(dataRDD)
+
+ val summary = computeSummary(data2RDD)
+
+ assert((sparseData, data2, data2RDD.collect()).zipped.forall {
+ case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+ case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+ case _ => false
+ }, "The vector type should be preserved after standardization.")
+
+ assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+
+ assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+
+ assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
+ assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
+ }
+
test("Standardization with sparse input") {
- val data = Array(
- Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
- Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))),
- Vectors.sparse(3, Seq((1, -5.1))),
- Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))),
- Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))),
- Vectors.sparse(3, Seq((1, 1.9)))
- )
- val dataRDD = sc.parallelize(data, 3)
+ val dataRDD = sc.parallelize(sparseData, 3)
val standardizer1 = new StandardScaler(withMean = true, withStd = true)
val standardizer2 = new StandardScaler()
@@ -126,25 +253,26 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
val model2 = standardizer2.fit(dataRDD)
val model3 = standardizer3.fit(dataRDD)
- val data2 = data.map(model2.transform)
+ val data2 = sparseData.map(model2.transform)
withClue("Standardization with mean can not be applied on sparse input.") {
intercept[IllegalArgumentException] {
- data.map(model1.transform)
+ sparseData.map(model1.transform)
}
}
withClue("Standardization with mean can not be applied on sparse input.") {
intercept[IllegalArgumentException] {
- data.map(model3.transform)
+ sparseData.map(model3.transform)
}
}
val data2RDD = model2.transform(dataRDD)
- val summary2 = computeSummary(data2RDD)
- assert((data, data2, data2RDD.collect()).zipped.forall {
+ val summary = computeSummary(data2RDD)
+
+ assert((sparseData, data2, data2RDD.collect()).zipped.forall {
case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
case _ => false
@@ -152,23 +280,44 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
- assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
- assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+ assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+ assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
}
+ test("Standardization with constant input when means and stds are provided") {
+
+ val dataRDD = sc.parallelize(constantData, 2)
+
+ val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+ val standardizer2 = new StandardScaler(withMean = true, withStd = false)
+ val standardizer3 = new StandardScaler(withMean = false, withStd = true)
+
+ val model1 = standardizer1.fit(dataRDD)
+ val model2 = standardizer2.fit(dataRDD)
+ val model3 = standardizer3.fit(dataRDD)
+
+ val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean)
+ val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
+ val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
+
+ val data1 = constantData.map(equivalentModel1.transform)
+ val data2 = constantData.map(equivalentModel2.transform)
+ val data3 = constantData.map(equivalentModel3.transform)
+
+ assert(data1.forall(_.toArray.forall(_ == 0.0)),
+ "The variance is zero, so the transformed result should be 0.0")
+ assert(data2.forall(_.toArray.forall(_ == 0.0)),
+ "The variance is zero, so the transformed result should be 0.0")
+ assert(data3.forall(_.toArray.forall(_ == 0.0)),
+ "The variance is zero, so the transformed result should be 0.0")
+ }
+
test("Standardization with constant input") {
- // When the input data is all constant, the variance is zero. The standardization against
- // zero variance is not well-defined, but we decide to just set it into zero here.
- val data = Array(
- Vectors.dense(2.0),
- Vectors.dense(2.0),
- Vectors.dense(2.0)
- )
- val dataRDD = sc.parallelize(data, 2)
+ val dataRDD = sc.parallelize(constantData, 2)
val standardizer1 = new StandardScaler(withMean = true, withStd = true)
val standardizer2 = new StandardScaler(withMean = true, withStd = false)
@@ -178,9 +327,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
val model2 = standardizer2.fit(dataRDD)
val model3 = standardizer3.fit(dataRDD)
- val data1 = data.map(model1.transform)
- val data2 = data.map(model2.transform)
- val data3 = data.map(model3.transform)
+ val data1 = constantData.map(model1.transform)
+ val data2 = constantData.map(model2.transform)
+ val data3 = constantData.map(model3.transform)
assert(data1.forall(_.toArray.forall(_ == 0.0)),
"The variance is zero, so the transformed result should be 0.0")
@@ -190,4 +339,29 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
"The variance is zero, so the transformed result should be 0.0")
}
+ test("StandardScalerModel argument nulls are properly handled") {
+
+ withClue("model needs at least one of std or mean vectors") {
+ intercept[IllegalArgumentException] {
+ val model = new StandardScalerModel(null, null)
+ }
+ }
+ withClue("model needs std to set withStd to true") {
+ intercept[IllegalArgumentException] {
+ val model = new StandardScalerModel(null, Vectors.dense(0.0))
+ model.setWithStd(true)
+ }
+ }
+ withClue("model needs mean to set withMean to true") {
+ intercept[IllegalArgumentException] {
+ val model = new StandardScalerModel(Vectors.dense(0.0), null)
+ model.setWithMean(true)
+ }
+ }
+ withClue("model needs std and mean vectors to be equal size when both are provided") {
+ intercept[IllegalArgumentException] {
+ val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0))
+ }
+ }
+ }
}