aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2014-07-25 19:17:49 -0700
committerMichael Armbrust <michael@databricks.com>2014-07-25 19:17:49 -0700
commit8904791230a0fae336db93e5a80f65c4d9d584dc (patch)
tree9d236ecea8ee3a1acf5454df532f43c93adac903 /sql
parent9d8666cac84fc4fc867f6a5e80097dbe5cb65301 (diff)
downloadspark-8904791230a0fae336db93e5a80f65c4d9d584dc.tar.gz
spark-8904791230a0fae336db93e5a80f65c4d9d584dc.tar.bz2
spark-8904791230a0fae336db93e5a80f65c4d9d584dc.zip
[SPARK-2659][SQL] Fix division semantics for hive
Author: Michael Armbrust <michael@databricks.com> Closes #1557 from marmbrus/fixDivision and squashes the following commits: b85077f [Michael Armbrust] Fix unit tests. af98f29 [Michael Armbrust] Change DIV to long type 0c29ae8 [Michael Armbrust] Fix division semantics for hive
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala18
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala2
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala3
-rw-r--r--sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d45851
-rw-r--r--sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f71
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala6
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala5
7 files changed, 27 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 67a8ce9b88..47c7ad076a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -50,6 +50,7 @@ trait HiveTypeCoercion {
StringToIntegralCasts ::
FunctionArgumentConversion ::
CastNulls ::
+ Division ::
Nil
/**
@@ -318,6 +319,23 @@ trait HiveTypeCoercion {
}
/**
+ * Hive only performs integral division with the DIV operator. The arguments to / are always
+ * converted to fractional types.
+ */
+ object Division extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes who's children have not been resolved yet.
+ case e if !e.childrenResolved => e
+
+ // Decimal and Double remain the same
+ case d: Divide if d.dataType == DoubleType => d
+ case d: Divide if d.dataType == DecimalType => d
+
+ case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
+ }
+ }
+
+ /**
* Ensures that NullType gets casted to some other types under certain circumstances.
*/
object CastNulls extends Rule[LogicalPlan] {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index d607eed1be..0a27cce337 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -83,7 +83,7 @@ class ConstantFoldingSuite extends PlanTest {
Literal(10) as Symbol("2*3+4"),
Literal(14) as Symbol("2*(3+4)"))
.where(Literal(true))
- .groupBy(Literal(3))(Literal(3) as Symbol("9/3"))
+ .groupBy(Literal(3.0))(Literal(3.0) as Symbol("9/3"))
.analyze
comparePlans(optimized, correctAnswer)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 4395874526..e6ab68b563 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -925,7 +925,8 @@ private[hive] object HiveQl {
case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
- case Token(DIV(), left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+ case Token(DIV(), left :: right:: Nil) =>
+ Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
/* Comparisons */
diff --git a/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585
new file mode 100644
index 0000000000..17ba0bea72
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585
@@ -0,0 +1 @@
+0 0 0 1 2
diff --git a/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7
new file mode 100644
index 0000000000..7b7a917511
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7
@@ -0,0 +1 @@
+2.0 0.5 0.3333333333333333 0.002
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 08ef4d9b6b..b4dbf2b115 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -350,12 +350,6 @@ abstract class HiveComparisonTest
val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n")
- println("hive output")
- hive.foreach(println)
-
- println("catalyst printout")
- catalyst.foreach(println)
-
if (recomputeCache) {
logger.warn(s"Clearing cache files for failed test $testCaseName")
hiveCacheFiles.foreach(_.delete())
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 6f36a4f8cb..a8623b64c6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -52,7 +52,10 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT * FROM src WHERE key Between 1 and 2")
createQueryTest("div",
- "SELECT 1 DIV 2, 1 div 2, 1 dIv 2 FROM src LIMIT 1")
+ "SELECT 1 DIV 2, 1 div 2, 1 dIv 2, 100 DIV 51, 100 DIV 49 FROM src LIMIT 1")
+
+ createQueryTest("division",
+ "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1")
test("Query expressed in SQL") {
assert(sql("SELECT 1").collect() === Array(Seq(1)))