aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala136
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala201
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala209
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala45
-rw-r--r--project/MimaExcludes.scala1
5 files changed, 458 insertions, 134 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 99cb6516e0..711e32a330 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -28,138 +28,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
-import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
-
-/**
- * Column statistics aggregator implementing
- * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
- * together with add() and merge() function.
- * A numerically stable algorithm is implemented to compute sample mean and variance:
- * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
- * Zero elements (including explicit zero values) are skipped when calling add() and merge(),
- * to have time complexity O(nnz) instead of O(n) for each column.
- */
-private class ColumnStatisticsAggregator(private val n: Int)
- extends MultivariateStatisticalSummary with Serializable {
-
- private val currMean: BDV[Double] = BDV.zeros[Double](n)
- private val currM2n: BDV[Double] = BDV.zeros[Double](n)
- private var totalCnt = 0.0
- private val nnz: BDV[Double] = BDV.zeros[Double](n)
- private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
- private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)
-
- override def mean: Vector = {
- val realMean = BDV.zeros[Double](n)
- var i = 0
- while (i < n) {
- realMean(i) = currMean(i) * nnz(i) / totalCnt
- i += 1
- }
- Vectors.fromBreeze(realMean)
- }
-
- override def variance: Vector = {
- val realVariance = BDV.zeros[Double](n)
-
- val denominator = totalCnt - 1.0
-
- // Sample variance is computed, if the denominator is less than 0, the variance is just 0.
- if (denominator > 0.0) {
- val deltaMean = currMean
- var i = 0
- while (i < currM2n.size) {
- realVariance(i) =
- currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
- realVariance(i) /= denominator
- i += 1
- }
- }
-
- Vectors.fromBreeze(realVariance)
- }
-
- override def count: Long = totalCnt.toLong
-
- override def numNonzeros: Vector = Vectors.fromBreeze(nnz)
-
- override def max: Vector = {
- var i = 0
- while (i < n) {
- if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
- i += 1
- }
- Vectors.fromBreeze(currMax)
- }
-
- override def min: Vector = {
- var i = 0
- while (i < n) {
- if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
- i += 1
- }
- Vectors.fromBreeze(currMin)
- }
-
- /**
- * Aggregates a row.
- */
- def add(currData: BV[Double]): this.type = {
- currData.activeIterator.foreach {
- case (_, 0.0) => // Skip explicit zero elements.
- case (i, value) =>
- if (currMax(i) < value) {
- currMax(i) = value
- }
- if (currMin(i) > value) {
- currMin(i) = value
- }
-
- val tmpPrevMean = currMean(i)
- currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
- currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
-
- nnz(i) += 1.0
- }
-
- totalCnt += 1.0
- this
- }
-
- /**
- * Merges another aggregator.
- */
- def merge(other: ColumnStatisticsAggregator): this.type = {
- require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")
-
- totalCnt += other.totalCnt
- val deltaMean = currMean - other.currMean
-
- var i = 0
- while (i < n) {
- // merge mean together
- if (other.currMean(i) != 0.0) {
- currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
- (nnz(i) + other.nnz(i))
- }
- // merge m2n together
- if (nnz(i) + other.nnz(i) != 0.0) {
- currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
- (nnz(i) + other.nnz(i))
- }
- if (currMax(i) < other.currMax(i)) {
- currMax(i) = other.currMax(i)
- }
- if (currMin(i) > other.currMin(i)) {
- currMin(i) = other.currMin(i)
- }
- i += 1
- }
-
- nnz += other.nnz
- this
- }
-}
+import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
/**
* :: Experimental ::
@@ -478,8 +347,7 @@ class RowMatrix(
* Computes column-wise summary statistics.
*/
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
- val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
- val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
+ val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
new file mode 100644
index 0000000000..5105b5c37a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import breeze.linalg.{DenseVector => BDV}
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
+
+/**
+ * :: DeveloperApi ::
+ * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean,
+ * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector
+ * format in a online fashion.
+ *
+ * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of
+ * the corresponding joint dataset.
+ *
+ * A numerically stable algorithm is implemented to compute sample mean and variance:
+ * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]
+ * Zero elements (including explicit zero values) are skipped when calling add(),
+ * to have time complexity O(nnz) instead of O(n) for each column.
+ */
+@DeveloperApi
+class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
+
+ private var n = 0
+ private var currMean: BDV[Double] = _
+ private var currM2n: BDV[Double] = _
+ private var totalCnt: Long = 0
+ private var nnz: BDV[Double] = _
+ private var currMax: BDV[Double] = _
+ private var currMin: BDV[Double] = _
+
+ /**
+ * Add a new sample to this summarizer, and update the statistical summary.
+ *
+ * @param sample The sample in dense/sparse vector format to be added into this summarizer.
+ * @return This MultivariateOnlineSummarizer object.
+ */
+ def add(sample: Vector): this.type = {
+ if (n == 0) {
+ require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.")
+ n = sample.toBreeze.length
+
+ currMean = BDV.zeros[Double](n)
+ currM2n = BDV.zeros[Double](n)
+ nnz = BDV.zeros[Double](n)
+ currMax = BDV.fill(n)(Double.MinValue)
+ currMin = BDV.fill(n)(Double.MaxValue)
+ }
+
+ require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." +
+ s" Expecting $n but got ${sample.toBreeze.length}.")
+
+ sample.toBreeze.activeIterator.foreach {
+ case (_, 0.0) => // Skip explicit zero elements.
+ case (i, value) =>
+ if (currMax(i) < value) {
+ currMax(i) = value
+ }
+ if (currMin(i) > value) {
+ currMin(i) = value
+ }
+
+ val tmpPrevMean = currMean(i)
+ currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
+ currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
+
+ nnz(i) += 1.0
+ }
+
+ totalCnt += 1
+ this
+ }
+
+ /**
+ * Merge another MultivariateOnlineSummarizer, and update the statistical summary.
+ * (Note that this is an in-place merge; as a result, `this` object will be modified.)
+ *
+ * @param other The other MultivariateOnlineSummarizer to be merged.
+ * @return This MultivariateOnlineSummarizer object.
+ */
+ def merge(other: MultivariateOnlineSummarizer): this.type = {
+ if (this.totalCnt != 0 && other.totalCnt != 0) {
+ require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
+ s"Expecting $n but got ${other.n}.")
+ totalCnt += other.totalCnt
+ val deltaMean: BDV[Double] = currMean - other.currMean
+ var i = 0
+ while (i < n) {
+ // merge mean together
+ if (other.currMean(i) != 0.0) {
+ currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
+ (nnz(i) + other.nnz(i))
+ }
+ // merge m2n together
+ if (nnz(i) + other.nnz(i) != 0.0) {
+ currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
+ (nnz(i) + other.nnz(i))
+ }
+ if (currMax(i) < other.currMax(i)) {
+ currMax(i) = other.currMax(i)
+ }
+ if (currMin(i) > other.currMin(i)) {
+ currMin(i) = other.currMin(i)
+ }
+ i += 1
+ }
+ nnz += other.nnz
+ } else if (totalCnt == 0 && other.totalCnt != 0) {
+ this.n = other.n
+ this.currMean = other.currMean.copy
+ this.currM2n = other.currM2n.copy
+ this.totalCnt = other.totalCnt
+ this.nnz = other.nnz.copy
+ this.currMax = other.currMax.copy
+ this.currMin = other.currMin.copy
+ }
+ this
+ }
+
+ override def mean: Vector = {
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+ val realMean = BDV.zeros[Double](n)
+ var i = 0
+ while (i < n) {
+ realMean(i) = currMean(i) * (nnz(i) / totalCnt)
+ i += 1
+ }
+ Vectors.fromBreeze(realMean)
+ }
+
+ override def variance: Vector = {
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+ val realVariance = BDV.zeros[Double](n)
+
+ val denominator = totalCnt - 1.0
+
+ // Sample variance is computed; if the denominator (count - 1) is not greater than 0, the variance is just 0.
+ if (denominator > 0.0) {
+ val deltaMean = currMean
+ var i = 0
+ while (i < currM2n.size) {
+ realVariance(i) =
+ currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
+ realVariance(i) /= denominator
+ i += 1
+ }
+ }
+
+ Vectors.fromBreeze(realVariance)
+ }
+
+ override def count: Long = totalCnt
+
+ override def numNonzeros: Vector = {
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+ Vectors.fromBreeze(nnz)
+ }
+
+ override def max: Vector = {
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+ var i = 0
+ while (i < n) {
+ if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
+ i += 1
+ }
+ Vectors.fromBreeze(currMax)
+ }
+
+ override def min: Vector = {
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+ var i = 0
+ while (i < n) {
+ if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
+ i += 1
+ }
+ Vectors.fromBreeze(currMin)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
new file mode 100644
index 0000000000..4b7b019d82
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.TestingUtils._
+
+class MultivariateOnlineSummarizerSuite extends FunSuite {
+
+ test("basic error handing") {
+ val summarizer = new MultivariateOnlineSummarizer
+
+ assert(summarizer.count === 0, "should be zero since nothing is added.")
+
+ withClue("Getting numNonzeros from empty summarizer should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.numNonzeros
+ }
+ }
+
+ withClue("Getting variance from empty summarizer should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.variance
+ }
+ }
+
+ withClue("Getting mean from empty summarizer should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.mean
+ }
+ }
+
+ withClue("Getting max from empty summarizer should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.max
+ }
+ }
+
+ withClue("Getting min from empty summarizer should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.min
+ }
+ }
+
+ summarizer.add(Vectors.dense(-1.0, 2.0, 6.0)).add(Vectors.sparse(3, Seq((0, -2.0), (1, 6.0))))
+
+ withClue("Adding a new dense sample with different array size should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.add(Vectors.dense(3.0, 1.0))
+ }
+ }
+
+ withClue("Adding a new sparse sample with different array size should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.add(Vectors.sparse(5, Seq((0, -2.0), (1, 6.0))))
+ }
+ }
+
+ val summarizer2 = (new MultivariateOnlineSummarizer).add(Vectors.dense(1.0, -2.0, 0.0, 4.0))
+ withClue("Merging a new summarizer with different dimensions should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.merge(summarizer2)
+ }
+ }
+ }
+
+ test("dense vector input") {
+ // For column 2, the maximum will be 0.0, even though 0.0 is never explicitly added, since we
+ // skip all zero values; this is a case we need to test. For column 3, the minimum will be 0.0,
+ // which we need to test as well.
+ val summarizer = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(-1.0, 0.0, 6.0))
+ .add(Vectors.dense(3.0, -3.0, 0.0))
+
+ assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch")
+
+ assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch")
+
+ assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch")
+
+ assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch")
+
+ assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch")
+
+ assert(summarizer.count === 2)
+ }
+
+ test("sparse vector input") {
+ val summarizer = (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0))))
+ .add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0))))
+
+ assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch")
+
+ assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch")
+
+ assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch")
+
+ assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch")
+
+ assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch")
+
+ assert(summarizer.count === 2)
+ }
+
+ test("mixing dense and sparse vector input") {
+ val summarizer = (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))))
+ .add(Vectors.dense(0.0, -1.0, -3.0))
+ .add(Vectors.sparse(3, Seq((1, -5.1))))
+ .add(Vectors.dense(3.8, 0.0, 1.9))
+ .add(Vectors.dense(1.7, -0.6, 0.0))
+ .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0))))
+
+ assert(summarizer.mean.almostEquals(
+ Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch")
+
+ assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch")
+
+ assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch")
+
+ assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch")
+
+ assert(summarizer.variance.almostEquals(
+ Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch")
+
+ assert(summarizer.count === 6)
+ }
+
+ test("merging two summarizers") {
+ val summarizer1 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))))
+ .add(Vectors.dense(0.0, -1.0, -3.0))
+
+ val summarizer2 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((1, -5.1))))
+ .add(Vectors.dense(3.8, 0.0, 1.9))
+ .add(Vectors.dense(1.7, -0.6, 0.0))
+ .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0))))
+
+ val summarizer = summarizer1.merge(summarizer2)
+
+ assert(summarizer.mean.almostEquals(
+ Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch")
+
+ assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch")
+
+ assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch")
+
+ assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch")
+
+ assert(summarizer.variance.almostEquals(
+ Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch")
+
+ assert(summarizer.count === 6)
+ }
+
+ test("merging summarizer with empty summarizer") {
+ // If exactly one of the two summarizers is non-empty, merging should return the non-empty one.
+ // If both of them are empty, then just return the empty summarizer.
+ val summarizer1 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(0.0, -1.0, -3.0)).merge(new MultivariateOnlineSummarizer)
+ assert(summarizer1.count === 1)
+
+ val summarizer2 = (new MultivariateOnlineSummarizer)
+ .merge((new MultivariateOnlineSummarizer).add(Vectors.dense(0.0, -1.0, -3.0)))
+ assert(summarizer2.count === 1)
+
+ val summarizer3 = (new MultivariateOnlineSummarizer).merge(new MultivariateOnlineSummarizer)
+ assert(summarizer3.count === 0)
+
+ assert(summarizer1.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch")
+
+ assert(summarizer2.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch")
+
+ assert(summarizer1.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch")
+
+ assert(summarizer2.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch")
+
+ assert(summarizer1.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch")
+
+ assert(summarizer2.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch")
+
+ assert(summarizer1.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch")
+
+ assert(summarizer2.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch")
+
+ assert(summarizer1.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch")
+
+ assert(summarizer2.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch")
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
new file mode 100644
index 0000000000..64b1ba7527
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.apache.spark.mllib.linalg.Vector
+
+object TestingUtils {
+
+ implicit class DoubleWithAlmostEquals(val x: Double) {
+ // Unlike a naive relative comparison, we always divide by the number with the larger
+ // magnitude. This avoids the problem of dividing by zero.
+ def almostEquals(y: Double, epsilon: Double = 1E-10): Boolean = {
+ if(x == y) {
+ true
+ } else if(math.abs(x) > math.abs(y)) {
+ math.abs(x - y) / math.abs(x) < epsilon
+ } else {
+ math.abs(x - y) / math.abs(y) < epsilon
+ }
+ }
+ }
+
+ implicit class VectorWithAlmostEquals(val x: Vector) {
+ def almostEquals(y: Vector, epsilon: Double = 1E-10): Boolean = {
+ x.toArray.corresponds(y.toArray) {
+ _.almostEquals(_, epsilon)
+ }
+ }
+ }
+}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 3b7b87b80c..d67c6571a0 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -75,6 +75,7 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$<init>$default$7")
) ++
+ MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++
MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++