author    azagrebin <azagrebin@gmail.com>  2015-03-26 00:25:04 -0700
committer Reynold Xin <rxin@databricks.com>  2015-03-26 00:25:04 -0700
commit    5bbcd1304cfebba31ec6857a80d3825a40d02e83 (patch)
tree      1c623020abda55e9429c3957671538e709c6f1d1 /sql
parent    f535802977c5a3ce45894d89fdf59f8723f023c8 (diff)
[SPARK-6117] [SQL] add describe function to DataFrame for summary statistics
Please review my solution for SPARK-6117

Author: azagrebin <azagrebin@gmail.com>

Closes #5073 from azagrebin/SPARK-6117 and squashes the following commits:

f9056ac [azagrebin] [SPARK-6117] [SQL] create one aggregation and split it locally into resulting DF, colocate test data with test case
ddb3950 [azagrebin] [SPARK-6117] [SQL] simplify implementation, add test for DF without numeric columns
9daf31e [azagrebin] [SPARK-6117] [SQL] add describe function to DataFrame for summary statistics
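For orientation before the diff, a minimal usage sketch of the API this patch adds (the data and the `sqlContext` value below are illustrative assumptions, not part of the patch):

// Minimal usage sketch of the describe() API added by this patch.
// Assumes a SQLContext named `sqlContext`; the data is made up for illustration.
import sqlContext.implicits._

val df = Seq(("Alice", 32, 164), ("Bob", 16, 176)).toDF("name", "age", "height")

// Describe two named columns...
df.describe("age", "height").show()

// ...or, with no arguments, all numeric columns.
df.describe().show()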
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala      | 53
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 45
2 files changed, 97 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 5aece166aa..db561825e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType}
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils
@@ -752,6 +752,57 @@ class DataFrame private[sql](
}
/**
+ * Compute numerical statistics for the given columns of this [[DataFrame]]:
+ * count, mean (avg), stddev (standard deviation), min, and max.
+ * Each row of the resulting [[DataFrame]] contains the statistic name in one column
+ * and, for each of the given columns, a column with that statistic's value.
+ * If no columns are given, statistics are computed for all numerical columns.
+ *
+ * {{{
+ * df.describe("age", "height")
+ *
+ * // summary age height
+ * // count 10.0 10.0
+ * // mean 53.3 178.05
+ * // stddev 11.6 15.7
+ * // min 18.0 163.0
+ * // max 92.0 192.0
+ * }}}
+ */
+ @scala.annotation.varargs
+ def describe(cols: String*): DataFrame = {
+
+ def stddevExpr(expr: Expression) =
+ Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
+
+ val statistics = List[(String, Expression => Expression)](
+ "count" -> Count,
+ "mean" -> Average,
+ "stddev" -> stddevExpr,
+ "min" -> Min,
+ "max" -> Max)
+
+ val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
+
+ val localAgg = if (aggCols.nonEmpty) {
+ val aggExprs = statistics.flatMap { case (_, colToAgg) =>
+ aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
+ }
+
+ agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+ .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
+ Row(statistic :: aggregation.toList: _*)
+ }
+ } else {
+ statistics.map { case (name, _) => Row(name) }
+ }
+
+ val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType)))
+ val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
+ sqlContext.createDataFrame(rowRdd, schema)
+ }
+
+ /**
* Returns the first `n` rows.
* @group action
*/
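A note on the `stddevExpr` helper above: it computes the population standard deviation through the identity stddev = sqrt(E[X^2] - E[X]^2), so one aggregation pass over the data suffices. A plain-Scala sketch of the same arithmetic (this helper is illustrative only, not part of the patch):

// Plain-Scala sketch of the identity behind stddevExpr:
// population stddev = sqrt(mean(x * x) - mean(x) * mean(x)).
// Illustrative helper only; not part of the patch.
def populationStddev(xs: Seq[Double]): Double = {
  val mean = xs.sum / xs.size
  val meanOfSquares = xs.map(x => x * x).sum / xs.size
  math.sqrt(meanOfSquares - mean * mean)
}

populationStddev(Seq(16.0, 32.0, 60.0, 24.0))
// sqrt(1364.0 - 1089.0) = sqrt(275.0) ≈ 16.583123951777

The single agg() call likewise returns one Row laid out statistic-major, e.g. [count(c1), count(c2), mean(c1), mean(c2), ...]; the grouped(aggCols.size) step then splits that flat sequence back into one group of values per statistic before each group becomes a result Row.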
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index c30ed694a6..afbedd1e58 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -443,6 +443,51 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
}
+ test("describe") {
+
+ val describeTestData = Seq(
+ ("Bob", 16, 176),
+ ("Alice", 32, 164),
+ ("David", 60, 192),
+ ("Amy", 24, 180)).toDF("name", "age", "height")
+
+ val describeResult = Seq(
+ Row("count", 4, 4),
+ Row("mean", 33.0, 178.0),
+ Row("stddev", 16.583123951777, 10.0),
+ Row("min", 16, 164),
+ Row("max", 60, 192))
+
+ val emptyDescribeResult = Seq(
+ Row("count", 0, 0),
+ Row("mean", null, null),
+ Row("stddev", null, null),
+ Row("min", null, null),
+ Row("max", null, null))
+
+ def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
+
+ val describeTwoCols = describeTestData.describe("age", "height")
+ assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))
+ checkAnswer(describeTwoCols, describeResult)
+
+ val describeAllCols = describeTestData.describe()
+ assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))
+ checkAnswer(describeAllCols, describeResult)
+
+ val describeOneCol = describeTestData.describe("age")
+ assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
+ checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} )
+
+ val describeNoCol = describeTestData.select("name").describe()
+ assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
+ checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} )
+
+ val emptyDescription = describeTestData.limit(0).describe()
+ assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height"))
+ checkAnswer(emptyDescription, emptyDescribeResult)
+ }
+
test("apply on query results (SPARK-5462)") {
val df = testData.sqlContext.sql("select key from testData")
checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
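As a sanity check on the literals in describeResult above, the expected mean and stddev values follow directly from the four test rows (a sketch; meanOf is a throwaway helper for this check, not part of the suite):

// Hand-check of the expected describe() output for the test data.
val ages    = Seq(16.0, 32.0, 60.0, 24.0)
val heights = Seq(176.0, 164.0, 192.0, 180.0)

def meanOf(xs: Seq[Double]): Double = xs.sum / xs.size

meanOf(ages)     // 33.0  -> matches Row("mean", 33.0, 178.0)
meanOf(heights)  // 178.0

// Population stddev via sqrt(mean(x^2) - mean(x)^2):
math.sqrt(meanOf(ages.map(x => x * x)) - math.pow(meanOf(ages), 2))
// sqrt(1364.0 - 1089.0) = sqrt(275.0) ≈ 16.583123951777
math.sqrt(meanOf(heights.map(x => x * x)) - math.pow(meanOf(heights), 2))
// sqrt(31784.0 - 31684.0) = 10.0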