author     Sameer Agarwal <sameer@databricks.com>  2016-07-11 20:26:01 -0700
committer  Reynold Xin <rxin@databricks.com>  2016-07-11 20:26:01 -0700
commit     9cc74f95edb6e4f56151966139cd0dc24e377949
tree       998e7df35cad5f66e09f87228f6ca0d918384279
parent     e50efd53f073890d789a8448f850cc219cca7708
[SPARK-16488] Fix codegen variable namespace collision in pmod and partitionBy
## What changes were proposed in this pull request?

This patch fixes a variable namespace collision bug in pmod and partitionBy.

## How was this patch tested?

Regression test for one possible occurrence. A more general fix in `ExpressionEvalHelper.checkEvaluation` will be in a subsequent PR.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #14144 from sameeragarwal/codegen-bug.
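Editorial note for readers of this commit: Catalyst's codegen can splice the Java snippets for several expressions into a single generated method, so a hard-coded local name like the `r` that `Pmod.doGenCode` used to emit can end up declared twice in one scope, and the generated class then fails to compile. The patch routes the name through `CodegenContext.freshName` instead. Below is a minimal sketch of the freshName idea; `TinyCodegenContext` is a hypothetical stand-in, not Spark's actual implementation.

```scala
import java.util.concurrent.atomic.AtomicLong

// Hypothetical stand-in for Spark's CodegenContext.freshName: every call
// returns a distinct identifier, so two Pmod instances that generate code
// into the same method no longer both declare a local named `r`.
class TinyCodegenContext {
  private val counter = new AtomicLong(0)
  def freshName(prefix: String): String = s"${prefix}_${counter.getAndIncrement()}"
}

object FreshNameDemo extends App {
  val ctx = new TinyCodegenContext
  println(ctx.freshName("remainder")) // remainder_0
  println(ctx.freshName("remainder")) // remainder_1 -- distinct, no collision
}
```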
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala  25
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala      14
2 files changed, 27 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 4db1352291..91ffac0ba2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -498,34 +498,35 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic wi
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+      val remainder = ctx.freshName("remainder")
       dataType match {
         case dt: DecimalType =>
           val decimalAdd = "$plus"
           s"""
-            ${ctx.javaType(dataType)} r = $eval1.remainder($eval2);
-            if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
-              ${ev.value} = (r.$decimalAdd($eval2)).remainder($eval2);
+            ${ctx.javaType(dataType)} $remainder = $eval1.remainder($eval2);
+            if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
+              ${ev.value} = ($remainder.$decimalAdd($eval2)).remainder($eval2);
             } else {
-              ${ev.value} = r;
+              ${ev.value} = $remainder;
             }
           """
         // byte and short are casted into int when add, minus, times or divide
         case ByteType | ShortType =>
           s"""
-            ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2);
-            if (r < 0) {
-              ${ev.value} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2);
+            ${ctx.javaType(dataType)} $remainder = (${ctx.javaType(dataType)})($eval1 % $eval2);
+            if ($remainder < 0) {
+              ${ev.value} = (${ctx.javaType(dataType)})(($remainder + $eval2) % $eval2);
             } else {
-              ${ev.value} = r;
+              ${ev.value} = $remainder;
             }
           """
         case _ =>
           s"""
-            ${ctx.javaType(dataType)} r = $eval1 % $eval2;
-            if (r < 0) {
-              ${ev.value} = (r + $eval2) % $eval2;
+            ${ctx.javaType(dataType)} $remainder = $eval1 % $eval2;
+            if ($remainder < 0) {
+              ${ev.value} = ($remainder + $eval2) % $eval2;
             } else {
-              ${ev.value} = r;
+              ${ev.value} = $remainder;
             }
           """
       }
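To make the failure mode concrete, here is an illustrative reconstruction (not output captured from Spark) of what fused generated Java looked like before and after the rename; the `value*` and `out*` identifiers and the exact fresh-name suffixes are hypothetical:

```scala
object GeneratedJavaSketch extends App {
  // Before the fix: both pmod snippets hard-code `r`, so splicing them into
  // one generated method declares the same local twice and fails to compile.
  val beforeFix: String =
    """int r = value1 % value2;   // first pmod snippet
      |if (r < 0) { out1 = (r + value2) % value2; } else { out1 = r; }
      |int r = value3 % value4;   // second snippet: duplicate declaration of `r`
      |""".stripMargin

  // After the fix: each Pmod instance asks the context for a fresh name.
  val afterFix: String =
    """int remainder0 = value1 % value2;
      |if (remainder0 < 0) { out1 = (remainder0 + value2) % value2; } else { out1 = remainder0; }
      |int remainder1 = value3 % value4;   // distinct name: compiles fine
      |""".stripMargin

  println(beforeFix)
  println(afterFix)
}
```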
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 05935cec4b..f706b20364 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -449,6 +449,20 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     }
   }
 
+  test("pmod with partitionBy") {
+    val spark = this.spark
+    import spark.implicits._
+
+    case class Test(a: Int, b: String)
+    val data = Seq((0, "a"), (1, "b"), (1, "a"))
+    spark.createDataset(data).createOrReplaceTempView("test")
+    sql("select * from test distribute by pmod(_1, 2)")
+      .write
+      .partitionBy("_2")
+      .mode("overwrite")
+      .parquet(dir)
+  }
+
   private def testRead(
       df: => DataFrame,
       expectedResult: Seq[String],
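A standalone reproduction along the lines of the regression test above might look as follows; the app name, local master, and `/tmp/pmod-repro` output path are assumptions, and before this patch the write step failed while compiling the generated partitioning code:

```scala
import org.apache.spark.sql.SparkSession

object PmodPartitionByRepro extends App {
  val spark = SparkSession.builder()
    .master("local[2]")
    .appName("pmod-repro")
    .getOrCreate()
  import spark.implicits._

  // Tuple columns are named _1 and _2, matching the regression test above.
  Seq((0, "a"), (1, "b"), (1, "a")).toDS().createOrReplaceTempView("test")

  // `distribute by pmod(_1, 2)` plus `partitionBy` exercises pmod codegen
  // in the write path, the combination the fix targets.
  spark.sql("select * from test distribute by pmod(_1, 2)")
    .write
    .partitionBy("_2")
    .mode("overwrite")
    .parquet("/tmp/pmod-repro")

  spark.stop()
}
```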