author     Xusen Yin <yinxusen@gmail.com>  2014-04-11 19:43:22 -0700
committer  Patrick Wendell <pwendell@gmail.com>  2014-04-11 19:43:22 -0700
commit     fdfb45e691946f3153d6c696bec6d7f3e391e301 (patch)
tree       df93888223f9c274ddfb5520fc4cb998f57703b1 /mllib
parent     7038b00be9c84a4d92f9d95ff3d75fae47d57d87 (diff)
[WIP] [SPARK-1328] Add vector statistics
As part of the new vector system in MLlib, it is useful to add some new APIs to process `RDD[Vector]`. Besides, the former implementation of `computeStat` is not numerically stable: it can lose precision and may produce `NaN`, as described in [SPARK-1328](https://spark-project.atlassian.net/browse/SPARK-1328).

The new APIs include:

* rowMeans(): RDD[Double]
* rowNorm2(): RDD[Double]
* rowSDs(): RDD[Double]
* colMeans(): Vector
* colMeans(size: Int): Vector
* colNorm2(): Vector
* colNorm2(size: Int): Vector
* colSDs(): Vector
* colSDs(size: Int): Vector
* maxOption((Vector, Vector) => Boolean): Option[Vector]
* minOption((Vector, Vector) => Boolean): Option[Vector]
* rowShrink(): RDD[Vector]
* colShrink(): RDD[Vector]

This is a work in progress; more APIs will be added to `LabeledPoint`, and the implicit declaration will move from `MLUtils` to `MLContext` later.

Author: Xusen Yin <yinxusen@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #268 from yinxusen/vector-statistics and squashes the following commits:

d61363f [Xusen Yin] rebase to latest master
16ae684 [Xusen Yin] fix minor error and remove useless method
10cf5d3 [Xusen Yin] refine some return type
b064714 [Xusen Yin] remove computeStat in MLUtils
cbbefdb [Xiangrui Meng] update multivariate statistical summary interface and clean tests
4eaf28a [Xusen Yin] merge VectorRDDStatistics into RowMatrix
48ee053 [Xusen Yin] fix minor error
e624f93 [Xusen Yin] fix scala style error
1fba230 [Xusen Yin] merge while loop together
69e1f37 [Xusen Yin] remove lazy eval, and minor memory footprint
548e9de [Xusen Yin] minor revision
86522c4 [Xusen Yin] add comments on functions
dc77e38 [Xusen Yin] test sparse vector RDD
18cf072 [Xusen Yin] change def to lazy val to make sure that the computations in function be evaluated only once
f7a3ca2 [Xusen Yin] fix the corner case of maxmin
967d041 [Xusen Yin] full revision with Aggregator class
138300c [Xusen Yin] add new Aggregator class
1376ff4 [Xusen Yin] rename variables and adjust code
4a5c38d [Xusen Yin] add scala doc, refine code and comments
036b7a5 [Xusen Yin] fix the bug of Nan occur
f6e8e9a [Xusen Yin] add sparse vectors test
4cfbadf [Xusen Yin] fix bug of min max
4e4fbd1 [Xusen Yin] separate seqop and combop out as independent functions
a6d5a2e [Xusen Yin] rewrite for only computing non-zero elements
3980287 [Xusen Yin] rename variables
62a2c3e [Xusen Yin] use axpy and in-place if possible
9a75ebd [Xusen Yin] add case class to wrap return values
d816ac7 [Xusen Yin] remove useless APIs
c4651bb [Xusen Yin] remove row-wise APIs and refine code
1338ea1 [Xusen Yin] all-in-one version test passed
cc65810 [Xusen Yin] add parallel mean and variance
9af2e95 [Xusen Yin] refine the code style
ad6c82d [Xusen Yin] add shrink test
e09d5d2 [Xusen Yin] add scala docs and refine shrink method
8ef3377 [Xusen Yin] pass all tests
28cf060 [Xusen Yin] fix error of column means
54b19ab [Xusen Yin] add new API to shrink RDD[Vector]
8c6c0e1 [Xusen Yin] add basic statistics
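A minimal sketch of how the merged column-summary API is used from the driver side (`sc` is assumed to be an existing `SparkContext`, and the rows are made-up sample data):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// Two dense rows; sparse rows are handled the same way.
val rows = sc.parallelize(Seq(
  Vectors.dense(0.0, 1.0, 2.0),
  Vectors.dense(3.0, 4.0, 5.0)))
val summary = new RowMatrix(rows).computeColumnSummaryStatistics()

summary.mean        // [1.5, 2.5, 3.5]
summary.variance    // sample variances: [4.5, 4.5, 4.5]
summary.numNonzeros // [1.0, 2.0, 2.0]
```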
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala          165
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala    56
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala                           57
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala      15
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala                      13
5 files changed, 230 insertions, 76 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index f65f43dd30..0c0afcd9ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed
import java.util
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
+import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
import breeze.numerics.{sqrt => brzSqrt}
import com.github.fommil.netlib.BLAS.{getInstance => blas}
@@ -27,6 +27,138 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
+import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
+
+/**
+ * Column statistics aggregator implementing
+ * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
+ * together with add() and merge() functions.
+ * A numerically stable algorithm is implemented to compute sample mean and variance:
+ * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
+ * Zero elements (including explicit zero values) are skipped when calling add() and merge(),
+ * so that the time complexity is O(nnz) instead of O(n) for each column.
+ */
+private class ColumnStatisticsAggregator(private val n: Int)
+ extends MultivariateStatisticalSummary with Serializable {
+
+ private val currMean: BDV[Double] = BDV.zeros[Double](n)
+ private val currM2n: BDV[Double] = BDV.zeros[Double](n)
+ private var totalCnt = 0.0
+ private val nnz: BDV[Double] = BDV.zeros[Double](n)
+ private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
+ private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)
+
+ override def mean: Vector = {
+ val realMean = BDV.zeros[Double](n)
+ var i = 0
+ while (i < n) {
+ realMean(i) = currMean(i) * nnz(i) / totalCnt
+ i += 1
+ }
+ Vectors.fromBreeze(realMean)
+ }
+
+ override def variance: Vector = {
+ val realVariance = BDV.zeros[Double](n)
+
+ val denominator = totalCnt - 1.0
+
+ // Sample variance is computed; if the denominator is not greater than 0, the variance is just 0.
+ if (denominator > 0.0) {
+ val deltaMean = currMean
+ var i = 0
+ while (i < currM2n.size) {
+ realVariance(i) =
+ currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
+ realVariance(i) /= denominator
+ i += 1
+ }
+ }
+
+ Vectors.fromBreeze(realVariance)
+ }
+
+ override def count: Long = totalCnt.toLong
+
+ override def numNonzeros: Vector = Vectors.fromBreeze(nnz)
+
+ override def max: Vector = {
+ var i = 0
+ while (i < n) {
+ if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
+ i += 1
+ }
+ Vectors.fromBreeze(currMax)
+ }
+
+ override def min: Vector = {
+ var i = 0
+ while (i < n) {
+ if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
+ i += 1
+ }
+ Vectors.fromBreeze(currMin)
+ }
+
+ /**
+ * Aggregates a row.
+ */
+ def add(currData: BV[Double]): this.type = {
+ currData.activeIterator.foreach {
+ case (_, 0.0) => // Skip explicit zero elements.
+ case (i, value) =>
+ if (currMax(i) < value) {
+ currMax(i) = value
+ }
+ if (currMin(i) > value) {
+ currMin(i) = value
+ }
+
+ val tmpPrevMean = currMean(i)
+ currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
+ currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
+
+ nnz(i) += 1.0
+ }
+
+ totalCnt += 1.0
+ this
+ }
+
+ /**
+ * Merges another aggregator.
+ */
+ def merge(other: ColumnStatisticsAggregator): this.type = {
+ require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")
+
+ totalCnt += other.totalCnt
+ val deltaMean = currMean - other.currMean
+
+ var i = 0
+ while (i < n) {
+ // merge mean together
+ if (other.currMean(i) != 0.0) {
+ currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
+ (nnz(i) + other.nnz(i))
+ }
+ // merge m2n together
+ if (nnz(i) + other.nnz(i) != 0.0) {
+ currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
+ (nnz(i) + other.nnz(i))
+ }
+ if (currMax(i) < other.currMax(i)) {
+ currMax(i) = other.currMax(i)
+ }
+ if (currMin(i) > other.currMin(i)) {
+ currMin(i) = other.currMin(i)
+ }
+ i += 1
+ }
+
+ nnz += other.nnz
+ this
+ }
+}
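The add()/merge() pair above is a column-wise variant of Welford's online algorithm: currMean is updated incrementally and currM2n accumulates squared deviations from the running mean. A single-variable sketch of the same recurrence (the input values here are illustrative only):

```scala
var n = 0.0
var mean = 0.0
var m2n = 0.0 // running sum of squared deviations from the current mean
for (x <- Seq(1.0, 2.0, 4.0)) {
  val prevMean = mean
  n += 1.0
  mean += (x - mean) / n // equivalent to (mean * (n - 1) + x) / n in add()
  m2n += (x - mean) * (x - prevMean)
}
val sampleVariance = if (n > 1.0) m2n / (n - 1.0) else 0.0
// mean == 7.0 / 3 and sampleVariance == 7.0 / 3 for this input
```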
/**
* :: Experimental ::
@@ -182,13 +314,7 @@ class RowMatrix(
combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
)
- // Update _m if it is not set, or verify its value.
- if (nRows <= 0L) {
- nRows = m
- } else {
- require(nRows == m,
- s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
- }
+ updateNumRows(m)
mean :/= m.toDouble
@@ -241,6 +367,19 @@ class RowMatrix(
}
/**
+ * Computes column-wise summary statistics.
+ */
+ def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
+ val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
+ val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
+ (aggregator, data) => aggregator.add(data),
+ (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+ )
+ updateNumRows(summary.count)
+ summary
+ }
+
+ /**
* Multiply this matrix by a local matrix on the right.
*
* @param B a local matrix whose number of rows must match the number of columns of this matrix
@@ -276,6 +415,16 @@ class RowMatrix(
}
mat
}
+
+ /** Updates or verifies the number of rows. */
+ private def updateNumRows(m: Long) {
+ if (nRows <= 0) {
+ nRows = m
+ } else {
+ require(nRows == m,
+ s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
+ }
+ }
}
object RowMatrix {
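computeColumnSummaryStatistics() above leans on the standard aggregate(zeroValue)(seqOp, combOp) contract: seqOp folds one row into a per-partition aggregator, and combOp merges aggregators across partitions. A local sketch of the same contract on a plain Scala collection (SumMax and the input are illustrative, not part of MLlib):

```scala
// A tiny commutative/associative aggregator, analogous in shape to
// ColumnStatisticsAggregator but tracking only a sum and a max.
case class SumMax(var sum: Double = 0.0, var max: Double = Double.MinValue) {
  def add(x: Double): SumMax = { sum += x; if (x > max) max = x; this }               // seqOp role
  def merge(o: SumMax): SumMax = { sum += o.sum; if (o.max > max) max = o.max; this } // combOp role
}

val data = Seq(1.0, 4.0, 2.0)
val result = data.aggregate(SumMax())((agg, x) => agg.add(x), (a, b) => a.merge(b))
// result.sum == 7.0, result.max == 4.0
```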
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
new file mode 100644
index 0000000000..f9eb343da2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import org.apache.spark.mllib.linalg.Vector
+
+/**
+ * Trait for multivariate statistical summary of a data matrix.
+ */
+trait MultivariateStatisticalSummary {
+
+ /**
+ * Sample mean vector.
+ */
+ def mean: Vector
+
+ /**
+ * Sample variance vector. Should return a zero vector if the sample size is 1.
+ */
+ def variance: Vector
+
+ /**
+ * Sample size.
+ */
+ def count: Long
+
+ /**
+ * Number of nonzero elements in each column (explicitly presented zero values are not counted).
+ */
+ def numNonzeros: Vector
+
+ /**
+ * Maximum value of each column.
+ */
+ def max: Vector
+
+ /**
+ * Minimum value of each column.
+ */
+ def min: Vector
+}
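Any container of precomputed statistics can implement this trait. A hypothetical in-memory implementation, modeling a dataset of n identical rows (for tests or mocks; not part of this patch):

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary

class ConstantRowsSummary(row: Vector, n: Long) extends MultivariateStatisticalSummary {
  override def mean: Vector = row
  override def variance: Vector = Vectors.dense(Array.fill(row.size)(0.0)) // identical rows
  override def count: Long = n
  override def numNonzeros: Vector =
    Vectors.dense(row.toArray.map(v => if (v != 0.0) n.toDouble else 0.0))
  override def max: Vector = row
  override def min: Vector = row
}
```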
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index ac2360c429..901c3180ea 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.util
-import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
- squaredDistance => breezeSquaredDistance}
+import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.Vectors
/**
* Helper methods to load, save and pre-process data used in ML Lib.
@@ -159,58 +158,6 @@ object MLUtils {
}
/**
- * Utility function to compute mean and standard deviation on a given dataset.
- *
- * @param data - input data set whose statistics are computed
- * @param numFeatures - number of features
- * @param numExamples - number of examples in input dataset
- *
- * @return (yMean, xColMean, xColSd) - Tuple consisting of
- * yMean - mean of the labels
- * xColMean - Row vector with mean for every column (or feature) of the input data
- * xColSd - Row vector standard deviation for every column (or feature) of the input data.
- */
- private[mllib] def computeStats(
- data: RDD[LabeledPoint],
- numFeatures: Int,
- numExamples: Long): (Double, Vector, Vector) = {
- val brzData = data.map { case LabeledPoint(label, features) =>
- (label, features.toBreeze)
- }
- val aggStats = brzData.aggregate(
- (0L, 0.0, BDV.zeros[Double](numFeatures), BDV.zeros[Double](numFeatures))
- )(
- seqOp = (c, v) => (c, v) match {
- case ((n, sumLabel, sum, sumSq), (label, features)) =>
- features.activeIterator.foreach { case (i, x) =>
- sumSq(i) += x * x
- }
- (n + 1L, sumLabel + label, sum += features, sumSq)
- },
- combOp = (c1, c2) => (c1, c2) match {
- case ((n1, sumLabel1, sum1, sumSq1), (n2, sumLabel2, sum2, sumSq2)) =>
- (n1 + n2, sumLabel1 + sumLabel2, sum1 += sum2, sumSq1 += sumSq2)
- }
- )
- val (nl, sumLabel, sum, sumSq) = aggStats
-
- require(nl > 0, "Input data is empty.")
- require(nl == numExamples)
-
- val n = nl.toDouble
- val yMean = sumLabel / n
- val mean = sum / n
- val std = new Array[Double](sum.length)
- var i = 0
- while (i < numFeatures) {
- std(i) = sumSq(i) / n - mean(i) * mean(i)
- i += 1
- }
-
- (yMean, Vectors.fromBreeze(mean), Vectors.dense(std))
- }
-
- /**
* Returns the squared Euclidean distance between two vectors. The following formula will be used
* if it does not introduce too much numerical error:
* <pre>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index 71ee8e8a4f..c9f9acf4c1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -170,4 +170,19 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
))
}
}
+
+ test("compute column summary statistics") {
+ for (mat <- Seq(denseMat, sparseMat)) {
+ val summary = mat.computeColumnSummaryStatistics()
+ // Run twice to make sure no internal states are changed.
+ for (k <- 0 to 1) {
+ assert(summary.mean === Vectors.dense(4.5, 3.0, 4.0), "mean mismatch")
+ assert(summary.variance === Vectors.dense(15.0, 10.0, 10.0), "variance mismatch")
+ assert(summary.count === m, "count mismatch")
+ assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch")
+ assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch")
+ assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "min mismatch")
+ }
+ }
+ }
}
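The expected values in the new test are consistent with a 4x3 matrix whose rows are (0,1,2), (3,4,5), (6,7,8), (9,0,1) (an assumption about denseMat, which is defined elsewhere in the suite). A hand-check of the first column:

```scala
val col = Seq(0.0, 3.0, 6.0, 9.0) // assumed first column of denseMat
val mean = col.sum / col.size     // 4.5, matching summary.mean(0)
val variance =
  col.map(x => (x - mean) * (x - mean)).sum / (col.size - 1) // 15.0, matching summary.variance(0)
```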
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index e451c350b8..812a843478 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -27,7 +27,6 @@ import com.google.common.base.Charsets
import com.google.common.io.Files
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
class MLUtilsSuite extends FunSuite with LocalSparkContext {
@@ -56,18 +55,6 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
}
}
- test("compute stats") {
- val data = Seq.fill(3)(Seq(
- LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 3.0)),
- LabeledPoint(0.0, Vectors.dense(3.0, 4.0, 5.0))
- )).flatten
- val rdd = sc.parallelize(data, 2)
- val (meanLabel, mean, std) = MLUtils.computeStats(rdd, 3, 6)
- assert(meanLabel === 0.5)
- assert(mean === Vectors.dense(2.0, 3.0, 4.0))
- assert(std === Vectors.dense(1.0, 1.0, 1.0))
- }
-
test("loadLibSVMData") {
val lines =
"""