aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhyukjinkwon <gurwls223@gmail.com>2016-10-22 20:09:04 +0200
committerHerman van Hovell <hvanhovell@databricks.com>2016-10-22 20:09:04 +0200
commit5fa9f8795a71e08bcbef5975ba8c072db5be8866 (patch)
tree50d1eced012bf01b7ed01cae6f8d42f286e1489e
parent3eca283aca68ac81c127d60ad5699f854d5f14b7 (diff)
downloadspark-5fa9f8795a71e08bcbef5975ba8c072db5be8866.tar.gz
spark-5fa9f8795a71e08bcbef5975ba8c072db5be8866.tar.bz2
spark-5fa9f8795a71e08bcbef5975ba8c072db5be8866.zip
[SPARK-17123][SQL] Use type-widened encoder for DataFrame rather than existing encoder to allow type-widening from set operations
# What changes were proposed in this pull request? This PR fixes set operations in `DataFrame` to be performed fine without exceptions when the types are non-scala native types. (e.g, `TimestampType`, `DateType` and `DecimalType`). The problem is, it seems set operations such as `union`, `intersect` and `except` uses the encoder belonging to the `Dataset` in caller. So, `Dataset` of the caller holds `ExpressionEncoder[Row]` as it is when the set operations are performed. However, the return types can be actually widen. So, we should use `ExpressionEncoder[Row]` constructed from executed plan rather than using existing one. Otherwise, this will generate some codes wrongly via `StaticInvoke`. Running the codes below: ```scala val dates = Seq( (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) ).toDF("date", "timestamp", "decimal") val widenTypedRows = Seq( (new Timestamp(2), 10.5D, "string") ).toDF("date", "timestamp", "decimal") val results = dates.union(widenTypedRows).collect() results.foreach(println) ``` prints below: **Before** ```java 23:08:54.490 ERROR org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 28, Column 107: No applicable constructor/method found for actual parameters "long"; candidates are: "public static java.sql.Date org.apache.spark.sql.catalyst.util.DateTimeUtils.toJavaDate(int)" /* 001 */ public java.lang.Object generate(Object[] references) { /* 002 */ return new SpecificSafeProjection(references); /* 003 */ } /* 004 */ /* 005 */ class SpecificSafeProjection extends org.apache.spark.sql.catalyst.expressions.codegen.BaseProjection { /* 006 */ /* 007 */ private Object[] references; /* 008 */ private MutableRow mutableRow; /* 009 */ private Object[] values; /* 010 */ private org.apache.spark.sql.types.StructType schema; /* 011 */ /* 012 */ /* 013 */ public SpecificSafeProjection(Object[] references) { /* 014 */ this.references = references; /* 015 */ mutableRow = (MutableRow) references[references.length - 1]; /* 016 */ /* 017 */ this.schema = (org.apache.spark.sql.types.StructType) references[0]; /* 018 */ } /* 019 */ /* 020 */ public java.lang.Object apply(java.lang.Object _i) { /* 021 */ InternalRow i = (InternalRow) _i; /* 022 */ /* 023 */ values = new Object[3]; /* 024 */ /* 025 */ boolean isNull2 = i.isNullAt(0); /* 026 */ long value2 = isNull2 ? -1L : (i.getLong(0)); /* 027 */ boolean isNull1 = isNull2; /* 028 */ final java.sql.Date value1 = isNull1 ? null : org.apache.spark.sql.catalyst.util.DateTimeUtils.toJavaDate(value2); /* 029 */ isNull1 = value1 == null; /* 030 */ if (isNull1) { /* 031 */ values[0] = null; /* 032 */ } else { /* 033 */ values[0] = value1; /* 034 */ } /* 035 */ /* 036 */ boolean isNull4 = i.isNullAt(1); /* 037 */ double value4 = isNull4 ? -1.0 : (i.getDouble(1)); /* 038 */ /* 039 */ boolean isNull3 = isNull4; /* 040 */ java.math.BigDecimal value3 = null; /* 041 */ if (!isNull3) { /* 042 */ /* 043 */ Object funcResult = null; /* 044 */ funcResult = value4.toJavaBigDecimal(); /* 045 */ if (funcResult == null) { /* 046 */ isNull3 = true; /* 047 */ } else { /* 048 */ value3 = (java.math.BigDecimal) funcResult; /* 049 */ } /* 050 */ /* 051 */ } /* 052 */ isNull3 = value3 == null; /* 053 */ if (isNull3) { /* 054 */ values[1] = null; /* 055 */ } else { /* 056 */ values[1] = value3; /* 057 */ } /* 058 */ /* 059 */ boolean isNull6 = i.isNullAt(2); /* 060 */ UTF8String value6 = isNull6 ? null : (i.getUTF8String(2)); /* 061 */ boolean isNull5 = isNull6; /* 062 */ final java.sql.Timestamp value5 = isNull5 ? null : org.apache.spark.sql.catalyst.util.DateTimeUtils.toJavaTimestamp(value6); /* 063 */ isNull5 = value5 == null; /* 064 */ if (isNull5) { /* 065 */ values[2] = null; /* 066 */ } else { /* 067 */ values[2] = value5; /* 068 */ } /* 069 */ /* 070 */ final org.apache.spark.sql.Row value = new org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema(values, schema); /* 071 */ if (false) { /* 072 */ mutableRow.setNullAt(0); /* 073 */ } else { /* 074 */ /* 075 */ mutableRow.update(0, value); /* 076 */ } /* 077 */ /* 078 */ return mutableRow; /* 079 */ } /* 080 */ } ``` **After** ```bash [1969-12-31 00:00:00.0,1.0,1969-12-31 16:00:00.002] [1969-12-31 00:00:00.0,4.0,1969-12-31 16:00:00.005] [1969-12-31 16:00:00.002,10.5,string] ``` ## How was this patch tested? Unit tests in `DataFrameSuite` Author: hyukjinkwon <gurwls223@gmail.com> Closes #15072 from HyukjinKwon/SPARK-17123.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala18
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala16
2 files changed, 30 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 073d2b1512..286d8549bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -556,7 +556,7 @@ class Dataset[T] private[sql](
* 1983 03 0.410516 0.442194
* 1984 04 0.450090 0.483521
* }}}
- *
+ *
* @param numRows Number of rows to show
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
@@ -1524,7 +1524,7 @@ class Dataset[T] private[sql](
* @group typedrel
* @since 2.0.0
*/
- def union(other: Dataset[T]): Dataset[T] = withTypedPlan {
+ def union(other: Dataset[T]): Dataset[T] = withSetOperator {
// This breaks caching, but it's usually ok because it addresses a very specific use case:
// using union to union many files or partitions.
CombineUnions(Union(logicalPlan, other.logicalPlan))
@@ -1540,7 +1540,7 @@ class Dataset[T] private[sql](
* @group typedrel
* @since 1.6.0
*/
- def intersect(other: Dataset[T]): Dataset[T] = withTypedPlan {
+ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator {
Intersect(logicalPlan, other.logicalPlan)
}
@@ -1554,7 +1554,7 @@ class Dataset[T] private[sql](
* @group typedrel
* @since 2.0.0
*/
- def except(other: Dataset[T]): Dataset[T] = withTypedPlan {
+ def except(other: Dataset[T]): Dataset[T] = withSetOperator {
Except(logicalPlan, other.logicalPlan)
}
@@ -2725,4 +2725,14 @@ class Dataset[T] private[sql](
@inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
Dataset(sparkSession, logicalPlan)
}
+
+ /** A convenient function to wrap a set based logical plan and produce a Dataset. */
+ @inline private def withSetOperator[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+ if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) {
+ // Set operators widen types (change the schema), so we cannot reuse the row encoder.
+ Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]]
+ } else {
+ Dataset(sparkSession, logicalPlan)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 16cc368208..e87baa454c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.io.File
import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
import java.util.UUID
import scala.util.Random
@@ -1615,4 +1616,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
qe.assertAnalyzed()
}
}
+
+ test("SPARK-17123: Performing set operations that combine non-scala native types") {
+ val dates = Seq(
+ (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)),
+ (new Date(3), BigDecimal.valueOf(4), new Timestamp(5))
+ ).toDF("date", "timestamp", "decimal")
+
+ val widenTypedRows = Seq(
+ (new Timestamp(2), 10.5D, "string")
+ ).toDF("date", "timestamp", "decimal")
+
+ dates.union(widenTypedRows).collect()
+ dates.except(widenTypedRows).collect()
+ dates.intersect(widenTypedRows).collect()
+ }
}