-rw-r--r--  project/MimaExcludes.scala                                                                    4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala    4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                                   20
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                              11
4 files changed, 29 insertions, 10 deletions
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 56061559fe..a201d7f838 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -38,7 +38,9 @@ object MimaExcludes {
lazy val v21excludes = v20excludes ++ {
Seq(
// [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references")
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"),
+ // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select")
)
}
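
Note on the exclusion above: the typed select overload in Dataset.scala below drops its Encoder[U1] context bound, which changes the compiled method signature and would otherwise trip MiMa's binary-compatibility check. A minimal sketch of the signature change only, with bodies elided:

    // before: the element encoder is supplied implicitly by the caller
    def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1]

    // after: the encoder is taken from the TypedColumn itself, so no context bound
    def select[U1](c1: TypedColumn[T, U1]): Dataset[U1]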
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 1fac26c438..b96b744b4f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -169,6 +169,10 @@ object ExpressionEncoder {
ClassTag(cls))
}
+ // Tuple1
+ def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
+ tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]
+
def tuple[T1, T2](
e1: ExpressionEncoder[T1],
e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
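
The new single-argument tuple overload above simply delegates to the existing Seq-based tuple builder and casts the result to ExpressionEncoder[Tuple1[T]]. A minimal usage sketch (identifiers illustrative; ClassData is the case class used in the test suite):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    case class ClassData(a: String, b: Int)

    // a non-flat encoder: ClassData is serialized as a struct with two fields
    val enc: ExpressionEncoder[ClassData] = ExpressionEncoder[ClassData]()

    // wrap it so the value is carried as the single field of a Tuple1 row
    val wrapped: ExpressionEncoder[Tuple1[ClassData]] = ExpressionEncoder.tuple(enc)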
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 8b6443c8b9..306ca773d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1061,15 +1061,17 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
- new Dataset[U1](
- sparkSession,
- Project(
- c1.withInputType(
- exprEnc.deserializer,
- logicalPlan.output).named :: Nil,
- logicalPlan),
- implicitly[Encoder[U1]])
+ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
+ implicit val encoder = c1.encoder
+ val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+   logicalPlan)
+
+ if (encoder.flat) {
+   new Dataset[U1](sparkSession, project, encoder)
+ } else {
+   // Flattens inner fields of U1
+   new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1)
+ }
}
/**
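
In the rewritten select above, the encoder comes from the TypedColumn rather than from an implicit parameter. A flat encoder (a single top-level value such as Int or String) decodes the projection directly; a non-flat encoder (a case class or tuple, whose fields land as multiple top-level columns) is first read as Tuple1[U1] via the new ExpressionEncoder.tuple helper and then unwrapped with map, which is what avoids the SPARK-16853 encoder error. A hedged end-to-end sketch mirroring the new test below, assuming a SparkSession named spark with spark.implicits._ imported:

    import org.apache.spark.sql.Dataset
    import org.apache.spark.sql.functions.expr
    import spark.implicits._

    case class ClassData(a: String, b: Int)

    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()

    // flat result type: decoded directly
    val ints: Dataset[Int] = ds.select(expr("_2 + 1").as[Int])

    // non-flat result type: internally wrapped in Tuple1 and unwrapped via map
    val structs: Dataset[ClassData] =
      ds.select(expr("named_struct('a', _1, 'b', _2)").as[ClassData])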
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 7e3b7b63d8..8a756fd474 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -184,6 +184,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
}
+ test("SPARK-16853: select, case class and tuple") {
+ val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+ checkDataset(
+ ds.select(expr("struct(_2, _2)").as[(Int, Int)]): Dataset[(Int, Int)],
+ (1, 1), (2, 2), (3, 3))
+
+ checkDataset(
+ ds.select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]): Dataset[ClassData],
+ ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
+ }
+
test("select 2") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
checkDataset(