From 7bf92127643570e4eb3610fa3ffd36839eba2718 Mon Sep 17 00:00:00 2001
From: Zhenhua Wang
Date: Mon, 3 Oct 2016 10:12:02 -0700
Subject: [SPARK-17073][SQL] generate column-level statistics
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What changes were proposed in this pull request?

Generate basic column statistics for all the atomic types:
- numeric types: max, min, num of nulls, ndv (number of distinct values)
- date/timestamp types: they are also represented as numbers internally, so they have the same stats as numeric types.
- string: avg length, max length, num of nulls, ndv
- binary: avg length, max length, num of nulls
- boolean: num of nulls, num of trues, num of falses

Also support storing and loading these statistics.

One thing to note: we support analyzing columns independently, e.g.:
sql1: `ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key;`
sql2: `ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS value;`
When sql2 collects column stats for `value`, we do not remove the stats of column `key` collected by sql1. As a result, **users need to guarantee consistency** between sql1 and sql2: if the table has changed before sql2 runs, users should re-analyze column `key` together with column `value`:
`ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key, value;`

## How was this patch tested?

add unit tests

Author: Zhenhua Wang

Closes #15090 from wzhfy/colStats.
---
 .../spark/sql/execution/SparkSqlParser.scala       |  18 +-
 .../execution/command/AnalyzeColumnCommand.scala   | 175 +++++++++++
 .../execution/command/AnalyzeTableCommand.scala    | 112 +++----
 .../org/apache/spark/sql/internal/SQLConf.scala    |   9 +
 .../apache/spark/sql/internal/SessionState.scala   |   8 +-
 .../apache/spark/sql/StatisticsColumnSuite.scala   | 334 +++++++++++++++++++++
 .../org/apache/spark/sql/StatisticsSuite.scala     |  16 +-
 .../org/apache/spark/sql/StatisticsTest.scala      | 129 ++++++++
 8 files changed, 724 insertions(+), 77 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala

(limited to 'sql/core')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3f34d0f253..7f1e23e665 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -87,19 +87,27 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create an [[AnalyzeTableCommand]] command. This currently only implements the NOSCAN - * option (other options are passed on to Hive) e.g.: + * Create an [[AnalyzeTableCommand]] command or an [[AnalyzeColumnCommand]] command.
+ * Example SQL for analyzing table : * {{{ - * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN; + * ANALYZE TABLE table COMPUTE STATISTICS [NOSCAN]; + * }}} + * Example SQL for analyzing columns : + * {{{ + * ANALYZE TABLE table COMPUTE STATISTICS FOR COLUMNS column1, column2; * }}} */ override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { if (ctx.partitionSpec == null && ctx.identifier != null && ctx.identifier.getText.toLowerCase == "noscan") { - AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString) + AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) + } else if (ctx.identifierSeq() == null) { + AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier), noscan = false) } else { - AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString, noscan = false) + AnalyzeColumnCommand( + visitTableIdentifier(ctx.tableIdentifier), + visitIdentifierSeq(ctx.identifierSeq())) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala new file mode 100644 index 0000000000..7066378279 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import scala.collection.mutable + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.types._ + + +/** + * Analyzes the given columns of the given table to generate statistics, which will be used in + * query optimizations. 
+ */ +case class AnalyzeColumnCommand( + tableIdent: TableIdentifier, + columnNames: Seq[String]) extends RunnableCommand { + + override def run(sparkSession: SparkSession): Seq[Row] = { + val sessionState = sparkSession.sessionState + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) + + relation match { + case catalogRel: CatalogRelation => + updateStats(catalogRel.catalogTable, + AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) + + case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => + updateStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) + + case otherRelation => + throw new AnalysisException("ANALYZE TABLE is not supported for " + + s"${otherRelation.nodeName}.") + } + + def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { + val (rowCount, columnStats) = computeColStats(sparkSession, relation) + val statistics = Statistics( + sizeInBytes = newTotalSize, + rowCount = Some(rowCount), + colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map())) + sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) + // Refresh the cached data source table in the catalog. + sessionState.catalog.refreshTable(tableIdentWithDB) + } + + Seq.empty[Row] + } + + def computeColStats( + sparkSession: SparkSession, + relation: LogicalPlan): (Long, Map[String, ColumnStat]) = { + + // check correctness of column names + val attributesToAnalyze = mutable.MutableList[Attribute]() + val duplicatedColumns = mutable.MutableList[String]() + val resolver = sparkSession.sessionState.conf.resolver + columnNames.foreach { col => + val exprOption = relation.output.find(attr => resolver(attr.name, col)) + val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + // do deduplication + if (!attributesToAnalyze.contains(expr)) { + attributesToAnalyze += expr + } else { + duplicatedColumns += col + } + } + if (duplicatedColumns.nonEmpty) { + logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " + + s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.") + } + + // Collect statistics per column. + // The first element in the result will be the overall row count, the following elements + // will be structs containing all column stats. + // The layout of each struct follows the layout of the ColumnStats. 
+ val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError + val expressions = Count(Literal(1)).toAggregateExpression() +: + attributesToAnalyze.map(ColumnStatStruct(_, ndvMaxErr)) + val namedExpressions = expressions.map(e => Alias(e, e.toString)()) + val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) + .queryExecution.toRdd.collect().head + + // unwrap the result + val rowCount = statsRow.getLong(0) + val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => + val numFields = ColumnStatStruct.numStatFields(expr.dataType) + (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) + }.toMap + (rowCount, columnStats) + } +} + +object ColumnStatStruct { + val zero = Literal(0, LongType) + val one = Literal(1, LongType) + + def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero + def max(e: Expression): Expression = Max(e) + def min(e: Expression): Expression = Min(e) + def ndv(e: Expression, relativeSD: Double): Expression = { + // the approximate ndv should never be larger than the number of rows + Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one))) + } + def avgLength(e: Expression): Expression = Average(Length(e)) + def maxLength(e: Expression): Expression = Max(Length(e)) + def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) + def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) + + def getStruct(exprs: Seq[Expression]): CreateStruct = { + CreateStruct(exprs.map { expr: Expression => + expr.transformUp { + case af: AggregateFunction => af.toAggregateExpression() + } + }) + } + + def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD)) + } + + def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD)) + } + + def binaryColumnStat(e: Expression): Seq[Expression] = { + Seq(numNulls(e), avgLength(e), maxLength(e)) + } + + def booleanColumnStat(e: Expression): Seq[Expression] = { + Seq(numNulls(e), numTrues(e), numFalses(e)) + } + + def numStatFields(dataType: DataType): Int = { + dataType match { + case BinaryType | BooleanType => 3 + case _ => 4 + } + } + + def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match { + // Use aggregate functions to compute statistics we need. 
+ case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD)) + case StringType => getStruct(stringColumnStat(e, relativeSD)) + case BinaryType => getStruct(binaryColumnStat(e)) + case BooleanType => getStruct(booleanColumnStat(e)) + case otherType => + throw new AnalysisException("Analyzing columns is not supported for column " + + s"${e.name} of data type: ${e.dataType}.") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 40aecafecf..7b0e49b665 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -21,81 +21,40 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.SessionState /** - * Analyzes the given table in the current database to generate statistics, which will be - * used in query optimizations. + * Analyzes the given table to generate statistics, which will be used in query optimizations. */ -case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extends RunnableCommand { +case class AnalyzeTableCommand( + tableIdent: TableIdentifier, + noscan: Boolean = true) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentwithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentwithDB)) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) relation match { case relation: CatalogRelation => - val catalogTable: CatalogTable = relation.catalogTable - // This method is mainly based on - // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) - // in Hive 0.13 (except that we do not use fs.getContentSummary). - // TODO: Generalize statistics collection. - // TODO: Why fs.getContentSummary returns wrong size on Jenkins? - // Can we use fs.getContentSummary in future? - // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use - // countFileSize to count the table size. 
- val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - - def calculateTableSize(fs: FileSystem, path: Path): Long = { - val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDirectory) { - fs.listStatus(path) - .map { status => - if (!status.getPath.getName.startsWith(stagingDir)) { - calculateTableSize(fs, status.getPath) - } else { - 0L - } - }.sum - } else { - fileStatus.getLen - } - - size - } - - val newTotalSize = - catalogTable.storage.locationUri.map { p => - val path = new Path(p) - try { - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) - calculateTableSize(fs, path) - } catch { - case NonFatal(e) => - logWarning( - s"Failed to get the size of table ${catalogTable.identifier.table} in the " + - s"database ${catalogTable.identifier.database} because of ${e.toString}", e) - 0L - } - }.getOrElse(0L) - - updateTableStats(catalogTable, newTotalSize) + updateTableStats(relation.catalogTable, + AnalyzeTableCommand.calculateTotalSize(sessionState, relation.catalogTable)) // data source tables have been converted into LogicalRelations case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => updateTableStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) case otherRelation => - throw new AnalysisException(s"ANALYZE TABLE is not supported for " + + throw new AnalysisException("ANALYZE TABLE is not supported for " + s"${otherRelation.nodeName}.") } @@ -125,10 +84,57 @@ case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extend if (newStats.isDefined) { sessionState.catalog.alterTable(catalogTable.copy(stats = newStats)) // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdent) + sessionState.catalog.refreshTable(tableIdentWithDB) } } Seq.empty[Row] } } + +object AnalyzeTableCommand extends Logging { + + def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // countFileSize to count the table size. 
+ val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + + def calculateTableSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDirectory) { + fs.listStatus(path) + .map { status => + if (!status.getPath.getName.startsWith(stagingDir)) { + calculateTableSize(fs, status.getPath) + } else { + 0L + } + }.sum + } else { + fileStatus.getLen + } + + size + } + + catalogTable.storage.locationUri.map { p => + val path = new Path(p) + try { + val fs = path.getFileSystem(sessionState.newHadoopConf()) + calculateTableSize(fs, path) + } catch { + case NonFatal(e) => + logWarning( + s"Failed to get the size of table ${catalogTable.identifier.table} in the " + + s"database ${catalogTable.identifier.database} because of ${e.toString}", e) + 0L + } + }.getOrElse(0L) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e67140fefe..fecdf792fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -581,6 +581,13 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10L) + val NDV_MAX_ERROR = + SQLConfigBuilder("spark.sql.statistics.ndv.maxError") + .internal() + .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.") + .doubleConf + .createWithDefault(0.05) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -757,6 +764,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + + def ndvMaxError: Double = getConf(NDV_MAX_ERROR) /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index c899773b6b..9f7d0019c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -188,11 +189,8 @@ private[sql] class SessionState(sparkSession: SparkSession) { /** * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. - * - * Right now, it only supports catalog tables and it only updates the size of a catalog table - * in the external catalog. 
*/ - def analyze(tableName: String, noscan: Boolean = true): Unit = { - AnalyzeTableCommand(tableName, noscan).run(sparkSession) + def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = { + AnalyzeTableCommand(tableIdent, noscan).run(sparkSession) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala new file mode 100644 index 0000000000..0ee0547c45 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.command.AnalyzeColumnCommand +import org.apache.spark.sql.test.SQLTestData.ArrayData +import org.apache.spark.sql.types._ + +class StatisticsColumnSuite extends StatisticsTest { + import testImplicits._ + + test("parse analyze column commands") { + val tableName = "tbl" + + // we need to specify column names + intercept[ParseException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS") + } + + val analyzeSql = s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key, value" + val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql) + val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value")) + comparePlans(parsed, expected) + } + + test("analyzing columns of non-atomic types is not supported") { + val tableName = "tbl" + withTable(tableName) { + Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) + val err = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") + } + assert(err.message.contains("Analyzing columns is not supported")) + } + } + + test("check correctness of columns") { + val table = "tbl" + val colName1 = "abc" + val colName2 = "x.yz" + withTable(table) { + sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET") + + val invalidColError = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key") + } + assert(invalidColError.message == "Invalid column name: key.") + + withSQLConf("spark.sql.caseSensitive" -> "true") { + val invalidErr = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${colName1.toUpperCase}") + } + assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.") + } + + 
withSQLConf("spark.sql.caseSensitive" -> "false") { + val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2) + val tableIdent = TableIdentifier(table, Some("default")) + val relation = spark.sessionState.catalog.lookupRelation(tableIdent) + val (_, columnStats) = + AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation) + assert(columnStats.contains(colName1)) + assert(columnStats.contains(colName2)) + // check deduplication + assert(columnStats.size == 2) + assert(!columnStats.contains(colName2.toUpperCase)) + } + } + } + + private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = { + values.filter(_.isDefined).map(_.get) + } + + test("column-level statistics for integral type columns") { + val values = (0 to 5).map { i => + if (i % 2 == 0) None else Some(i) + } + val data = values.map { i => + (i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong)) + } + + val df = data.toDF("c1", "c2", "c3", "c4") + val nonNullValues = getNonNullValues[Int](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.max, + nonNullValues.min, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for fractional type columns") { + val values: Seq[Option[Decimal]] = (0 to 5).map { i => + if (i == 0) None else Some(Decimal(i + i * 0.01)) + } + val data = values.map { i => + (i.map(_.toFloat), i.map(_.toDouble), i) + } + + val df = data.toDF("c1", "c2", "c3") + val nonNullValues = getNonNullValues[Decimal](values) + val numNulls = values.count(_.isEmpty).toLong + val ndv = nonNullValues.distinct.length.toLong + val expectedColStatsSeq = df.schema.map { f => + val colStat = f.dataType match { + case floatType: FloatType => + ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat, + ndv)) + case doubleType: DoubleType => + ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble, + ndv)) + case decimalType: DecimalType => + ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv)) + } + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for string column") { + val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[String](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, + nonNullValues.map(_.length).max.toLong, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for binary column") { + val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes)) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Array[Byte]](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, + nonNullValues.map(_.length).max.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for boolean column") { + val values = Seq(None, Some(true), Some(false), Some(true)) + val df = values.toDF("c1") + val nonNullValues = 
getNonNullValues[Boolean](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.count(_.equals(true)).toLong, + nonNullValues.count(_.equals(false)).toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for date column") { + val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf)) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Date](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + // Internally, DateType is represented as the number of days from 1970-01-01. + nonNullValues.map(DateTimeUtils.fromJavaDate).max, + nonNullValues.map(DateTimeUtils.fromJavaDate).min, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for timestamp column") { + val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i => + i.map(Timestamp.valueOf) + } + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Timestamp](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + // Internally, TimestampType is represented as the number of days from 1970-01-01 + nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max, + nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for null columns") { + val values = Seq(None, None) + val data = values.map { i => + (i.map(_.toString), i.map(_.toString.toInt)) + } + val df = data.toDF("c1", "c2") + val expectedColStatsSeq = df.schema.map { f => + (f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L))) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for columns with different types") { + val intSeq = Seq(1, 2) + val doubleSeq = Seq(1.01d, 2.02d) + val stringSeq = Seq("a", "bb") + val binarySeq = Seq("a", "bb").map(_.getBytes) + val booleanSeq = Seq(true, false) + val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf) + val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf) + val longSeq = Seq(5L, 4L) + + val data = intSeq.indices.map { i => + (intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i), + timestampSeq(i), longSeq(i)) + } + val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8") + val expectedColStatsSeq = df.schema.map { f => + val colStat = f.dataType match { + case IntegerType => + ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) + case DoubleType => + ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min, + doubleSeq.distinct.length.toLong)) + case StringType => + ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) + case BinaryType => + ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, + binarySeq.map(_.length).max.toLong)) + case BooleanType => + ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, + booleanSeq.count(_.equals(false)).toLong)) + case DateType => + ColumnStat(InternalRow(0L, 
dateSeq.map(DateTimeUtils.fromJavaDate).max, + dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong)) + case TimestampType => + ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max, + timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min, + timestampSeq.distinct.length.toLong)) + case LongType => + ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong)) + } + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("update table-level stats while collecting column-level stats") { + val table = "tbl" + withTable(table) { + sql(s"CREATE TABLE $table (c1 int) USING PARQUET") + sql(s"INSERT INTO $table SELECT 1") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + checkTableStats(tableName = table, expectedRowCount = Some(1)) + + // update table-level stats between analyze table and analyze column commands + sql(s"INSERT INTO $table SELECT 1") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2)) + + val colStat = fetchedStats.get.colStats("c1") + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = colStat, + expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), + rsd = spark.sessionState.conf.ndvMaxError) + } + } + + test("analyze column stats independently") { + val table = "tbl" + withTable(table) { + sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0)) + assert(fetchedStats1.get.colStats.size == 1) + val expected1 = ColumnStat(InternalRow(0L, null, null, 0L)) + val rsd = spark.sessionState.conf.ndvMaxError + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = fetchedStats1.get.colStats("c1"), + expectedColStat = expected1, + rsd = rsd) + + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") + val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0)) + // column c1 is kept in the stats + assert(fetchedStats2.get.colStats.size == 2) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = fetchedStats2.get.colStats("c1"), + expectedColStat = expected1, + rsd = rsd) + val expected2 = ColumnStat(InternalRow(0L, null, null, 0L)) + StatisticsTest.checkColStat( + dataType = LongType, + colStat = fetchedStats2.get.colStats("c2"), + expectedColStat = expected2, + rsd = rsd) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala index 264a2ffbeb..8cf42e9248 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit} -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class StatisticsSuite extends QueryTest with SharedSQLContext { +class StatisticsSuite extends StatisticsTest { import testImplicits._ test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { @@ -77,20 +75,10 @@ class StatisticsSuite extends QueryTest with SharedSQLContext { } test("test table-level statistics for data source table created in InMemoryCatalog") { - def 
checkTableStats(tableName: String, expectedRowCount: Option[BigInt]): Unit = { - val df = sql(s"SELECT * FROM $tableName") - val relations = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.isDefined) - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel - } - assert(relations.size === 1) - } - val tableName = "tbl" withTable(tableName) { sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") - Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName) // noscan won't count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala new file mode 100644 index 0000000000..5134ac0e7e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, ColumnStatStruct} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +trait StatisticsTest extends QueryTest with SharedSQLContext { + + def checkColStats( + df: DataFrame, + expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { + val table = "tbl" + withTable(table) { + df.write.format("json").saveAsTable(table) + val columns = expectedColStatsSeq.map(_._1) + val tableIdent = TableIdentifier(table, Some("default")) + val relation = spark.sessionState.catalog.lookupRelation(tableIdent) + val (_, columnStats) = + AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation) + expectedColStatsSeq.foreach { case (field, expectedColStat) => + assert(columnStats.contains(field.name)) + val colStat = columnStats(field.name) + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = colStat, + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + + // check if we get the same colStat after encoding and decoding + val encodedCS = colStat.toString + val numFields = ColumnStatStruct.numStatFields(field.dataType) + val decodedCS = ColumnStat(numFields, encodedCS) + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = decodedCS, + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + } + } + } + + def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { + val df = spark.table(tableName) + val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => + assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) + rel.catalogTable.get.stats + } + assert(stats.size == 1) + stats.head + } +} + +object StatisticsTest { + def checkColStat( + dataType: DataType, + colStat: ColumnStat, + expectedColStat: ColumnStat, + rsd: Double): Unit = { + dataType match { + case StringType => + val cs = colStat.forString + val expectedCS = expectedColStat.forString + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.avgColLen == expectedCS.avgColLen) + assert(cs.maxColLen == expectedCS.maxColLen) + checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) + case BinaryType => + val cs = colStat.forBinary + val expectedCS = expectedColStat.forBinary + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.avgColLen == expectedCS.avgColLen) + assert(cs.maxColLen == expectedCS.maxColLen) + case BooleanType => + val cs = colStat.forBoolean + val expectedCS = expectedColStat.forBoolean + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.numTrues == expectedCS.numTrues) + assert(cs.numFalses == expectedCS.numFalses) + case atomicType: AtomicType => + checkNumericColStats( + dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd) + } + } + + private def checkNumericColStats( + dataType: AtomicType, + colStat: ColumnStat, + expectedColStat: ColumnStat, + rsd: Double): Unit = { + val cs = colStat.forNumeric(dataType) + val expectedCS = expectedColStat.forNumeric(dataType) + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.max == expectedCS.max) + assert(cs.min == expectedCS.min) + checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) + } + + 
private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = { + // ndv is an approximate value, so we make sure we have the value, and it should be + // within 3*SD's of the given rsd. + if (expectedNdv == 0) { + assert(ndv == 0) + } else if (expectedNdv > 0) { + assert(ndv > 0) + val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) + assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.") + } + } +}
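
For reference, a minimal usage sketch of the feature introduced above. It is not part of the patch: the table name `demo_tbl` and its columns are illustrative, it assumes a build that contains this change, and it reads the stored statistics back through Spark-internal classes in the same way the StatisticsTest suite does (e.g. pasted into spark-shell, where `spark` is predefined).

import org.apache.spark.sql.execution.datasources.LogicalRelation

// Create and populate a small table (illustrative names).
spark.sql("CREATE TABLE demo_tbl (key INT, value STRING) USING PARQUET")
spark.sql("INSERT INTO demo_tbl SELECT 1, 'a'")
spark.sql("INSERT INTO demo_tbl SELECT 2, 'bb'")

// Collect table-level statistics, then column-level statistics for both columns.
spark.sql("ANALYZE TABLE demo_tbl COMPUTE STATISTICS")
spark.sql("ANALYZE TABLE demo_tbl COMPUTE STATISTICS FOR COLUMNS key, value")

// Read the persisted statistics back from the catalog table attached to the
// analyzed plan, as StatisticsTest.checkTableStats does.
val stats = spark.table("demo_tbl").queryExecution.analyzed.collectFirst {
  case rel: LogicalRelation if rel.catalogTable.isDefined => rel.catalogTable.get.stats
}.flatten

stats.foreach { s =>
  println(s"sizeInBytes = ${s.sizeInBytes}, rowCount = ${s.rowCount}")
  s.colStats.foreach { case (name, colStat) => println(s"$name -> $colStat") }
}

Note that ndv is an estimate computed with HyperLogLog++: with the default `spark.sql.statistics.ndv.maxError` of 0.05, the checkNdv helper above accepts an estimate within 3 * 0.05 = 15% of the exact distinct count.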