From fdfac22d08fc4fdc640843dd93a29e2ce4aee2ef Mon Sep 17 00:00:00 2001 From: Narine Kokhlikyan Date: Mon, 4 Jan 2016 16:14:49 -0800 Subject: [SPARK-12509][SQL] Fixed error messages for DataFrame correlation and covariance Currently, when we call corr or cov on dataframe with invalid input we see these error messages for both corr and cov: - "Currently cov supports calculating the covariance between two columns" - "Covariance calculation for columns with dataType "[DataType Name]" not supported." I've fixed this issue by passing the function name as an argument. We could also do the input checks separately for each function. I avoided doing that because of code duplication. Thanks! Author: Narine Kokhlikyan Closes #10458 from NarineK/sparksqlstatsmessages. --- .../org/apache/spark/sql/execution/stat/StatFunctions.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 00231d65a7..725d6821bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -29,7 +29,7 @@ private[sql] object StatFunctions extends Logging { /** Calculate the Pearson Correlation Coefficient for the given columns */ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { - val counts = collectStatisticalData(df, cols) + val counts = collectStatisticalData(df, cols, "correlation") counts.Ck / math.sqrt(counts.MkX * counts.MkY) } @@ -73,13 +73,14 @@ private[sql] object StatFunctions extends Logging { def cov: Double = Ck / (count - 1) } - private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = { - require(cols.length == 2, "Currently cov supports calculating the covariance " + + private def collectStatisticalData(df: DataFrame, cols: Seq[String], + functionName: String): CovarianceCounter = { + require(cols.length == 2, s"Currently $functionName calculation is supported " + "between two columns.") cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") - require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " + - s"with dataType ${data.get.dataType} not supported.") + require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + + s"for columns with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)( @@ -98,7 +99,7 @@ private[sql] object StatFunctions extends Logging { * @return the covariance of the two columns. */ private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { - val counts = collectStatisticalData(df, cols) + val counts = collectStatisticalData(df, cols, "covariance") counts.cov } -- cgit v1.2.3