From ae58aea2d1435b5bb011e68127e1bcddc2edf5b2 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Sun, 3 Aug 2014 21:39:21 -0700 Subject: SPARK-2272 [MLlib] Feature scaling which standardizes the range of independent variables or features of data Feature scaling is a method used to standardize the range of independent variables or features of data. In data processing, it is generally performed during the data preprocessing step. In this work, a trait called `VectorTransformer` is defined for generic transformation on a vector. It contains one method to be implemented, `transform` which applies transformation on a vector. There are two implementations of `VectorTransformer` now, and they all can be easily extended with PMML transformation support. 1) `StandardScaler` - Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set. 2) `Normalizer` - Normalizes samples individually to unit L^n norm Author: DB Tsai Closes #1207 from dbtsai/dbtsai-feature-scaling and squashes the following commits: 78c15d3 [DB Tsai] Alpine Data Labs --- .../apache/spark/mllib/feature/Normalizer.scala | 76 ++++++++ .../spark/mllib/feature/StandardScaler.scala | 119 ++++++++++++ .../spark/mllib/feature/VectorTransformer.scala | 51 ++++++ .../spark/mllib/linalg/distributed/RowMatrix.scala | 2 +- .../spark/mllib/feature/NormalizerSuite.scala | 120 +++++++++++++ .../spark/mllib/feature/StandardScalerSuite.scala | 200 +++++++++++++++++++++ 6 files changed, 567 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala new file mode 100644 index 0000000000..ea9fd0a80d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * :: DeveloperApi :: + * Normalizes samples individually to unit L^p^ norm + * + * For any 1 <= p < Double.PositiveInfinity, normalizes samples using + * sum(abs(vector).^p^)^(1/p)^ as norm. + * + * For p = Double.PositiveInfinity, max(abs(vector)) will be used as norm for normalization. + * + * @param p Normalization in L^p^ space, p = 2 by default. + */ +@DeveloperApi +class Normalizer(p: Double) extends VectorTransformer { + + def this() = this(2) + + require(p >= 1.0) + + /** + * Applies unit length normalization on a vector. + * + * @param vector vector to be normalized. + * @return normalized vector. If the norm of the input is zero, it will return the input vector. + */ + override def transform(vector: Vector): Vector = { + var norm = vector.toBreeze.norm(p) + + if (norm != 0.0) { + // For dense vector, we've to allocate new memory for new output vector. + // However, for sparse vector, the `index` array will not be changed, + // so we can re-use it to save memory. + vector.toBreeze match { + case dv: BDV[Double] => Vectors.fromBreeze(dv :/ norm) + case sv: BSV[Double] => + val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + var i = 0 + while (i < output.data.length) { + output.data(i) /= norm + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else { + // Since the norm is zero, return the input vector object itself. + // Note that it's safe since we always assume that the data in RDD + // should be immutable. + vector + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala new file mode 100644 index 0000000000..cc2d7579c2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Standardizes features by removing the mean and scaling to unit variance using column summary + * statistics on the samples in the training set. + * + * @param withMean False by default. Centers the data with mean before scaling. It will build a + * dense output, so this does not work on sparse input and will raise an exception. + * @param withStd True by default. Scales the data to unit standard deviation. + */ +@DeveloperApi +class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { + + def this() = this(false, true) + + require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.") + + private var mean: BV[Double] = _ + private var factor: BV[Double] = _ + + /** + * Computes the mean and variance and stores as a model to be used for later scaling. + * + * @param data The data used to compute the mean and variance to build the transformation model. + * @return This StandardScalar object. + */ + def fit(data: RDD[Vector]): this.type = { + val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( + (aggregator, data) => aggregator.add(data), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + + mean = summary.mean.toBreeze + factor = summary.variance.toBreeze + require(mean.length == factor.length) + + var i = 0 + while (i < factor.length) { + factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0 + i += 1 + } + + this + } + + /** + * Applies standardization transformation on a vector. + * + * @param vector Vector to be standardized. + * @return Standardized vector. If the variance of a column is zero, it will return default `0.0` + * for the column with zero variance. + */ + override def transform(vector: Vector): Vector = { + if (mean == null || factor == null) { + throw new IllegalStateException( + "Haven't learned column summary statistics yet. Call fit first.") + } + + require(vector.size == mean.length) + + if (withMean) { + vector.toBreeze match { + case dv: BDV[Double] => + val output = vector.toBreeze.copy + var i = 0 + while (i < output.length) { + output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0) + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else if (withStd) { + vector.toBreeze match { + case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor) + case sv: BSV[Double] => + // For sparse vector, the `index` array inside sparse vector object will not be changed, + // so we can re-use it to save memory. + val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + var i = 0 + while (i < output.data.length) { + output.data(i) *= factor(output.index(i)) + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else { + // Note that it's safe since we always assume that the data in RDD should be immutable. + vector + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala new file mode 100644 index 0000000000..415a845332 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Trait for transformation of a vector + */ +@DeveloperApi +trait VectorTransformer extends Serializable { + + /** + * Applies transformation on a vector. + * + * @param vector vector to be transformed. + * @return transformed vector. + */ + def transform(vector: Vector): Vector + + /** + * Applies transformation on an RDD[Vector]. + * + * @param data RDD[Vector] to be transformed. + * @return transformed RDD[Vector]. + */ + def transform(data: RDD[Vector]): RDD[Vector] = { + // Later in #1498 , all RDD objects are sent via broadcasting instead of akka. + // So it should be no longer necessary to explicitly broadcast `this` object. + data.map(x => this.transform(x)) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 58c1322757..45486b2c7d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import java.util.Arrays -import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} import breeze.linalg.{svd => brzSvd, axpy => brzAxpy} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala new file mode 100644 index 0000000000..fb76dccfdf --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class NormalizerSuite extends FunSuite with LocalSparkContext { + + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))), + Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))), + Vectors.sparse(3, Seq()) + ) + + lazy val dataRDD = sc.parallelize(data, 3) + + test("Normalization using L1 distance") { + val l1Normalizer = new Normalizer(1) + + val data1 = data.map(l1Normalizer.transform) + val data1RDD = l1Normalizer.transform(dataRDD) + + assert((data, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + + assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) + assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(data1(2) ~== Vectors.dense(0.12765957, -0.23404255, -0.63829787) absTol 1E-5) + assert(data1(3) ~== Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))) absTol 1E-5) + assert(data1(4) ~== Vectors.dense(0.625, 0.07894737, 0.29605263) absTol 1E-5) + assert(data1(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + + test("Normalization using L2 distance") { + val l2Normalizer = new Normalizer() + + val data2 = data.map(l2Normalizer.transform) + val data2RDD = l2Normalizer.transform(dataRDD) + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + + assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) + assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(data2(2) ~== Vectors.dense(0.184549876, -0.3383414, -0.922749378) absTol 1E-5) + assert(data2(3) ~== Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.897906166, 0.113419726, 0.42532397) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + + test("Normalization using L^Inf distance.") { + val lInfNormalizer = new Normalizer(Double.PositiveInfinity) + + val dataInf = data.map(lInfNormalizer.transform) + val dataInfRDD = lInfNormalizer.transform(dataRDD) + + assert((data, dataInf, dataInfRDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((dataInf, dataInfRDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(dataInf(0).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(2).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(3).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(4).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + + assert(dataInf(0) ~== Vectors.sparse(3, Seq((0, -0.86956522), (1, 1.0))) absTol 1E-5) + assert(dataInf(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(dataInf(2) ~== Vectors.dense(0.2, -0.36666667, -1.0) absTol 1E-5) + assert(dataInf(3) ~== Vectors.sparse(3, Seq((1, 0.284375), (2, 1.0))) absTol 1E-5) + assert(dataInf(4) ~== Vectors.dense(1.0, 0.12631579, 0.473684211) absTol 1E-5) + assert(dataInf(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala new file mode 100644 index 0000000000..5a9be923a8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} +import org.apache.spark.rdd.RDD + +class StandardScalerSuite extends FunSuite with LocalSparkContext { + + private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = { + data.treeAggregate(new MultivariateOnlineSummarizer)( + (aggregator, data) => aggregator.add(data), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + } + + test("Standardization with dense input") { + val data = Array( + Vectors.dense(-2.0, 2.3, 0), + Vectors.dense(0.0, -1.0, -3.0), + Vectors.dense(0.0, -5.1, 0.0), + Vectors.dense(3.8, 0.0, 1.9), + Vectors.dense(1.7, -0.6, 0.0), + Vectors.dense(0.0, 1.9, 0.0) + ) + + val dataRDD = sc.parallelize(data, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + withClue("Using a standardizer before fitting the model should throw exception.") { + intercept[IllegalStateException] { + data.map(standardizer1.transform) + } + } + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data1 = data.map(standardizer1.transform) + val data2 = data.map(standardizer2.transform) + val data3 = data.map(standardizer3.transform) + + val data1RDD = standardizer1.transform(dataRDD) + val data2RDD = standardizer2.transform(dataRDD) + val data3RDD = standardizer3.transform(dataRDD) + + val summary = computeSummary(dataRDD) + val summary1 = computeSummary(data1RDD) + val summary2 = computeSummary(data2RDD) + val summary3 = computeSummary(data3RDD) + + assert((data, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data, data3, data3RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary3.variance ~== summary.variance absTol 1E-5) + + assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5) + assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5) + assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5) + assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5) + assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5) + } + + + test("Standardization with sparse input") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))), + Vectors.sparse(3, Seq((1, -5.1))), + Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))), + Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))), + Vectors.sparse(3, Seq((1, 1.9))) + ) + + val dataRDD = sc.parallelize(data, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data2 = data.map(standardizer2.transform) + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + data.map(standardizer1.transform) + } + } + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + data.map(standardizer3.transform) + } + } + + val data2RDD = standardizer2.transform(dataRDD) + + val summary2 = computeSummary(data2RDD) + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5) + } + + test("Standardization with constant input") { + // When the input data is all constant, the variance is zero. The standardization against + // zero variance is not well-defined, but we decide to just set it into zero here. + val data = Array( + Vectors.dense(2.0), + Vectors.dense(2.0), + Vectors.dense(2.0) + ) + + val dataRDD = sc.parallelize(data, 2) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler(withMean = true, withStd = false) + val standardizer3 = new StandardScaler(withMean = false, withStd = true) + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data1 = data.map(standardizer1.transform) + val data2 = data.map(standardizer2.transform) + val data3 = data.map(standardizer3.transform) + + assert(data1.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data2.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data3.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + } + +} -- cgit v1.2.3