diff options
author | wangzhenhua <wangzhenhua@huawei.com> | 2016-10-14 21:18:49 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-10-14 21:18:49 +0800 |
commit | 7486442fe0b70f2aea21d569604e71d7ddf19a77 (patch) | |
tree | e691541bddda84476f509aeb676f8d3b5fc8c82c /sql | |
parent | 28b645b1e643ae0f6c56cbe5a92356623306717f (diff) | |
download | spark-7486442fe0b70f2aea21d569604e71d7ddf19a77.tar.gz spark-7486442fe0b70f2aea21d569604e71d7ddf19a77.tar.bz2 spark-7486442fe0b70f2aea21d569604e71d7ddf19a77.zip |
[SPARK-17073][SQL][FOLLOWUP] generate column-level statistics
## What changes were proposed in this pull request?
This PR adds test cases for statistics covering case-sensitive column names, non-ASCII column names, and refreshing a table, and also improves some documentation.
## How was this patch tested?
Added test cases.
Author: wangzhenhua <wangzhenhua@huawei.com>
Closes #15360 from wzhfy/colStats2.
Diffstat (limited to 'sql')
3 files changed, 197 insertions, 57 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 7066378279..488138709a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -59,10 +59,12 @@ case class AnalyzeColumnCommand( def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { val (rowCount, columnStats) = computeColStats(sparkSession, relation) + // We also update table-level stats in order to keep them consistent with column-level stats. val statistics = Statistics( sizeInBytes = newTotalSize, rowCount = Some(rowCount), - colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map())) + // Newly computed column stats should override the existing ones. + colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats) sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) @@ -90,8 +92,9 @@ case class AnalyzeColumnCommand( } } if (duplicatedColumns.nonEmpty) { - logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " + - s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.") + logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " + + s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " + + s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.") } // Collect statistics per column. 
@@ -116,22 +119,24 @@ case class AnalyzeColumnCommand( } object ColumnStatStruct { - val zero = Literal(0, LongType) - val one = Literal(1, LongType) + private val zero = Literal(0, LongType) + private val one = Literal(1, LongType) - def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero - def max(e: Expression): Expression = Max(e) - def min(e: Expression): Expression = Min(e) - def ndv(e: Expression, relativeSD: Double): Expression = { + private def numNulls(e: Expression): Expression = { + if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero + } + private def max(e: Expression): Expression = Max(e) + private def min(e: Expression): Expression = Min(e) + private def ndv(e: Expression, relativeSD: Double): Expression = { // the approximate ndv should never be larger than the number of rows Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one))) } - def avgLength(e: Expression): Expression = Average(Length(e)) - def maxLength(e: Expression): Expression = Max(Length(e)) - def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) - def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) + private def avgLength(e: Expression): Expression = Average(Length(e)) + private def maxLength(e: Expression): Expression = Max(Length(e)) + private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) + private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - def getStruct(exprs: Seq[Expression]): CreateStruct = { + private def getStruct(exprs: Seq[Expression]): CreateStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -139,19 +144,19 @@ object ColumnStatStruct { }) } - def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD)) } - def 
stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD)) } - def binaryColumnStat(e: Expression): Seq[Expression] = { + private def binaryColumnStat(e: Expression): Seq[Expression] = { Seq(numNulls(e), avgLength(e), maxLength(e)) } - def booleanColumnStat(e: Expression): Seq[Expression] = { + private def booleanColumnStat(e: Expression): Seq[Expression] = { Seq(numNulls(e), numTrues(e), numFalses(e)) } @@ -162,14 +167,14 @@ object ColumnStatStruct { } } - def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match { + def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match { // Use aggregate functions to compute statistics we need. - case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD)) - case StringType => getStruct(stringColumnStat(e, relativeSD)) - case BinaryType => getStruct(binaryColumnStat(e)) - case BooleanType => getStruct(booleanColumnStat(e)) + case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) + case StringType => getStruct(stringColumnStat(attr, relativeSD)) + case BinaryType => getStruct(binaryColumnStat(attr)) + case BooleanType => getStruct(booleanColumnStat(attr)) case otherType => throw new AnalysisException("Analyzing columns is not supported for column " + - s"${e.name} of data type: ${e.dataType}.") + s"${attr.name} of data type: ${attr.dataType}.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e671604c39..c8447651dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -578,7 +578,8 @@ object SQLConf { val NDV_MAX_ERROR = 
SQLConfigBuilder("spark.sql.statistics.ndv.maxError") .internal() - .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.") + .doc("The maximum estimation error allowed in HyperLogLog++ algorithm when generating " + + "column level statistics.") .doubleConf .createWithDefault(0.05) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 85228bb001..c351063a63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, PrintWriter} import scala.reflect.ClassTag -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} @@ -358,53 +358,187 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } - test("generate column-level statistics and load them from hive metastore") { + private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = { + val tableName = "tbl" + var statsBeforeUpdate: Statistics = null + var statsAfterUpdate: Statistics = null + withTable(tableName) { + val tableIndent = TableIdentifier(tableName, Some("default")) + val catalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog] + sql(s"CREATE TABLE $tableName (key int) USING PARQUET") + sql(s"INSERT INTO $tableName SELECT 1") + if (isAnalyzeColumns) { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key") + } else { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") + } + // Table lookup will make the table cached. 
+ catalog.lookupRelation(tableIndent) + statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent) + .asInstanceOf[LogicalRelation].catalogTable.get.stats.get + + sql(s"INSERT INTO $tableName SELECT 2") + if (isAnalyzeColumns) { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key") + } else { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") + } + catalog.lookupRelation(tableIndent) + statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent) + .asInstanceOf[LogicalRelation].catalogTable.get.stats.get + } + (statsBeforeUpdate, statsAfterUpdate) + } + + test("test refreshing table stats of cached data source table by `ANALYZE TABLE` statement") { + val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = false) + + assert(statsBeforeUpdate.sizeInBytes > 0) + assert(statsBeforeUpdate.rowCount == Some(1)) + + assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes) + assert(statsAfterUpdate.rowCount == Some(2)) + } + + test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") { + val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true) + + assert(statsBeforeUpdate.sizeInBytes > 0) + assert(statsBeforeUpdate.rowCount == Some(1)) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = statsBeforeUpdate.colStats("key"), + expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), + rsd = spark.sessionState.conf.ndvMaxError) + + assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes) + assert(statsAfterUpdate.rowCount == Some(2)) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = statsAfterUpdate.colStats("key"), + expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)), + rsd = spark.sessionState.conf.ndvMaxError) + } + + private lazy val (testDataFrame, expectedColStatsSeq) = { import testImplicits._ val intSeq = Seq(1, 2) val stringSeq = Seq("a", "bb") + val binarySeq = Seq("a", 
"bb").map(_.getBytes) val booleanSeq = Seq(true, false) - val data = intSeq.indices.map { i => - (intSeq(i), stringSeq(i), booleanSeq(i)) + (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i)) } - val tableName = "table" - withTable(tableName) { - val df = data.toDF("c1", "c2", "c3") - df.write.format("parquet").saveAsTable(tableName) - val expectedColStatsSeq = df.schema.map { f => - val colStat = f.dataType match { - case IntegerType => - ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) - case StringType => - ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) - case BooleanType => - ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, - booleanSeq.count(_.equals(false)).toLong)) - } - (f, colStat) + val df: DataFrame = data.toDF("c1", "c2", "c3", "c4") + val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f => + val colStat = f.dataType match { + case IntegerType => + ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) + case StringType => + ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) + case BinaryType => + ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, + binarySeq.map(_.length).max.toInt)) + case BooleanType => + ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, + booleanSeq.count(_.equals(false)).toLong)) } + (f, colStat) + } + (df, expectedColStatsSeq) + } + + private def checkColStats( + tableName: String, + isDataSourceTable: Boolean, + expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { + val readback = spark.table(tableName) + val stats = readback.queryExecution.analyzed.collect { + case rel: MetastoreRelation => + assert(!isDataSourceTable, "Expected a Hive serde table, but 
got a data source table") + rel.catalogTable.stats.get + case rel: LogicalRelation => + assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") + rel.catalogTable.get.stats.get + } + assert(stats.length == 1) + val columnStats = stats.head.colStats + assert(columnStats.size == expectedColStatsSeq.length) + expectedColStatsSeq.foreach { case (field, expectedColStat) => + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = columnStats(field.name), + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + } + } - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1, c2, c3") - val readback = spark.table(tableName) - val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => - val columnStats = rel.catalogTable.get.stats.get.colStats - expectedColStatsSeq.foreach { case (field, expectedColStat) => - assert(columnStats.contains(field.name)) - val colStat = columnStats(field.name) + test("generate and load column-level stats for data source table") { + val dsTable = "dsTable" + withTable(dsTable) { + testDataFrame.write.format("parquet").saveAsTable(dsTable) + sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4") + checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq) + } + } + + test("generate and load column-level stats for hive serde table") { + val hTable = "hTable" + val tmp = "tmp" + withTable(hTable, tmp) { + testDataFrame.write.format("parquet").saveAsTable(tmp) + sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE") + sql(s"INSERT INTO $hTable SELECT * FROM $tmp") + sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4") + checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq) + } + } + + // When caseSensitive is on, for columns with only case difference, they are different columns + // and we should generate column stats for all of them. 
+ private def checkCaseSensitiveColStats(columnName: String): Unit = { + val tableName = "tbl" + withTable(tableName) { + val column1 = columnName.toLowerCase + val column2 = columnName.toUpperCase + withSQLConf("spark.sql.caseSensitive" -> "true") { + sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET") + sql(s"INSERT INTO $tableName SELECT 1, 3.0") + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`") + val readback = spark.table(tableName) + val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => + val columnStats = rel.catalogTable.get.stats.get.colStats + assert(columnStats.size == 2) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = columnStats(column1), + expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), + rsd = spark.sessionState.conf.ndvMaxError) StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = colStat, - expectedColStat = expectedColStat, + dataType = DoubleType, + colStat = columnStats(column2), + expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)), rsd = spark.sessionState.conf.ndvMaxError) + rel } - rel + assert(relations.size == 1) } - assert(relations.size == 1) } } + test("check column statistics for case sensitive column names") { + checkCaseSensitiveColStats(columnName = "c1") + } + + test("check column statistics for case sensitive non-ascii column names") { + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkCaseSensitiveColStats(columnName = "列c") + // scalastyle:on + } + test("estimates the size of a test MetastoreRelation") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => |