author     Reynold Xin <rxin@databricks.com>     2016-11-23 20:48:41 +0800
committer  Wenchen Fan <wenchen@databricks.com>  2016-11-23 20:48:41 +0800
commit     70ad07a9d20586ae182c4e60ed97bdddbcbceff3 (patch)
tree       14666ca06583b5ee8fc6ee09b0434aa824c2efde
parent     9785ed40d7fe4e1fcd440e55706519c6e5f8d6b1 (diff)
[SPARK-18522][SQL] Explicit contract for column stats serialization
## What changes were proposed in this pull request?

The current implementation of column stats uses the base64 encoding of the internal UnsafeRow format to persist statistics (in table properties in the Hive metastore). This is an internal format that is not stable across different versions of Spark and should NOT be used for persistence. In addition, it would be better if the statistics stored in the catalog were human readable.

This pull request introduces the following changes:

1. Created a single ColumnStat class for all data types. All data types track the same set of statistics.
2. Updated the stats collection implementation to remove the dependency on internal data structures (e.g. InternalRow, or storing DateType as an int32). For example, dates were previously stored as a single integer but are now stored as java.sql.Date. When we implement the next steps of CBO, we can add code to convert those back into internal types again.
3. Documented clearly which JVM data types are used to store which data.
4. Defined a simple Map[String, String] interface for serializing and deserializing column stats into/from the catalog.
5. Rearranged the method/function structure so it is clearer which data types are supported, and moved the stats generation logic into the ColumnStat class so it is easy to find.

## How was this patch tested?

Removed most of the original test cases created for column statistics, and added three very simple ones to cover all the cases. The three test cases validate:

1. Round-trip serialization works.
2. Behavior when analyzing a non-existent column or a column of an unsupported data type.
3. Results of stats collection for all valid data types.

Also moved parser-related tests into a parser test suite and added an explicit serialization test for the Hive external catalog.

Author: Reynold Xin <rxin@databricks.com>

Closes #15959 from rxin/SPARK-18522.
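As a rough illustration of the Map[String, String] contract described above, the following minimal, self-contained sketch mirrors the round trip. The simplified ColumnStatSketch class, its helper names, and the long-typed deserialization are illustrative assumptions, not the actual Spark implementation.

```scala
// Minimal sketch of the serialization contract: each stat field becomes a string
// entry; "version" is always present; min/max are omitted when undefined.
case class ColumnStatSketch(
    distinctCount: BigInt,
    min: Option[Any],
    max: Option[Any],
    nullCount: BigInt,
    avgLen: Long,
    maxLen: Long) {

  def toMap: Map[String, String] = {
    val base = Map(
      "version" -> "1",
      "distinctCount" -> distinctCount.toString,
      "nullCount" -> nullCount.toString,
      "avgLen" -> avgLen.toString,
      "maxLen" -> maxLen.toString)
    // min/max simply do not appear in the map when they are None.
    base ++ min.map(v => "min" -> v.toString) ++ max.map(v => "max" -> v.toString)
  }
}

object ColumnStatSketch {
  // Deserialization side of the sketch, assuming a long-typed column;
  // parsing failures fall back to None, as the real fromMap logs and skips.
  def fromMap(map: Map[String, String]): Option[ColumnStatSketch] =
    scala.util.Try(ColumnStatSketch(
      distinctCount = BigInt(map("distinctCount")),
      min = map.get("min").map(_.toLong),
      max = map.get("max").map(_.toLong),
      nullCount = BigInt(map("nullCount")),
      avgLen = map("avgLen").toLong,
      maxLen = map("maxLen").toLong)).toOption
}
```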
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala | 212
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala | 105
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala | 218
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala | 334
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala | 92
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala | 130
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala | 26
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala | 93
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 299
9 files changed, 591 insertions, 918 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index f3e2147b8f..79865609cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.commons.codec.binary.Base64
+import scala.util.control.NonFatal
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types._
+
/**
* Estimates of various statistics. The default estimation logic simply lazily multiplies the
* corresponding statistic produced by the children. To override this behavior, override
@@ -58,60 +61,175 @@ case class Statistics(
}
}
+
/**
- * Statistics for a column.
+ * Statistics collected for a column.
+ *
+ * 1. Supported data types are defined in `ColumnStat.supportsType`.
+ * 2. The JVM data type stored in min/max is the external data type (used in Row) for the
+ * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for
+ * TimestampType we store java.sql.Timestamp.
+ * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs.
+ * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms
+ * (sketches) might have been used, and the data collected can also be stale.
+ *
+ * @param distinctCount number of distinct values
+ * @param min minimum value
+ * @param max maximum value
+ * @param nullCount number of nulls
+ * @param avgLen average length of the values. For fixed-length types, this should be a constant.
+ * @param maxLen maximum length of the values. For fixed-length types, this should be a constant.
*/
-case class ColumnStat(statRow: InternalRow) {
+case class ColumnStat(
+ distinctCount: BigInt,
+ min: Option[Any],
+ max: Option[Any],
+ nullCount: BigInt,
+ avgLen: Long,
+ maxLen: Long) {
- def forNumeric[T <: AtomicType](dataType: T): NumericColumnStat[T] = {
- NumericColumnStat(statRow, dataType)
- }
- def forString: StringColumnStat = StringColumnStat(statRow)
- def forBinary: BinaryColumnStat = BinaryColumnStat(statRow)
- def forBoolean: BooleanColumnStat = BooleanColumnStat(statRow)
+ // We currently don't store min/max for binary/string type. This can change in the future and
+ // then we need to remove this require.
+ require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String]))
+ require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String]))
- override def toString: String = {
- // use Base64 for encoding
- Base64.encodeBase64String(statRow.asInstanceOf[UnsafeRow].getBytes)
+ /**
+ * Returns a map from string to string that can be used to serialize the column stats.
+ * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string
+ * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]].
+ *
+ * As part of the protocol, the returned map always contains a key called "version".
+ * In the case min/max values are null (None), they won't appear in the map.
+ */
+ def toMap: Map[String, String] = {
+ val map = new scala.collection.mutable.HashMap[String, String]
+ map.put(ColumnStat.KEY_VERSION, "1")
+ map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString)
+ map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString)
+ map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString)
+ map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString)
+ min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) }
+ max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) }
+ map.toMap
}
}
-object ColumnStat {
- def apply(numFields: Int, str: String): ColumnStat = {
- // use Base64 for decoding
- val bytes = Base64.decodeBase64(str)
- val unsafeRow = new UnsafeRow(numFields)
- unsafeRow.pointTo(bytes, bytes.length)
- ColumnStat(unsafeRow)
+
+object ColumnStat extends Logging {
+
+ // List of string keys used to serialize ColumnStat
+ val KEY_VERSION = "version"
+ private val KEY_DISTINCT_COUNT = "distinctCount"
+ private val KEY_MIN_VALUE = "min"
+ private val KEY_MAX_VALUE = "max"
+ private val KEY_NULL_COUNT = "nullCount"
+ private val KEY_AVG_LEN = "avgLen"
+ private val KEY_MAX_LEN = "maxLen"
+
+ /** Returns true iff we support gathering column statistics on a column of the given type. */
+ def supportsType(dataType: DataType): Boolean = dataType match {
+ case _: IntegralType => true
+ case _: DecimalType => true
+ case DoubleType | FloatType => true
+ case BooleanType => true
+ case DateType => true
+ case TimestampType => true
+ case BinaryType | StringType => true
+ case _ => false
}
-}
-case class NumericColumnStat[T <: AtomicType](statRow: InternalRow, dataType: T) {
- // The indices here must be consistent with `ColumnStatStruct.numericColumnStat`.
- val numNulls: Long = statRow.getLong(0)
- val max: T#InternalType = statRow.get(1, dataType).asInstanceOf[T#InternalType]
- val min: T#InternalType = statRow.get(2, dataType).asInstanceOf[T#InternalType]
- val ndv: Long = statRow.getLong(3)
-}
+ /**
+ * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats
+ * from some external storage. The serialization side is defined in [[ColumnStat.toMap]].
+ */
+ def fromMap(table: String, field: StructField, map: Map[String, String])
+ : Option[ColumnStat] = {
+ val str2val: (String => Any) = field.dataType match {
+ case _: IntegralType => _.toLong
+ case _: DecimalType => new java.math.BigDecimal(_)
+ case DoubleType | FloatType => _.toDouble
+ case BooleanType => _.toBoolean
+ case DateType => java.sql.Date.valueOf
+ case TimestampType => java.sql.Timestamp.valueOf
+ // This version of Spark does not use min/max for binary/string types so we ignore it.
+ case BinaryType | StringType => _ => null
+ case _ =>
+ throw new AnalysisException("Column statistics deserialization is not supported for " +
+ s"column ${field.name} of data type: ${field.dataType}.")
+ }
-case class StringColumnStat(statRow: InternalRow) {
- // The indices here must be consistent with `ColumnStatStruct.stringColumnStat`.
- val numNulls: Long = statRow.getLong(0)
- val avgColLen: Double = statRow.getDouble(1)
- val maxColLen: Long = statRow.getInt(2)
- val ndv: Long = statRow.getLong(3)
-}
+ try {
+ Some(ColumnStat(
+ distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong),
+ // Note that flatMap(Option.apply) turns Option(null) into None.
+ min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply),
+ max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply),
+ nullCount = BigInt(map(KEY_NULL_COUNT).toLong),
+ avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong,
+ maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong
+ ))
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e)
+ None
+ }
+ }
-case class BinaryColumnStat(statRow: InternalRow) {
- // The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`.
- val numNulls: Long = statRow.getLong(0)
- val avgColLen: Double = statRow.getDouble(1)
- val maxColLen: Long = statRow.getInt(2)
-}
+ /**
+ * Constructs an expression to compute column statistics for a given column.
+ *
+ * The expression should create a single struct column with the following schema:
+ * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long
+ *
+ * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and
+ * as a result should stay in sync with it.
+ */
+ def statExprs(col: Attribute, relativeSD: Double): CreateNamedStruct = {
+ def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr =>
+ expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() }
+ })
+ val one = Literal(1, LongType)
+
+ // the approximate ndv (num distinct value) should never be larger than the number of rows
+ val numNonNulls = if (col.nullable) Count(col) else Count(one)
+ val ndv = Least(Seq(HyperLogLogPlusPlus(col, relativeSD), numNonNulls))
+ val numNulls = Subtract(Count(one), numNonNulls)
+
+ def fixedLenTypeStruct(castType: DataType) = {
+ // For fixed width types, avg size should be the same as max size.
+ val avgSize = Literal(col.dataType.defaultSize, LongType)
+ struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, avgSize, avgSize)
+ }
+
+ col.dataType match {
+ case _: IntegralType => fixedLenTypeStruct(LongType)
+ case _: DecimalType => fixedLenTypeStruct(col.dataType)
+ case DoubleType | FloatType => fixedLenTypeStruct(DoubleType)
+ case BooleanType => fixedLenTypeStruct(col.dataType)
+ case DateType => fixedLenTypeStruct(col.dataType)
+ case TimestampType => fixedLenTypeStruct(col.dataType)
+ case BinaryType | StringType =>
+ // For string and binary type, we don't store min/max.
+ val nullLit = Literal(null, col.dataType)
+ struct(
+ ndv, nullLit, nullLit, numNulls,
+ Ceil(Average(Length(col))), Cast(Max(Length(col)), LongType))
+ case _ =>
+ throw new AnalysisException("Analyzing column statistics is not supported for column " +
+ s"${col.name} of data type: ${col.dataType}.")
+ }
+ }
+
+ /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */
+ def rowToColumnStat(row: Row): ColumnStat = {
+ ColumnStat(
+ distinctCount = BigInt(row.getLong(0)),
+ min = Option(row.get(1)), // for string/binary min/max, get should return null
+ max = Option(row.get(2)),
+ nullCount = BigInt(row.getLong(3)),
+ avgLen = row.getLong(4),
+ maxLen = row.getLong(5)
+ )
+ }
-case class BooleanColumnStat(statRow: InternalRow) {
- // The indices here must be consistent with `ColumnStatStruct.booleanColumnStat`.
- val numNulls: Long = statRow.getLong(0)
- val numTrues: Long = statRow.getLong(1)
- val numFalses: Long = statRow.getLong(2)
}
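For intuition, the single-pass, per-column aggregation that statExprs builds above can be approximated with the public DataFrame API. The sketch below is only an illustration: it assumes a local SparkSession, a single nullable long column named cint, and substitutes the public approx_count_distinct function for the internal HyperLogLogPlusPlus expression.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object StatExprsSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("stats-sketch").getOrCreate()
  import spark.implicits._

  // A toy table with one nullable long column.
  val df = Seq[java.lang.Long](1L, 2L, 2L, null).toDF("cint")

  val relativeSD = 0.05 // plays the role of spark.sql.statistics.ndv.maxError

  // One struct per analyzed column, mirroring the schema documented in statExprs:
  // (distinctCount, min, max, nullCount, avgLen, maxLen).
  val statsRow = df.agg(
    count(lit(1)).as("rowCount"),
    struct(
      // Cap the approximate ndv by the number of non-null values.
      least(approx_count_distinct($"cint", relativeSD), count($"cint")),
      min($"cint").cast("long"),
      max($"cint").cast("long"),
      count(lit(1)) - count($"cint"), // nullCount = total rows - non-null rows
      lit(8L),                        // avgLen: default size of a fixed-width LongType
      lit(8L)                         // maxLen: same as avgLen for fixed-width types
    ).as("cint_stats")
  ).head()

  println(statsRow)
  spark.stop()
}
```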
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 7fc57d09e9..9dffe3614a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -24,9 +24,8 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.types._
/**
@@ -62,7 +61,7 @@ case class AnalyzeColumnCommand(
// Compute stats for each column
val (rowCount, newColStats) =
- AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames)
+ AnalyzeColumnCommand.computeColumnStats(sparkSession, tableIdent.table, relation, columnNames)
// We also update table-level stats in order to keep them consistent with column-level stats.
val statistics = Statistics(
@@ -88,8 +87,9 @@ object AnalyzeColumnCommand extends Logging {
*
* This is visible for testing.
*/
- def computeColStats(
+ def computeColumnStats(
sparkSession: SparkSession,
+ tableName: String,
relation: LogicalPlan,
columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = {
@@ -97,102 +97,33 @@ object AnalyzeColumnCommand extends Logging {
val resolver = sparkSession.sessionState.conf.resolver
val attributesToAnalyze = AttributeSet(columnNames.map { col =>
val exprOption = relation.output.find(attr => resolver(attr.name, col))
- exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col."))
+ exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist."))
}).toSeq
+ // Make sure the column types are supported for stats gathering.
+ attributesToAnalyze.foreach { attr =>
+ if (!ColumnStat.supportsType(attr.dataType)) {
+ throw new AnalysisException(
+ s"Column ${attr.name} in table $tableName is of type ${attr.dataType}, " +
+ "and Spark does not support statistics collection on this column type.")
+ }
+ }
+
// Collect statistics per column.
// The first element in the result will be the overall row count, the following elements
// will be structs containing all column stats.
// The layout of each struct follows the layout of the ColumnStats.
val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError
val expressions = Count(Literal(1)).toAggregateExpression() +:
- attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr))
+ attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr))
+
val namedExpressions = expressions.map(e => Alias(e, e.toString)())
- val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation))
- .queryExecution.toRdd.collect().head
+ val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
- // unwrap the result
- // TODO: Get rid of numFields by using the public Dataset API.
val rowCount = statsRow.getLong(0)
val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) =>
- val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType)
- (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields)))
+ (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1)))
}.toMap
(rowCount, columnStats)
}
-
- private val zero = Literal(0, LongType)
- private val one = Literal(1, LongType)
-
- private def numNulls(e: Expression): Expression = {
- if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
- }
- private def max(e: Expression): Expression = Max(e)
- private def min(e: Expression): Expression = Min(e)
- private def ndv(e: Expression, relativeSD: Double): Expression = {
- // the approximate ndv should never be larger than the number of rows
- Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one)))
- }
- private def avgLength(e: Expression): Expression = Average(Length(e))
- private def maxLength(e: Expression): Expression = Max(Length(e))
- private def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
- private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
-
- /**
- * Creates a struct that groups the sequence of expressions together. This is used to create
- * one top level struct per column.
- */
- private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = {
- CreateStruct(exprs.map { expr: Expression =>
- expr.transformUp {
- case af: AggregateFunction => af.toAggregateExpression()
- }
- })
- }
-
- private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
- Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD))
- }
-
- private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
- Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD))
- }
-
- private def binaryColumnStat(e: Expression): Seq[Expression] = {
- Seq(numNulls(e), avgLength(e), maxLength(e))
- }
-
- private def booleanColumnStat(e: Expression): Seq[Expression] = {
- Seq(numNulls(e), numTrues(e), numFalses(e))
- }
-
- // TODO(rxin): Get rid of this function.
- def numStatFields(dataType: DataType): Int = {
- dataType match {
- case BinaryType | BooleanType => 3
- case _ => 4
- }
- }
-
- /**
- * Creates a struct expression that contains the statistics to collect for a column.
- *
- * @param attr column to collect statistics
- * @param relativeSD relative error for approximate number of distinct values.
- */
- def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = {
- attr.dataType match {
- case _: NumericType | TimestampType | DateType =>
- createStruct(numericColumnStat(attr, relativeSD))
- case StringType =>
- createStruct(stringColumnStat(attr, relativeSD))
- case BinaryType =>
- createStruct(binaryColumnStat(attr))
- case BooleanType =>
- createStruct(booleanColumnStat(attr))
- case otherType =>
- throw new AnalysisException("Analyzing columns is not supported for column " +
- s"${attr.name} of data type: ${attr.dataType}.")
- }
- }
}
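End to end, the command implemented in this file is driven by SQL. The following usage sketch assumes an existing SparkSession named spark and uses a made-up table name; the catalog lookup at the end mirrors what the new test suite below does (sessionState is package-private in some Spark versions, so that part may only compile from code inside the org.apache.spark.sql package).

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

// Create a small table, then collect table-level and column-level statistics.
spark.range(0, 100).selectExpr("id AS cint", "cast(id AS string) AS cstring")
  .write.saveAsTable("stats_demo")

spark.sql("ANALYZE TABLE stats_demo COMPUTE STATISTICS")  // row count / total size
spark.sql("ANALYZE TABLE stats_demo COMPUTE STATISTICS FOR COLUMNS cint, cstring")

// Read the collected statistics back from the catalog entry.
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("stats_demo"))
table.stats.foreach { s =>
  println(s"rowCount = ${s.rowCount}")
  s.colStats.foreach { case (name, stat) => println(s"$name -> ${stat.toMap}") }
}
```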
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
new file mode 100644
index 0000000000..1fcccd0610
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.{lang => jl}
+import java.sql.{Date, Timestamp}
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.sql.test.SQLTestData.ArrayData
+import org.apache.spark.sql.types._
+
+
+/**
+ * End-to-end suite testing statistics collection and use on both entire table and columns.
+ */
+class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext {
+ import testImplicits._
+
+ private def checkTableStats(tableName: String, expectedRowCount: Option[Int])
+ : Option[Statistics] = {
+ val df = spark.table(tableName)
+ val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation =>
+ assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount)
+ rel.catalogTable.get.stats
+ }
+ assert(stats.size == 1)
+ stats.head
+ }
+
+ test("estimates the size of a limit 0 on outer join") {
+ withTempView("test") {
+ Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
+ .createOrReplaceTempView("test")
+ val df1 = spark.table("test")
+ val df2 = spark.table("test").limit(0)
+ val df = df1.join(df2, Seq("k"), "left")
+
+ val sizes = df.queryExecution.analyzed.collect { case g: Join =>
+ g.statistics.sizeInBytes
+ }
+
+ assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
+ assert(sizes.head === BigInt(96),
+ s"expected exact size 96 for table 'test', got: ${sizes.head}")
+ }
+ }
+
+ test("analyze column command - unsupported types and invalid columns") {
+ val tableName = "column_stats_test1"
+ withTable(tableName) {
+ Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName)
+
+ // Test unsupported data types
+ val err1 = intercept[AnalysisException] {
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data")
+ }
+ assert(err1.message.contains("does not support statistics collection"))
+
+ // Test invalid columns
+ val err2 = intercept[AnalysisException] {
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS some_random_column")
+ }
+ assert(err2.message.contains("does not exist"))
+ }
+ }
+
+ test("test table-level statistics for data source table") {
+ val tableName = "tbl"
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet")
+ Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName)
+
+ // noscan won't count the number of rows
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
+ checkTableStats(tableName, expectedRowCount = None)
+
+ // without noscan, we count the number of rows
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
+ checkTableStats(tableName, expectedRowCount = Some(2))
+ }
+ }
+
+ test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
+ val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
+ val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
+ assert(df.queryExecution.analyzed.statistics.sizeInBytes >
+ spark.sessionState.conf.autoBroadcastJoinThreshold)
+ assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes >
+ spark.sessionState.conf.autoBroadcastJoinThreshold)
+ }
+
+ test("estimates the size of limit") {
+ withTempView("test") {
+ Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
+ .createOrReplaceTempView("test")
+ Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
+ val df = sql(s"""SELECT * FROM test limit $limit""")
+
+ val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
+ g.statistics.sizeInBytes
+ }
+ assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+ assert(sizesGlobalLimit.head === BigInt(expected),
+ s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
+
+ val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
+ l.statistics.sizeInBytes
+ }
+ assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+ assert(sizesLocalLimit.head === BigInt(expected),
+ s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
+ }
+ }
+ }
+
+}
+
+
+/**
+ * The base for test cases that we want to include in both the hive module (for verifying behavior
+ * when using the Hive external catalog) as well as in the sql/core module.
+ */
+abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils {
+ import testImplicits._
+
+ private val dec1 = new java.math.BigDecimal("1.000000000000000000")
+ private val dec2 = new java.math.BigDecimal("8.000000000000000000")
+ private val d1 = Date.valueOf("2016-05-08")
+ private val d2 = Date.valueOf("2016-05-09")
+ private val t1 = Timestamp.valueOf("2016-05-08 00:00:01")
+ private val t2 = Timestamp.valueOf("2016-05-09 00:00:02")
+
+ /**
+ * Define a very simple 3 row table used for testing column serialization.
+ * Note: last column is seq[int] which doesn't support stats collection.
+ */
+ protected val data = Seq[
+ (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long,
+ jl.Double, jl.Float, java.math.BigDecimal,
+ String, Array[Byte], Date, Timestamp,
+ Seq[Int])](
+ (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null),
+ (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null),
+ (null, null, null, null, null, null, null, null, null, null, null, null, null)
+ )
+
+ /** A mapping from column to the stats collected. */
+ protected val stats = mutable.LinkedHashMap(
+ "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1),
+ "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1),
+ "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2),
+ "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4),
+ "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8),
+ "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8),
+ "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4),
+ "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16),
+ "cstring" -> ColumnStat(2, None, None, 1, 3, 3),
+ "cbinary" -> ColumnStat(2, None, None, 1, 3, 3),
+ "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4),
+ "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8)
+ )
+
+ test("column stats round trip serialization") {
+ // Make sure that serializing and then deserializing gives back the original data
+ val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
+ stats.zip(df.schema).foreach { case ((k, v), field) =>
+ withClue(s"column $k with type ${field.dataType}") {
+ val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap)
+ assert(roundtrip == Some(v))
+ }
+ }
+ }
+
+ test("analyze column command - result verification") {
+ val tableName = "column_stats_test2"
+ // (data.head.productArity - 1) because the last column does not support stats collection.
+ assert(stats.size == data.head.productArity - 1)
+ val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
+
+ withTable(tableName) {
+ df.write.saveAsTable(tableName)
+
+ // Collect statistics
+ sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", "))
+
+ // Validate statistics
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
+ assert(table.stats.isDefined)
+ assert(table.stats.get.colStats.size == stats.size)
+
+ stats.foreach { case (k, v) =>
+ withClue(s"column $k") {
+ assert(table.stats.get.colStats(k) == v)
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala
deleted file mode 100644
index e866ac2cb3..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala
+++ /dev/null
@@ -1,334 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.sql.{Date, Timestamp}
-
-import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.command.AnalyzeColumnCommand
-import org.apache.spark.sql.test.SQLTestData.ArrayData
-import org.apache.spark.sql.types._
-
-class StatisticsColumnSuite extends StatisticsTest {
- import testImplicits._
-
- test("parse analyze column commands") {
- val tableName = "tbl"
-
- // we need to specify column names
- intercept[ParseException] {
- sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS")
- }
-
- val analyzeSql = s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key, value"
- val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql)
- val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value"))
- comparePlans(parsed, expected)
- }
-
- test("analyzing columns of non-atomic types is not supported") {
- val tableName = "tbl"
- withTable(tableName) {
- Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName)
- val err = intercept[AnalysisException] {
- sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data")
- }
- assert(err.message.contains("Analyzing columns is not supported"))
- }
- }
-
- test("check correctness of columns") {
- val table = "tbl"
- val colName1 = "abc"
- val colName2 = "x.yz"
- withTable(table) {
- sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET")
-
- val invalidColError = intercept[AnalysisException] {
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key")
- }
- assert(invalidColError.message == "Invalid column name: key.")
-
- withSQLConf("spark.sql.caseSensitive" -> "true") {
- val invalidErr = intercept[AnalysisException] {
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${colName1.toUpperCase}")
- }
- assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.")
- }
-
- withSQLConf("spark.sql.caseSensitive" -> "false") {
- val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2)
- val tableIdent = TableIdentifier(table, Some("default"))
- val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
- val (_, columnStats) =
- AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze)
- assert(columnStats.contains(colName1))
- assert(columnStats.contains(colName2))
- // check deduplication
- assert(columnStats.size == 2)
- assert(!columnStats.contains(colName2.toUpperCase))
- }
- }
- }
-
- private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = {
- values.filter(_.isDefined).map(_.get)
- }
-
- test("column-level statistics for integral type columns") {
- val values = (0 to 5).map { i =>
- if (i % 2 == 0) None else Some(i)
- }
- val data = values.map { i =>
- (i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong))
- }
-
- val df = data.toDF("c1", "c2", "c3", "c4")
- val nonNullValues = getNonNullValues[Int](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- nonNullValues.max,
- nonNullValues.min,
- nonNullValues.distinct.length.toLong))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for fractional type columns") {
- val values: Seq[Option[Decimal]] = (0 to 5).map { i =>
- if (i == 0) None else Some(Decimal(i + i * 0.01))
- }
- val data = values.map { i =>
- (i.map(_.toFloat), i.map(_.toDouble), i)
- }
-
- val df = data.toDF("c1", "c2", "c3")
- val nonNullValues = getNonNullValues[Decimal](values)
- val numNulls = values.count(_.isEmpty).toLong
- val ndv = nonNullValues.distinct.length.toLong
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = f.dataType match {
- case floatType: FloatType =>
- ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat,
- ndv))
- case doubleType: DoubleType =>
- ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble,
- ndv))
- case decimalType: DecimalType =>
- ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv))
- }
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for string column") {
- val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some(""))
- val df = values.toDF("c1")
- val nonNullValues = getNonNullValues[String](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- nonNullValues.map(_.length).sum / nonNullValues.length.toDouble,
- nonNullValues.map(_.length).max.toInt,
- nonNullValues.distinct.length.toLong))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for binary column") {
- val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes))
- val df = values.toDF("c1")
- val nonNullValues = getNonNullValues[Array[Byte]](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- nonNullValues.map(_.length).sum / nonNullValues.length.toDouble,
- nonNullValues.map(_.length).max.toInt))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for boolean column") {
- val values = Seq(None, Some(true), Some(false), Some(true))
- val df = values.toDF("c1")
- val nonNullValues = getNonNullValues[Boolean](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- nonNullValues.count(_.equals(true)).toLong,
- nonNullValues.count(_.equals(false)).toLong))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for date column") {
- val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf))
- val df = values.toDF("c1")
- val nonNullValues = getNonNullValues[Date](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- // Internally, DateType is represented as the number of days from 1970-01-01.
- nonNullValues.map(DateTimeUtils.fromJavaDate).max,
- nonNullValues.map(DateTimeUtils.fromJavaDate).min,
- nonNullValues.distinct.length.toLong))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for timestamp column") {
- val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i =>
- i.map(Timestamp.valueOf)
- }
- val df = values.toDF("c1")
- val nonNullValues = getNonNullValues[Timestamp](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- // Internally, TimestampType is represented as the number of days from 1970-01-01
- nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max,
- nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min,
- nonNullValues.distinct.length.toLong))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for null columns") {
- val values = Seq(None, None)
- val data = values.map { i =>
- (i.map(_.toString), i.map(_.toString.toInt))
- }
- val df = data.toDF("c1", "c2")
- val expectedColStatsSeq = df.schema.map { f =>
- (f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L)))
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for columns with different types") {
- val intSeq = Seq(1, 2)
- val doubleSeq = Seq(1.01d, 2.02d)
- val stringSeq = Seq("a", "bb")
- val binarySeq = Seq("a", "bb").map(_.getBytes)
- val booleanSeq = Seq(true, false)
- val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf)
- val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf)
- val longSeq = Seq(5L, 4L)
-
- val data = intSeq.indices.map { i =>
- (intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i),
- timestampSeq(i), longSeq(i))
- }
- val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8")
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = f.dataType match {
- case IntegerType =>
- ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
- case DoubleType =>
- ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min,
- doubleSeq.distinct.length.toLong))
- case StringType =>
- ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
- stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
- case BinaryType =>
- ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
- binarySeq.map(_.length).max.toInt))
- case BooleanType =>
- ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
- booleanSeq.count(_.equals(false)).toLong))
- case DateType =>
- ColumnStat(InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max,
- dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong))
- case TimestampType =>
- ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max,
- timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min,
- timestampSeq.distinct.length.toLong))
- case LongType =>
- ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong))
- }
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("update table-level stats while collecting column-level stats") {
- val table = "tbl"
- withTable(table) {
- sql(s"CREATE TABLE $table (c1 int) USING PARQUET")
- sql(s"INSERT INTO $table SELECT 1")
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS")
- checkTableStats(tableName = table, expectedRowCount = Some(1))
-
- // update table-level stats between analyze table and analyze column commands
- sql(s"INSERT INTO $table SELECT 1")
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1")
- val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2))
-
- val colStat = fetchedStats.get.colStats("c1")
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = colStat,
- expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
- rsd = spark.sessionState.conf.ndvMaxError)
- }
- }
-
- test("analyze column stats independently") {
- val table = "tbl"
- withTable(table) {
- sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET")
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1")
- val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0))
- assert(fetchedStats1.get.colStats.size == 1)
- val expected1 = ColumnStat(InternalRow(0L, null, null, 0L))
- val rsd = spark.sessionState.conf.ndvMaxError
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = fetchedStats1.get.colStats("c1"),
- expectedColStat = expected1,
- rsd = rsd)
-
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2")
- val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0))
- // column c1 is kept in the stats
- assert(fetchedStats2.get.colStats.size == 2)
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = fetchedStats2.get.colStats("c1"),
- expectedColStat = expected1,
- rsd = rsd)
- val expected2 = ColumnStat(InternalRow(0L, null, null, 0L))
- StatisticsTest.checkColStat(
- dataType = LongType,
- colStat = fetchedStats2.get.colStats("c2"),
- expectedColStat = expected2,
- rsd = rsd)
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
deleted file mode 100644
index 8cf42e9248..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit}
-import org.apache.spark.sql.types._
-
-class StatisticsSuite extends StatisticsTest {
- import testImplicits._
-
- test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
- val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
- val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
- assert(df.queryExecution.analyzed.statistics.sizeInBytes >
- spark.sessionState.conf.autoBroadcastJoinThreshold)
- assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes >
- spark.sessionState.conf.autoBroadcastJoinThreshold)
- }
-
- test("estimates the size of limit") {
- withTempView("test") {
- Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
- .createOrReplaceTempView("test")
- Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
- val df = sql(s"""SELECT * FROM test limit $limit""")
-
- val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
- g.statistics.sizeInBytes
- }
- assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
- assert(sizesGlobalLimit.head === BigInt(expected),
- s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
-
- val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
- l.statistics.sizeInBytes
- }
- assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
- assert(sizesLocalLimit.head === BigInt(expected),
- s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
- }
- }
- }
-
- test("estimates the size of a limit 0 on outer join") {
- withTempView("test") {
- Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
- .createOrReplaceTempView("test")
- val df1 = spark.table("test")
- val df2 = spark.table("test").limit(0)
- val df = df1.join(df2, Seq("k"), "left")
-
- val sizes = df.queryExecution.analyzed.collect { case g: Join =>
- g.statistics.sizeInBytes
- }
-
- assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
- assert(sizes.head === BigInt(96),
- s"expected exact size 96 for table 'test', got: ${sizes.head}")
- }
- }
-
- test("test table-level statistics for data source table created in InMemoryCatalog") {
- val tableName = "tbl"
- withTable(tableName) {
- sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet")
- Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName)
-
- // noscan won't count the number of rows
- sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
- checkTableStats(tableName, expectedRowCount = None)
-
- // without noscan, we count the number of rows
- sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
- checkTableStats(tableName, expectedRowCount = Some(2))
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala
deleted file mode 100644
index 915ee0d31b..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
-import org.apache.spark.sql.execution.command.AnalyzeColumnCommand
-import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types._
-
-
-trait StatisticsTest extends QueryTest with SharedSQLContext {
-
- def checkColStats(
- df: DataFrame,
- expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
- val table = "tbl"
- withTable(table) {
- df.write.format("json").saveAsTable(table)
- val columns = expectedColStatsSeq.map(_._1)
- val tableIdent = TableIdentifier(table, Some("default"))
- val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
- val (_, columnStats) =
- AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name))
- expectedColStatsSeq.foreach { case (field, expectedColStat) =>
- assert(columnStats.contains(field.name))
- val colStat = columnStats(field.name)
- StatisticsTest.checkColStat(
- dataType = field.dataType,
- colStat = colStat,
- expectedColStat = expectedColStat,
- rsd = spark.sessionState.conf.ndvMaxError)
-
- // check if we get the same colStat after encoding and decoding
- val encodedCS = colStat.toString
- val numFields = AnalyzeColumnCommand.numStatFields(field.dataType)
- val decodedCS = ColumnStat(numFields, encodedCS)
- StatisticsTest.checkColStat(
- dataType = field.dataType,
- colStat = decodedCS,
- expectedColStat = expectedColStat,
- rsd = spark.sessionState.conf.ndvMaxError)
- }
- }
- }
-
- def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = {
- val df = spark.table(tableName)
- val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation =>
- assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount)
- rel.catalogTable.get.stats
- }
- assert(stats.size == 1)
- stats.head
- }
-}
-
-object StatisticsTest {
- def checkColStat(
- dataType: DataType,
- colStat: ColumnStat,
- expectedColStat: ColumnStat,
- rsd: Double): Unit = {
- dataType match {
- case StringType =>
- val cs = colStat.forString
- val expectedCS = expectedColStat.forString
- assert(cs.numNulls == expectedCS.numNulls)
- assert(cs.avgColLen == expectedCS.avgColLen)
- assert(cs.maxColLen == expectedCS.maxColLen)
- checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd)
- case BinaryType =>
- val cs = colStat.forBinary
- val expectedCS = expectedColStat.forBinary
- assert(cs.numNulls == expectedCS.numNulls)
- assert(cs.avgColLen == expectedCS.avgColLen)
- assert(cs.maxColLen == expectedCS.maxColLen)
- case BooleanType =>
- val cs = colStat.forBoolean
- val expectedCS = expectedColStat.forBoolean
- assert(cs.numNulls == expectedCS.numNulls)
- assert(cs.numTrues == expectedCS.numTrues)
- assert(cs.numFalses == expectedCS.numFalses)
- case atomicType: AtomicType =>
- checkNumericColStats(
- dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd)
- }
- }
-
- private def checkNumericColStats(
- dataType: AtomicType,
- colStat: ColumnStat,
- expectedColStat: ColumnStat,
- rsd: Double): Unit = {
- val cs = colStat.forNumeric(dataType)
- val expectedCS = expectedColStat.forNumeric(dataType)
- assert(cs.numNulls == expectedCS.numNulls)
- assert(cs.max == expectedCS.max)
- assert(cs.min == expectedCS.min)
- checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd)
- }
-
- private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = {
- // ndv is an approximate value, so we make sure we have the value, and it should be
- // within 3*SD's of the given rsd.
- if (expectedNdv == 0) {
- assert(ndv == 0)
- } else if (expectedNdv > 0) {
- assert(ndv > 0)
- val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
- assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.")
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index 797fe9ffa8..b070138be0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat,
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DescribeFunctionCommand,
- DescribeTableCommand, ShowFunctionsCommand}
-import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing}
+import org.apache.spark.sql.execution.command._
+import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
@@ -221,12 +220,22 @@ class SparkSqlParserSuite extends PlanTest {
intercept("explain describe tables x", "Unsupported SQL statement")
}
- test("SPARK-18106 analyze table") {
+ test("analyze table statistics") {
assertEqual("analyze table t compute statistics",
AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
assertEqual("analyze table t compute statistics noscan",
AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
- assertEqual("analyze table t partition (a) compute statistics noscan",
+ assertEqual("analyze table t partition (a) compute statistics nOscAn",
+ AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
+
+ // Partitions specified - we currently parse them but don't do anything with it
+ assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS",
+ AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
+ assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan",
+ AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
+ assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS",
+ AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
+ assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS noscan",
AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
intercept("analyze table t compute statistics xxxx",
@@ -234,4 +243,11 @@ class SparkSqlParserSuite extends PlanTest {
intercept("analyze table t partition (a) compute statistics xxxx",
"Expected `NOSCAN` instead of `xxxx`")
}
+
+ test("analyze table column statistics") {
+ intercept("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS", "")
+
+ assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value",
+ AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value")))
+ }
}
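The HiveExternalCatalog changes below persist each ColumnStat field as a separate table property keyed by columnStatKeyPropName, i.e. "spark.sql.statistics.colStats.<column>.<field>". As a hypothetical illustration, the properties written for a single analyzed column named cint might look like the following sketch; the values mirror the cint entry in the test suite above and are examples only.

```scala
// Sketch of the per-column table properties stored in the Hive metastore;
// key names follow columnStatKeyPropName, values are illustrative.
val persistedColStatProps: Map[String, String] = Map(
  "spark.sql.statistics.colStats.cint.version"       -> "1",
  "spark.sql.statistics.colStats.cint.distinctCount" -> "2",
  "spark.sql.statistics.colStats.cint.nullCount"     -> "1",
  "spark.sql.statistics.colStats.cint.avgLen"        -> "4",
  "spark.sql.statistics.colStats.cint.maxLen"        -> "4",
  "spark.sql.statistics.colStats.cint.min"           -> "1",
  "spark.sql.statistics.colStats.cint.max"           -> "4")
```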
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index ff0923f048..fd9dc32063 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils}
+import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.HiveSerDe
import org.apache.spark.sql.internal.StaticSQLConf._
@@ -514,7 +514,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()
}
stats.colStats.foreach { case (colName, colStat) =>
- statsProperties += (STATISTICS_COL_STATS_PREFIX + colName) -> colStat.toString
+ colStat.toMap.foreach { case (k, v) =>
+ statsProperties += (columnStatKeyPropName(colName, k) -> v)
+ }
}
tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties)
} else {
@@ -605,48 +607,65 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
* It reads table schema, provider, partition column names and bucket specification from table
* properties, and filter out these special entries from table properties.
*/
- private def restoreTableMetadata(table: CatalogTable): CatalogTable = {
+ private def restoreTableMetadata(inputTable: CatalogTable): CatalogTable = {
if (conf.get(DEBUG_MODE)) {
- return table
+ return inputTable
}
- val tableWithSchema = if (table.tableType == VIEW) {
- table
- } else {
- getProviderFromTableProperties(table) match {
+ var table = inputTable
+
+ if (table.tableType != VIEW) {
+ table.properties.get(DATASOURCE_PROVIDER) match {
// No provider in table properties, which means this table is created by Spark prior to 2.1,
// or is created at Hive side.
case None =>
- table.copy(provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true)
+ table = table.copy(
+ provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true)
// This is a Hive serde table created by Spark 2.1 or higher versions.
- case Some(DDLUtils.HIVE_PROVIDER) => restoreHiveSerdeTable(table)
+ case Some(DDLUtils.HIVE_PROVIDER) =>
+ table = restoreHiveSerdeTable(table)
// This is a regular data source table.
- case Some(provider) => restoreDataSourceTable(table, provider)
+ case Some(provider) =>
+ table = restoreDataSourceTable(table, provider)
}
}
// construct Spark's statistics from information in Hive metastore
- val statsProps = tableWithSchema.properties.filterKeys(_.startsWith(STATISTICS_PREFIX))
- val tableWithStats = if (statsProps.nonEmpty) {
- val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX))
- .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) }
- val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect {
- case f if colStatsProps.contains(f.name) =>
- val numFields = AnalyzeColumnCommand.numStatFields(f.dataType)
- (f.name, ColumnStat(numFields, colStatsProps(f.name)))
- }.toMap
- tableWithSchema.copy(
+ val statsProps = table.properties.filterKeys(_.startsWith(STATISTICS_PREFIX))
+
+ if (statsProps.nonEmpty) {
+ val colStats = new scala.collection.mutable.HashMap[String, ColumnStat]
+
+ // For each column, recover its column stats. Note that this is currently an O(n^2) operation,
+ // but since the number of columns is usually not enormous, this is probably OK as a start.
+ // If we want to make this a linear operation, we'd need a stronger contract on the
+ // naming convention used for serialization.
+ table.schema.foreach { field =>
+ if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) {
+ // If "version" field is defined, then the column stat is defined.
+ val keyPrefix = columnStatKeyPropName(field.name, "")
+ val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) =>
+ (k.drop(keyPrefix.length), v)
+ }
+
+ ColumnStat.fromMap(table.identifier.table, field, colStatMap).foreach {
+ colStat => colStats += field.name -> colStat
+ }
+ }
+ }
+
+ table = table.copy(
stats = Some(Statistics(
- sizeInBytes = BigInt(tableWithSchema.properties(STATISTICS_TOTAL_SIZE)),
- rowCount = tableWithSchema.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)),
- colStats = colStats)))
- } else {
- tableWithSchema
+ sizeInBytes = BigInt(table.properties(STATISTICS_TOTAL_SIZE)),
+ rowCount = table.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)),
+ colStats = colStats.toMap)))
}
- tableWithStats.copy(properties = getOriginalTableProperties(table))
+ // Get the original table properties as defined by the user.
+ table.copy(
+ properties = table.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) })
}
private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = {
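Illustrative sketch (not part of the patch): the restore path above does the inverse, collecting the properties that share a column's key prefix and stripping the prefix to rebuild that column's stat map, guarded by the "version" marker. A self-contained model with hypothetical sample properties:

object ColumnStatRecoverSketch {
  private val ColStatsPrefix = "spark.sql.statistics.colStats."

  // Rebuild one column's stat map by filtering on its key prefix and stripping it,
  // as restoreTableMetadata does above.
  def recover(colName: String, statsProps: Map[String, String]): Map[String, String] = {
    val keyPrefix = ColStatsPrefix + colName + "."
    statsProps.collect { case (k, v) if k.startsWith(keyPrefix) => k.drop(keyPrefix.length) -> v }
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical table properties for two columns.
    val statsProps = Map(
      "spark.sql.statistics.colStats.cint.version" -> "1",
      "spark.sql.statistics.colStats.cint.min" -> "1",
      "spark.sql.statistics.colStats.cint.max" -> "4",
      "spark.sql.statistics.colStats.clong.version" -> "1")
    val cintStats = recover("cint", statsProps)
    // Only treat a column as having stats when its "version" marker is present.
    assert(cintStats.contains("version"))
    println(cintStats) // e.g. Map(version -> 1, min -> 1, max -> 4)
  }
}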
@@ -1020,17 +1039,17 @@ object HiveExternalCatalog {
val TABLE_PARTITION_PROVIDER_CATALOG = "catalog"
val TABLE_PARTITION_PROVIDER_FILESYSTEM = "filesystem"
-
- def getProviderFromTableProperties(metadata: CatalogTable): Option[String] = {
- metadata.properties.get(DATASOURCE_PROVIDER)
- }
-
- def getOriginalTableProperties(metadata: CatalogTable): Map[String, String] = {
- metadata.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }
+ /**
+ * Returns the fully qualified name used in table properties for a particular column stat.
+ * For example, for column "mycol" and the "min" stat, this should return
+ * "spark.sql.statistics.colStats.mycol.min".
+ */
+ private def columnStatKeyPropName(columnName: String, statKey: String): String = {
+ STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey
}
// A persisted data source table always stores its schema in the catalog.
- def getSchemaFromTableProperties(metadata: CatalogTable): StructType = {
+ private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = {
val errorMessage = "Could not read schema from the hive metastore because it is corrupted."
val props = metadata.properties
val schema = props.get(DATASOURCE_SCHEMA)
@@ -1078,11 +1097,11 @@ object HiveExternalCatalog {
)
}
- def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = {
+ private def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = {
getColumnNamesByType(metadata.properties, "part", "partitioning columns")
}
- def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = {
+ private def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = {
metadata.properties.get(DATASOURCE_SCHEMA_NUMBUCKETS).map { numBuckets =>
BucketSpec(
numBuckets.toInt,
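Illustrative sketch (not part of the patch): a standalone stand-in for the key-naming helper documented above, checked against the example in its scaladoc; the object name is hypothetical:

object ColumnStatKeyNameSketch {
  private val ColStatsPrefix = "spark.sql.statistics.colStats."

  // Standalone stand-in for columnStatKeyPropName: prefix + column name + "." + stat key.
  def keyPropName(columnName: String, statKey: String): String =
    ColStatsPrefix + columnName + "." + statKey

  def main(args: Array[String]): Unit = {
    // Reproduces the example from the scaladoc above.
    assert(keyPropName("mycol", "min") == "spark.sql.statistics.colStats.mycol.min")
    println(keyPropName("mycol", "min"))
  }
}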
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 4f5ebc3d83..5ae202fdc9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -22,56 +22,16 @@ import java.io.{File, PrintWriter}
import scala.reflect.ClassTag
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
-import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
-class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
-
- test("parse analyze commands") {
- def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
- val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand)
- val operators = parsed.collect {
- case a: AnalyzeTableCommand => a
- case o => o
- }
-
- assert(operators.size === 1)
- if (operators(0).getClass() != c) {
- fail(
- s"""$analyzeCommand expected command: $c, but got ${operators(0)}
- |parsed command:
- |$parsed
- """.stripMargin)
- }
- }
-
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 COMPUTE STATISTICS",
- classOf[AnalyzeTableCommand])
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS",
- classOf[AnalyzeTableCommand])
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan",
- classOf[AnalyzeTableCommand])
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS",
- classOf[AnalyzeTableCommand])
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS noscan",
- classOf[AnalyzeTableCommand])
-
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 COMPUTE STATISTICS nOscAn",
- classOf[AnalyzeTableCommand])
- }
+class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton {
test("MetastoreRelations fallback to HDFS for size estimation") {
val enableFallBackToHdfsForStats = spark.sessionState.conf.fallBackToHdfsForStatsEnabled
@@ -310,6 +270,110 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
}
}
+ test("verify serialized column stats after analyzing columns") {
+ import testImplicits._
+
+ val tableName = "column_stats_test2"
+ // (data.head.productArity - 1) because the last column does not support stats collection.
+ assert(stats.size == data.head.productArity - 1)
+ val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
+
+ withTable(tableName) {
+ df.write.saveAsTable(tableName)
+
+ // Collect statistics
+ sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", "))
+
+ // Validate statistics
+ val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
+ val table = hiveClient.getTable("default", tableName)
+
+ val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats"))
+ assert(props == Map(
+ "spark.sql.statistics.colStats.cbinary.avgLen" -> "3",
+ "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cbinary.maxLen" -> "3",
+ "spark.sql.statistics.colStats.cbinary.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cbinary.version" -> "1",
+ "spark.sql.statistics.colStats.cbool.avgLen" -> "1",
+ "spark.sql.statistics.colStats.cbool.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cbool.max" -> "true",
+ "spark.sql.statistics.colStats.cbool.maxLen" -> "1",
+ "spark.sql.statistics.colStats.cbool.min" -> "false",
+ "spark.sql.statistics.colStats.cbool.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cbool.version" -> "1",
+ "spark.sql.statistics.colStats.cbyte.avgLen" -> "1",
+ "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cbyte.max" -> "2",
+ "spark.sql.statistics.colStats.cbyte.maxLen" -> "1",
+ "spark.sql.statistics.colStats.cbyte.min" -> "1",
+ "spark.sql.statistics.colStats.cbyte.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cbyte.version" -> "1",
+ "spark.sql.statistics.colStats.cdate.avgLen" -> "4",
+ "spark.sql.statistics.colStats.cdate.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09",
+ "spark.sql.statistics.colStats.cdate.maxLen" -> "4",
+ "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08",
+ "spark.sql.statistics.colStats.cdate.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cdate.version" -> "1",
+ "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16",
+ "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000",
+ "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16",
+ "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000",
+ "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cdecimal.version" -> "1",
+ "spark.sql.statistics.colStats.cdouble.avgLen" -> "8",
+ "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cdouble.max" -> "6.0",
+ "spark.sql.statistics.colStats.cdouble.maxLen" -> "8",
+ "spark.sql.statistics.colStats.cdouble.min" -> "1.0",
+ "spark.sql.statistics.colStats.cdouble.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cdouble.version" -> "1",
+ "spark.sql.statistics.colStats.cfloat.avgLen" -> "4",
+ "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cfloat.max" -> "7.0",
+ "spark.sql.statistics.colStats.cfloat.maxLen" -> "4",
+ "spark.sql.statistics.colStats.cfloat.min" -> "1.0",
+ "spark.sql.statistics.colStats.cfloat.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cfloat.version" -> "1",
+ "spark.sql.statistics.colStats.cint.avgLen" -> "4",
+ "spark.sql.statistics.colStats.cint.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cint.max" -> "4",
+ "spark.sql.statistics.colStats.cint.maxLen" -> "4",
+ "spark.sql.statistics.colStats.cint.min" -> "1",
+ "spark.sql.statistics.colStats.cint.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cint.version" -> "1",
+ "spark.sql.statistics.colStats.clong.avgLen" -> "8",
+ "spark.sql.statistics.colStats.clong.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.clong.max" -> "5",
+ "spark.sql.statistics.colStats.clong.maxLen" -> "8",
+ "spark.sql.statistics.colStats.clong.min" -> "1",
+ "spark.sql.statistics.colStats.clong.nullCount" -> "1",
+ "spark.sql.statistics.colStats.clong.version" -> "1",
+ "spark.sql.statistics.colStats.cshort.avgLen" -> "2",
+ "spark.sql.statistics.colStats.cshort.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cshort.max" -> "3",
+ "spark.sql.statistics.colStats.cshort.maxLen" -> "2",
+ "spark.sql.statistics.colStats.cshort.min" -> "1",
+ "spark.sql.statistics.colStats.cshort.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cshort.version" -> "1",
+ "spark.sql.statistics.colStats.cstring.avgLen" -> "3",
+ "spark.sql.statistics.colStats.cstring.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cstring.maxLen" -> "3",
+ "spark.sql.statistics.colStats.cstring.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cstring.version" -> "1",
+ "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8",
+ "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0",
+ "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8",
+ "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0",
+ "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1",
+ "spark.sql.statistics.colStats.ctimestamp.version" -> "1"
+ ))
+ }
+ }
+
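Illustrative sketch (not part of the patch): the min/max values serialized above are human-readable strings (e.g. "2016-05-09" for the date column, "2016-05-08 00:00:01.0" for the timestamp column). The snippet below only shows that such strings round-trip through the standard JDK parsers; whether ColumnStat.fromMap parses them exactly this way is an assumption here:

import java.sql.{Date, Timestamp}

object HumanReadableStatValueSketch {
  def main(args: Array[String]): Unit = {
    // Sample values taken from the expected table properties in the test above.
    val cdateMax = Date.valueOf("2016-05-09")
    val ctimestampMin = Timestamp.valueOf("2016-05-08 00:00:01.0")
    val cintMax = "4".toInt
    println(Seq(cdateMax, ctimestampMin, cintMax).mkString(", "))
  }
}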
private def testUpdatingTableStats(tableDescription: String, createTableCmd: String): Unit = {
test("test table-level statistics for " + tableDescription) {
val parquetTable = "parquetTable"
@@ -319,7 +383,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
TableIdentifier(parquetTable))
assert(DDLUtils.isDatasourceTable(catalogTable))
- sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src")
+ // Add a filter to avoid creating too many partitions
+ sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10")
checkTableStats(
parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None)
@@ -328,7 +393,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
val fetchedStats1 = checkTableStats(
parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None)
- sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src")
+ sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10")
sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan")
val fetchedStats2 = checkTableStats(
parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None)
@@ -340,7 +405,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
parquetTable,
isDataSourceTable = true,
hasSizeInBytes = true,
- expectedRowCounts = Some(1000))
+ expectedRowCounts = Some(20))
assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes)
}
}
@@ -369,6 +434,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
}
}
+ /** Used to test refreshing cached metadata once table stats are updated. */
private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = {
val tableName = "tbl"
var statsBeforeUpdate: Statistics = null
@@ -411,145 +477,6 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
assert(statsAfterUpdate.rowCount == Some(2))
}
- test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") {
- val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true)
-
- assert(statsBeforeUpdate.sizeInBytes > 0)
- assert(statsBeforeUpdate.rowCount == Some(1))
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = statsBeforeUpdate.colStats("key"),
- expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
- rsd = spark.sessionState.conf.ndvMaxError)
-
- assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
- assert(statsAfterUpdate.rowCount == Some(2))
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = statsAfterUpdate.colStats("key"),
- expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)),
- rsd = spark.sessionState.conf.ndvMaxError)
- }
-
- private lazy val (testDataFrame, expectedColStatsSeq) = {
- import testImplicits._
-
- val intSeq = Seq(1, 2)
- val stringSeq = Seq("a", "bb")
- val binarySeq = Seq("a", "bb").map(_.getBytes)
- val booleanSeq = Seq(true, false)
- val data = intSeq.indices.map { i =>
- (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i))
- }
- val df: DataFrame = data.toDF("c1", "c2", "c3", "c4")
- val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f =>
- val colStat = f.dataType match {
- case IntegerType =>
- ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
- case StringType =>
- ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
- stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
- case BinaryType =>
- ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
- binarySeq.map(_.length).max.toInt))
- case BooleanType =>
- ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
- booleanSeq.count(_.equals(false)).toLong))
- }
- (f, colStat)
- }
- (df, expectedColStatsSeq)
- }
-
- private def checkColStats(
- tableName: String,
- isDataSourceTable: Boolean,
- expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
- val readback = spark.table(tableName)
- val stats = readback.queryExecution.analyzed.collect {
- case rel: MetastoreRelation =>
- assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table")
- rel.catalogTable.stats.get
- case rel: LogicalRelation =>
- assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table")
- rel.catalogTable.get.stats.get
- }
- assert(stats.length == 1)
- val columnStats = stats.head.colStats
- assert(columnStats.size == expectedColStatsSeq.length)
- expectedColStatsSeq.foreach { case (field, expectedColStat) =>
- StatisticsTest.checkColStat(
- dataType = field.dataType,
- colStat = columnStats(field.name),
- expectedColStat = expectedColStat,
- rsd = spark.sessionState.conf.ndvMaxError)
- }
- }
-
- test("generate and load column-level stats for data source table") {
- val dsTable = "dsTable"
- withTable(dsTable) {
- testDataFrame.write.format("parquet").saveAsTable(dsTable)
- sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
- checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq)
- }
- }
-
- test("generate and load column-level stats for hive serde table") {
- val hTable = "hTable"
- val tmp = "tmp"
- withTable(hTable, tmp) {
- testDataFrame.write.format("parquet").saveAsTable(tmp)
- sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE")
- sql(s"INSERT INTO $hTable SELECT * FROM $tmp")
- sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
- checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq)
- }
- }
-
- // When caseSensitive is on, for columns with only case difference, they are different columns
- // and we should generate column stats for all of them.
- private def checkCaseSensitiveColStats(columnName: String): Unit = {
- val tableName = "tbl"
- withTable(tableName) {
- val column1 = columnName.toLowerCase
- val column2 = columnName.toUpperCase
- withSQLConf("spark.sql.caseSensitive" -> "true") {
- sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET")
- sql(s"INSERT INTO $tableName SELECT 1, 3.0")
- sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`")
- val readback = spark.table(tableName)
- val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
- val columnStats = rel.catalogTable.get.stats.get.colStats
- assert(columnStats.size == 2)
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = columnStats(column1),
- expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
- rsd = spark.sessionState.conf.ndvMaxError)
- StatisticsTest.checkColStat(
- dataType = DoubleType,
- colStat = columnStats(column2),
- expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)),
- rsd = spark.sessionState.conf.ndvMaxError)
- rel
- }
- assert(relations.size == 1)
- }
- }
- }
-
- test("check column statistics for case sensitive column names") {
- checkCaseSensitiveColStats(columnName = "c1")
- }
-
- test("check column statistics for case sensitive non-ascii column names") {
- // scalastyle:off
- // non ascii characters are not allowed in the source code, so we disable the scalastyle.
- checkCaseSensitiveColStats(columnName = "列c")
- // scalastyle:on
- }
-
test("estimates the size of a test MetastoreRelation") {
val df = sql("""SELECT * FROM src""")
val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation =>