about summary refs log tree commit diff
path: root/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala 70
1 files changed, 13 insertions, 57 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
index 3d13967cb6..4ac5ba5689 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
@@ -17,12 +17,8 @@
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
-import java.math.{BigDecimal => JDecimal}
-import java.sql.{Date, Timestamp}
-
import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
+import org.apache.spark.sql.types._
/** Value range of a column. */
@@ -31,13 +27,10 @@ trait Range {
}
/** For simplicity we use decimal to unify operations of numeric ranges. */
-case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
+case class NumericRange(min: Decimal, max: Decimal) extends Range {
override def contains(l: Literal): Boolean = {
- val decimal = l.dataType match {
- case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
- case _ => new JDecimal(l.value.toString)
- }
- min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
+ val lit = EstimationUtils.toDecimal(l.value, l.dataType)
+ min <= lit && max >= lit
}
}
@@ -58,7 +51,10 @@ object Range {
def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
case StringType | BinaryType => new DefaultRange()
case _ if min.isEmpty || max.isEmpty => new NullRange()
- case _ => toNumericRange(min.get, max.get, dataType)
+ case _ =>
+ NumericRange(
+ min = EstimationUtils.toDecimal(min.get, dataType),
+ max = EstimationUtils.toDecimal(max.get, dataType))
}
def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
@@ -82,51 +78,11 @@ object Range {
// binary/string types don't support intersecting.
(None, None)
case (n1: NumericRange, n2: NumericRange) =>
- val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max))
- val (newMin, newMax) = fromNumericRange(newRange, dt)
- (Some(newMin), Some(newMax))
+ // Choose the maximum of two min values, and the minimum of two max values.
+ val newMin = if (n1.min <= n2.min) n2.min else n1.min
+ val newMax = if (n1.max <= n2.max) n1.max else n2.max
+ (Some(EstimationUtils.fromDecimal(newMin, dt)),
+ Some(EstimationUtils.fromDecimal(newMax, dt)))
}
}
-
- /**
- * For simplicity we use decimal to unify operations of numeric types, the two methods below
- * are the contract of conversion.
- */
- private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = {
- dataType match {
- case _: NumericType =>
- NumericRange(new JDecimal(min.toString), new JDecimal(max.toString))
- case BooleanType =>
- val min1 = if (min.asInstanceOf[Boolean]) 1 else 0
- val max1 = if (max.asInstanceOf[Boolean]) 1 else 0
- NumericRange(new JDecimal(min1), new JDecimal(max1))
- case DateType =>
- val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date])
- val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date])
- NumericRange(new JDecimal(min1), new JDecimal(max1))
- case TimestampType =>
- val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp])
- val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp])
- NumericRange(new JDecimal(min1), new JDecimal(max1))
- }
- }
-
- private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = {
- dataType match {
- case _: IntegralType =>
- (n.min.longValue(), n.max.longValue())
- case FloatType | DoubleType =>
- (n.min.doubleValue(), n.max.doubleValue())
- case _: DecimalType =>
- (n.min, n.max)
- case BooleanType =>
- (n.min.longValue() == 1, n.max.longValue() == 1)
- case DateType =>
- (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue()))
- case TimestampType =>
- (DateTimeUtils.toJavaTimestamp(n.min.longValue()),
- DateTimeUtils.toJavaTimestamp(n.max.longValue()))
- }
- }
-
}