aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala | 14
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 8
2 files changed, 16 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index e76bc9feff..2e414a73be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -53,8 +53,14 @@ class RowMatrix(
/** Gets or computes the number of columns. */
override def numCols(): Long = {
if (nCols <= 0) {
- // Calling `first` will throw an exception if `rows` is empty.
- nCols = rows.first().size
+ try {
+ // Calling `first` will throw an exception if `rows` is empty.
+ nCols = rows.first().size
+ } catch {
+ case err: UnsupportedOperationException =>
+ sys.error("Cannot determine the number of cols because it is not specified in the " +
+ "constructor and the rows RDD is empty.")
+ }
}
nCols
}
@@ -293,6 +299,10 @@ class RowMatrix(
(s1._1 + s2._1, s1._2 += s2._2)
)
+ if (m <= 1) {
+ sys.error(s"RowMatrix.computeCovariance called on matrix with only $m rows." +
+ " Cannot compute the covariance of a RowMatrix with <= 1 row.")
+ }
updateNumRows(m)
mean :/= m.toDouble
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 5105b5c37a..7d845c4436 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -55,8 +55,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
def add(sample: Vector): this.type = {
if (n == 0) {
- require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.")
- n = sample.toBreeze.length
+ require(sample.size > 0, s"Vector should have dimension larger than zero.")
+ n = sample.size
currMean = BDV.zeros[Double](n)
currM2n = BDV.zeros[Double](n)
@@ -65,8 +65,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMin = BDV.fill(n)(Double.MaxValue)
}
- require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." +
- s" Expecting $n but got ${sample.toBreeze.length}.")
+ require(n == sample.size, s"Dimensions mismatch when adding new sample." +
+ s" Expecting $n but got ${sample.size}.")
sample.toBreeze.activeIterator.foreach {
case (_, 0.0) => // Skip explicit zero elements.