From 6f7ff75091154fed7649ea6d79e887aad9fbde6a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 18 Nov 2016 16:34:11 -0800 Subject: [SPARK-18505][SQL] Simplify AnalyzeColumnCommand ## What changes were proposed in this pull request? I'm spending more time at the design & code level for cost-based optimizer now, and have found a number of issues related to maintainability and compatibility that I will like to address. This is a small pull request to clean up AnalyzeColumnCommand: 1. Removed warning on duplicated columns. Warnings in log messages are useless since most users that run SQL don't see them. 2. Removed the nested updateStats function, by just inlining the function. 3. Renamed a few functions to better reflect what they do. 4. Removed the factory apply method for ColumnStatStruct. It is a bad pattern to use a apply method that returns an instantiation of a class that is not of the same type (ColumnStatStruct.apply used to return CreateNamedStruct). 5. Renamed ColumnStatStruct to just AnalyzeColumnCommand. 6. Added more documentation explaining some of the non-obvious return types and code blocks. In follow-up pull requests, I'd like to address the following: 1. Get rid of the Map[String, ColumnStat] map, since internally we should be using Attribute to reference columns, rather than strings. 2. Decouple the fields exposed by ColumnStat and internals of Spark SQL's execution path. Currently the two are coupled because ColumnStat takes in an InternalRow. 3. Correctness: Remove code path that stores statistics in the catalog using the base64 encoding of the UnsafeRow format, which is not stable across Spark versions. 4. Clearly document the data representation stored in the catalog for statistics. ## How was this patch tested? Affected test cases have been updated. Author: Reynold Xin Closes #15933 from rxin/SPARK-18505. --- .../execution/command/AnalyzeColumnCommand.scala | 115 ++++++++++++--------- .../apache/spark/sql/StatisticsColumnSuite.scala | 2 +- .../org/apache/spark/sql/StatisticsTest.scala | 7 +- .../spark/sql/hive/HiveExternalCatalog.scala | 4 +- .../spark/sql/hive/client/HiveClientImpl.scala | 2 +- 5 files changed, 74 insertions(+), 56 deletions(-) (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 6141fab4af..7fc57d09e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.command -import scala.collection.mutable - +import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases @@ -44,13 +43,16 @@ case class AnalyzeColumnCommand( val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) - relation match { + // Compute total size + val (catalogTable: CatalogTable, sizeInBytes: Long) = relation match { case catalogRel: CatalogRelation => - updateStats(catalogRel.catalogTable, + // This is a Hive serde format table + (catalogRel.catalogTable, AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - updateStats(logicalRel.catalogTable.get, + // This is a data source format table + (logicalRel.catalogTable.get, AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) case otherRelation => @@ -58,45 +60,45 @@ case class AnalyzeColumnCommand( s"${otherRelation.nodeName}.") } - def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { - val (rowCount, columnStats) = computeColStats(sparkSession, relation) - // We also update table-level stats in order to keep them consistent with column-level stats. - val statistics = Statistics( - sizeInBytes = newTotalSize, - rowCount = Some(rowCount), - // Newly computed column stats should override the existing ones. - colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats) - sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) - // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdentWithDB) - } + // Compute stats for each column + val (rowCount, newColStats) = + AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames) + + // We also update table-level stats in order to keep them consistent with column-level stats. + val statistics = Statistics( + sizeInBytes = sizeInBytes, + rowCount = Some(rowCount), + // Newly computed column stats should override the existing ones. + colStats = catalogTable.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) + + sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) + + // Refresh the cached data source table in the catalog. + sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] } +} +object AnalyzeColumnCommand extends Logging { + + /** + * Compute stats for the given columns. + * @return (row count, map from column name to ColumnStats) + * + * This is visible for testing. + */ def computeColStats( sparkSession: SparkSession, - relation: LogicalPlan): (Long, Map[String, ColumnStat]) = { + relation: LogicalPlan, + columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { - // check correctness of column names - val attributesToAnalyze = mutable.MutableList[Attribute]() - val duplicatedColumns = mutable.MutableList[String]() + // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver - columnNames.foreach { col => + val attributesToAnalyze = AttributeSet(columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) - val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) - // do deduplication - if (!attributesToAnalyze.contains(expr)) { - attributesToAnalyze += expr - } else { - duplicatedColumns += col - } - } - if (duplicatedColumns.nonEmpty) { - logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " + - s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " + - s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.") - } + exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + }).toSeq // Collect statistics per column. // The first element in the result will be the overall row count, the following elements @@ -104,22 +106,21 @@ case class AnalyzeColumnCommand( // The layout of each struct follows the layout of the ColumnStats. val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStatStruct(_, ndvMaxErr)) + attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) .queryExecution.toRdd.collect().head // unwrap the result + // TODO: Get rid of numFields by using the public Dataset API. val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - val numFields = ColumnStatStruct.numStatFields(expr.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType) (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) }.toMap (rowCount, columnStats) } -} -object ColumnStatStruct { private val zero = Literal(0, LongType) private val one = Literal(1, LongType) @@ -137,7 +138,11 @@ object ColumnStatStruct { private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = { + /** + * Creates a struct that groups the sequence of expressions together. This is used to create + * one top level struct per column. + */ + private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -161,6 +166,7 @@ object ColumnStatStruct { Seq(numNulls(e), numTrues(e), numFalses(e)) } + // TODO(rxin): Get rid of this function. def numStatFields(dataType: DataType): Int = { dataType match { case BinaryType | BooleanType => 3 @@ -168,14 +174,25 @@ object ColumnStatStruct { } } - def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match { - // Use aggregate functions to compute statistics we need. - case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) - case StringType => getStruct(stringColumnStat(attr, relativeSD)) - case BinaryType => getStruct(binaryColumnStat(attr)) - case BooleanType => getStruct(booleanColumnStat(attr)) - case otherType => - throw new AnalysisException("Analyzing columns is not supported for column " + - s"${attr.name} of data type: ${attr.dataType}.") + /** + * Creates a struct expression that contains the statistics to collect for a column. + * + * @param attr column to collect statistics + * @param relativeSD relative error for approximate number of distinct values. + */ + def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = { + attr.dataType match { + case _: NumericType | TimestampType | DateType => + createStruct(numericColumnStat(attr, relativeSD)) + case StringType => + createStruct(stringColumnStat(attr, relativeSD)) + case BinaryType => + createStruct(binaryColumnStat(attr)) + case BooleanType => + createStruct(booleanColumnStat(attr)) + case otherType => + throw new AnalysisException("Analyzing columns is not supported for column " + + s"${attr.name} of data type: ${attr.dataType}.") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index f1a201abd8..e866ac2cb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -79,7 +79,7 @@ class StatisticsColumnSuite extends StatisticsTest { val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) val (_, columnStats) = - AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation) + AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze) assert(columnStats.contains(colName1)) assert(columnStats.contains(colName2)) // check deduplication diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index 5134ac0e7e..915ee0d31b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, ColumnStatStruct} +import org.apache.spark.sql.execution.command.AnalyzeColumnCommand import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ + trait StatisticsTest extends QueryTest with SharedSQLContext { def checkColStats( @@ -36,7 +37,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) val (_, columnStats) = - AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation) + AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name)) expectedColStatsSeq.foreach { case (field, expectedColStat) => assert(columnStats.contains(field.name)) val colStat = columnStats(field.name) @@ -48,7 +49,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { // check if we get the same colStat after encoding and decoding val encodedCS = colStat.toString - val numFields = ColumnStatStruct.numStatFields(field.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(field.dataType) val decodedCS = ColumnStat(numFields, encodedCS) StatisticsTest.checkColStat( dataType = field.dataType, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index cacffcf33c..5dbb4024bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils} import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -634,7 +634,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect { case f if colStatsProps.contains(f.name) => - val numFields = ColumnStatStruct.numStatFields(f.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(f.dataType) (f.name, ColumnStat(numFields, colStatsProps(f.name))) }.toMap tableWithSchema.copy( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 2bf9a26b0b..daae8523c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -97,7 +97,7 @@ private[hive] class HiveClientImpl( } // Create an internal session state for this HiveClientImpl. - val state = { + val state: SessionState = { val original = Thread.currentThread().getContextClassLoader // Switch to the initClassLoader. Thread.currentThread().setContextClassLoader(initClassLoader) -- cgit v1.2.3