-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala    212
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala   105
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala                218
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala                    334
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala                           92
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala                           130
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala             26
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala                  93
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala                     299
9 files changed, 591 insertions(+), 918 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index f3e2147b8f..79865609cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.commons.codec.binary.Base64
+import scala.util.control.NonFatal
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types._
+
/**
* Estimates of various statistics. The default estimation logic simply lazily multiplies the
* corresponding statistic produced by the children. To override this behavior, override
@@ -58,60 +61,175 @@ case class Statistics(
}
}
+
/**
- * Statistics for a column.
+ * Statistics collected for a column.
+ *
+ * 1. Supported data types are defined in `ColumnStat.supportsType`.
+ * 2. The JVM data type stored in min/max is the external data type (used in Row) for the
+ * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for
+ * TimestampType we store java.sql.Timestamp.
+ * 3. All integral types are upcast to longs, i.e. shorts are stored as longs.
+ * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms
+ * (sketches) might have been used, and the data collected can also be stale.
+ *
+ * @param distinctCount number of distinct values
+ * @param min minimum value
+ * @param max maximum value
+ * @param nullCount number of nulls
+ * @param avgLen average length of the values. For fixed-length types, this should be a constant.
+ * @param maxLen maximum length of the values. For fixed-length types, this should be a constant.
*/
-case class ColumnStat(statRow: InternalRow) {
+case class ColumnStat(
+ distinctCount: BigInt,
+ min: Option[Any],
+ max: Option[Any],
+ nullCount: BigInt,
+ avgLen: Long,
+ maxLen: Long) {
- def forNumeric[T <: AtomicType](dataType: T): NumericColumnStat[T] = {
- NumericColumnStat(statRow, dataType)
- }
- def forString: StringColumnStat = StringColumnStat(statRow)
- def forBinary: BinaryColumnStat = BinaryColumnStat(statRow)
- def forBoolean: BooleanColumnStat = BooleanColumnStat(statRow)
+  // We currently don't store min/max for binary/string type. If this changes in the future,
+  // this require should be removed.
+ require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String]))
+ require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String]))
- override def toString: String = {
- // use Base64 for encoding
- Base64.encodeBase64String(statRow.asInstanceOf[UnsafeRow].getBytes)
+ /**
+ * Returns a map from string to string that can be used to serialize the column stats.
+ * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string
+ * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]].
+ *
+ * As part of the protocol, the returned map always contains a key called "version".
+ * In the case min/max values are null (None), they won't appear in the map.
+ */
+ def toMap: Map[String, String] = {
+ val map = new scala.collection.mutable.HashMap[String, String]
+ map.put(ColumnStat.KEY_VERSION, "1")
+ map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString)
+ map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString)
+ map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString)
+ map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString)
+ min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) }
+ max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) }
+ map.toMap
}
}
-object ColumnStat {
- def apply(numFields: Int, str: String): ColumnStat = {
- // use Base64 for decoding
- val bytes = Base64.decodeBase64(str)
- val unsafeRow = new UnsafeRow(numFields)
- unsafeRow.pointTo(bytes, bytes.length)
- ColumnStat(unsafeRow)
+
+object ColumnStat extends Logging {
+
+ // List of string keys used to serialize ColumnStat
+ val KEY_VERSION = "version"
+ private val KEY_DISTINCT_COUNT = "distinctCount"
+ private val KEY_MIN_VALUE = "min"
+ private val KEY_MAX_VALUE = "max"
+ private val KEY_NULL_COUNT = "nullCount"
+ private val KEY_AVG_LEN = "avgLen"
+ private val KEY_MAX_LEN = "maxLen"
+
+  /** Returns true iff we support gathering column statistics on a column of the given type. */
+ def supportsType(dataType: DataType): Boolean = dataType match {
+ case _: IntegralType => true
+ case _: DecimalType => true
+ case DoubleType | FloatType => true
+ case BooleanType => true
+ case DateType => true
+ case TimestampType => true
+ case BinaryType | StringType => true
+ case _ => false
}
-}
-case class NumericColumnStat[T <: AtomicType](statRow: InternalRow, dataType: T) {
- // The indices here must be consistent with `ColumnStatStruct.numericColumnStat`.
- val numNulls: Long = statRow.getLong(0)
- val max: T#InternalType = statRow.get(1, dataType).asInstanceOf[T#InternalType]
- val min: T#InternalType = statRow.get(2, dataType).asInstanceOf[T#InternalType]
- val ndv: Long = statRow.getLong(3)
-}
+ /**
+ * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats
+ * from some external storage. The serialization side is defined in [[ColumnStat.toMap]].
+ */
+ def fromMap(table: String, field: StructField, map: Map[String, String])
+ : Option[ColumnStat] = {
+ val str2val: (String => Any) = field.dataType match {
+ case _: IntegralType => _.toLong
+ case _: DecimalType => new java.math.BigDecimal(_)
+ case DoubleType | FloatType => _.toDouble
+ case BooleanType => _.toBoolean
+ case DateType => java.sql.Date.valueOf
+ case TimestampType => java.sql.Timestamp.valueOf
+ // This version of Spark does not use min/max for binary/string types so we ignore it.
+ case BinaryType | StringType => _ => null
+ case _ =>
+ throw new AnalysisException("Column statistics deserialization is not supported for " +
+ s"column ${field.name} of data type: ${field.dataType}.")
+ }
-case class StringColumnStat(statRow: InternalRow) {
- // The indices here must be consistent with `ColumnStatStruct.stringColumnStat`.
- val numNulls: Long = statRow.getLong(0)
- val avgColLen: Double = statRow.getDouble(1)
- val maxColLen: Long = statRow.getInt(2)
- val ndv: Long = statRow.getLong(3)
-}
+ try {
+ Some(ColumnStat(
+ distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong),
+ // Note that flatMap(Option.apply) turns Option(null) into None.
+ min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply),
+ max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply),
+ nullCount = BigInt(map(KEY_NULL_COUNT).toLong),
+ avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong,
+ maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong
+ ))
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e)
+ None
+ }
+ }
-case class BinaryColumnStat(statRow: InternalRow) {
- // The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`.
- val numNulls: Long = statRow.getLong(0)
- val avgColLen: Double = statRow.getDouble(1)
- val maxColLen: Long = statRow.getInt(2)
-}
+ /**
+ * Constructs an expression to compute column statistics for a given column.
+ *
+ * The expression should create a single struct column with the following schema:
+ * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long
+ *
+ * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and
+ * as a result should stay in sync with it.
+ */
+ def statExprs(col: Attribute, relativeSD: Double): CreateNamedStruct = {
+ def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr =>
+ expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() }
+ })
+ val one = Literal(1, LongType)
+
+    // the approximate ndv (number of distinct values) should never be larger than the number of rows
+ val numNonNulls = if (col.nullable) Count(col) else Count(one)
+ val ndv = Least(Seq(HyperLogLogPlusPlus(col, relativeSD), numNonNulls))
+ val numNulls = Subtract(Count(one), numNonNulls)
+
+ def fixedLenTypeStruct(castType: DataType) = {
+ // For fixed width types, avg size should be the same as max size.
+ val avgSize = Literal(col.dataType.defaultSize, LongType)
+ struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, avgSize, avgSize)
+ }
+
+ col.dataType match {
+ case _: IntegralType => fixedLenTypeStruct(LongType)
+ case _: DecimalType => fixedLenTypeStruct(col.dataType)
+ case DoubleType | FloatType => fixedLenTypeStruct(DoubleType)
+ case BooleanType => fixedLenTypeStruct(col.dataType)
+ case DateType => fixedLenTypeStruct(col.dataType)
+ case TimestampType => fixedLenTypeStruct(col.dataType)
+ case BinaryType | StringType =>
+ // For string and binary type, we don't store min/max.
+ val nullLit = Literal(null, col.dataType)
+ struct(
+ ndv, nullLit, nullLit, numNulls,
+ Ceil(Average(Length(col))), Cast(Max(Length(col)), LongType))
+ case _ =>
+ throw new AnalysisException("Analyzing column statistics is not supported for column " +
+ s"${col.name} of data type: ${col.dataType}.")
+ }
+ }
+
+ /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */
+ def rowToColumnStat(row: Row): ColumnStat = {
+ ColumnStat(
+ distinctCount = BigInt(row.getLong(0)),
+ min = Option(row.get(1)), // for string/binary min/max, get should return null
+ max = Option(row.get(2)),
+ nullCount = BigInt(row.getLong(3)),
+ avgLen = row.getLong(4),
+ maxLen = row.getLong(5)
+ )
+ }
-case class BooleanColumnStat(statRow: InternalRow) {
- // The indices here must be consistent with `ColumnStatStruct.booleanColumnStat`.
- val numNulls: Long = statRow.getLong(0)
- val numTrues: Long = statRow.getLong(1)
- val numFalses: Long = statRow.getLong(2)
}
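
The toMap/fromMap doc comments above describe a small map-based serialization protocol: a mandatory "version" key, optional min/max entries that are simply absent when the value is None, and a best-effort parse on the way back. The following is a minimal, self-contained Scala sketch of that round trip; the names are simplified stand-ins rather than the Spark classes introduced in this patch, and the parser assumes a long-typed column.

object ColumnStatMapProtocolSketch {
  case class SimpleColumnStat(
      distinctCount: BigInt,
      min: Option[Any],
      max: Option[Any],
      nullCount: BigInt,
      avgLen: Long,
      maxLen: Long) {
    // Serialize to string key/value pairs; min/max are omitted when empty.
    def toMap: Map[String, String] = {
      val base = Map(
        "version" -> "1",
        "distinctCount" -> distinctCount.toString,
        "nullCount" -> nullCount.toString,
        "avgLen" -> avgLen.toString,
        "maxLen" -> maxLen.toString)
      base ++ min.map(v => "min" -> v.toString) ++ max.map(v => "max" -> v.toString)
    }
  }

  // Deserialize a map produced by toMap, assuming a long-typed column;
  // any parse failure yields None instead of propagating the exception.
  def fromMap(map: Map[String, String]): Option[SimpleColumnStat] =
    scala.util.Try {
      SimpleColumnStat(
        distinctCount = BigInt(map("distinctCount").toLong),
        min = map.get("min").map(_.toLong),
        max = map.get("max").map(_.toLong),
        nullCount = BigInt(map("nullCount").toLong),
        avgLen = map("avgLen").toLong,
        maxLen = map("maxLen").toLong)
    }.toOption

  def main(args: Array[String]): Unit = {
    val stat = SimpleColumnStat(2, Some(1L), Some(5L), 1, 8, 8)
    assert(fromMap(stat.toMap).contains(stat)) // round trip preserves the stats
  }
}
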
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 7fc57d09e9..9dffe3614a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -24,9 +24,8 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.types._
/**
@@ -62,7 +61,7 @@ case class AnalyzeColumnCommand(
// Compute stats for each column
val (rowCount, newColStats) =
- AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames)
+ AnalyzeColumnCommand.computeColumnStats(sparkSession, tableIdent.table, relation, columnNames)
// We also update table-level stats in order to keep them consistent with column-level stats.
val statistics = Statistics(
@@ -88,8 +87,9 @@ object AnalyzeColumnCommand extends Logging {
*
* This is visible for testing.
*/
- def computeColStats(
+ def computeColumnStats(
sparkSession: SparkSession,
+ tableName: String,
relation: LogicalPlan,
columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = {
@@ -97,102 +97,33 @@ object AnalyzeColumnCommand extends Logging {
val resolver = sparkSession.sessionState.conf.resolver
val attributesToAnalyze = AttributeSet(columnNames.map { col =>
val exprOption = relation.output.find(attr => resolver(attr.name, col))
- exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col."))
+ exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist."))
}).toSeq
+ // Make sure the column types are supported for stats gathering.
+ attributesToAnalyze.foreach { attr =>
+ if (!ColumnStat.supportsType(attr.dataType)) {
+ throw new AnalysisException(
+ s"Column ${attr.name} in table $tableName is of type ${attr.dataType}, " +
+ "and Spark does not support statistics collection on this column type.")
+ }
+ }
+
// Collect statistics per column.
// The first element in the result will be the overall row count, the following elements
// will be structs containing all column stats.
// The layout of each struct follows the layout of the ColumnStats.
val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError
val expressions = Count(Literal(1)).toAggregateExpression() +:
- attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr))
+ attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr))
+
val namedExpressions = expressions.map(e => Alias(e, e.toString)())
- val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation))
- .queryExecution.toRdd.collect().head
+ val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
- // unwrap the result
- // TODO: Get rid of numFields by using the public Dataset API.
val rowCount = statsRow.getLong(0)
val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) =>
- val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType)
- (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields)))
+ (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1)))
}.toMap
(rowCount, columnStats)
}
-
- private val zero = Literal(0, LongType)
- private val one = Literal(1, LongType)
-
- private def numNulls(e: Expression): Expression = {
- if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
- }
- private def max(e: Expression): Expression = Max(e)
- private def min(e: Expression): Expression = Min(e)
- private def ndv(e: Expression, relativeSD: Double): Expression = {
- // the approximate ndv should never be larger than the number of rows
- Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one)))
- }
- private def avgLength(e: Expression): Expression = Average(Length(e))
- private def maxLength(e: Expression): Expression = Max(Length(e))
- private def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
- private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
-
- /**
- * Creates a struct that groups the sequence of expressions together. This is used to create
- * one top level struct per column.
- */
- private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = {
- CreateStruct(exprs.map { expr: Expression =>
- expr.transformUp {
- case af: AggregateFunction => af.toAggregateExpression()
- }
- })
- }
-
- private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
- Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD))
- }
-
- private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
- Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD))
- }
-
- private def binaryColumnStat(e: Expression): Seq[Expression] = {
- Seq(numNulls(e), avgLength(e), maxLength(e))
- }
-
- private def booleanColumnStat(e: Expression): Seq[Expression] = {
- Seq(numNulls(e), numTrues(e), numFalses(e))
- }
-
- // TODO(rxin): Get rid of this function.
- def numStatFields(dataType: DataType): Int = {
- dataType match {
- case BinaryType | BooleanType => 3
- case _ => 4
- }
- }
-
- /**
- * Creates a struct expression that contains the statistics to collect for a column.
- *
- * @param attr column to collect statistics
- * @param relativeSD relative error for approximate number of distinct values.
- */
- def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = {
- attr.dataType match {
- case _: NumericType | TimestampType | DateType =>
- createStruct(numericColumnStat(attr, relativeSD))
- case StringType =>
- createStruct(stringColumnStat(attr, relativeSD))
- case BinaryType =>
- createStruct(binaryColumnStat(attr))
- case BooleanType =>
- createStruct(booleanColumnStat(attr))
- case otherType =>
- throw new AnalysisException("Analyzing columns is not supported for column " +
- s"${attr.name} of data type: ${attr.dataType}.")
- }
- }
}
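
The rewritten computeColumnStats builds a single aggregate whose first field is the overall count(1) and whose remaining fields are one struct per analyzed column, as the comments above and statExprs describe. Below is a hedged sketch of the same single-pass layout using only the public DataFrame API rather than the internal Catalyst expressions; it assumes the approx_count_distinct(col, rsd) helper is available in this Spark version, and it omits avgLen/maxLen for brevity.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// One aggregate row: overall row count first, then a simplified stats struct
// per requested column, all computed in a single pass over the data.
def statsInOnePass(df: DataFrame, cols: Seq[String], rsd: Double) = {
  val perColumn = cols.map { c =>
    struct(
      // approximate ndv, capped by the non-null count so it never exceeds the row count
      least(approx_count_distinct(col(c), rsd), count(col(c))).as("distinctCount"),
      min(col(c)).as("min"),
      max(col(c)).as("max"),
      // nullCount = total rows minus non-null rows
      (count(lit(1)) - count(col(c))).as("nullCount")
    ).as(s"stats_$c")
  }
  df.agg(count(lit(1)).as("rowCount"), perColumn: _*).head()
}
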
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
new file mode 100644
index 0000000000..1fcccd0610
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.{lang => jl}
+import java.sql.{Date, Timestamp}
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.sql.test.SQLTestData.ArrayData
+import org.apache.spark.sql.types._
+
+
+/**
+ * End-to-end suite testing statistics collection and use on both entire table and columns.
+ */
+class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext {
+ import testImplicits._
+
+ private def checkTableStats(tableName: String, expectedRowCount: Option[Int])
+ : Option[Statistics] = {
+ val df = spark.table(tableName)
+ val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation =>
+ assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount)
+ rel.catalogTable.get.stats
+ }
+ assert(stats.size == 1)
+ stats.head
+ }
+
+ test("estimates the size of a limit 0 on outer join") {
+ withTempView("test") {
+ Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
+ .createOrReplaceTempView("test")
+ val df1 = spark.table("test")
+ val df2 = spark.table("test").limit(0)
+ val df = df1.join(df2, Seq("k"), "left")
+
+ val sizes = df.queryExecution.analyzed.collect { case g: Join =>
+ g.statistics.sizeInBytes
+ }
+
+ assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
+ assert(sizes.head === BigInt(96),
+ s"expected exact size 96 for table 'test', got: ${sizes.head}")
+ }
+ }
+
+ test("analyze column command - unsupported types and invalid columns") {
+ val tableName = "column_stats_test1"
+ withTable(tableName) {
+ Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName)
+
+ // Test unsupported data types
+ val err1 = intercept[AnalysisException] {
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data")
+ }
+ assert(err1.message.contains("does not support statistics collection"))
+
+ // Test invalid columns
+ val err2 = intercept[AnalysisException] {
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS some_random_column")
+ }
+ assert(err2.message.contains("does not exist"))
+ }
+ }
+
+ test("test table-level statistics for data source table") {
+ val tableName = "tbl"
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet")
+ Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName)
+
+ // noscan won't count the number of rows
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
+ checkTableStats(tableName, expectedRowCount = None)
+
+ // without noscan, we count the number of rows
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
+ checkTableStats(tableName, expectedRowCount = Some(2))
+ }
+ }
+
+ test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
+ val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
+ val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
+ assert(df.queryExecution.analyzed.statistics.sizeInBytes >
+ spark.sessionState.conf.autoBroadcastJoinThreshold)
+ assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes >
+ spark.sessionState.conf.autoBroadcastJoinThreshold)
+ }
+
+ test("estimates the size of limit") {
+ withTempView("test") {
+ Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
+ .createOrReplaceTempView("test")
+ Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
+ val df = sql(s"""SELECT * FROM test limit $limit""")
+
+ val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
+ g.statistics.sizeInBytes
+ }
+ assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+ assert(sizesGlobalLimit.head === BigInt(expected),
+ s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
+
+ val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
+ l.statistics.sizeInBytes
+ }
+ assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+ assert(sizesLocalLimit.head === BigInt(expected),
+ s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
+ }
+ }
+ }
+
+}
+
+
+/**
+ * The base for test cases that we want to include in both the hive module (for verifying behavior
+ * when using the Hive external catalog) as well as in the sql/core module.
+ */
+abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils {
+ import testImplicits._
+
+ private val dec1 = new java.math.BigDecimal("1.000000000000000000")
+ private val dec2 = new java.math.BigDecimal("8.000000000000000000")
+ private val d1 = Date.valueOf("2016-05-08")
+ private val d2 = Date.valueOf("2016-05-09")
+ private val t1 = Timestamp.valueOf("2016-05-08 00:00:01")
+ private val t2 = Timestamp.valueOf("2016-05-09 00:00:02")
+
+ /**
+ * Define a very simple 3 row table used for testing column serialization.
+   * Note: the last column is Seq[Int], which doesn't support stats collection.
+ */
+ protected val data = Seq[
+ (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long,
+ jl.Double, jl.Float, java.math.BigDecimal,
+ String, Array[Byte], Date, Timestamp,
+ Seq[Int])](
+ (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null),
+ (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null),
+ (null, null, null, null, null, null, null, null, null, null, null, null, null)
+ )
+
+ /** A mapping from column to the stats collected. */
+ protected val stats = mutable.LinkedHashMap(
+ "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1),
+ "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1),
+ "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2),
+ "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4),
+ "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8),
+ "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8),
+ "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4),
+ "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16),
+ "cstring" -> ColumnStat(2, None, None, 1, 3, 3),
+ "cbinary" -> ColumnStat(2, None, None, 1, 3, 3),
+ "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4),
+ "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8)
+ )
+
+ test("column stats round trip serialization") {
+    // Make sure that serializing and then deserializing yields the same ColumnStat.
+ val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
+ stats.zip(df.schema).foreach { case ((k, v), field) =>
+ withClue(s"column $k with type ${field.dataType}") {
+ val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap)
+ assert(roundtrip == Some(v))
+ }
+ }
+ }
+
+ test("analyze column command - result verification") {
+ val tableName = "column_stats_test2"
+ // (data.head.productArity - 1) because the last column does not support stats collection.
+ assert(stats.size == data.head.productArity - 1)
+ val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
+
+ withTable(tableName) {
+ df.write.saveAsTable(tableName)
+
+ // Collect statistics
+ sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", "))
+
+ // Validate statistics
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
+ assert(table.stats.isDefined)
+ assert(table.stats.get.colStats.size == stats.size)
+
+ stats.foreach { case (k, v) =>
+ withClue(s"column $k") {
+ assert(table.stats.get.colStats(k) == v)
+ }
+ }
+ }
+ }
+}
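
A short usage sketch mirroring the verification pattern in the suite above: collect column statistics with ANALYZE, then read them back through the session catalog. It assumes a running SparkSession named spark and an existing table t with columns c1 and c2.

import org.apache.spark.sql.catalyst.TableIdentifier

spark.sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS c1, c2")

val meta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
meta.stats.foreach { s =>
  println(s"rowCount=${s.rowCount}, sizeInBytes=${s.sizeInBytes}")
  // Each ColumnStat can be rendered through the same toMap used for persistence.
  s.colStats.foreach { case (name, cs) => println(s"$name -> ${cs.toMap}") }
}
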
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala
deleted file mode 100644
index e866ac2cb3..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala
+++ /dev/null
@@ -1,334 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.sql.{Date, Timestamp}
-
-import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.command.AnalyzeColumnCommand
-import org.apache.spark.sql.test.SQLTestData.ArrayData
-import org.apache.spark.sql.types._
-
-class StatisticsColumnSuite extends StatisticsTest {
- import testImplicits._
-
- test("parse analyze column commands") {
- val tableName = "tbl"
-
- // we need to specify column names
- intercept[ParseException] {
- sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS")
- }
-
- val analyzeSql = s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key, value"
- val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql)
- val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value"))
- comparePlans(parsed, expected)
- }
-
- test("analyzing columns of non-atomic types is not supported") {
- val tableName = "tbl"
- withTable(tableName) {
- Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName)
- val err = intercept[AnalysisException] {
- sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data")
- }
- assert(err.message.contains("Analyzing columns is not supported"))
- }
- }
-
- test("check correctness of columns") {
- val table = "tbl"
- val colName1 = "abc"
- val colName2 = "x.yz"
- withTable(table) {
- sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET")
-
- val invalidColError = intercept[AnalysisException] {
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key")
- }
- assert(invalidColError.message == "Invalid column name: key.")
-
- withSQLConf("spark.sql.caseSensitive" -> "true") {
- val invalidErr = intercept[AnalysisException] {
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${colName1.toUpperCase}")
- }
- assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.")
- }
-
- withSQLConf("spark.sql.caseSensitive" -> "false") {
- val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2)
- val tableIdent = TableIdentifier(table, Some("default"))
- val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
- val (_, columnStats) =
- AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze)
- assert(columnStats.contains(colName1))
- assert(columnStats.contains(colName2))
- // check deduplication
- assert(columnStats.size == 2)
- assert(!columnStats.contains(colName2.toUpperCase))
- }
- }
- }
-
- private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = {
- values.filter(_.isDefined).map(_.get)
- }
-
- test("column-level statistics for integral type columns") {
- val values = (0 to 5).map { i =>
- if (i % 2 == 0) None else Some(i)
- }
- val data = values.map { i =>
- (i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong))
- }
-
- val df = data.toDF("c1", "c2", "c3", "c4")
- val nonNullValues = getNonNullValues[Int](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- nonNullValues.max,
- nonNullValues.min,
- nonNullValues.distinct.length.toLong))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for fractional type columns") {
- val values: Seq[Option[Decimal]] = (0 to 5).map { i =>
- if (i == 0) None else Some(Decimal(i + i * 0.01))
- }
- val data = values.map { i =>
- (i.map(_.toFloat), i.map(_.toDouble), i)
- }
-
- val df = data.toDF("c1", "c2", "c3")
- val nonNullValues = getNonNullValues[Decimal](values)
- val numNulls = values.count(_.isEmpty).toLong
- val ndv = nonNullValues.distinct.length.toLong
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = f.dataType match {
- case floatType: FloatType =>
- ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat,
- ndv))
- case doubleType: DoubleType =>
- ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble,
- ndv))
- case decimalType: DecimalType =>
- ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv))
- }
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for string column") {
- val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some(""))
- val df = values.toDF("c1")
- val nonNullValues = getNonNullValues[String](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- nonNullValues.map(_.length).sum / nonNullValues.length.toDouble,
- nonNullValues.map(_.length).max.toInt,
- nonNullValues.distinct.length.toLong))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for binary column") {
- val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes))
- val df = values.toDF("c1")
- val nonNullValues = getNonNullValues[Array[Byte]](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- nonNullValues.map(_.length).sum / nonNullValues.length.toDouble,
- nonNullValues.map(_.length).max.toInt))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for boolean column") {
- val values = Seq(None, Some(true), Some(false), Some(true))
- val df = values.toDF("c1")
- val nonNullValues = getNonNullValues[Boolean](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- nonNullValues.count(_.equals(true)).toLong,
- nonNullValues.count(_.equals(false)).toLong))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for date column") {
- val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf))
- val df = values.toDF("c1")
- val nonNullValues = getNonNullValues[Date](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- // Internally, DateType is represented as the number of days from 1970-01-01.
- nonNullValues.map(DateTimeUtils.fromJavaDate).max,
- nonNullValues.map(DateTimeUtils.fromJavaDate).min,
- nonNullValues.distinct.length.toLong))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for timestamp column") {
- val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i =>
- i.map(Timestamp.valueOf)
- }
- val df = values.toDF("c1")
- val nonNullValues = getNonNullValues[Timestamp](values)
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = ColumnStat(InternalRow(
- values.count(_.isEmpty).toLong,
- // Internally, TimestampType is represented as the number of days from 1970-01-01
- nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max,
- nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min,
- nonNullValues.distinct.length.toLong))
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for null columns") {
- val values = Seq(None, None)
- val data = values.map { i =>
- (i.map(_.toString), i.map(_.toString.toInt))
- }
- val df = data.toDF("c1", "c2")
- val expectedColStatsSeq = df.schema.map { f =>
- (f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L)))
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("column-level statistics for columns with different types") {
- val intSeq = Seq(1, 2)
- val doubleSeq = Seq(1.01d, 2.02d)
- val stringSeq = Seq("a", "bb")
- val binarySeq = Seq("a", "bb").map(_.getBytes)
- val booleanSeq = Seq(true, false)
- val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf)
- val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf)
- val longSeq = Seq(5L, 4L)
-
- val data = intSeq.indices.map { i =>
- (intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i),
- timestampSeq(i), longSeq(i))
- }
- val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8")
- val expectedColStatsSeq = df.schema.map { f =>
- val colStat = f.dataType match {
- case IntegerType =>
- ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
- case DoubleType =>
- ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min,
- doubleSeq.distinct.length.toLong))
- case StringType =>
- ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
- stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
- case BinaryType =>
- ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
- binarySeq.map(_.length).max.toInt))
- case BooleanType =>
- ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
- booleanSeq.count(_.equals(false)).toLong))
- case DateType =>
- ColumnStat(InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max,
- dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong))
- case TimestampType =>
- ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max,
- timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min,
- timestampSeq.distinct.length.toLong))
- case LongType =>
- ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong))
- }
- (f, colStat)
- }
- checkColStats(df, expectedColStatsSeq)
- }
-
- test("update table-level stats while collecting column-level stats") {
- val table = "tbl"
- withTable(table) {
- sql(s"CREATE TABLE $table (c1 int) USING PARQUET")
- sql(s"INSERT INTO $table SELECT 1")
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS")
- checkTableStats(tableName = table, expectedRowCount = Some(1))
-
- // update table-level stats between analyze table and analyze column commands
- sql(s"INSERT INTO $table SELECT 1")
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1")
- val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2))
-
- val colStat = fetchedStats.get.colStats("c1")
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = colStat,
- expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
- rsd = spark.sessionState.conf.ndvMaxError)
- }
- }
-
- test("analyze column stats independently") {
- val table = "tbl"
- withTable(table) {
- sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET")
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1")
- val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0))
- assert(fetchedStats1.get.colStats.size == 1)
- val expected1 = ColumnStat(InternalRow(0L, null, null, 0L))
- val rsd = spark.sessionState.conf.ndvMaxError
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = fetchedStats1.get.colStats("c1"),
- expectedColStat = expected1,
- rsd = rsd)
-
- sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2")
- val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0))
- // column c1 is kept in the stats
- assert(fetchedStats2.get.colStats.size == 2)
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = fetchedStats2.get.colStats("c1"),
- expectedColStat = expected1,
- rsd = rsd)
- val expected2 = ColumnStat(InternalRow(0L, null, null, 0L))
- StatisticsTest.checkColStat(
- dataType = LongType,
- colStat = fetchedStats2.get.colStats("c2"),
- expectedColStat = expected2,
- rsd = rsd)
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
deleted file mode 100644
index 8cf42e9248..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit}
-import org.apache.spark.sql.types._
-
-class StatisticsSuite extends StatisticsTest {
- import testImplicits._
-
- test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
- val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
- val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
- assert(df.queryExecution.analyzed.statistics.sizeInBytes >
- spark.sessionState.conf.autoBroadcastJoinThreshold)
- assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes >
- spark.sessionState.conf.autoBroadcastJoinThreshold)
- }
-
- test("estimates the size of limit") {
- withTempView("test") {
- Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
- .createOrReplaceTempView("test")
- Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
- val df = sql(s"""SELECT * FROM test limit $limit""")
-
- val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
- g.statistics.sizeInBytes
- }
- assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
- assert(sizesGlobalLimit.head === BigInt(expected),
- s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
-
- val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
- l.statistics.sizeInBytes
- }
- assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
- assert(sizesLocalLimit.head === BigInt(expected),
- s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
- }
- }
- }
-
- test("estimates the size of a limit 0 on outer join") {
- withTempView("test") {
- Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
- .createOrReplaceTempView("test")
- val df1 = spark.table("test")
- val df2 = spark.table("test").limit(0)
- val df = df1.join(df2, Seq("k"), "left")
-
- val sizes = df.queryExecution.analyzed.collect { case g: Join =>
- g.statistics.sizeInBytes
- }
-
- assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
- assert(sizes.head === BigInt(96),
- s"expected exact size 96 for table 'test', got: ${sizes.head}")
- }
- }
-
- test("test table-level statistics for data source table created in InMemoryCatalog") {
- val tableName = "tbl"
- withTable(tableName) {
- sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet")
- Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName)
-
- // noscan won't count the number of rows
- sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
- checkTableStats(tableName, expectedRowCount = None)
-
- // without noscan, we count the number of rows
- sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
- checkTableStats(tableName, expectedRowCount = Some(2))
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala
deleted file mode 100644
index 915ee0d31b..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
-import org.apache.spark.sql.execution.command.AnalyzeColumnCommand
-import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types._
-
-
-trait StatisticsTest extends QueryTest with SharedSQLContext {
-
- def checkColStats(
- df: DataFrame,
- expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
- val table = "tbl"
- withTable(table) {
- df.write.format("json").saveAsTable(table)
- val columns = expectedColStatsSeq.map(_._1)
- val tableIdent = TableIdentifier(table, Some("default"))
- val relation = spark.sessionState.catalog.lookupRelation(tableIdent)
- val (_, columnStats) =
- AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name))
- expectedColStatsSeq.foreach { case (field, expectedColStat) =>
- assert(columnStats.contains(field.name))
- val colStat = columnStats(field.name)
- StatisticsTest.checkColStat(
- dataType = field.dataType,
- colStat = colStat,
- expectedColStat = expectedColStat,
- rsd = spark.sessionState.conf.ndvMaxError)
-
- // check if we get the same colStat after encoding and decoding
- val encodedCS = colStat.toString
- val numFields = AnalyzeColumnCommand.numStatFields(field.dataType)
- val decodedCS = ColumnStat(numFields, encodedCS)
- StatisticsTest.checkColStat(
- dataType = field.dataType,
- colStat = decodedCS,
- expectedColStat = expectedColStat,
- rsd = spark.sessionState.conf.ndvMaxError)
- }
- }
- }
-
- def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = {
- val df = spark.table(tableName)
- val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation =>
- assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount)
- rel.catalogTable.get.stats
- }
- assert(stats.size == 1)
- stats.head
- }
-}
-
-object StatisticsTest {
- def checkColStat(
- dataType: DataType,
- colStat: ColumnStat,
- expectedColStat: ColumnStat,
- rsd: Double): Unit = {
- dataType match {
- case StringType =>
- val cs = colStat.forString
- val expectedCS = expectedColStat.forString
- assert(cs.numNulls == expectedCS.numNulls)
- assert(cs.avgColLen == expectedCS.avgColLen)
- assert(cs.maxColLen == expectedCS.maxColLen)
- checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd)
- case BinaryType =>
- val cs = colStat.forBinary
- val expectedCS = expectedColStat.forBinary
- assert(cs.numNulls == expectedCS.numNulls)
- assert(cs.avgColLen == expectedCS.avgColLen)
- assert(cs.maxColLen == expectedCS.maxColLen)
- case BooleanType =>
- val cs = colStat.forBoolean
- val expectedCS = expectedColStat.forBoolean
- assert(cs.numNulls == expectedCS.numNulls)
- assert(cs.numTrues == expectedCS.numTrues)
- assert(cs.numFalses == expectedCS.numFalses)
- case atomicType: AtomicType =>
- checkNumericColStats(
- dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd)
- }
- }
-
- private def checkNumericColStats(
- dataType: AtomicType,
- colStat: ColumnStat,
- expectedColStat: ColumnStat,
- rsd: Double): Unit = {
- val cs = colStat.forNumeric(dataType)
- val expectedCS = expectedColStat.forNumeric(dataType)
- assert(cs.numNulls == expectedCS.numNulls)
- assert(cs.max == expectedCS.max)
- assert(cs.min == expectedCS.min)
- checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd)
- }
-
- private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = {
- // ndv is an approximate value, so we make sure we have the value, and it should be
- // within 3*SD's of the given rsd.
- if (expectedNdv == 0) {
- assert(ndv == 0)
- } else if (expectedNdv > 0) {
- assert(ndv > 0)
- val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
- assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.")
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index 797fe9ffa8..b070138be0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat,
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DescribeFunctionCommand,
- DescribeTableCommand, ShowFunctionsCommand}
-import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing}
+import org.apache.spark.sql.execution.command._
+import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
@@ -221,12 +220,22 @@ class SparkSqlParserSuite extends PlanTest {
intercept("explain describe tables x", "Unsupported SQL statement")
}
- test("SPARK-18106 analyze table") {
+ test("analyze table statistics") {
assertEqual("analyze table t compute statistics",
AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
assertEqual("analyze table t compute statistics noscan",
AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
- assertEqual("analyze table t partition (a) compute statistics noscan",
+ assertEqual("analyze table t partition (a) compute statistics nOscAn",
+ AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
+
+  // Partitions specified - we currently parse them but don't do anything with them
+ assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS",
+ AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
+ assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan",
+ AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
+ assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS",
+ AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
+ assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS noscan",
AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
intercept("analyze table t compute statistics xxxx",
@@ -234,4 +243,11 @@ class SparkSqlParserSuite extends PlanTest {
intercept("analyze table t partition (a) compute statistics xxxx",
"Expected `NOSCAN` instead of `xxxx`")
}
+
+ test("analyze table column statistics") {
+ intercept("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS", "")
+
+ assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value",
+ AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value")))
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index ff0923f048..fd9dc32063 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils}
+import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.HiveSerDe
import org.apache.spark.sql.internal.StaticSQLConf._
@@ -514,7 +514,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()
}
stats.colStats.foreach { case (colName, colStat) =>
- statsProperties += (STATISTICS_COL_STATS_PREFIX + colName) -> colStat.toString
+ colStat.toMap.foreach { case (k, v) =>
+ statsProperties += (columnStatKeyPropName(colName, k) -> v)
+ }
}
tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties)
} else {
@@ -605,48 +607,65 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
* It reads table schema, provider, partition column names and bucket specification from table
* properties, and filter out these special entries from table properties.
*/
- private def restoreTableMetadata(table: CatalogTable): CatalogTable = {
+ private def restoreTableMetadata(inputTable: CatalogTable): CatalogTable = {
if (conf.get(DEBUG_MODE)) {
- return table
+ return inputTable
}
- val tableWithSchema = if (table.tableType == VIEW) {
- table
- } else {
- getProviderFromTableProperties(table) match {
+ var table = inputTable
+
+ if (table.tableType != VIEW) {
+ table.properties.get(DATASOURCE_PROVIDER) match {
// No provider in table properties, which means this table is created by Spark prior to 2.1,
// or is created at Hive side.
case None =>
- table.copy(provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true)
+ table = table.copy(
+ provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true)
// This is a Hive serde table created by Spark 2.1 or higher versions.
- case Some(DDLUtils.HIVE_PROVIDER) => restoreHiveSerdeTable(table)
+ case Some(DDLUtils.HIVE_PROVIDER) =>
+ table = restoreHiveSerdeTable(table)
// This is a regular data source table.
- case Some(provider) => restoreDataSourceTable(table, provider)
+ case Some(provider) =>
+ table = restoreDataSourceTable(table, provider)
}
}
// construct Spark's statistics from information in Hive metastore
- val statsProps = tableWithSchema.properties.filterKeys(_.startsWith(STATISTICS_PREFIX))
- val tableWithStats = if (statsProps.nonEmpty) {
- val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX))
- .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) }
- val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect {
- case f if colStatsProps.contains(f.name) =>
- val numFields = AnalyzeColumnCommand.numStatFields(f.dataType)
- (f.name, ColumnStat(numFields, colStatsProps(f.name)))
- }.toMap
- tableWithSchema.copy(
+ val statsProps = table.properties.filterKeys(_.startsWith(STATISTICS_PREFIX))
+
+ if (statsProps.nonEmpty) {
+ val colStats = new scala.collection.mutable.HashMap[String, ColumnStat]
+
+      // For each column, recover its column stats. Note that this is currently an O(n^2) operation,
+      // but since the number of columns is usually not enormous, this is probably OK as a start.
+      // If we want to make this a linear operation, we'd need a stronger contract on the
+      // naming convention used for serialization.
+ table.schema.foreach { field =>
+ if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) {
+ // If "version" field is defined, then the column stat is defined.
+ val keyPrefix = columnStatKeyPropName(field.name, "")
+ val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) =>
+ (k.drop(keyPrefix.length), v)
+ }
+
+ ColumnStat.fromMap(table.identifier.table, field, colStatMap).foreach {
+ colStat => colStats += field.name -> colStat
+ }
+ }
+ }
+
+ table = table.copy(
stats = Some(Statistics(
- sizeInBytes = BigInt(tableWithSchema.properties(STATISTICS_TOTAL_SIZE)),
- rowCount = tableWithSchema.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)),
- colStats = colStats)))
- } else {
- tableWithSchema
+ sizeInBytes = BigInt(table.properties(STATISTICS_TOTAL_SIZE)),
+ rowCount = table.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)),
+ colStats = colStats.toMap)))
}
- tableWithStats.copy(properties = getOriginalTableProperties(table))
+ // Get the original table properties as defined by the user.
+ table.copy(
+ properties = table.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) })
}
private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = {
@@ -1020,17 +1039,17 @@ object HiveExternalCatalog {
val TABLE_PARTITION_PROVIDER_CATALOG = "catalog"
val TABLE_PARTITION_PROVIDER_FILESYSTEM = "filesystem"
-
- def getProviderFromTableProperties(metadata: CatalogTable): Option[String] = {
- metadata.properties.get(DATASOURCE_PROVIDER)
- }
-
- def getOriginalTableProperties(metadata: CatalogTable): Map[String, String] = {
- metadata.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }
+ /**
+ * Returns the fully qualified name used in table properties for a particular column stat.
+   * For example, for column "mycol" and stat "min", this returns
+ * "spark.sql.statistics.colStats.mycol.min".
+ */
+ private def columnStatKeyPropName(columnName: String, statKey: String): String = {
+ STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey
}
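
For a quick illustration of the convention this helper encodes (illustrative only: colStatKey simply mirrors the private columnStatKeyPropName above, with the prefix value asserted in the StatisticsSuite test below):

    // Illustrative stand-in for the private helper above.
    val prefix = "spark.sql.statistics.colStats."
    def colStatKey(column: String, statKey: String): String = prefix + column + "." + statKey

    val key = colStatKey("cint", "max")         // "spark.sql.statistics.colStats.cint.max"
    val keyPrefix = colStatKey("cint", "")      // "spark.sql.statistics.colStats.cint."
    val recovered = key.drop(keyPrefix.length)  // "max", as restoreTableMetadata recovers it
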
  // A persisted data source table always stores its schema in the catalog.
- def getSchemaFromTableProperties(metadata: CatalogTable): StructType = {
+ private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = {
val errorMessage = "Could not read schema from the hive metastore because it is corrupted."
val props = metadata.properties
val schema = props.get(DATASOURCE_SCHEMA)
@@ -1078,11 +1097,11 @@ object HiveExternalCatalog {
)
}
- def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = {
+ private def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = {
getColumnNamesByType(metadata.properties, "part", "partitioning columns")
}
- def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = {
+ private def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = {
metadata.properties.get(DATASOURCE_SCHEMA_NUMBUCKETS).map { numBuckets =>
BucketSpec(
numBuckets.toInt,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 4f5ebc3d83..5ae202fdc9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -22,56 +22,16 @@ import java.io.{File, PrintWriter}
import scala.reflect.ClassTag
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
-import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
-class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
-
- test("parse analyze commands") {
- def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
- val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand)
- val operators = parsed.collect {
- case a: AnalyzeTableCommand => a
- case o => o
- }
-
- assert(operators.size === 1)
- if (operators(0).getClass() != c) {
- fail(
- s"""$analyzeCommand expected command: $c, but got ${operators(0)}
- |parsed command:
- |$parsed
- """.stripMargin)
- }
- }
-
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 COMPUTE STATISTICS",
- classOf[AnalyzeTableCommand])
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS",
- classOf[AnalyzeTableCommand])
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan",
- classOf[AnalyzeTableCommand])
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS",
- classOf[AnalyzeTableCommand])
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS noscan",
- classOf[AnalyzeTableCommand])
-
- assertAnalyzeCommand(
- "ANALYZE TABLE Table1 COMPUTE STATISTICS nOscAn",
- classOf[AnalyzeTableCommand])
- }
+class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton {
test("MetastoreRelations fallback to HDFS for size estimation") {
val enableFallBackToHdfsForStats = spark.sessionState.conf.fallBackToHdfsForStatsEnabled
@@ -310,6 +270,110 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
}
}
+ test("verify serialized column stats after analyzing columns") {
+ import testImplicits._
+
+ val tableName = "column_stats_test2"
+ // (data.head.productArity - 1) because the last column does not support stats collection.
+ assert(stats.size == data.head.productArity - 1)
+ val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
+
+ withTable(tableName) {
+ df.write.saveAsTable(tableName)
+
+ // Collect statistics
+ sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", "))
+
+ // Validate statistics
+ val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
+ val table = hiveClient.getTable("default", tableName)
+
+ val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats"))
+ assert(props == Map(
+ "spark.sql.statistics.colStats.cbinary.avgLen" -> "3",
+ "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cbinary.maxLen" -> "3",
+ "spark.sql.statistics.colStats.cbinary.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cbinary.version" -> "1",
+ "spark.sql.statistics.colStats.cbool.avgLen" -> "1",
+ "spark.sql.statistics.colStats.cbool.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cbool.max" -> "true",
+ "spark.sql.statistics.colStats.cbool.maxLen" -> "1",
+ "spark.sql.statistics.colStats.cbool.min" -> "false",
+ "spark.sql.statistics.colStats.cbool.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cbool.version" -> "1",
+ "spark.sql.statistics.colStats.cbyte.avgLen" -> "1",
+ "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cbyte.max" -> "2",
+ "spark.sql.statistics.colStats.cbyte.maxLen" -> "1",
+ "spark.sql.statistics.colStats.cbyte.min" -> "1",
+ "spark.sql.statistics.colStats.cbyte.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cbyte.version" -> "1",
+ "spark.sql.statistics.colStats.cdate.avgLen" -> "4",
+ "spark.sql.statistics.colStats.cdate.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09",
+ "spark.sql.statistics.colStats.cdate.maxLen" -> "4",
+ "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08",
+ "spark.sql.statistics.colStats.cdate.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cdate.version" -> "1",
+ "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16",
+ "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000",
+ "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16",
+ "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000",
+ "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cdecimal.version" -> "1",
+ "spark.sql.statistics.colStats.cdouble.avgLen" -> "8",
+ "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cdouble.max" -> "6.0",
+ "spark.sql.statistics.colStats.cdouble.maxLen" -> "8",
+ "spark.sql.statistics.colStats.cdouble.min" -> "1.0",
+ "spark.sql.statistics.colStats.cdouble.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cdouble.version" -> "1",
+ "spark.sql.statistics.colStats.cfloat.avgLen" -> "4",
+ "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cfloat.max" -> "7.0",
+ "spark.sql.statistics.colStats.cfloat.maxLen" -> "4",
+ "spark.sql.statistics.colStats.cfloat.min" -> "1.0",
+ "spark.sql.statistics.colStats.cfloat.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cfloat.version" -> "1",
+ "spark.sql.statistics.colStats.cint.avgLen" -> "4",
+ "spark.sql.statistics.colStats.cint.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cint.max" -> "4",
+ "spark.sql.statistics.colStats.cint.maxLen" -> "4",
+ "spark.sql.statistics.colStats.cint.min" -> "1",
+ "spark.sql.statistics.colStats.cint.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cint.version" -> "1",
+ "spark.sql.statistics.colStats.clong.avgLen" -> "8",
+ "spark.sql.statistics.colStats.clong.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.clong.max" -> "5",
+ "spark.sql.statistics.colStats.clong.maxLen" -> "8",
+ "spark.sql.statistics.colStats.clong.min" -> "1",
+ "spark.sql.statistics.colStats.clong.nullCount" -> "1",
+ "spark.sql.statistics.colStats.clong.version" -> "1",
+ "spark.sql.statistics.colStats.cshort.avgLen" -> "2",
+ "spark.sql.statistics.colStats.cshort.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cshort.max" -> "3",
+ "spark.sql.statistics.colStats.cshort.maxLen" -> "2",
+ "spark.sql.statistics.colStats.cshort.min" -> "1",
+ "spark.sql.statistics.colStats.cshort.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cshort.version" -> "1",
+ "spark.sql.statistics.colStats.cstring.avgLen" -> "3",
+ "spark.sql.statistics.colStats.cstring.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.cstring.maxLen" -> "3",
+ "spark.sql.statistics.colStats.cstring.nullCount" -> "1",
+ "spark.sql.statistics.colStats.cstring.version" -> "1",
+ "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8",
+ "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2",
+ "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0",
+ "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8",
+ "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0",
+ "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1",
+ "spark.sql.statistics.colStats.ctimestamp.version" -> "1"
+ ))
+ }
+ }
+
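
To connect the assertions above to the read path, here is a hedged sketch of how one of these serialized columns round-trips back into a column stat via ColumnStat.fromMap, the same call restoreTableMetadata makes. It assumes fromMap is reachable from test code and that "cint" is an IntegerType column, as the serialized min/max values suggest.

    // Sketch only, not part of the patch: deserialize the "cint" stats asserted above.
    import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
    import org.apache.spark.sql.types.{IntegerType, StructField}

    val keyPrefix = "spark.sql.statistics.colStats.cint."
    val serialized = Map(
      "spark.sql.statistics.colStats.cint.avgLen" -> "4",
      "spark.sql.statistics.colStats.cint.distinctCount" -> "2",
      "spark.sql.statistics.colStats.cint.max" -> "4",
      "spark.sql.statistics.colStats.cint.maxLen" -> "4",
      "spark.sql.statistics.colStats.cint.min" -> "1",
      "spark.sql.statistics.colStats.cint.nullCount" -> "1",
      "spark.sql.statistics.colStats.cint.version" -> "1")

    // Strip the per-column prefix, as restoreTableMetadata does before calling fromMap.
    val cintMap = serialized.map { case (k, v) => k.drop(keyPrefix.length) -> v }
    val cintStat =
      ColumnStat.fromMap("column_stats_test2", StructField("cint", IntegerType), cintMap)
    // Expect Some(...) with distinctCount = 2, nullCount = 1, avgLen = 4, maxLen = 4,
    // and min/max recovered from "1" and "4".
    assert(cintStat.isDefined)
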
private def testUpdatingTableStats(tableDescription: String, createTableCmd: String): Unit = {
test("test table-level statistics for " + tableDescription) {
val parquetTable = "parquetTable"
@@ -319,7 +383,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
TableIdentifier(parquetTable))
assert(DDLUtils.isDatasourceTable(catalogTable))
- sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src")
+ // Add a filter to avoid creating too many partitions
+ sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10")
checkTableStats(
parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None)
@@ -328,7 +393,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
val fetchedStats1 = checkTableStats(
parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None)
- sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src")
+ sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10")
sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan")
val fetchedStats2 = checkTableStats(
parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None)
@@ -340,7 +405,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
parquetTable,
isDataSourceTable = true,
hasSizeInBytes = true,
- expectedRowCounts = Some(1000))
+ expectedRowCounts = Some(20))
assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes)
}
}
@@ -369,6 +434,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
}
}
+ /** Used to test refreshing cached metadata once table stats are updated. */
private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = {
val tableName = "tbl"
var statsBeforeUpdate: Statistics = null
@@ -411,145 +477,6 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
assert(statsAfterUpdate.rowCount == Some(2))
}
- test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") {
- val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true)
-
- assert(statsBeforeUpdate.sizeInBytes > 0)
- assert(statsBeforeUpdate.rowCount == Some(1))
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = statsBeforeUpdate.colStats("key"),
- expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
- rsd = spark.sessionState.conf.ndvMaxError)
-
- assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
- assert(statsAfterUpdate.rowCount == Some(2))
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = statsAfterUpdate.colStats("key"),
- expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)),
- rsd = spark.sessionState.conf.ndvMaxError)
- }
-
- private lazy val (testDataFrame, expectedColStatsSeq) = {
- import testImplicits._
-
- val intSeq = Seq(1, 2)
- val stringSeq = Seq("a", "bb")
- val binarySeq = Seq("a", "bb").map(_.getBytes)
- val booleanSeq = Seq(true, false)
- val data = intSeq.indices.map { i =>
- (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i))
- }
- val df: DataFrame = data.toDF("c1", "c2", "c3", "c4")
- val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f =>
- val colStat = f.dataType match {
- case IntegerType =>
- ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
- case StringType =>
- ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
- stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
- case BinaryType =>
- ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
- binarySeq.map(_.length).max.toInt))
- case BooleanType =>
- ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
- booleanSeq.count(_.equals(false)).toLong))
- }
- (f, colStat)
- }
- (df, expectedColStatsSeq)
- }
-
- private def checkColStats(
- tableName: String,
- isDataSourceTable: Boolean,
- expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
- val readback = spark.table(tableName)
- val stats = readback.queryExecution.analyzed.collect {
- case rel: MetastoreRelation =>
- assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table")
- rel.catalogTable.stats.get
- case rel: LogicalRelation =>
- assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table")
- rel.catalogTable.get.stats.get
- }
- assert(stats.length == 1)
- val columnStats = stats.head.colStats
- assert(columnStats.size == expectedColStatsSeq.length)
- expectedColStatsSeq.foreach { case (field, expectedColStat) =>
- StatisticsTest.checkColStat(
- dataType = field.dataType,
- colStat = columnStats(field.name),
- expectedColStat = expectedColStat,
- rsd = spark.sessionState.conf.ndvMaxError)
- }
- }
-
- test("generate and load column-level stats for data source table") {
- val dsTable = "dsTable"
- withTable(dsTable) {
- testDataFrame.write.format("parquet").saveAsTable(dsTable)
- sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
- checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq)
- }
- }
-
- test("generate and load column-level stats for hive serde table") {
- val hTable = "hTable"
- val tmp = "tmp"
- withTable(hTable, tmp) {
- testDataFrame.write.format("parquet").saveAsTable(tmp)
- sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE")
- sql(s"INSERT INTO $hTable SELECT * FROM $tmp")
- sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
- checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq)
- }
- }
-
- // When caseSensitive is on, for columns with only case difference, they are different columns
- // and we should generate column stats for all of them.
- private def checkCaseSensitiveColStats(columnName: String): Unit = {
- val tableName = "tbl"
- withTable(tableName) {
- val column1 = columnName.toLowerCase
- val column2 = columnName.toUpperCase
- withSQLConf("spark.sql.caseSensitive" -> "true") {
- sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET")
- sql(s"INSERT INTO $tableName SELECT 1, 3.0")
- sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`")
- val readback = spark.table(tableName)
- val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
- val columnStats = rel.catalogTable.get.stats.get.colStats
- assert(columnStats.size == 2)
- StatisticsTest.checkColStat(
- dataType = IntegerType,
- colStat = columnStats(column1),
- expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
- rsd = spark.sessionState.conf.ndvMaxError)
- StatisticsTest.checkColStat(
- dataType = DoubleType,
- colStat = columnStats(column2),
- expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)),
- rsd = spark.sessionState.conf.ndvMaxError)
- rel
- }
- assert(relations.size == 1)
- }
- }
- }
-
- test("check column statistics for case sensitive column names") {
- checkCaseSensitiveColStats(columnName = "c1")
- }
-
- test("check column statistics for case sensitive non-ascii column names") {
- // scalastyle:off
- // non ascii characters are not allowed in the source code, so we disable the scalastyle.
- checkCaseSensitiveColStats(columnName = "列c")
- // scalastyle:on
- }
-
test("estimates the size of a test MetastoreRelation") {
val df = sql("""SELECT * FROM src""")
val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation =>