aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-05-18 21:53:44 -0700
committerReynold Xin <rxin@databricks.com>2015-05-18 21:53:44 -0700
commitc9fa870a6de3f7d0903fa7a75ea5ffb6a2fcd174 (patch)
treefdb1f2695bb214e1252ec7f64f350ecf032f1e27 /sql
parentc2437de1899e09894df4ec27adfaa7fac158fd3a (diff)
downloadspark-c9fa870a6de3f7d0903fa7a75ea5ffb6a2fcd174.tar.gz
spark-c9fa870a6de3f7d0903fa7a75ea5ffb6a2fcd174.tar.bz2
spark-c9fa870a6de3f7d0903fa7a75ea5ffb6a2fcd174.zip
[SPARK-7687] [SQL] DataFrame.describe() should cast all aggregates to String
In `DataFrame.describe()`, the `count` aggregate produces an integer, the `avg` and `stdev` aggregates produce doubles, and `min` and `max` aggregates can produce varying types depending on what type of column they're applied to. As a result, we should cast all aggregate results to String so that `describe()`'s output types match its declared output schema. Author: Josh Rosen <joshrosen@databricks.com> Closes #6218 from JoshRosen/SPARK-7687 and squashes the following commits: 146b615 [Josh Rosen] Fix R test. 2974bd5 [Josh Rosen] Cast to string type instead f206580 [Josh Rosen] Cast to double to fix SPARK-7687 307ecbf [Josh Rosen] Add failing regression test for SPARK-7687
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala17
2 files changed, 14 insertions, 9 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 27e9af49f0..adad85806d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1063,7 +1063,7 @@ class DataFrame private[sql](
val ret: Seq[Row] = if (outputCols.nonEmpty) {
val aggExprs = statistics.flatMap { case (_, colToAgg) =>
- outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
+ outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
}
val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
@@ -1077,9 +1077,9 @@ class DataFrame private[sql](
statistics.map { case (name, _) => Row(name) }
}
- // The first column is string type, and the rest are double type.
+ // All columns are string type
val schema = StructType(
- StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes
+ StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
LocalRelation(schema, ret)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f05d059d44..0dcba80ef2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -370,14 +370,14 @@ class DataFrameSuite extends QueryTest {
("Amy", 24, 180)).toDF("name", "age", "height")
val describeResult = Seq(
- Row("count", 4, 4),
- Row("mean", 33.0, 178.0),
- Row("stddev", 16.583123951777, 10.0),
- Row("min", 16, 164),
- Row("max", 60, 192))
+ Row("count", "4", "4"),
+ Row("mean", "33.0", "178.0"),
+ Row("stddev", "16.583123951777", "10.0"),
+ Row("min", "16", "164"),
+ Row("max", "60", "192"))
val emptyDescribeResult = Seq(
- Row("count", 0, 0),
+ Row("count", "0", "0"),
Row("mean", null, null),
Row("stddev", null, null),
Row("min", null, null),
@@ -388,6 +388,11 @@ class DataFrameSuite extends QueryTest {
val describeTwoCols = describeTestData.describe("age", "height")
assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))
checkAnswer(describeTwoCols, describeResult)
+ // All aggregate value should have been cast to string
+ describeTwoCols.collect().foreach { row =>
+ assert(row.get(1).isInstanceOf[String], "expected string but found " + row.get(1).getClass)
+ assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass)
+ }
val describeAllCols = describeTestData.describe()
assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))