author     Wenchen Fan <wenchen@databricks.com>       2015-11-13 11:13:09 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-11-13 11:13:09 -0800
commit     23b8188f75d945ef70fbb1c4dc9720c2c5f8cbc3 (patch)
tree       637f94c371b0f360eddc48837190015089fda784
parent     a24477996e936b0861819ffb420f763f80f0b1da (diff)
[SPARK-11654][SQL][FOLLOW-UP] fix some mistakes and clean up
* rename `AppendColumn` to `AppendColumns` to be consistent with the physical plan name.
* clean up stale comments.
* always pass in resolved encoder to `TypedColumn.withInputType` (test added)
* enable a mistakenly disabled java test.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9688 from cloud-fan/follow.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                                    |  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                                   | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala                            |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                 |  2
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                       |  1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala                    |  4
7 files changed, 17 insertions, 15 deletions
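
The behavioral piece of this follow-up is the `TypedColumn.withInputType` contract: callers used to pre-bind the input encoder to the analyzed output themselves, and now always pass the resolved (unbound) encoder so that `withInputType` performs the binding internally. A caller-side sketch, drawn from the Column.scala and Dataset.scala hunks below:

    // Before this patch: each caller bound the encoder to the analyzed schema itself.
    c1.withInputType(
      resolvedTEncoder.bind(queryExecution.analyzed.output),
      queryExecution.analyzed.output)

    // After this patch: callers always pass the resolved encoder, and withInputType
    // does the binding internally:
    //   val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]]
    c1.withInputType(
      resolvedTEncoder,
      queryExecution.analyzed.output)
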
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d9f046efce..e2b97b27a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -482,13 +482,13 @@ case class MapPartitions[T, U](
}
/** Factory for constructing new `AppendColumn` nodes. */
-object AppendColumn {
+object AppendColumns {
def apply[T, U : Encoder](
func: T => U,
tEncoder: ExpressionEncoder[T],
- child: LogicalPlan): AppendColumn[T, U] = {
+ child: LogicalPlan): AppendColumns[T, U] = {
val attrs = encoderFor[U].schema.toAttributes
- new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child)
+ new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child)
}
}
@@ -497,7 +497,7 @@ object AppendColumn {
* resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to
* decode/encode from the JVM object representation expected by `func.`
*/
-case class AppendColumn[T, U](
+case class AppendColumns[T, U](
func: T => U,
tEncoder: ExpressionEncoder[T],
uEncoder: ExpressionEncoder[U],
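
The renamed node is constructed in exactly one place, `Dataset.groupBy`. A short sketch of that call site (it mirrors the Dataset.scala hunk further down); the note about the node's output shape is an inference from the factory above, not something stated in this diff:

    // Inside Dataset.groupBy: compute the grouping key with `func` and append its
    // columns to every row of the analyzed child plan.
    val inputPlan = queryExecution.analyzed
    val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
    // The appended attributes are encoderFor[K].schema.toAttributes (see the factory),
    // so the node's output is presumably the child's output followed by the key columns.
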
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 929224460d..82e9cd7f50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -58,10 +58,11 @@ class TypedColumn[-T, U](
private[sql] def withInputType(
inputEncoder: ExpressionEncoder[_],
schema: Seq[Attribute]): TypedColumn[T, U] = {
+ val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]]
new TypedColumn[T, U] (expr transform {
case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
ta.copy(
- aEncoder = Some(inputEncoder.asInstanceOf[ExpressionEncoder[Any]]),
+ aEncoder = Some(boundEncoder),
children = schema)
}, encoder)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b930e4661c..4cc3aa2465 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -299,7 +299,7 @@ class Dataset[T] private[sql](
*/
def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
val inputPlan = queryExecution.analyzed
- val withGroupingKey = AppendColumn(func, resolvedTEncoder, inputPlan)
+ val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
val executed = sqlContext.executePlan(withGroupingKey)
new GroupedDataset(
@@ -364,13 +364,11 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
- // We use an unbound encoder since the expression will make up its own schema.
- // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
new Dataset[U1](
sqlContext,
Project(
c1.withInputType(
- resolvedTEncoder.bind(queryExecution.analyzed.output),
+ resolvedTEncoder,
queryExecution.analyzed.output).named :: Nil,
logicalPlan))
}
@@ -382,10 +380,8 @@ class Dataset[T] private[sql](
*/
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
- // We use an unbound encoder since the expression will make up its own schema.
- // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
val namedColumns =
- columns.map(_.withInputType(unresolvedTEncoder, queryExecution.analyzed.output).named)
+ columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)
val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index ae1272ae53..9c16940707 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -89,7 +89,7 @@ class GroupedDataset[K, T] private[sql](
}
/**
- * Applies the given function to each group of data. For each unique group, the function will
+ * Applies the given function to each group of data. For each unique group, the function will
* be passed the group key and an iterator that contains all of the elements in the group. The
* function can return an iterator containing elements of an arbitrary type which will be returned
* as a new [[Dataset]].
@@ -162,7 +162,7 @@ class GroupedDataset[K, T] private[sql](
val encoders = columns.map(_.encoder)
val namedColumns =
columns.map(
- _.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes).named)
+ _.withInputType(resolvedTEncoder, dataAttributes).named)
val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)
val execution = new QueryExecution(sqlContext, aggregate)
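
The doc comment touched above describes the per-group functional API that this encoder plumbing serves. A minimal usage sketch of that API; the data is illustrative, `sqlContext.implicits._` is assumed to be in scope as in the test suites, and `flatMapGroups` is the name this method carries in the released 1.6 API:

    import sqlContext.implicits._

    val ds = Seq("a", "foo", "bar").toDS()      // Dataset[String]
    val grouped = ds.groupBy(_.length)          // GroupedDataset[Int, String]

    // For each unique key, the function receives the key and an iterator over the
    // group's elements, and may return values of any encodable type.
    val tagged = grouped.flatMapGroups { (len, strs) =>
      strs.map(s => s"$len:$s")
    }                                           // Dataset[String]
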
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index a99ae4674b..67201a2c19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -321,7 +321,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapPartitions(f, tEnc, uEnc, output, child) =>
execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil
- case logical.AppendColumn(f, tEnc, uEnc, newCol, child) =>
+ case logical.AppendColumns(f, tEnc, uEnc, newCol, child) =>
execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 46169ca07d..eb6fa1e72e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -157,6 +157,7 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(6, reduced);
}
+ @Test
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 20896efdfe..46f9f077fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -162,6 +162,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
1)
checkAnswer(
+ ds.select(expr("avg(a)").as[Double], ClassInputAgg.toColumn),
+ (1.0, 1))
+
+ checkAnswer(
ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
("one", 1))
}
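
To close, a self-contained sketch of the pattern the added test exercises: mixing an untyped SQL expression with an Aggregator-backed typed column in a single select, and applying the same aggregator per group. The case class, aggregator, and data are illustrative stand-ins (not the suite's AggData/ClassInputAgg), written against the 1.6-era API in which Aggregator.toColumn takes implicit encoders:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.expr

    // Illustrative stand-ins for the suite's AggData / ClassInputAgg.
    case class Record(a: Int, b: String)

    object SumOfA extends Aggregator[Record, Int, Int] {
      def zero: Int = 0
      def reduce(b: Int, a: Record): Int = b + a.a
      def merge(b1: Int, b2: Int): Int = b1 + b2
      def finish(reduction: Int): Int = reduction
    }

    object TypedAggSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("typed-agg-sketch"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val ds = Seq(Record(1, "one"), Record(3, "three")).toDS()

        // Untyped expression column and Aggregator column in one typed select;
        // both columns go through TypedColumn.withInputType with the resolved encoder.
        val mixed = ds.select(expr("avg(a)").as[Double], SumOfA.toColumn).collect()
        // -> Array((2.0, 4))

        // The same aggregator applied per group (row order not guaranteed).
        val perGroup = ds.groupBy(_.b).agg(SumOfA.toColumn).collect()
        // -> Array(("one", 1), ("three", 3))

        sc.stop()
      }
    }
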