aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph.kurata.bradley@gmail.com>2014-08-18 18:01:39 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-18 18:01:39 -0700
commitc8b16ca0d86cc60fb960eebf0cb383f159a88b03 (patch)
tree27f6b16cc7bd14af681d1678fda53ea3051e2e36 /mllib/src/main
parent115eeb30dd9c9dd10685a71f2c23ca23794d3142 (diff)
downloadspark-c8b16ca0d86cc60fb960eebf0cb383f159a88b03.tar.gz
spark-c8b16ca0d86cc60fb960eebf0cb383f159a88b03.tar.bz2
spark-c8b16ca0d86cc60fb960eebf0cb383f159a88b03.zip
[SPARK-2850] [SPARK-2626] [mllib] MLlib stats examples + small fixes
Added examples for statistical summarization: * Scala: StatisticalSummary.scala ** Tests: correlation, MultivariateOnlineSummarizer * python: statistical_summary.py ** Tests: correlation (since MultivariateOnlineSummarizer has no Python API) Added examples for random and sampled RDDs: * Scala: RandomAndSampledRDDs.scala * python: random_and_sampled_rdds.py * Both test: ** RandomRDDGenerators.normalRDD, normalVectorRDD ** RDD.sample, takeSample, sampleByKey Added sc.stop() to all examples. CorrelationSuite.scala * Added 1 test for RDDs with only 1 value RowMatrix.scala * numCols(): Added check for numRows = 0, with error message. * computeCovariance(): Added check for numRows <= 1, with error message. Python SparseVector (pyspark/mllib/linalg.py) * Added toDense() function python/run-tests script * Added stat.py (doc test) CC: mengxr dorx Main changes were examples to show usage across APIs. Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes #1878 from jkbradley/mllib-stats-api-check and squashes the following commits: ea5c047 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check dafebe2 [Joseph K. Bradley] Bug fixes for examples SampledRDDs.scala and sampled_rdds.py: Check for division by 0 and for missing key in maps. 8d1e555 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 60c72d9 [Joseph K. Bradley] Fixed stat.py doc test to work for Python versions printing nan or NaN. b20d90a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 4e5d15e [Joseph K. Bradley] Changed pyspark/mllib/stat.py doc tests to use NaN instead of nan. 32173b7 [Joseph K. Bradley] Stats examples update. c8c20dc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check cf70b07 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 0b7cec3 [Joseph K. Bradley] Small updates based on code review. Renamed statistical_summary.py to correlations.py ab48f6e [Joseph K. Bradley] RowMatrix.scala * numCols(): Added check for numRows = 0, with error message. * computeCovariance(): Added check for numRows <= 1, with error message. 65e4ebc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 8195c78 [Joseph K. Bradley] Added examples for random and sampled RDDs: * Scala: RandomAndSampledRDDs.scala * python: random_and_sampled_rdds.py * Both test: ** RandomRDDGenerators.normalRDD, normalVectorRDD ** RDD.sample, takeSample, sampleByKey 064985b [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check ee918e9 [Joseph K. Bradley] Added examples for statistical summarization: * Scala: StatisticalSummary.scala ** Tests: correlation, MultivariateOnlineSummarizer * python: statistical_summary.py ** Tests: correlation (since MultivariateOnlineSummarizer has no Python API)
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala14
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala8
2 files changed, 16 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index e76bc9feff..2e414a73be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -53,8 +53,14 @@ class RowMatrix(
/** Gets or computes the number of columns. */
override def numCols(): Long = {
if (nCols <= 0) {
- // Calling `first` will throw an exception if `rows` is empty.
- nCols = rows.first().size
+ try {
+ // Calling `first` will throw an exception if `rows` is empty.
+ nCols = rows.first().size
+ } catch {
+ case err: UnsupportedOperationException =>
+ sys.error("Cannot determine the number of cols because it is not specified in the " +
+ "constructor and the rows RDD is empty.")
+ }
}
nCols
}
@@ -293,6 +299,10 @@ class RowMatrix(
(s1._1 + s2._1, s1._2 += s2._2)
)
+ if (m <= 1) {
+ sys.error(s"RowMatrix.computeCovariance called on matrix with only $m rows." +
+ " Cannot compute the covariance of a RowMatrix with <= 1 row.")
+ }
updateNumRows(m)
mean :/= m.toDouble
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 5105b5c37a..7d845c4436 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -55,8 +55,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
def add(sample: Vector): this.type = {
if (n == 0) {
- require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.")
- n = sample.toBreeze.length
+ require(sample.size > 0, s"Vector should have dimension larger than zero.")
+ n = sample.size
currMean = BDV.zeros[Double](n)
currM2n = BDV.zeros[Double](n)
@@ -65,8 +65,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMin = BDV.fill(n)(Double.MaxValue)
}
- require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." +
- s" Expecting $n but got ${sample.toBreeze.length}.")
+ require(n == sample.size, s"Dimensions mismatch when adding new sample." +
+ s" Expecting $n but got ${sample.size}.")
sample.toBreeze.activeIterator.foreach {
case (_, 0.0) => // Skip explicit zero elements.