diff options
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala | 14 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 8 |
2 files changed, 16 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index e76bc9feff..2e414a73be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -53,8 +53,14 @@ class RowMatrix( /** Gets or computes the number of columns. */ override def numCols(): Long = { if (nCols <= 0) { - // Calling `first` will throw an exception if `rows` is empty. - nCols = rows.first().size + try { + // Calling `first` will throw an exception if `rows` is empty. + nCols = rows.first().size + } catch { + case err: UnsupportedOperationException => + sys.error("Cannot determine the number of cols because it is not specified in the " + + "constructor and the rows RDD is empty.") + } } nCols } @@ -293,6 +299,10 @@ class RowMatrix( (s1._1 + s2._1, s1._2 += s2._2) ) + if (m <= 1) { + sys.error(s"RowMatrix.computeCovariance called on matrix with only $m rows." + + " Cannot compute the covariance of a RowMatrix with <= 1 row.") + } updateNumRows(m) mean :/= m.toDouble diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 5105b5c37a..7d845c4436 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -55,8 +55,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ def add(sample: Vector): this.type = { if (n == 0) { - require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.") - n = sample.toBreeze.length + require(sample.size > 0, s"Vector should have dimension larger than zero.") + n = sample.size currMean = BDV.zeros[Double](n) currM2n = BDV.zeros[Double](n) @@ -65,8 +65,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMin = BDV.fill(n)(Double.MaxValue) } - require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." + - s" Expecting $n but got ${sample.toBreeze.length}.") + require(n == sample.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${sample.size}.") sample.toBreeze.activeIterator.foreach { case (_, 0.0) => // Skip explicit zero elements. |