 sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 00231d65a7..725d6821bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -29,7 +29,7 @@ private[sql] object StatFunctions extends Logging {
 
   /** Calculate the Pearson Correlation Coefficient for the given columns */
   private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
-    val counts = collectStatisticalData(df, cols)
+    val counts = collectStatisticalData(df, cols, "correlation")
     counts.Ck / math.sqrt(counts.MkX * counts.MkY)
   }
 
@@ -73,13 +73,14 @@ private[sql] object StatFunctions extends Logging {
     def cov: Double = Ck / (count - 1)
   }
 
-  private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = {
-    require(cols.length == 2, "Currently cov supports calculating the covariance " +
+  private def collectStatisticalData(df: DataFrame, cols: Seq[String],
+      functionName: String): CovarianceCounter = {
+    require(cols.length == 2, s"Currently $functionName calculation is supported " +
       "between two columns.")
     cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>
       require(data.nonEmpty, s"Couldn't find column with name $name")
-      require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " +
-        s"with dataType ${data.get.dataType} not supported.")
+      require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " +
+        s"for columns with dataType ${data.get.dataType} not supported.")
     }
     val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
     df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)(
@@ -98,7 +99,7 @@ private[sql] object StatFunctions extends Logging {
    * @return the covariance of the two columns.
    */
   private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
-    val counts = collectStatisticalData(df, cols)
+    val counts = collectStatisticalData(df, cols, "covariance")
     counts.cov
   }
 
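
For context, a minimal usage sketch of the public entry points that delegate to the functions patched above (DataFrame.stat.corr and DataFrame.stat.cov). The SparkContext/SQLContext setup and the sample data are illustrative assumptions, not part of this commit:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Illustrative local setup (assumption, not part of this patch).
val sc = new SparkContext(new SparkConf().setAppName("stat-example").setMaster("local[*]"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Two numeric columns; df.stat.corr and df.stat.cov route to
// StatFunctions.pearsonCorrelation and StatFunctions.calculateCov,
// both of which call collectStatisticalData shown above.
val df = Seq((1.0, 2.0), (2.0, 4.1), (3.0, 5.9)).toDF("x", "y")
val corr = df.stat.corr("x", "y")   // Pearson correlation coefficient
val cov  = df.stat.cov("x", "y")    // sample covariance, Ck / (count - 1)

// With a non-numeric or missing column, the require(...) checks above now report
// the statistic actually requested, e.g. corr mentions "correlation" instead of
// the old covariance-only wording.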