| field | value | date |
|---|---|---|
| author | hyukjinkwon <gurwls223@gmail.com> | 2016-10-22 20:09:04 +0200 |
| committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-10-22 20:09:04 +0200 |
| commit | 5fa9f8795a71e08bcbef5975ba8c072db5be8866 | |
| tree | 50d1eced012bf01b7ed01cae6f8d42f286e1489e | |
| parent | 3eca283aca68ac81c127d60ad5699f854d5f14b7 | |
# [SPARK-17123][SQL] Use type-widened encoder for DataFrame rather than existing encoder to allow type-widening from set operations
## What changes were proposed in this pull request?
This PR fixes the set operations on `DataFrame` so that they run without exceptions when the column types are non-Scala native types (e.g., `TimestampType`, `DateType` and `DecimalType`).

The problem is that set operations such as `union`, `intersect` and `except` reuse the encoder belonging to the calling `Dataset`: the caller's `Dataset` keeps its original `ExpressionEncoder[Row]` while the set operation is performed, even though the result types may actually be widened. We should instead use an `ExpressionEncoder[Row]` constructed from the executed plan; otherwise incorrect code is generated via `StaticInvoke`.
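The widening itself can be observed on the analyzed plan before any code generation runs. Here is a minimal sketch (assuming a `SparkSession` named `spark` with its implicits imported, e.g. in `spark-shell`), using the same data as the reproduction below:

```scala
import java.sql.{Date, Timestamp}

import spark.implicits._

val dates = Seq((new Date(0), BigDecimal.valueOf(1), new Timestamp(2)))
  .toDF("date", "timestamp", "decimal")
val widenTypedRows = Seq((new Timestamp(2), 10.5D, "string"))
  .toDF("date", "timestamp", "decimal")

// The analyzer widens each column pair to a common type, so the union's
// schema differs from `dates.schema` -- which is why the caller's row
// encoder, built for the original schema, cannot describe the result.
dates.union(widenTypedRows).printSchema()
```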
Running the code below:
```scala
val dates = Seq(
(new Date(0), BigDecimal.valueOf(1), new Timestamp(2)),
(new Date(3), BigDecimal.valueOf(4), new Timestamp(5))
).toDF("date", "timestamp", "decimal")
val widenTypedRows = Seq(
(new Timestamp(2), 10.5D, "string")
).toDF("date", "timestamp", "decimal")
val results = dates.union(widenTypedRows).collect()
results.foreach(println)
```
prints the following.

**Before** (the reused row encoder still deserializes the first column as a `DateType` backed by an `int` of days, but after widening the column is physically a `long`-backed `TimestampType`, so the generated call to `DateTimeUtils.toJavaDate(int)` receives a `long` and fails to compile):
```java
23:08:54.490 ERROR org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 28, Column 107: No applicable constructor/method found for actual parameters "long"; candidates are: "public static java.sql.Date org.apache.spark.sql.catalyst.util.DateTimeUtils.toJavaDate(int)"
/* 001 */ public java.lang.Object generate(Object[] references) {
/* 002 */ return new SpecificSafeProjection(references);
/* 003 */ }
/* 004 */
/* 005 */ class SpecificSafeProjection extends org.apache.spark.sql.catalyst.expressions.codegen.BaseProjection {
/* 006 */
/* 007 */ private Object[] references;
/* 008 */ private MutableRow mutableRow;
/* 009 */ private Object[] values;
/* 010 */ private org.apache.spark.sql.types.StructType schema;
/* 011 */
/* 012 */
/* 013 */ public SpecificSafeProjection(Object[] references) {
/* 014 */ this.references = references;
/* 015 */ mutableRow = (MutableRow) references[references.length - 1];
/* 016 */
/* 017 */ this.schema = (org.apache.spark.sql.types.StructType) references[0];
/* 018 */ }
/* 019 */
/* 020 */ public java.lang.Object apply(java.lang.Object _i) {
/* 021 */ InternalRow i = (InternalRow) _i;
/* 022 */
/* 023 */ values = new Object[3];
/* 024 */
/* 025 */ boolean isNull2 = i.isNullAt(0);
/* 026 */ long value2 = isNull2 ? -1L : (i.getLong(0));
/* 027 */ boolean isNull1 = isNull2;
/* 028 */ final java.sql.Date value1 = isNull1 ? null : org.apache.spark.sql.catalyst.util.DateTimeUtils.toJavaDate(value2);
/* 029 */ isNull1 = value1 == null;
/* 030 */ if (isNull1) {
/* 031 */ values[0] = null;
/* 032 */ } else {
/* 033 */ values[0] = value1;
/* 034 */ }
/* 035 */
/* 036 */ boolean isNull4 = i.isNullAt(1);
/* 037 */ double value4 = isNull4 ? -1.0 : (i.getDouble(1));
/* 038 */
/* 039 */ boolean isNull3 = isNull4;
/* 040 */ java.math.BigDecimal value3 = null;
/* 041 */ if (!isNull3) {
/* 042 */
/* 043 */ Object funcResult = null;
/* 044 */ funcResult = value4.toJavaBigDecimal();
/* 045 */ if (funcResult == null) {
/* 046 */ isNull3 = true;
/* 047 */ } else {
/* 048 */ value3 = (java.math.BigDecimal) funcResult;
/* 049 */ }
/* 050 */
/* 051 */ }
/* 052 */ isNull3 = value3 == null;
/* 053 */ if (isNull3) {
/* 054 */ values[1] = null;
/* 055 */ } else {
/* 056 */ values[1] = value3;
/* 057 */ }
/* 058 */
/* 059 */ boolean isNull6 = i.isNullAt(2);
/* 060 */ UTF8String value6 = isNull6 ? null : (i.getUTF8String(2));
/* 061 */ boolean isNull5 = isNull6;
/* 062 */ final java.sql.Timestamp value5 = isNull5 ? null : org.apache.spark.sql.catalyst.util.DateTimeUtils.toJavaTimestamp(value6);
/* 063 */ isNull5 = value5 == null;
/* 064 */ if (isNull5) {
/* 065 */ values[2] = null;
/* 066 */ } else {
/* 067 */ values[2] = value5;
/* 068 */ }
/* 069 */
/* 070 */ final org.apache.spark.sql.Row value = new org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema(values, schema);
/* 071 */ if (false) {
/* 072 */ mutableRow.setNullAt(0);
/* 073 */ } else {
/* 074 */
/* 075 */ mutableRow.update(0, value);
/* 076 */ }
/* 077 */
/* 078 */ return mutableRow;
/* 079 */ }
/* 080 */ }
```
**After** (the set operation succeeds and the rows print with their widened values):
```bash
[1969-12-31 00:00:00.0,1.0,1969-12-31 16:00:00.002]
[1969-12-31 00:00:00.0,4.0,1969-12-31 16:00:00.005]
[1969-12-31 16:00:00.002,10.5,string]
```
## How was this patch tested?
Unit tests in `DataFrameSuite`
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #15072 from HyukjinKwon/SPARK-17123.
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala        | 18
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 16
2 files changed, 30 insertions(+), 4 deletions(-)
```diff
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 073d2b1512..286d8549bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -556,7 +556,7 @@ class Dataset[T] private[sql](
    * 1983 03 0.410516 0.442194
    * 1984 04 0.450090 0.483521
    * }}}
-   * 
+   *
    * @param numRows Number of rows to show
    * @param truncate If set to more than 0, truncates strings to `truncate` characters and
    *                 all cells will be aligned right.
@@ -1524,7 +1524,7 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 2.0.0
    */
-  def union(other: Dataset[T]): Dataset[T] = withTypedPlan {
+  def union(other: Dataset[T]): Dataset[T] = withSetOperator {
     // This breaks caching, but it's usually ok because it addresses a very specific use case:
     // using union to union many files or partitions.
     CombineUnions(Union(logicalPlan, other.logicalPlan))
@@ -1540,7 +1540,7 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def intersect(other: Dataset[T]): Dataset[T] = withTypedPlan {
+  def intersect(other: Dataset[T]): Dataset[T] = withSetOperator {
     Intersect(logicalPlan, other.logicalPlan)
   }
 
@@ -1554,7 +1554,7 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 2.0.0
    */
-  def except(other: Dataset[T]): Dataset[T] = withTypedPlan {
+  def except(other: Dataset[T]): Dataset[T] = withSetOperator {
     Except(logicalPlan, other.logicalPlan)
   }
 
@@ -2725,4 +2725,14 @@ class Dataset[T] private[sql](
   @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
     Dataset(sparkSession, logicalPlan)
   }
+
+  /** A convenient function to wrap a set based logical plan and produce a Dataset. */
+  @inline private def withSetOperator[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+    if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) {
+      // Set operators widen types (change the schema), so we cannot reuse the row encoder.
+      Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]]
+    } else {
+      Dataset(sparkSession, logicalPlan)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 16cc368208..e87baa454c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import java.io.File
 import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
 import java.util.UUID
 
 import scala.util.Random
@@ -1615,4 +1616,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       qe.assertAnalyzed()
     }
   }
+
+  test("SPARK-17123: Performing set operations that combine non-scala native types") {
+    val dates = Seq(
+      (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)),
+      (new Date(3), BigDecimal.valueOf(4), new Timestamp(5))
+    ).toDF("date", "timestamp", "decimal")
+
+    val widenTypedRows = Seq(
+      (new Timestamp(2), 10.5D, "string")
+    ).toDF("date", "timestamp", "decimal")
+
+    dates.union(widenTypedRows).collect()
+    dates.except(widenTypedRows).collect()
+    dates.intersect(widenTypedRows).collect()
+  }
 }
```
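For context on the `classTag` test in `withSetOperator`: a `DataFrame` is just `Dataset[Row]`, so checking whether the dataset's runtime element class is (a supertype of) `Row` is how the patch decides to rebuild the row encoder from the widened plan. A standalone sketch of the same check, with a hypothetical helper name:

```scala
import scala.reflect.{classTag, ClassTag}

import org.apache.spark.sql.Row

// Hypothetical helper mirroring the patch's test: true for Dataset[Row]
// (i.e., a DataFrame), whose encoder must be rebuilt from the widened
// schema; false for a typed Dataset[T], whose encoder can be kept.
def usesRowEncoder[T: ClassTag]: Boolean =
  classTag[T].runtimeClass.isAssignableFrom(classOf[Row])

// Usage: usesRowEncoder[Row] == true, usesRowEncoder[(Int, String)] == false
```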