aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author Liang-Chi Hsieh <viirya@gmail.com> 2015-05-08 14:41:16 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-05-08 14:41:16 -0700
commit 90527f560462cc2d693176bd961b02767e460e06 (patch)
tree 2a2fb7ced252e14aa4f4250d571c02f3b40cb415 /sql
parent 5467c34c3d6538e053957b5513df218f1f5bae6b (diff)
download spark-90527f560462cc2d693176bd961b02767e460e06.tar.gz
spark-90527f560462cc2d693176bd961b02767e460e06.tar.bz2
spark-90527f560462cc2d693176bd961b02767e460e06.zip
[SPARK-7390] [SQL] Only merge other CovarianceCounter when its count is greater than zero
JIRA: https://issues.apache.org/jira/browse/SPARK-7390

Also fix a minor typo.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #5931 from viirya/fix_covariancecounter and squashes the following commits:

352eda6 [Liang-Chi Hsieh] Only merge other CovarianceCounter when its count is greater than zero.
Diffstat (limited to 'sql')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala 22
1 files changed, 12 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 386ac969f1..71b7f6c2a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -38,7 +38,7 @@ private[sql] object StatFunctions extends Logging {
var yAvg = 0.0 // the mean of all examples seen so far in col2
var Ck = 0.0 // the co-moment after k examples
var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
- var MkY = 0.0 // sum of squares of differences from the (current) mean for col1
+ var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
var count = 0L // count of observed examples
// add an example to the calculation
def add(x: Double, y: Double): this.type = {
@@ -55,15 +55,17 @@ private[sql] object StatFunctions extends Logging {
// merge counters from other partitions. Formula can be found at:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
def merge(other: CovarianceCounter): this.type = {
- val totalCount = count + other.count
- val deltaX = xAvg - other.xAvg
- val deltaY = yAvg - other.yAvg
- Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
- xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
- yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
- MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
- MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
- count = totalCount
+ if (other.count > 0) {
+ val totalCount = count + other.count
+ val deltaX = xAvg - other.xAvg
+ val deltaY = yAvg - other.yAvg
+ Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
+ xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
+ yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
+ MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
+ MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
+ count = totalCount
+ }
this
}
// return the sample covariance for the observed examples