aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph.kurata.bradley@gmail.com>2014-08-18 18:01:39 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-18 18:01:39 -0700
commitc8b16ca0d86cc60fb960eebf0cb383f159a88b03 (patch)
tree27f6b16cc7bd14af681d1678fda53ea3051e2e36 /mllib
parent115eeb30dd9c9dd10685a71f2c23ca23794d3142 (diff)
downloadspark-c8b16ca0d86cc60fb960eebf0cb383f159a88b03.tar.gz
spark-c8b16ca0d86cc60fb960eebf0cb383f159a88b03.tar.bz2
spark-c8b16ca0d86cc60fb960eebf0cb383f159a88b03.zip
[SPARK-2850] [SPARK-2626] [mllib] MLlib stats examples + small fixes
Added examples for statistical summarization: * Scala: StatisticalSummary.scala ** Tests: correlation, MultivariateOnlineSummarizer * python: statistical_summary.py ** Tests: correlation (since MultivariateOnlineSummarizer has no Python API) Added examples for random and sampled RDDs: * Scala: RandomAndSampledRDDs.scala * python: random_and_sampled_rdds.py * Both test: ** RandomRDDGenerators.normalRDD, normalVectorRDD ** RDD.sample, takeSample, sampleByKey Added sc.stop() to all examples. CorrelationSuite.scala * Added 1 test for RDDs with only 1 value RowMatrix.scala * numCols(): Added check for numRows = 0, with error message. * computeCovariance(): Added check for numRows <= 1, with error message. Python SparseVector (pyspark/mllib/linalg.py) * Added toDense() function python/run-tests script * Added stat.py (doc test) CC: mengxr dorx Main changes were examples to show usage across APIs. Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes #1878 from jkbradley/mllib-stats-api-check and squashes the following commits: ea5c047 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check dafebe2 [Joseph K. Bradley] Bug fixes for examples SampledRDDs.scala and sampled_rdds.py: Check for division by 0 and for missing key in maps. 8d1e555 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 60c72d9 [Joseph K. Bradley] Fixed stat.py doc test to work for Python versions printing nan or NaN. b20d90a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 4e5d15e [Joseph K. Bradley] Changed pyspark/mllib/stat.py doc tests to use NaN instead of nan. 32173b7 [Joseph K. Bradley] Stats examples update. c8c20dc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check cf70b07 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 0b7cec3 [Joseph K. Bradley] Small updates based on code review. Renamed statistical_summary.py to correlations.py ab48f6e [Joseph K. Bradley] RowMatrix.scala * numCols(): Added check for numRows = 0, with error message. * computeCovariance(): Added check for numRows <= 1, with error message. 65e4ebc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 8195c78 [Joseph K. Bradley] Added examples for random and sampled RDDs: * Scala: RandomAndSampledRDDs.scala * python: random_and_sampled_rdds.py * Both test: ** RandomRDDGenerators.normalRDD, normalVectorRDD ** RDD.sample, takeSample, sampleByKey 064985b [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check ee918e9 [Joseph K. Bradley] Added examples for statistical summarization: * Scala: StatisticalSummary.scala ** Tests: correlation, MultivariateOnlineSummarizer * python: statistical_summary.py ** Tests: correlation (since MultivariateOnlineSummarizer has no Python API)
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala14
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala15
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala6
4 files changed, 33 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index e76bc9feff..2e414a73be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -53,8 +53,14 @@ class RowMatrix(
/** Gets or computes the number of columns. */
override def numCols(): Long = {
if (nCols <= 0) {
- // Calling `first` will throw an exception if `rows` is empty.
- nCols = rows.first().size
+ try {
+ // Calling `first` will throw an exception if `rows` is empty.
+ nCols = rows.first().size
+ } catch {
+ case err: UnsupportedOperationException =>
+ sys.error("Cannot determine the number of cols because it is not specified in the " +
+ "constructor and the rows RDD is empty.")
+ }
}
nCols
}
@@ -293,6 +299,10 @@ class RowMatrix(
(s1._1 + s2._1, s1._2 += s2._2)
)
+ if (m <= 1) {
+ sys.error(s"RowMatrix.computeCovariance called on matrix with only $m rows." +
+ " Cannot compute the covariance of a RowMatrix with <= 1 row.")
+ }
updateNumRows(m)
mean :/= m.toDouble
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 5105b5c37a..7d845c4436 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -55,8 +55,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
def add(sample: Vector): this.type = {
if (n == 0) {
- require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.")
- n = sample.toBreeze.length
+ require(sample.size > 0, s"Vector should have dimension larger than zero.")
+ n = sample.size
currMean = BDV.zeros[Double](n)
currM2n = BDV.zeros[Double](n)
@@ -65,8 +65,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMin = BDV.fill(n)(Double.MaxValue)
}
- require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." +
- s" Expecting $n but got ${sample.toBreeze.length}.")
+ require(n == sample.size, s"Dimensions mismatch when adding new sample." +
+ s" Expecting $n but got ${sample.size}.")
sample.toBreeze.activeIterator.foreach {
case (_, 0.0) => // Skip explicit zero elements.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
index a3f76f77a5..34548c86eb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
@@ -39,6 +39,17 @@ class CorrelationSuite extends FunSuite with LocalSparkContext {
Vectors.dense(9.0, 0.0, 0.0, 1.0)
)
+ test("corr(x, y) pearson, 1 value in data") {
+ val x = sc.parallelize(Array(1.0))
+ val y = sc.parallelize(Array(4.0))
+ intercept[RuntimeException] {
+ Statistics.corr(x, y, "pearson")
+ }
+ intercept[RuntimeException] {
+ Statistics.corr(x, y, "spearman")
+ }
+ }
+
test("corr(x, y) default, pearson") {
val x = sc.parallelize(xData)
val y = sc.parallelize(yData)
@@ -58,7 +69,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext {
// RDD of zero variance
val z = sc.parallelize(zeros)
- assert(Statistics.corr(x, z).isNaN())
+ assert(Statistics.corr(x, z).isNaN)
}
test("corr(x, y) spearman") {
@@ -78,7 +89,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext {
// RDD of zero variance => zero variance in ranks
val z = sc.parallelize(zeros)
- assert(Statistics.corr(x, z, "spearman").isNaN())
+ assert(Statistics.corr(x, z, "spearman").isNaN)
}
test("corr(X) default, pearson") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index db13f142df..1e94152491 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -139,7 +139,8 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer.variance ~==
- Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch")
+ Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5,
+ "variance mismatch")
assert(summarizer.count === 6)
}
@@ -167,7 +168,8 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer.variance ~==
- Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch")
+ Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5,
+ "variance mismatch")
assert(summarizer.count === 6)
}