diff options
author | Sean Zhong <seanzhong@databricks.com> | 2016-08-04 19:45:47 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-08-04 19:45:47 +0800 |
commit | 9d7a47406ed538f0005cdc7a62bc6e6f20634815 (patch) | |
tree | b015623f922bd4646c2e80c9e27fd2349d04a469 | |
parent | 43f4fd6f9bfff749af17e3c65b53a33f5ecb0922 (diff) | |
download | spark-9d7a47406ed538f0005cdc7a62bc6e6f20634815.tar.gz spark-9d7a47406ed538f0005cdc7a62bc6e6f20634815.tar.bz2 spark-9d7a47406ed538f0005cdc7a62bc6e6f20634815.zip |
[SPARK-16853][SQL] fixes encoder error in DataSet typed select
## What changes were proposed in this pull request?
For DataSet typed select:
```
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1]
```
If type T is a case class or a tuple class that is not atomic, the resulting logical plan's schema will mismatch with `Dataset[T]` encoder's schema, which will cause encoder error and throw AnalysisException.
### Before change:
```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A])
org.apache.spark.sql.AnalysisException: cannot resolve '`a`' given input columns: [_2];
..
```
### After change:
```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A]).show
+---+---+
| a| b|
+---+---+
| 1| 2|
+---+---+
```
## How was this patch tested?
Unit test.
Author: Sean Zhong <seanzhong@databricks.com>
Closes #14474 from clockfly/SPARK-16853.
4 files changed, 29 insertions, 10 deletions
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 56061559fe..a201d7f838 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -38,7 +38,9 @@ object MimaExcludes { lazy val v21excludes = v20excludes ++ { Seq( // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"), + // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select") ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 1fac26c438..b96b744b4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -169,6 +169,10 @@ object ExpressionEncoder { ClassTag(cls)) } + // Tuple1 + def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] = + tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]] + def tuple[T1, T2]( e1: ExpressionEncoder[T1], e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8b6443c8b9..306ca773d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1061,15 +1061,17 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { - new Dataset[U1]( - sparkSession, - Project( - c1.withInputType( - exprEnc.deserializer, - logicalPlan.output).named :: Nil, - logicalPlan), - implicitly[Encoder[U1]]) + def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { + implicit val encoder = c1.encoder + val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil, + logicalPlan) + + if (encoder.flat) { + new Dataset[U1](sparkSession, project, encoder) + } else { + // Flattens inner fields of U1 + new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 7e3b7b63d8..8a756fd474 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -184,6 +184,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 2, 3, 4) } + test("SPARK-16853: select, case class and tuple") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + checkDataset( + ds.select(expr("struct(_2, _2)").as[(Int, Int)]): Dataset[(Int, Int)], + (1, 1), (2, 2), (3, 3)) + + checkDataset( + ds.select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]): Dataset[ClassData], + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + } + test("select 2") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDataset( |