aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorwangzhenhua <wangzhenhua@huawei.com>2017-04-14 19:16:47 +0800
committerWenchen Fan <wenchen@databricks.com>2017-04-14 19:16:47 +0800
commitfb036c4413c2cd4d90880d080f418ec468d6c0fc (patch)
tree8d155d76971538e5ffd6c1f5262653ae813646ce /sql
parent7536e2849df6d63587fbf16b4ecb5db06fed7125 (diff)
downloadspark-fb036c4413c2cd4d90880d080f418ec468d6c0fc.tar.gz
spark-fb036c4413c2cd4d90880d080f418ec468d6c0fc.tar.bz2
spark-fb036c4413c2cd4d90880d080f418ec468d6c0fc.zip
[SPARK-20318][SQL] Use Catalyst type for min/max in ColumnStat for ease of estimation
## What changes were proposed in this pull request? Currently when estimating predicates like col > literal or col = literal, we will update min or max in column stats based on literal value. However, literal value is of Catalyst type (internal type), while min/max is of external type. Then for the next predicate, we again need to do type conversion to compare and update column stats. This is awkward and causes many unnecessary conversions in estimation. To solve this, we use Catalyst type for min/max in `ColumnStat`. Note that the persistent format in metastore is still of external type, so there's no inconsistency for statistics in metastore. This pr also fixes a bug for boolean type in `IN` condition. ## How was this patch tested? The changes for ColumnStat are covered by existing tests. For bug fix, a new test for boolean type in IN condition is added Author: wangzhenhua <wangzhenhua@huawei.com> Closes #17630 from wzhfy/refactorColumnStat.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala95
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala30
-rwxr-xr-xsql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala68
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala70
-rwxr-xr-xsql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala41
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala15
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala21
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala19
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala4
10 files changed, 189 insertions, 182 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index f24b240956..3d4efef953 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -74,11 +75,10 @@ case class Statistics(
* Statistics collected for a column.
*
* 1. Supported data types are defined in `ColumnStat.supportsType`.
- * 2. The JVM data type stored in min/max is the external data type (used in Row) for the
- * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for
- * TimestampType we store java.sql.Timestamp.
- * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs.
- * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms
+ * 2. The JVM data type stored in min/max is the internal data type for the corresponding
+ * Catalyst data type. For example, the internal type of DateType is Int, and that the internal
+ * type of TimestampType is Long.
+ * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms
* (sketches) might have been used, and the data collected can also be stale.
*
* @param distinctCount number of distinct values
@@ -104,22 +104,43 @@ case class ColumnStat(
/**
* Returns a map from string to string that can be used to serialize the column stats.
* The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string
- * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]].
+ * representation for the value. min/max values are converted to the external data type. For
+ * example, for DateType we store java.sql.Date, and for TimestampType we store
+ * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]].
*
* As part of the protocol, the returned map always contains a key called "version".
* In the case min/max values are null (None), they won't appear in the map.
*/
- def toMap: Map[String, String] = {
+ def toMap(colName: String, dataType: DataType): Map[String, String] = {
val map = new scala.collection.mutable.HashMap[String, String]
map.put(ColumnStat.KEY_VERSION, "1")
map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString)
map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString)
map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString)
map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString)
- min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) }
- max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) }
+ min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) }
+ max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) }
map.toMap
}
+
+ /**
+ * Converts the given value from Catalyst data type to string representation of external
+ * data type.
+ */
+ private def toExternalString(v: Any, colName: String, dataType: DataType): String = {
+ val externalValue = dataType match {
+ case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int])
+ case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long])
+ case BooleanType | _: IntegralType | FloatType | DoubleType => v
+ case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal
+ // This version of Spark does not use min/max for binary/string types so we ignore it.
+ case _ =>
+ throw new AnalysisException("Column statistics deserialization is not supported for " +
+ s"column $colName of data type: $dataType.")
+ }
+ externalValue.toString
+ }
+
}
@@ -150,28 +171,15 @@ object ColumnStat extends Logging {
* Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats
* from some external storage. The serialization side is defined in [[ColumnStat.toMap]].
*/
- def fromMap(table: String, field: StructField, map: Map[String, String])
- : Option[ColumnStat] = {
- val str2val: (String => Any) = field.dataType match {
- case _: IntegralType => _.toLong
- case _: DecimalType => new java.math.BigDecimal(_)
- case DoubleType | FloatType => _.toDouble
- case BooleanType => _.toBoolean
- case DateType => java.sql.Date.valueOf
- case TimestampType => java.sql.Timestamp.valueOf
- // This version of Spark does not use min/max for binary/string types so we ignore it.
- case BinaryType | StringType => _ => null
- case _ =>
- throw new AnalysisException("Column statistics deserialization is not supported for " +
- s"column ${field.name} of data type: ${field.dataType}.")
- }
-
+ def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = {
try {
Some(ColumnStat(
distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong),
// Note that flatMap(Option.apply) turns Option(null) into None.
- min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply),
- max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply),
+ min = map.get(KEY_MIN_VALUE)
+ .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
+ max = map.get(KEY_MAX_VALUE)
+ .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
nullCount = BigInt(map(KEY_NULL_COUNT).toLong),
avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong,
maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong
@@ -184,6 +192,30 @@ object ColumnStat extends Logging {
}
/**
+ * Converts from string representation of external data type to the corresponding Catalyst data
+ * type.
+ */
+ private def fromExternalString(s: String, name: String, dataType: DataType): Any = {
+ dataType match {
+ case BooleanType => s.toBoolean
+ case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s))
+ case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s))
+ case ByteType => s.toByte
+ case ShortType => s.toShort
+ case IntegerType => s.toInt
+ case LongType => s.toLong
+ case FloatType => s.toFloat
+ case DoubleType => s.toDouble
+ case _: DecimalType => Decimal(s)
+ // This version of Spark does not use min/max for binary/string types so we ignore it.
+ case BinaryType | StringType => null
+ case _ =>
+ throw new AnalysisException("Column statistics deserialization is not supported for " +
+ s"column $name of data type: $dataType.")
+ }
+ }
+
+ /**
* Constructs an expression to compute column statistics for a given column.
*
* The expression should create a single struct column with the following schema:
@@ -232,11 +264,14 @@ object ColumnStat extends Logging {
}
/** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */
- def rowToColumnStat(row: Row): ColumnStat = {
+ def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = {
ColumnStat(
distinctCount = BigInt(row.getLong(0)),
- min = Option(row.get(1)), // for string/binary min/max, get should return null
- max = Option(row.get(2)),
+ // for string/binary min/max, get should return null
+ min = Option(row.get(1))
+ .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
+ max = Option(row.get(2))
+ .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
nullCount = BigInt(row.getLong(3)),
avgLen = row.getLong(4),
maxLen = row.getLong(5)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 5577233ffa..f1aff62cb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -22,7 +22,7 @@ import scala.math.BigDecimal.RoundingMode
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.sql.types.{DecimalType, _}
object EstimationUtils {
@@ -75,4 +75,32 @@ object EstimationUtils {
// (simple computation of statistics returns product of children).
if (outputRowCount > 0) outputRowCount * sizePerRow else 1
}
+
+ /**
+ * For simplicity we use Decimal to unify operations for data types whose min/max values can be
+ * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true).
+ * The two methods below are the contract of conversion.
+ */
+ def toDecimal(value: Any, dataType: DataType): Decimal = {
+ dataType match {
+ case _: NumericType | DateType | TimestampType => Decimal(value.toString)
+ case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0)
+ }
+ }
+
+ def fromDecimal(dec: Decimal, dataType: DataType): Any = {
+ dataType match {
+ case BooleanType => dec.toLong == 1
+ case DateType => dec.toInt
+ case TimestampType => dec.toLong
+ case ByteType => dec.toByte
+ case ShortType => dec.toShort
+ case IntegerType => dec.toInt
+ case LongType => dec.toLong
+ case FloatType => dec.toFloat
+ case DoubleType => dec.toDouble
+ case _: DecimalType => dec
+ }
+ }
+
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 7bd8e65112..4b6b3b14d9 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -25,7 +25,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -302,30 +301,6 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
}
/**
- * For a SQL data type, its internal data type may be different from its external type.
- * For DateType, its internal type is Int, and its external data type is Java Date type.
- * The min/max values in ColumnStat are saved in their corresponding external type.
- *
- * @param attrDataType the column data type
- * @param litValue the literal value
- * @return a BigDecimal value
- */
- def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = {
- attrDataType match {
- case DateType =>
- Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
- case TimestampType =>
- Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
- case _: DecimalType =>
- Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal)
- case StringType | BinaryType =>
- None
- case _ =>
- Some(litValue)
- }
- }
-
- /**
* Returns a percentage of rows meeting an equality (=) expression.
* This method evaluates the equality predicate for all data types.
*
@@ -356,12 +331,16 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
val statsRange = Range(colStat.min, colStat.max, attr.dataType)
if (statsRange.contains(literal)) {
if (update) {
- // We update ColumnStat structure after apply this equality predicate.
- // Set distinctCount to 1. Set nullCount to 0.
- // Need to save new min/max using the external type value of the literal
- val newValue = convertBoundValue(attr.dataType, literal.value)
- val newStats = colStat.copy(distinctCount = 1, min = newValue,
- max = newValue, nullCount = 0)
+ // We update ColumnStat structure after apply this equality predicate:
+ // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal
+ // value.
+ val newStats = attr.dataType match {
+ case StringType | BinaryType =>
+ colStat.copy(distinctCount = 1, nullCount = 0)
+ case _ =>
+ colStat.copy(distinctCount = 1, min = Some(literal.value),
+ max = Some(literal.value), nullCount = 0)
+ }
colStatsMap(attr) = newStats
}
@@ -430,18 +409,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
return Some(0.0)
}
- // Need to save new min/max using the external type value of the literal
- val newMax = convertBoundValue(
- attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString)))
- val newMin = convertBoundValue(
- attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString)))
-
+ val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType))
+ val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType))
// newNdv should not be greater than the old ndv. For example, column has only 2 values
// 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
newNdv = ndv.min(BigInt(validQuerySet.size))
if (update) {
- val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
- max = newMax, nullCount = 0)
+ val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin),
+ max = Some(newMax), nullCount = 0)
colStatsMap(attr) = newStats
}
@@ -478,8 +453,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
val colStat = colStatsMap(attr)
val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
- val max = BigDecimal(statsRange.max)
- val min = BigDecimal(statsRange.min)
+ val max = statsRange.max.toBigDecimal
+ val min = statsRange.min.toBigDecimal
val ndv = BigDecimal(colStat.distinctCount)
// determine the overlapping degree between predicate range and column's range
@@ -540,8 +515,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
}
if (update) {
- // Need to save new min/max using the external type value of the literal
- val newValue = convertBoundValue(attr.dataType, literal.value)
+ val newValue = Some(literal.value)
var newMax = colStat.max
var newMin = colStat.min
var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
@@ -606,14 +580,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
val colStatLeft = colStatsMap(attrLeft)
val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
.asInstanceOf[NumericRange]
- val maxLeft = BigDecimal(statsRangeLeft.max)
- val minLeft = BigDecimal(statsRangeLeft.min)
+ val maxLeft = statsRangeLeft.max
+ val minLeft = statsRangeLeft.min
val colStatRight = colStatsMap(attrRight)
val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType)
.asInstanceOf[NumericRange]
- val maxRight = BigDecimal(statsRangeRight.max)
- val minRight = BigDecimal(statsRangeRight.min)
+ val maxRight = statsRangeRight.max
+ val minRight = statsRangeRight.min
// determine the overlapping degree between predicate range and column's range
val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
index 3d13967cb6..4ac5ba5689 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
@@ -17,12 +17,8 @@
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
-import java.math.{BigDecimal => JDecimal}
-import java.sql.{Date, Timestamp}
-
import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
+import org.apache.spark.sql.types._
/** Value range of a column. */
@@ -31,13 +27,10 @@ trait Range {
}
/** For simplicity we use decimal to unify operations of numeric ranges. */
-case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
+case class NumericRange(min: Decimal, max: Decimal) extends Range {
override def contains(l: Literal): Boolean = {
- val decimal = l.dataType match {
- case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
- case _ => new JDecimal(l.value.toString)
- }
- min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
+ val lit = EstimationUtils.toDecimal(l.value, l.dataType)
+ min <= lit && max >= lit
}
}
@@ -58,7 +51,10 @@ object Range {
def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
case StringType | BinaryType => new DefaultRange()
case _ if min.isEmpty || max.isEmpty => new NullRange()
- case _ => toNumericRange(min.get, max.get, dataType)
+ case _ =>
+ NumericRange(
+ min = EstimationUtils.toDecimal(min.get, dataType),
+ max = EstimationUtils.toDecimal(max.get, dataType))
}
def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
@@ -82,51 +78,11 @@ object Range {
// binary/string types don't support intersecting.
(None, None)
case (n1: NumericRange, n2: NumericRange) =>
- val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max))
- val (newMin, newMax) = fromNumericRange(newRange, dt)
- (Some(newMin), Some(newMax))
+ // Choose the maximum of two min values, and the minimum of two max values.
+ val newMin = if (n1.min <= n2.min) n2.min else n1.min
+ val newMax = if (n1.max <= n2.max) n1.max else n2.max
+ (Some(EstimationUtils.fromDecimal(newMin, dt)),
+ Some(EstimationUtils.fromDecimal(newMax, dt)))
}
}
-
- /**
- * For simplicity we use decimal to unify operations of numeric types, the two methods below
- * are the contract of conversion.
- */
- private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = {
- dataType match {
- case _: NumericType =>
- NumericRange(new JDecimal(min.toString), new JDecimal(max.toString))
- case BooleanType =>
- val min1 = if (min.asInstanceOf[Boolean]) 1 else 0
- val max1 = if (max.asInstanceOf[Boolean]) 1 else 0
- NumericRange(new JDecimal(min1), new JDecimal(max1))
- case DateType =>
- val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date])
- val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date])
- NumericRange(new JDecimal(min1), new JDecimal(max1))
- case TimestampType =>
- val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp])
- val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp])
- NumericRange(new JDecimal(min1), new JDecimal(max1))
- }
- }
-
- private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = {
- dataType match {
- case _: IntegralType =>
- (n.min.longValue(), n.max.longValue())
- case FloatType | DoubleType =>
- (n.min.doubleValue(), n.max.doubleValue())
- case _: DecimalType =>
- (n.min, n.max)
- case BooleanType =>
- (n.min.longValue() == 1, n.max.longValue() == 1)
- case DateType =>
- (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue()))
- case TimestampType =>
- (DateTimeUtils.toJavaTimestamp(n.min.longValue()),
- DateTimeUtils.toJavaTimestamp(n.max.longValue()))
- }
- }
-
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index cffb0d8739..a28447840a 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.LeftOuter
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
/**
@@ -45,15 +46,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
nullCount = 0, avgLen = 1, maxLen = 1)
// column cdate has 10 values from 2017-01-01 through 2017-01-10.
- val dMin = Date.valueOf("2017-01-01")
- val dMax = Date.valueOf("2017-01-10")
+ val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01"))
+ val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10"))
val attrDate = AttributeReference("cdate", DateType)()
val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
nullCount = 0, avgLen = 4, maxLen = 4)
// column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
- val decMin = new java.math.BigDecimal("0.200000000000000000")
- val decMax = new java.math.BigDecimal("0.800000000000000000")
+ val decMin = Decimal("0.200000000000000000")
+ val decMax = Decimal("0.800000000000000000")
val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
nullCount = 0, avgLen = 8, maxLen = 8)
@@ -147,7 +148,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint < 3 OR null") {
val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))
- val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf)
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
Seq(attrInt -> colStatInt),
@@ -341,6 +341,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
expectedRowCount = 7)
}
+ test("cbool IN (true)") {
+ validateEstimatedStats(
+ Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)),
+ Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+ nullCount = 0, avgLen = 1, maxLen = 1)),
+ expectedRowCount = 5)
+ }
+
test("cbool = true") {
validateEstimatedStats(
Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)),
@@ -358,9 +366,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
}
test("cdate = cast('2017-01-02' AS DATE)") {
- val d20170102 = Date.valueOf("2017-01-02")
+ val d20170102 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-02"))
validateEstimatedStats(
- Filter(EqualTo(attrDate, Literal(d20170102)),
+ Filter(EqualTo(attrDate, Literal(d20170102, DateType)),
childStatsTestPlan(Seq(attrDate), 10L)),
Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
nullCount = 0, avgLen = 4, maxLen = 4)),
@@ -368,9 +376,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
}
test("cdate < cast('2017-01-03' AS DATE)") {
- val d20170103 = Date.valueOf("2017-01-03")
+ val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
validateEstimatedStats(
- Filter(LessThan(attrDate, Literal(d20170103)),
+ Filter(LessThan(attrDate, Literal(d20170103, DateType)),
childStatsTestPlan(Seq(attrDate), 10L)),
Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
nullCount = 0, avgLen = 4, maxLen = 4)),
@@ -379,19 +387,19 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("""cdate IN ( cast('2017-01-03' AS DATE),
cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") {
- val d20170103 = Date.valueOf("2017-01-03")
- val d20170104 = Date.valueOf("2017-01-04")
- val d20170105 = Date.valueOf("2017-01-05")
+ val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
+ val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04"))
+ val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05"))
validateEstimatedStats(
- Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
- childStatsTestPlan(Seq(attrDate), 10L)),
+ Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType),
+ Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)),
Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 3)
}
test("cdecimal = 0.400000000000000000") {
- val dec_0_40 = new java.math.BigDecimal("0.400000000000000000")
+ val dec_0_40 = Decimal("0.400000000000000000")
validateEstimatedStats(
Filter(EqualTo(attrDecimal, Literal(dec_0_40)),
childStatsTestPlan(Seq(attrDecimal), 4L)),
@@ -401,7 +409,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
}
test("cdecimal < 0.60 ") {
- val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
+ val dec_0_60 = Decimal("0.600000000000000000")
validateEstimatedStats(
Filter(LessThan(attrDecimal, Literal(dec_0_60)),
childStatsTestPlan(Seq(attrDecimal), 4L)),
@@ -532,7 +540,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint = cint3") {
// no records qualify due to no overlap
- val emptyColStats = Seq[(Attribute, ColumnStat)]()
validateEstimatedStats(
Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
Nil, // set to empty
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index f62df842fa..2d6b6e8e21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap,
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{DateType, TimestampType, _}
@@ -254,24 +255,24 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
test("test join keys of different types") {
/** Columns in a table with only one row */
def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = {
- val dec = new java.math.BigDecimal("1.000000000000000000")
- val date = Date.valueOf("2016-05-08")
- val timestamp = Timestamp.valueOf("2016-05-08 00:00:01")
+ val dec = Decimal("1.000000000000000000")
+ val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
+ val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
mutable.LinkedHashMap[Attribute, ColumnStat](
AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1,
min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1),
AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1,
- min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1),
+ min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1,
- min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2),
+ min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1,
- min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4),
+ min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4),
AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1,
min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8),
AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1,
min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8),
AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1,
- min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4),
+ min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4),
AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1,
min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16),
AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
index f408dc4153..a5c4d22a29 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -62,28 +63,28 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
}
test("test row size estimation") {
- val dec1 = new java.math.BigDecimal("1.000000000000000000")
- val dec2 = new java.math.BigDecimal("8.000000000000000000")
- val d1 = Date.valueOf("2016-05-08")
- val d2 = Date.valueOf("2016-05-09")
- val t1 = Timestamp.valueOf("2016-05-08 00:00:01")
- val t2 = Timestamp.valueOf("2016-05-09 00:00:02")
+ val dec1 = Decimal("1.000000000000000000")
+ val dec2 = Decimal("8.000000000000000000")
+ val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
+ val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-09"))
+ val t1 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
+ val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02"))
val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2,
min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1),
AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2,
- min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1),
+ min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2,
- min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2),
+ min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2,
- min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4),
+ min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4),
AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2,
min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8),
AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2,
min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8),
AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2,
- min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4),
+ min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4),
AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2,
min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16),
AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index b89014ed8e..0d8db2ff5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -73,10 +73,10 @@ case class AnalyzeColumnCommand(
val relation = sparkSession.table(tableIdent).logicalPlan
// Resolve the column names and dedup using AttributeSet
val resolver = sparkSession.sessionState.conf.resolver
- val attributesToAnalyze = AttributeSet(columnNames.map { col =>
+ val attributesToAnalyze = columnNames.map { col =>
val exprOption = relation.output.find(attr => resolver(attr.name, col))
exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist."))
- }).toSeq
+ }
// Make sure the column types are supported for stats gathering.
attributesToAnalyze.foreach { attr =>
@@ -99,8 +99,8 @@ case class AnalyzeColumnCommand(
val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
val rowCount = statsRow.getLong(0)
- val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) =>
- (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1)))
+ val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) =>
+ (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr))
}.toMap
(rowCount, columnStats)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 1f547c5a2a..ddc393c8da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -26,6 +26,7 @@ import scala.util.Random
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
@@ -117,7 +118,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
stats.zip(df.schema).foreach { case ((k, v), field) =>
withClue(s"column $k with type ${field.dataType}") {
- val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap)
+ val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType))
assert(roundtrip == Some(v))
}
}
@@ -201,17 +202,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
/** A mapping from column to the stats collected. */
protected val stats = mutable.LinkedHashMap(
"cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1),
- "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1),
- "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2),
- "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4),
+ "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1),
+ "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2),
+ "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4),
"clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8),
"cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8),
- "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4),
- "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16),
+ "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4),
+ "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16),
"cstring" -> ColumnStat(2, None, None, 1, 3, 3),
"cbinary" -> ColumnStat(2, None, None, 1, 3, 3),
- "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4),
- "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8)
+ "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)),
+ Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4),
+ "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)),
+ Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8)
)
private val randomName = new Random(31)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 806f2be5fa..8b0fdf49ce 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -526,8 +526,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
if (stats.rowCount.isDefined) {
statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()
}
+ val colNameTypeMap: Map[String, DataType] =
+ tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap
stats.colStats.foreach { case (colName, colStat) =>
- colStat.toMap.foreach { case (k, v) =>
+ colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) =>
statsProperties += (columnStatKeyPropName(colName, k) -> v)
}
}