-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala | 95
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala | 30
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala | 68
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala | 70
-rwxr-xr-x  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala | 41
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala | 15
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala | 19
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala | 4
10 files changed, 189 insertions, 182 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index f24b240956..3d4efef953 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -74,11 +75,10 @@ case class Statistics(
* Statistics collected for a column.
*
* 1. Supported data types are defined in `ColumnStat.supportsType`.
- * 2. The JVM data type stored in min/max is the external data type (used in Row) for the
- * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for
- * TimestampType we store java.sql.Timestamp.
- * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs.
- * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms
+ * 2. The JVM data type stored in min/max is the internal data type for the corresponding
+ * Catalyst data type. For example, the internal type of DateType is Int, and the internal
+ * type of TimestampType is Long.
+ * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms
* (sketches) might have been used, and the data collected can also be stale.
*
* @param distinctCount number of distinct values
@@ -104,22 +104,43 @@ case class ColumnStat(
/**
* Returns a map from string to string that can be used to serialize the column stats.
* The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string
- * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]].
+ * representation for the value. min/max values are converted to the external data type. For
+ * example, for DateType we store java.sql.Date, and for TimestampType we store
+ * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]].
*
* As part of the protocol, the returned map always contains a key called "version".
* In the case min/max values are null (None), they won't appear in the map.
*/
- def toMap: Map[String, String] = {
+ def toMap(colName: String, dataType: DataType): Map[String, String] = {
val map = new scala.collection.mutable.HashMap[String, String]
map.put(ColumnStat.KEY_VERSION, "1")
map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString)
map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString)
map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString)
map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString)
- min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) }
- max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) }
+ min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) }
+ max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) }
map.toMap
}
+
+ /**
+   * Converts the given value from its Catalyst internal representation to the string form of
+   * the external data type.
+ */
+ private def toExternalString(v: Any, colName: String, dataType: DataType): String = {
+ val externalValue = dataType match {
+ case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int])
+ case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long])
+ case BooleanType | _: IntegralType | FloatType | DoubleType => v
+ case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal
+ // This version of Spark does not use min/max for binary/string types so we ignore it.
+ case _ =>
+        throw new AnalysisException("Column statistics serialization is not supported for " +
+ s"column $colName of data type: $dataType.")
+ }
+ externalValue.toString
+ }
+
}
@@ -150,28 +171,15 @@ object ColumnStat extends Logging {
* Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats
* from some external storage. The serialization side is defined in [[ColumnStat.toMap]].
*/
- def fromMap(table: String, field: StructField, map: Map[String, String])
- : Option[ColumnStat] = {
- val str2val: (String => Any) = field.dataType match {
- case _: IntegralType => _.toLong
- case _: DecimalType => new java.math.BigDecimal(_)
- case DoubleType | FloatType => _.toDouble
- case BooleanType => _.toBoolean
- case DateType => java.sql.Date.valueOf
- case TimestampType => java.sql.Timestamp.valueOf
- // This version of Spark does not use min/max for binary/string types so we ignore it.
- case BinaryType | StringType => _ => null
- case _ =>
- throw new AnalysisException("Column statistics deserialization is not supported for " +
- s"column ${field.name} of data type: ${field.dataType}.")
- }
-
+ def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = {
try {
Some(ColumnStat(
distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong),
// Note that flatMap(Option.apply) turns Option(null) into None.
- min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply),
- max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply),
+ min = map.get(KEY_MIN_VALUE)
+ .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
+ max = map.get(KEY_MAX_VALUE)
+ .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
nullCount = BigInt(map(KEY_NULL_COUNT).toLong),
avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong,
maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong
@@ -184,6 +192,30 @@ object ColumnStat extends Logging {
}
/**
+   * Converts the string representation of the external data type to the corresponding Catalyst
+   * internal value.
+ */
+ private def fromExternalString(s: String, name: String, dataType: DataType): Any = {
+ dataType match {
+ case BooleanType => s.toBoolean
+ case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s))
+ case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s))
+ case ByteType => s.toByte
+ case ShortType => s.toShort
+ case IntegerType => s.toInt
+ case LongType => s.toLong
+ case FloatType => s.toFloat
+ case DoubleType => s.toDouble
+ case _: DecimalType => Decimal(s)
+ // This version of Spark does not use min/max for binary/string types so we ignore it.
+ case BinaryType | StringType => null
+ case _ =>
+ throw new AnalysisException("Column statistics deserialization is not supported for " +
+ s"column $name of data type: $dataType.")
+ }
+ }
+
+ /**
* Constructs an expression to compute column statistics for a given column.
*
* The expression should create a single struct column with the following schema:
@@ -232,11 +264,14 @@ object ColumnStat extends Logging {
}
/** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */
- def rowToColumnStat(row: Row): ColumnStat = {
+ def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = {
ColumnStat(
distinctCount = BigInt(row.getLong(0)),
- min = Option(row.get(1)), // for string/binary min/max, get should return null
- max = Option(row.get(2)),
+ // for string/binary min/max, get should return null
+ min = Option(row.get(1))
+ .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
+ max = Option(row.get(2))
+ .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
nullCount = BigInt(row.getLong(3)),
avgLen = row.getLong(4),
maxLen = row.getLong(5)
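
To make the new min/max contract concrete, here is a minimal round-trip sketch (not part of this patch; it assumes the spark-catalyst classes touched above are on the classpath). Min/max are held as Catalyst internal values, e.g. an Int day count for DateType, while toMap/fromMap speak the external string form:

    import java.sql.Date
    import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
    import org.apache.spark.sql.catalyst.util.DateTimeUtils
    import org.apache.spark.sql.types.{DateType, StructField}

    // Internal representation: days since epoch as an Int.
    val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01"))
    val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10"))
    val stat = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
      nullCount = 0, avgLen = 4, maxLen = 4)

    // Serialization renders min/max as the external type's string (e.g. "2017-01-01");
    // fromMap converts that string back to the internal Int, so the round trip is lossless.
    val serialized = stat.toMap("cdate", DateType)
    val roundtrip = ColumnStat.fromMap("table_foo", StructField("cdate", DateType), serialized)
    assert(roundtrip == Some(stat))
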
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 5577233ffa..f1aff62cb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -22,7 +22,7 @@ import scala.math.BigDecimal.RoundingMode
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.sql.types.{DecimalType, _}
object EstimationUtils {
@@ -75,4 +75,32 @@ object EstimationUtils {
// (simple computation of statistics returns product of children).
if (outputRowCount > 0) outputRowCount * sizePerRow else 1
}
+
+ /**
+ * For simplicity we use Decimal to unify operations for data types whose min/max values can be
+ * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true).
+   * The two methods below define the conversion contract.
+ */
+ def toDecimal(value: Any, dataType: DataType): Decimal = {
+ dataType match {
+ case _: NumericType | DateType | TimestampType => Decimal(value.toString)
+ case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0)
+ }
+ }
+
+ def fromDecimal(dec: Decimal, dataType: DataType): Any = {
+ dataType match {
+ case BooleanType => dec.toLong == 1
+ case DateType => dec.toInt
+ case TimestampType => dec.toLong
+ case ByteType => dec.toByte
+ case ShortType => dec.toShort
+ case IntegerType => dec.toInt
+ case LongType => dec.toLong
+ case FloatType => dec.toFloat
+ case DoubleType => dec.toDouble
+ case _: DecimalType => dec
+ }
+ }
+
}
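
A small usage sketch of the new contract (illustrative only, not part of the patch): every supported type funnels through Decimal for comparisons and converts back to its internal form.

    import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
    import org.apache.spark.sql.types.{BooleanType, DateType, Decimal}

    // DateType values are internal Ints (days since epoch), so they compare as plain numbers.
    val dec = EstimationUtils.toDecimal(17167, DateType)        // 17167 = 2017-01-01
    assert(EstimationUtils.fromDecimal(dec, DateType) == 17167)

    // Booleans map to 0/1, which lets them participate in the same numeric ranges.
    assert(EstimationUtils.toDecimal(true, BooleanType) == Decimal(1))
    assert(EstimationUtils.fromDecimal(Decimal(0), BooleanType) == false)
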
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 7bd8e65112..4b6b3b14d9 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -25,7 +25,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -302,30 +301,6 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
}
/**
- * For a SQL data type, its internal data type may be different from its external type.
- * For DateType, its internal type is Int, and its external data type is Java Date type.
- * The min/max values in ColumnStat are saved in their corresponding external type.
- *
- * @param attrDataType the column data type
- * @param litValue the literal value
- * @return a BigDecimal value
- */
- def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = {
- attrDataType match {
- case DateType =>
- Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
- case TimestampType =>
- Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
- case _: DecimalType =>
- Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal)
- case StringType | BinaryType =>
- None
- case _ =>
- Some(litValue)
- }
- }
-
- /**
* Returns a percentage of rows meeting an equality (=) expression.
* This method evaluates the equality predicate for all data types.
*
@@ -356,12 +331,16 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
val statsRange = Range(colStat.min, colStat.max, attr.dataType)
if (statsRange.contains(literal)) {
if (update) {
- // We update ColumnStat structure after apply this equality predicate.
- // Set distinctCount to 1. Set nullCount to 0.
- // Need to save new min/max using the external type value of the literal
- val newValue = convertBoundValue(attr.dataType, literal.value)
- val newStats = colStat.copy(distinctCount = 1, min = newValue,
- max = newValue, nullCount = 0)
+          // We update the ColumnStat structure after applying this equality predicate:
+          // set distinctCount to 1, nullCount to 0, and min/max (if they exist) to the
+          // literal value.
+ val newStats = attr.dataType match {
+ case StringType | BinaryType =>
+ colStat.copy(distinctCount = 1, nullCount = 0)
+ case _ =>
+ colStat.copy(distinctCount = 1, min = Some(literal.value),
+ max = Some(literal.value), nullCount = 0)
+ }
colStatsMap(attr) = newStats
}
@@ -430,18 +409,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
return Some(0.0)
}
- // Need to save new min/max using the external type value of the literal
- val newMax = convertBoundValue(
- attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString)))
- val newMin = convertBoundValue(
- attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString)))
-
+ val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType))
+ val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType))
// newNdv should not be greater than the old ndv. For example, column has only 2 values
// 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
newNdv = ndv.min(BigInt(validQuerySet.size))
if (update) {
- val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
- max = newMax, nullCount = 0)
+ val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin),
+ max = Some(newMax), nullCount = 0)
colStatsMap(attr) = newStats
}
@@ -478,8 +453,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
val colStat = colStatsMap(attr)
val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
- val max = BigDecimal(statsRange.max)
- val min = BigDecimal(statsRange.min)
+ val max = statsRange.max.toBigDecimal
+ val min = statsRange.min.toBigDecimal
val ndv = BigDecimal(colStat.distinctCount)
// determine the overlapping degree between predicate range and column's range
@@ -540,8 +515,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
}
if (update) {
- // Need to save new min/max using the external type value of the literal
- val newValue = convertBoundValue(attr.dataType, literal.value)
+ val newValue = Some(literal.value)
var newMax = colStat.max
var newMin = colStat.min
var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
@@ -606,14 +580,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
val colStatLeft = colStatsMap(attrLeft)
val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
.asInstanceOf[NumericRange]
- val maxLeft = BigDecimal(statsRangeLeft.max)
- val minLeft = BigDecimal(statsRangeLeft.min)
+ val maxLeft = statsRangeLeft.max
+ val minLeft = statsRangeLeft.min
val colStatRight = colStatsMap(attrRight)
val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType)
.asInstanceOf[NumericRange]
- val maxRight = BigDecimal(statsRangeRight.max)
- val minRight = BigDecimal(statsRangeRight.min)
+ val maxRight = statsRangeRight.max
+ val minRight = statsRangeRight.min
// determine the overlapping degree between predicate range and column's range
val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
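
The net effect on predicate evaluation: because column stats and literals now share the internal representation, an equality predicate can pin min/max to the literal value directly, with no conversion helper. A sketch mirroring the code above (illustrative, not part of the patch):

    import java.sql.Date
    import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    // cdate has values 2017-01-01 .. 2017-01-10, kept as internal Ints.
    val colStat = ColumnStat(distinctCount = 10,
      min = Some(DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01"))),
      max = Some(DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10"))),
      nullCount = 0, avgLen = 4, maxLen = 4)

    // After applying `cdate = '2017-01-02'`, the literal's internal value becomes both bounds.
    val d20170102 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-02"))
    val newStats = colStat.copy(distinctCount = 1, min = Some(d20170102),
      max = Some(d20170102), nullCount = 0)
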
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
index 3d13967cb6..4ac5ba5689 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
@@ -17,12 +17,8 @@
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
-import java.math.{BigDecimal => JDecimal}
-import java.sql.{Date, Timestamp}
-
import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
+import org.apache.spark.sql.types._
/** Value range of a column. */
@@ -31,13 +27,10 @@ trait Range {
}
/** For simplicity we use decimal to unify operations of numeric ranges. */
-case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
+case class NumericRange(min: Decimal, max: Decimal) extends Range {
override def contains(l: Literal): Boolean = {
- val decimal = l.dataType match {
- case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
- case _ => new JDecimal(l.value.toString)
- }
- min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
+ val lit = EstimationUtils.toDecimal(l.value, l.dataType)
+ min <= lit && max >= lit
}
}
@@ -58,7 +51,10 @@ object Range {
def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
case StringType | BinaryType => new DefaultRange()
case _ if min.isEmpty || max.isEmpty => new NullRange()
- case _ => toNumericRange(min.get, max.get, dataType)
+ case _ =>
+ NumericRange(
+ min = EstimationUtils.toDecimal(min.get, dataType),
+ max = EstimationUtils.toDecimal(max.get, dataType))
}
def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
@@ -82,51 +78,11 @@ object Range {
// binary/string types don't support intersecting.
(None, None)
case (n1: NumericRange, n2: NumericRange) =>
- val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max))
- val (newMin, newMax) = fromNumericRange(newRange, dt)
- (Some(newMin), Some(newMax))
+ // Choose the maximum of two min values, and the minimum of two max values.
+ val newMin = if (n1.min <= n2.min) n2.min else n1.min
+ val newMax = if (n1.max <= n2.max) n1.max else n2.max
+ (Some(EstimationUtils.fromDecimal(newMin, dt)),
+ Some(EstimationUtils.fromDecimal(newMax, dt)))
}
}
-
- /**
- * For simplicity we use decimal to unify operations of numeric types, the two methods below
- * are the contract of conversion.
- */
- private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = {
- dataType match {
- case _: NumericType =>
- NumericRange(new JDecimal(min.toString), new JDecimal(max.toString))
- case BooleanType =>
- val min1 = if (min.asInstanceOf[Boolean]) 1 else 0
- val max1 = if (max.asInstanceOf[Boolean]) 1 else 0
- NumericRange(new JDecimal(min1), new JDecimal(max1))
- case DateType =>
- val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date])
- val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date])
- NumericRange(new JDecimal(min1), new JDecimal(max1))
- case TimestampType =>
- val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp])
- val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp])
- NumericRange(new JDecimal(min1), new JDecimal(max1))
- }
- }
-
- private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = {
- dataType match {
- case _: IntegralType =>
- (n.min.longValue(), n.max.longValue())
- case FloatType | DoubleType =>
- (n.min.doubleValue(), n.max.doubleValue())
- case _: DecimalType =>
- (n.min, n.max)
- case BooleanType =>
- (n.min.longValue() == 1, n.max.longValue() == 1)
- case DateType =>
- (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue()))
- case TimestampType =>
- (DateTimeUtils.toJavaTimestamp(n.min.longValue()),
- DateTimeUtils.toJavaTimestamp(n.max.longValue()))
- }
- }
-
}
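
For reference, the simplified Range behaves as in this sketch (not part of the patch; assumes spark-catalyst on the classpath). Bounds arrive as internal values and are compared through Decimal:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.Range
    import org.apache.spark.sql.types.IntegerType

    // Integer min/max are internal values already, so no external-type conversion is needed.
    val r1 = Range(min = Some(1), max = Some(10), dataType = IntegerType)
    val r2 = Range(min = Some(5), max = Some(20), dataType = IntegerType)

    assert(r1.contains(Literal(7)))       // bounds and literal are compared as Decimal
    assert(Range.isIntersected(r1, r2))   // ranges overlap on [5, 10]
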
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index cffb0d8739..a28447840a 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.LeftOuter
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
/**
@@ -45,15 +46,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
nullCount = 0, avgLen = 1, maxLen = 1)
// column cdate has 10 values from 2017-01-01 through 2017-01-10.
- val dMin = Date.valueOf("2017-01-01")
- val dMax = Date.valueOf("2017-01-10")
+ val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01"))
+ val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10"))
val attrDate = AttributeReference("cdate", DateType)()
val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
nullCount = 0, avgLen = 4, maxLen = 4)
// column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
- val decMin = new java.math.BigDecimal("0.200000000000000000")
- val decMax = new java.math.BigDecimal("0.800000000000000000")
+ val decMin = Decimal("0.200000000000000000")
+ val decMax = Decimal("0.800000000000000000")
val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
nullCount = 0, avgLen = 8, maxLen = 8)
@@ -147,7 +148,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint < 3 OR null") {
val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))
- val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf)
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
Seq(attrInt -> colStatInt),
@@ -341,6 +341,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
expectedRowCount = 7)
}
+ test("cbool IN (true)") {
+ validateEstimatedStats(
+ Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)),
+ Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+ nullCount = 0, avgLen = 1, maxLen = 1)),
+ expectedRowCount = 5)
+ }
+
test("cbool = true") {
validateEstimatedStats(
Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)),
@@ -358,9 +366,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
}
test("cdate = cast('2017-01-02' AS DATE)") {
- val d20170102 = Date.valueOf("2017-01-02")
+ val d20170102 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-02"))
validateEstimatedStats(
- Filter(EqualTo(attrDate, Literal(d20170102)),
+ Filter(EqualTo(attrDate, Literal(d20170102, DateType)),
childStatsTestPlan(Seq(attrDate), 10L)),
Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
nullCount = 0, avgLen = 4, maxLen = 4)),
@@ -368,9 +376,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
}
test("cdate < cast('2017-01-03' AS DATE)") {
- val d20170103 = Date.valueOf("2017-01-03")
+ val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
validateEstimatedStats(
- Filter(LessThan(attrDate, Literal(d20170103)),
+ Filter(LessThan(attrDate, Literal(d20170103, DateType)),
childStatsTestPlan(Seq(attrDate), 10L)),
Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
nullCount = 0, avgLen = 4, maxLen = 4)),
@@ -379,19 +387,19 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("""cdate IN ( cast('2017-01-03' AS DATE),
cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") {
- val d20170103 = Date.valueOf("2017-01-03")
- val d20170104 = Date.valueOf("2017-01-04")
- val d20170105 = Date.valueOf("2017-01-05")
+ val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
+ val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04"))
+ val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05"))
validateEstimatedStats(
- Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
- childStatsTestPlan(Seq(attrDate), 10L)),
+ Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType),
+ Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)),
Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 3)
}
test("cdecimal = 0.400000000000000000") {
- val dec_0_40 = new java.math.BigDecimal("0.400000000000000000")
+ val dec_0_40 = Decimal("0.400000000000000000")
validateEstimatedStats(
Filter(EqualTo(attrDecimal, Literal(dec_0_40)),
childStatsTestPlan(Seq(attrDecimal), 4L)),
@@ -401,7 +409,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
}
test("cdecimal < 0.60 ") {
- val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
+ val dec_0_60 = Decimal("0.600000000000000000")
validateEstimatedStats(
Filter(LessThan(attrDecimal, Literal(dec_0_60)),
childStatsTestPlan(Seq(attrDecimal), 4L)),
@@ -532,7 +540,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint = cint3") {
// no records qualify due to no overlap
- val emptyColStats = Seq[(Attribute, ColumnStat)]()
validateEstimatedStats(
Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
Nil, // set to empty
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index f62df842fa..2d6b6e8e21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap,
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{DateType, TimestampType, _}
@@ -254,24 +255,24 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
test("test join keys of different types") {
/** Columns in a table with only one row */
def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = {
- val dec = new java.math.BigDecimal("1.000000000000000000")
- val date = Date.valueOf("2016-05-08")
- val timestamp = Timestamp.valueOf("2016-05-08 00:00:01")
+ val dec = Decimal("1.000000000000000000")
+ val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
+ val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
mutable.LinkedHashMap[Attribute, ColumnStat](
AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1,
min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1),
AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1,
- min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1),
+ min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1,
- min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2),
+ min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1,
- min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4),
+ min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4),
AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1,
min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8),
AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1,
min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8),
AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1,
- min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4),
+ min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4),
AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1,
min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16),
AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
index f408dc4153..a5c4d22a29 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -62,28 +63,28 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
}
test("test row size estimation") {
- val dec1 = new java.math.BigDecimal("1.000000000000000000")
- val dec2 = new java.math.BigDecimal("8.000000000000000000")
- val d1 = Date.valueOf("2016-05-08")
- val d2 = Date.valueOf("2016-05-09")
- val t1 = Timestamp.valueOf("2016-05-08 00:00:01")
- val t2 = Timestamp.valueOf("2016-05-09 00:00:02")
+ val dec1 = Decimal("1.000000000000000000")
+ val dec2 = Decimal("8.000000000000000000")
+ val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
+ val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-09"))
+ val t1 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
+ val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02"))
val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2,
min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1),
AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2,
- min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1),
+ min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2,
- min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2),
+ min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2,
- min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4),
+ min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4),
AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2,
min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8),
AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2,
min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8),
AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2,
- min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4),
+ min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4),
AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2,
min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16),
AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index b89014ed8e..0d8db2ff5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -73,10 +73,10 @@ case class AnalyzeColumnCommand(
val relation = sparkSession.table(tableIdent).logicalPlan
// Resolve the column names and dedup using AttributeSet
val resolver = sparkSession.sessionState.conf.resolver
- val attributesToAnalyze = AttributeSet(columnNames.map { col =>
+ val attributesToAnalyze = columnNames.map { col =>
val exprOption = relation.output.find(attr => resolver(attr.name, col))
exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist."))
- }).toSeq
+ }
// Make sure the column types are supported for stats gathering.
attributesToAnalyze.foreach { attr =>
@@ -99,8 +99,8 @@ case class AnalyzeColumnCommand(
val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
val rowCount = statsRow.getLong(0)
- val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) =>
- (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1)))
+ val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) =>
+ (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr))
}.toMap
(rowCount, columnStats)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 1f547c5a2a..ddc393c8da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -26,6 +26,7 @@ import scala.util.Random
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
@@ -117,7 +118,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
stats.zip(df.schema).foreach { case ((k, v), field) =>
withClue(s"column $k with type ${field.dataType}") {
- val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap)
+ val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType))
assert(roundtrip == Some(v))
}
}
@@ -201,17 +202,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
/** A mapping from column to the stats collected. */
protected val stats = mutable.LinkedHashMap(
"cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1),
- "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1),
- "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2),
- "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4),
+ "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1),
+ "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2),
+ "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4),
"clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8),
"cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8),
- "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4),
- "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16),
+ "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4),
+ "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16),
"cstring" -> ColumnStat(2, None, None, 1, 3, 3),
"cbinary" -> ColumnStat(2, None, None, 1, 3, 3),
- "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4),
- "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8)
+ "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)),
+ Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4),
+ "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)),
+ Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8)
)
private val randomName = new Random(31)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 806f2be5fa..8b0fdf49ce 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -526,8 +526,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
if (stats.rowCount.isDefined) {
statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()
}
+ val colNameTypeMap: Map[String, DataType] =
+ tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap
stats.colStats.foreach { case (colName, colStat) =>
- colStat.toMap.foreach { case (k, v) =>
+ colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) =>
statsProperties += (columnStatKeyPropName(colName, k) -> v)
}
}
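
End to end, the catalog side now works roughly as in this sketch (illustrative, not part of the patch; the column and values are hypothetical): the column's data type is looked up from the table schema so that min/max land in the table properties in their external string form.

    import java.sql.Date
    import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
    import org.apache.spark.sql.catalyst.util.DateTimeUtils
    import org.apache.spark.sql.types.{DataType, DateType, StructField, StructType}

    val schema = StructType(Seq(StructField("cdate", DateType)))
    val colNameTypeMap: Map[String, DataType] =
      schema.fields.map(f => (f.name, f.dataType)).toMap

    val cs = ColumnStat(distinctCount = 2,
      min = Some(DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01"))),
      max = Some(DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10"))),
      nullCount = 0, avgLen = 4, maxLen = 4)

    // The persisted key/value pairs carry the external string form of min/max,
    // e.g. a min entry holding "2017-01-01" rather than the internal Int.
    val props: Map[String, String] = cs.toMap("cdate", colNameTypeMap("cdate"))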