path: root/sql/core
author     Zhenhua Wang <wzh_zju@163.com>     2016-10-03 10:12:02 -0700
committer  Reynold Xin <rxin@databricks.com>  2016-10-03 10:12:02 -0700
commit     7bf92127643570e4eb3610fa3ffd36839eba2718 (patch)
tree       14386f49f956e97b50a8d6b2bbf0f776eab4dd39 /sql/core
parent     a27033c0bbaae8f31db9b91693947ed71738ed11 (diff)
[SPARK-17073][SQL] generate column-level statistics
## What changes were proposed in this pull request?

Generate basic column statistics for all the atomic types:
- numeric types: max, min, number of nulls, ndv (number of distinct values)
- date/timestamp types: they are represented as numbers internally, so they have the same stats as the numeric types.
- string: avg length, max length, number of nulls, ndv
- binary: avg length, max length, number of nulls
- boolean: number of nulls, number of trues, number of falses

Also support storing and loading these statistics.

One thing to note: we support analyzing columns independently, e.g.:
sql1: `ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key;`
sql2: `ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS value;`
When running sql2 to collect column stats for `value`, we do not remove the stats of column `key` that were collected by sql1 but not by sql2. As a result, **users need to guarantee consistency** between sql1 and sql2. If the table has changed between the two commands, users should re-analyze column `key` when they analyze column `value`:
`ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key, value;`

## How was this patch tested?

Added unit tests.

Author: Zhenhua Wang <wzh_zju@163.com>

Closes #15090 from wzhfy/colStats.
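A minimal sketch (not part of the patch) of how the statements described above would be issued from a `SparkSession`; the table `src` and columns `key`/`value` are the hypothetical names from the example:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("column-stats-demo").getOrCreate()

// Collect table-level stats plus column-level stats for `key`.
spark.sql("ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key")

// After the table changes, re-analyze both columns together so the stored
// stats for `key` and `value` describe the same version of the data.
spark.sql("ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key, value")
```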
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala                  18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala   175
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala    112
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala                           9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala                      8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala                     334
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala                            16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala                            129
8 files changed, 724 insertions, 77 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 3f34d0f253..7f1e23e665 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -87,19 +87,27 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
}
/**
- * Create an [[AnalyzeTableCommand]] command. This currently only implements the NOSCAN
- * option (other options are passed on to Hive) e.g.:
+ * Create an [[AnalyzeTableCommand]] command or an [[AnalyzeColumnCommand]] command.
+ * Example SQL for analyzing table :
* {{{
- * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN;
+ * ANALYZE TABLE table COMPUTE STATISTICS [NOSCAN];
+ * }}}
+ * Example SQL for analyzing columns :
+ * {{{
+ * ANALYZE TABLE table COMPUTE STATISTICS FOR COLUMNS column1, column2;
* }}}
*/
override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) {
if (ctx.partitionSpec == null &&
ctx.identifier != null &&
ctx.identifier.getText.toLowerCase == "noscan") {
- AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString)
+ AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier))
+ } else if (ctx.identifierSeq() == null) {
+ AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier), noscan = false)
} else {
- AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString, noscan = false)
+ AnalyzeColumnCommand(
+ visitTableIdentifier(ctx.tableIdentifier),
+ visitIdentifierSeq(ctx.identifierSeq()))
}
}
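For reference, a small illustrative sketch of how the three ANALYZE variants map to the commands built by the parser change above. It assumes the session's parser is reachable (as it is from the test suites later in this diff, which live in the `org.apache.spark.sql` package); the table and column names are hypothetical.

```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, AnalyzeTableCommand}

val parser = spark.sessionState.sqlParser

// NOSCAN keeps the default noscan = true.
assert(parser.parsePlan("ANALYZE TABLE t COMPUTE STATISTICS NOSCAN") ==
  AnalyzeTableCommand(TableIdentifier("t")))

// Without NOSCAN, the table is scanned to count rows.
assert(parser.parsePlan("ANALYZE TABLE t COMPUTE STATISTICS") ==
  AnalyzeTableCommand(TableIdentifier("t"), noscan = false))

// FOR COLUMNS produces the new column-level command.
assert(parser.parsePlan("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS c1, c2") ==
  AnalyzeColumnCommand(TableIdentifier("t"), Seq("c1", "c2")))
```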
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
new file mode 100644
index 0000000000..7066378279
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import scala.collection.mutable
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
+import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.types._
+
+
+/**
+ * Analyzes the given columns of the given table to generate statistics, which will be used in
+ * query optimizations.
+ */
+case class AnalyzeColumnCommand(
+ tableIdent: TableIdentifier,
+ columnNames: Seq[String]) extends RunnableCommand {
+
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val sessionState = sparkSession.sessionState
+ val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase)
+ val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db))
+ val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB))
+
+ relation match {
+ case catalogRel: CatalogRelation =>
+ updateStats(catalogRel.catalogTable,
+ AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable))
+
+ case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined =>
+ updateStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes)
+
+ case otherRelation =>
+ throw new AnalysisException("ANALYZE TABLE is not supported for " +
+ s"${otherRelation.nodeName}.")
+ }
+
+ def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = {
+ val (rowCount, columnStats) = computeColStats(sparkSession, relation)
+ val statistics = Statistics(
+ sizeInBytes = newTotalSize,
+ rowCount = Some(rowCount),
+ colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map()))
+ sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics)))
+ // Refresh the cached data source table in the catalog.
+ sessionState.catalog.refreshTable(tableIdentWithDB)
+ }
+
+ Seq.empty[Row]
+ }
+
+ def computeColStats(
+ sparkSession: SparkSession,
+ relation: LogicalPlan): (Long, Map[String, ColumnStat]) = {
+
+ // check correctness of column names
+ val attributesToAnalyze = mutable.MutableList[Attribute]()
+ val duplicatedColumns = mutable.MutableList[String]()
+ val resolver = sparkSession.sessionState.conf.resolver
+ columnNames.foreach { col =>
+ val exprOption = relation.output.find(attr => resolver(attr.name, col))
+ val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col."))
+ // do deduplication
+ if (!attributesToAnalyze.contains(expr)) {
+ attributesToAnalyze += expr
+ } else {
+ duplicatedColumns += col
+ }
+ }
+ if (duplicatedColumns.nonEmpty) {
+ logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " +
+ s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.")
+ }
+
+ // Collect statistics per column.
+ // The first element in the result will be the overall row count, the following elements
+ // will be structs containing all column stats.
+ // The layout of each struct follows the layout of the ColumnStats.
+ val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError
+ val expressions = Count(Literal(1)).toAggregateExpression() +:
+ attributesToAnalyze.map(ColumnStatStruct(_, ndvMaxErr))
+ val namedExpressions = expressions.map(e => Alias(e, e.toString)())
+ val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation))
+ .queryExecution.toRdd.collect().head
+
+ // unwrap the result
+ val rowCount = statsRow.getLong(0)
+ val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) =>
+ val numFields = ColumnStatStruct.numStatFields(expr.dataType)
+ (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields)))
+ }.toMap
+ (rowCount, columnStats)
+ }
+}
+
+object ColumnStatStruct {
+ val zero = Literal(0, LongType)
+ val one = Literal(1, LongType)
+
+ def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
+ def max(e: Expression): Expression = Max(e)
+ def min(e: Expression): Expression = Min(e)
+ def ndv(e: Expression, relativeSD: Double): Expression = {
+ // the approximate ndv should never be larger than the number of rows
+ Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one)))
+ }
+ def avgLength(e: Expression): Expression = Average(Length(e))
+ def maxLength(e: Expression): Expression = Max(Length(e))
+ def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
+ def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
+
+ def getStruct(exprs: Seq[Expression]): CreateStruct = {
+ CreateStruct(exprs.map { expr: Expression =>
+ expr.transformUp {
+ case af: AggregateFunction => af.toAggregateExpression()
+ }
+ })
+ }
+
+ def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
+ Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD))
+ }
+
+ def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
+ Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD))
+ }
+
+ def binaryColumnStat(e: Expression): Seq[Expression] = {
+ Seq(numNulls(e), avgLength(e), maxLength(e))
+ }
+
+ def booleanColumnStat(e: Expression): Seq[Expression] = {
+ Seq(numNulls(e), numTrues(e), numFalses(e))
+ }
+
+ def numStatFields(dataType: DataType): Int = {
+ dataType match {
+ case BinaryType | BooleanType => 3
+ case _ => 4
+ }
+ }
+
+ def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match {
+ // Use aggregate functions to compute statistics we need.
+ case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD))
+ case StringType => getStruct(stringColumnStat(e, relativeSD))
+ case BinaryType => getStruct(binaryColumnStat(e))
+ case BooleanType => getStruct(booleanColumnStat(e))
+ case otherType =>
+ throw new AnalysisException("Analyzing columns is not supported for column " +
+ s"${e.name} of data type: ${e.dataType}.")
+ }
+}
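The single aggregation that `computeColStats` builds (the overall row count first, then one struct of stats per analyzed column) is roughly equivalent to the following DataFrame-level sketch. This is an analogy for illustration only, not the code path the command uses; `src` and `key` are hypothetical, and the sketch assumes the standard `approx_count_distinct` function, which takes the same relative-error parameter that `ColumnStatStruct.ndv` passes to `HyperLogLogPlusPlus`.

```scala
import org.apache.spark.sql.functions._

val rsd = 0.05  // default of the new spark.sql.statistics.ndv.maxError conf

val statsRow = spark.table("src").agg(
  count(lit(1)).as("row_count"),                                    // overall row count
  struct(
    sum(when(col("key").isNull, 1L).otherwise(0L)).as("num_nulls"),
    max(col("key")).as("max"),
    min(col("key")).as("min"),
    approx_count_distinct(col("key"), rsd).as("ndv")                // HyperLogLog++-based
  ).as("key_stats")                                                 // one struct per column
).head()
```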
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
index 40aecafecf..7b0e49b665 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
@@ -21,81 +21,40 @@ import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable}
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.internal.SessionState
/**
- * Analyzes the given table in the current database to generate statistics, which will be
- * used in query optimizations.
+ * Analyzes the given table to generate statistics, which will be used in query optimizations.
*/
-case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extends RunnableCommand {
+case class AnalyzeTableCommand(
+ tableIdent: TableIdentifier,
+ noscan: Boolean = true) extends RunnableCommand {
override def run(sparkSession: SparkSession): Seq[Row] = {
val sessionState = sparkSession.sessionState
- val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase)
- val tableIdentwithDB = TableIdentifier(tableIdent.table, Some(db))
- val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentwithDB))
+ val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db))
+ val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB))
relation match {
case relation: CatalogRelation =>
- val catalogTable: CatalogTable = relation.catalogTable
- // This method is mainly based on
- // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table)
- // in Hive 0.13 (except that we do not use fs.getContentSummary).
- // TODO: Generalize statistics collection.
- // TODO: Why fs.getContentSummary returns wrong size on Jenkins?
- // Can we use fs.getContentSummary in future?
- // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use
- // countFileSize to count the table size.
- val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging")
-
- def calculateTableSize(fs: FileSystem, path: Path): Long = {
- val fileStatus = fs.getFileStatus(path)
- val size = if (fileStatus.isDirectory) {
- fs.listStatus(path)
- .map { status =>
- if (!status.getPath.getName.startsWith(stagingDir)) {
- calculateTableSize(fs, status.getPath)
- } else {
- 0L
- }
- }.sum
- } else {
- fileStatus.getLen
- }
-
- size
- }
-
- val newTotalSize =
- catalogTable.storage.locationUri.map { p =>
- val path = new Path(p)
- try {
- val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
- calculateTableSize(fs, path)
- } catch {
- case NonFatal(e) =>
- logWarning(
- s"Failed to get the size of table ${catalogTable.identifier.table} in the " +
- s"database ${catalogTable.identifier.database} because of ${e.toString}", e)
- 0L
- }
- }.getOrElse(0L)
-
- updateTableStats(catalogTable, newTotalSize)
+ updateTableStats(relation.catalogTable,
+ AnalyzeTableCommand.calculateTotalSize(sessionState, relation.catalogTable))
// data source tables have been converted into LogicalRelations
case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined =>
updateTableStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes)
case otherRelation =>
- throw new AnalysisException(s"ANALYZE TABLE is not supported for " +
+ throw new AnalysisException("ANALYZE TABLE is not supported for " +
s"${otherRelation.nodeName}.")
}
@@ -125,10 +84,57 @@ case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extend
if (newStats.isDefined) {
sessionState.catalog.alterTable(catalogTable.copy(stats = newStats))
// Refresh the cached data source table in the catalog.
- sessionState.catalog.refreshTable(tableIdent)
+ sessionState.catalog.refreshTable(tableIdentWithDB)
}
}
Seq.empty[Row]
}
}
+
+object AnalyzeTableCommand extends Logging {
+
+ def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = {
+ // This method is mainly based on
+ // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table)
+ // in Hive 0.13 (except that we do not use fs.getContentSummary).
+ // TODO: Generalize statistics collection.
+ // TODO: Why fs.getContentSummary returns wrong size on Jenkins?
+ // Can we use fs.getContentSummary in future?
+ // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use
+ // countFileSize to count the table size.
+ val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging")
+
+ def calculateTableSize(fs: FileSystem, path: Path): Long = {
+ val fileStatus = fs.getFileStatus(path)
+ val size = if (fileStatus.isDirectory) {
+ fs.listStatus(path)
+ .map { status =>
+ if (!status.getPath.getName.startsWith(stagingDir)) {
+ calculateTableSize(fs, status.getPath)
+ } else {
+ 0L
+ }
+ }.sum
+ } else {
+ fileStatus.getLen
+ }
+
+ size
+ }
+
+ catalogTable.storage.locationUri.map { p =>
+ val path = new Path(p)
+ try {
+ val fs = path.getFileSystem(sessionState.newHadoopConf())
+ calculateTableSize(fs, path)
+ } catch {
+ case NonFatal(e) =>
+ logWarning(
+ s"Failed to get the size of table ${catalogTable.identifier.table} in the " +
+ s"database ${catalogTable.identifier.database} because of ${e.toString}", e)
+ 0L
+ }
+ }.getOrElse(0L)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e67140fefe..fecdf792fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -581,6 +581,13 @@ object SQLConf {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefault(10L)
+ val NDV_MAX_ERROR =
+ SQLConfigBuilder("spark.sql.statistics.ndv.maxError")
+ .internal()
+ .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.")
+ .doubleConf
+ .createWithDefault(0.05)
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -757,6 +764,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
+
+ def ndvMaxError: Double = getConf(NDV_MAX_ERROR)
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
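Since the new `spark.sql.statistics.ndv.maxError` entry is a double-valued SQL conf, it can be tightened at runtime before analyzing columns. A minimal sketch, assuming the usual `RuntimeConfig` API (note the conf is marked internal, so this is for illustration only; table/column names are hypothetical):

```scala
// Trade more memory/CPU in HyperLogLog++ for a tighter ndv estimate than the
// 0.05 default, then run the column analysis.
spark.conf.set("spark.sql.statistics.ndv.maxError", "0.01")
spark.sql("ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key")
```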
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index c899773b6b..9f7d0019c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
@@ -188,11 +189,8 @@ private[sql] class SessionState(sparkSession: SparkSession) {
/**
* Analyzes the given table in the current database to generate statistics, which will be
* used in query optimizations.
- *
- * Right now, it only supports catalog tables and it only updates the size of a catalog table
- * in the external catalog.
*/
- def analyze(tableName: String, noscan: Boolean = true): Unit = {
- AnalyzeTableCommand(tableName, noscan).run(sparkSession)
+ def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = {
+ AnalyzeTableCommand(tableIdent, noscan).run(sparkSession)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala
new file mode 100644
index 0000000000..0ee0547c45
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala
@@ -0,0 +1,334 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.command.AnalyzeColumnCommand
+import org.apache.spark.sql.test.SQLTestData.ArrayData
+import org.apache.spark.sql.types._
+
+class StatisticsColumnSuite extends StatisticsTest {
+ import testImplicits._
+
+ test("parse analyze column commands") {
+ val tableName = "tbl"
+
+ // we need to specify column names
+ intercept[ParseException] {
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS")
+ }
+
+ val analyzeSql = s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key, value"
+ val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql)
+ val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value"))
+ comparePlans(parsed, expected)
+ }
+
+ test("analyzing columns of non-atomic types is not supported") {
+ val tableName = "tbl"
+ withTable(tableName) {
+ Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName)
+ val err = intercept[AnalysisException] {
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data")
+ }
+ assert(err.message.contains("Analyzing columns is not supported"))
+ }
+ }
+
+ test("check correctness of columns") {
+ val table = "tbl"
+ val colName1 = "abc"
+ val colName2 = "x.yz"
+ withTable(table) {
+ sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET")
+
+ val invalidColError = intercept[AnalysisException] {
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key")
+ }
+ assert(invalidColError.message == "Invalid column name: key.")
+
+ withSQLConf("spark.sql.caseSensitive" -> "true") {
+ val invalidErr = intercept[AnalysisException] {
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${colName1.toUpperCase}")
+ }
+ assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.")
+ }
+
+ withSQLConf("spark.sql.caseSensitive" -> "false") {
+ val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2)
+ val tableIdent = TableIdentifier(table, Some("default"))
+ val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
+ val (_, columnStats) =
+ AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation)
+ assert(columnStats.contains(colName1))
+ assert(columnStats.contains(colName2))
+ // check deduplication
+ assert(columnStats.size == 2)
+ assert(!columnStats.contains(colName2.toUpperCase))
+ }
+ }
+ }
+
+ private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = {
+ values.filter(_.isDefined).map(_.get)
+ }
+
+ test("column-level statistics for integral type columns") {
+ val values = (0 to 5).map { i =>
+ if (i % 2 == 0) None else Some(i)
+ }
+ val data = values.map { i =>
+ (i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong))
+ }
+
+ val df = data.toDF("c1", "c2", "c3", "c4")
+ val nonNullValues = getNonNullValues[Int](values)
+ val expectedColStatsSeq = df.schema.map { f =>
+ val colStat = ColumnStat(InternalRow(
+ values.count(_.isEmpty).toLong,
+ nonNullValues.max,
+ nonNullValues.min,
+ nonNullValues.distinct.length.toLong))
+ (f, colStat)
+ }
+ checkColStats(df, expectedColStatsSeq)
+ }
+
+ test("column-level statistics for fractional type columns") {
+ val values: Seq[Option[Decimal]] = (0 to 5).map { i =>
+ if (i == 0) None else Some(Decimal(i + i * 0.01))
+ }
+ val data = values.map { i =>
+ (i.map(_.toFloat), i.map(_.toDouble), i)
+ }
+
+ val df = data.toDF("c1", "c2", "c3")
+ val nonNullValues = getNonNullValues[Decimal](values)
+ val numNulls = values.count(_.isEmpty).toLong
+ val ndv = nonNullValues.distinct.length.toLong
+ val expectedColStatsSeq = df.schema.map { f =>
+ val colStat = f.dataType match {
+ case floatType: FloatType =>
+ ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat,
+ ndv))
+ case doubleType: DoubleType =>
+ ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble,
+ ndv))
+ case decimalType: DecimalType =>
+ ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv))
+ }
+ (f, colStat)
+ }
+ checkColStats(df, expectedColStatsSeq)
+ }
+
+ test("column-level statistics for string column") {
+ val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some(""))
+ val df = values.toDF("c1")
+ val nonNullValues = getNonNullValues[String](values)
+ val expectedColStatsSeq = df.schema.map { f =>
+ val colStat = ColumnStat(InternalRow(
+ values.count(_.isEmpty).toLong,
+ nonNullValues.map(_.length).sum / nonNullValues.length.toDouble,
+ nonNullValues.map(_.length).max.toLong,
+ nonNullValues.distinct.length.toLong))
+ (f, colStat)
+ }
+ checkColStats(df, expectedColStatsSeq)
+ }
+
+ test("column-level statistics for binary column") {
+ val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes))
+ val df = values.toDF("c1")
+ val nonNullValues = getNonNullValues[Array[Byte]](values)
+ val expectedColStatsSeq = df.schema.map { f =>
+ val colStat = ColumnStat(InternalRow(
+ values.count(_.isEmpty).toLong,
+ nonNullValues.map(_.length).sum / nonNullValues.length.toDouble,
+ nonNullValues.map(_.length).max.toLong))
+ (f, colStat)
+ }
+ checkColStats(df, expectedColStatsSeq)
+ }
+
+ test("column-level statistics for boolean column") {
+ val values = Seq(None, Some(true), Some(false), Some(true))
+ val df = values.toDF("c1")
+ val nonNullValues = getNonNullValues[Boolean](values)
+ val expectedColStatsSeq = df.schema.map { f =>
+ val colStat = ColumnStat(InternalRow(
+ values.count(_.isEmpty).toLong,
+ nonNullValues.count(_.equals(true)).toLong,
+ nonNullValues.count(_.equals(false)).toLong))
+ (f, colStat)
+ }
+ checkColStats(df, expectedColStatsSeq)
+ }
+
+ test("column-level statistics for date column") {
+ val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf))
+ val df = values.toDF("c1")
+ val nonNullValues = getNonNullValues[Date](values)
+ val expectedColStatsSeq = df.schema.map { f =>
+ val colStat = ColumnStat(InternalRow(
+ values.count(_.isEmpty).toLong,
+ // Internally, DateType is represented as the number of days from 1970-01-01.
+ nonNullValues.map(DateTimeUtils.fromJavaDate).max,
+ nonNullValues.map(DateTimeUtils.fromJavaDate).min,
+ nonNullValues.distinct.length.toLong))
+ (f, colStat)
+ }
+ checkColStats(df, expectedColStatsSeq)
+ }
+
+ test("column-level statistics for timestamp column") {
+ val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i =>
+ i.map(Timestamp.valueOf)
+ }
+ val df = values.toDF("c1")
+ val nonNullValues = getNonNullValues[Timestamp](values)
+ val expectedColStatsSeq = df.schema.map { f =>
+ val colStat = ColumnStat(InternalRow(
+ values.count(_.isEmpty).toLong,
+ // Internally, TimestampType is represented as the number of days from 1970-01-01
+ nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max,
+ nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min,
+ nonNullValues.distinct.length.toLong))
+ (f, colStat)
+ }
+ checkColStats(df, expectedColStatsSeq)
+ }
+
+ test("column-level statistics for null columns") {
+ val values = Seq(None, None)
+ val data = values.map { i =>
+ (i.map(_.toString), i.map(_.toString.toInt))
+ }
+ val df = data.toDF("c1", "c2")
+ val expectedColStatsSeq = df.schema.map { f =>
+ (f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L)))
+ }
+ checkColStats(df, expectedColStatsSeq)
+ }
+
+ test("column-level statistics for columns with different types") {
+ val intSeq = Seq(1, 2)
+ val doubleSeq = Seq(1.01d, 2.02d)
+ val stringSeq = Seq("a", "bb")
+ val binarySeq = Seq("a", "bb").map(_.getBytes)
+ val booleanSeq = Seq(true, false)
+ val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf)
+ val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf)
+ val longSeq = Seq(5L, 4L)
+
+ val data = intSeq.indices.map { i =>
+ (intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i),
+ timestampSeq(i), longSeq(i))
+ }
+ val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8")
+ val expectedColStatsSeq = df.schema.map { f =>
+ val colStat = f.dataType match {
+ case IntegerType =>
+ ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
+ case DoubleType =>
+ ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min,
+ doubleSeq.distinct.length.toLong))
+ case StringType =>
+ ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
+ stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong))
+ case BinaryType =>
+ ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
+ binarySeq.map(_.length).max.toLong))
+ case BooleanType =>
+ ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
+ booleanSeq.count(_.equals(false)).toLong))
+ case DateType =>
+ ColumnStat(InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max,
+ dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong))
+ case TimestampType =>
+ ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max,
+ timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min,
+ timestampSeq.distinct.length.toLong))
+ case LongType =>
+ ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong))
+ }
+ (f, colStat)
+ }
+ checkColStats(df, expectedColStatsSeq)
+ }
+
+ test("update table-level stats while collecting column-level stats") {
+ val table = "tbl"
+ withTable(table) {
+ sql(s"CREATE TABLE $table (c1 int) USING PARQUET")
+ sql(s"INSERT INTO $table SELECT 1")
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS")
+ checkTableStats(tableName = table, expectedRowCount = Some(1))
+
+ // update table-level stats between analyze table and analyze column commands
+ sql(s"INSERT INTO $table SELECT 1")
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1")
+ val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2))
+
+ val colStat = fetchedStats.get.colStats("c1")
+ StatisticsTest.checkColStat(
+ dataType = IntegerType,
+ colStat = colStat,
+ expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
+ rsd = spark.sessionState.conf.ndvMaxError)
+ }
+ }
+
+ test("analyze column stats independently") {
+ val table = "tbl"
+ withTable(table) {
+ sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET")
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1")
+ val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0))
+ assert(fetchedStats1.get.colStats.size == 1)
+ val expected1 = ColumnStat(InternalRow(0L, null, null, 0L))
+ val rsd = spark.sessionState.conf.ndvMaxError
+ StatisticsTest.checkColStat(
+ dataType = IntegerType,
+ colStat = fetchedStats1.get.colStats("c1"),
+ expectedColStat = expected1,
+ rsd = rsd)
+
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2")
+ val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0))
+ // column c1 is kept in the stats
+ assert(fetchedStats2.get.colStats.size == 2)
+ StatisticsTest.checkColStat(
+ dataType = IntegerType,
+ colStat = fetchedStats2.get.colStats("c1"),
+ expectedColStat = expected1,
+ rsd = rsd)
+ val expected2 = ColumnStat(InternalRow(0L, null, null, 0L))
+ StatisticsTest.checkColStat(
+ dataType = LongType,
+ colStat = fetchedStats2.get.colStats("c2"),
+ expectedColStat = expected2,
+ rsd = rsd)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
index 264a2ffbeb..8cf42e9248 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
@@ -18,11 +18,9 @@
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit}
-import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-class StatisticsSuite extends QueryTest with SharedSQLContext {
+class StatisticsSuite extends StatisticsTest {
import testImplicits._
test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
@@ -77,20 +75,10 @@ class StatisticsSuite extends QueryTest with SharedSQLContext {
}
test("test table-level statistics for data source table created in InMemoryCatalog") {
- def checkTableStats(tableName: String, expectedRowCount: Option[BigInt]): Unit = {
- val df = sql(s"SELECT * FROM $tableName")
- val relations = df.queryExecution.analyzed.collect { case rel: LogicalRelation =>
- assert(rel.catalogTable.isDefined)
- assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount)
- rel
- }
- assert(relations.size === 1)
- }
-
val tableName = "tbl"
withTable(tableName) {
sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet")
- Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl")
+ Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName)
// noscan won't count the number of rows
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala
new file mode 100644
index 0000000000..5134ac0e7e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
+import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, ColumnStatStruct}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
+
+trait StatisticsTest extends QueryTest with SharedSQLContext {
+
+ def checkColStats(
+ df: DataFrame,
+ expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
+ val table = "tbl"
+ withTable(table) {
+ df.write.format("json").saveAsTable(table)
+ val columns = expectedColStatsSeq.map(_._1)
+ val tableIdent = TableIdentifier(table, Some("default"))
+ val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
+ val (_, columnStats) =
+ AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation)
+ expectedColStatsSeq.foreach { case (field, expectedColStat) =>
+ assert(columnStats.contains(field.name))
+ val colStat = columnStats(field.name)
+ StatisticsTest.checkColStat(
+ dataType = field.dataType,
+ colStat = colStat,
+ expectedColStat = expectedColStat,
+ rsd = spark.sessionState.conf.ndvMaxError)
+
+ // check if we get the same colStat after encoding and decoding
+ val encodedCS = colStat.toString
+ val numFields = ColumnStatStruct.numStatFields(field.dataType)
+ val decodedCS = ColumnStat(numFields, encodedCS)
+ StatisticsTest.checkColStat(
+ dataType = field.dataType,
+ colStat = decodedCS,
+ expectedColStat = expectedColStat,
+ rsd = spark.sessionState.conf.ndvMaxError)
+ }
+ }
+ }
+
+ def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = {
+ val df = spark.table(tableName)
+ val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation =>
+ assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount)
+ rel.catalogTable.get.stats
+ }
+ assert(stats.size == 1)
+ stats.head
+ }
+}
+
+object StatisticsTest {
+ def checkColStat(
+ dataType: DataType,
+ colStat: ColumnStat,
+ expectedColStat: ColumnStat,
+ rsd: Double): Unit = {
+ dataType match {
+ case StringType =>
+ val cs = colStat.forString
+ val expectedCS = expectedColStat.forString
+ assert(cs.numNulls == expectedCS.numNulls)
+ assert(cs.avgColLen == expectedCS.avgColLen)
+ assert(cs.maxColLen == expectedCS.maxColLen)
+ checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd)
+ case BinaryType =>
+ val cs = colStat.forBinary
+ val expectedCS = expectedColStat.forBinary
+ assert(cs.numNulls == expectedCS.numNulls)
+ assert(cs.avgColLen == expectedCS.avgColLen)
+ assert(cs.maxColLen == expectedCS.maxColLen)
+ case BooleanType =>
+ val cs = colStat.forBoolean
+ val expectedCS = expectedColStat.forBoolean
+ assert(cs.numNulls == expectedCS.numNulls)
+ assert(cs.numTrues == expectedCS.numTrues)
+ assert(cs.numFalses == expectedCS.numFalses)
+ case atomicType: AtomicType =>
+ checkNumericColStats(
+ dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd)
+ }
+ }
+
+ private def checkNumericColStats(
+ dataType: AtomicType,
+ colStat: ColumnStat,
+ expectedColStat: ColumnStat,
+ rsd: Double): Unit = {
+ val cs = colStat.forNumeric(dataType)
+ val expectedCS = expectedColStat.forNumeric(dataType)
+ assert(cs.numNulls == expectedCS.numNulls)
+ assert(cs.max == expectedCS.max)
+ assert(cs.min == expectedCS.min)
+ checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd)
+ }
+
+ private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = {
+ // ndv is an approximate value, so we make sure we have the value, and it should be
+ // within 3*SD's of the given rsd.
+ if (expectedNdv == 0) {
+ assert(ndv == 0)
+ } else if (expectedNdv > 0) {
+ assert(ndv > 0)
+ val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
+ assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.")
+ }
+ }
+}
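As a standalone illustration of the bound `checkNdv` enforces above (the approximate ndv must lie within three relative standard deviations of the expected value), a self-contained sketch of the same check with example inputs:

```scala
// Mirrors checkNdv: ndv is approximate, so allow up to 3 * rsd relative error.
def ndvWithinBound(ndv: Long, expectedNdv: Long, rsd: Double): Boolean =
  if (expectedNdv == 0) ndv == 0
  else math.abs(ndv.toDouble / expectedNdv - 1.0) <= 3.0 * rsd

assert(ndvWithinBound(ndv = 98, expectedNdv = 100, rsd = 0.05))   // 2% error: accepted
assert(!ndvWithinBound(ndv = 80, expectedNdv = 100, rsd = 0.05))  // 20% error: rejected
```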