author     Reynold Xin <rxin@databricks.com>  2015-03-26 12:26:13 -0700
committer  Reynold Xin <rxin@databricks.com>  2015-03-26 12:26:13 -0700
commit     784fcd532784fcfd9bf0a1db71c9f71c469ee716 (patch)
tree       d38d70a2d3c2b8aa14187e7b6dec0f7f8783374a /sql
parent     c3a52a08248db08eade29b265f02483144a282d6 (diff)
[SPARK-6117] [SQL] Improvements to DataFrame.describe()
1. Slight modifications to the code to make it more readable.
2. Added a Python implementation.
3. Updated the documentation to state that we don't guarantee the output schema for this function and that it should only be used for exploratory data analysis.

Author: Reynold Xin <rxin@databricks.com>

Closes #5201 from rxin/df-describe and squashes the following commits:

25a7834 [Reynold Xin] Reset run-tests.
6abdfee [Reynold Xin] [SPARK-6117] [SQL] Improvements to DataFrame.describe()
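As a quick illustration of the documented contract, a minimal sketch of the two usage patterns the updated docs distinguish (the column names here are illustrative, not part of this patch):

    // Exploratory use: convenient, but the output schema of describe()
    // is not guaranteed to stay stable across versions.
    df.describe("age", "height").show()

    // Programmatic use: compute exactly the statistics you need via agg,
    // which has a well-defined result. (stddev is not yet a built-in
    // aggregate expression; see the TODO in the patch below.)
    import org.apache.spark.sql.functions._
    df.agg(count(df("age")), avg(df("age")), min(df("age")), max(df("age"))).show()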
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala      | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 3
2 files changed, 29 insertions(+), 20 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index db561825e6..4c80359cf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{expressions, ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils
@@ -752,15 +752,17 @@ class DataFrame private[sql](
}
/**
- * Compute numerical statistics for given columns of this [[DataFrame]]:
- * count, mean (avg), stddev (standard deviation), min, max.
- * Each row of the resulting [[DataFrame]] contains column with statistic name
- * and columns with statistic results for each given column.
- * If no columns are given then computes for all numerical columns.
+ * Computes statistics for numeric columns, including count, mean, stddev, min, and max.
+ * If no columns are given, this function computes statistics for all numerical columns.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to
+ * programmatically compute summary statistics, use the `agg` function instead.
*
* {{{
- * df.describe("age", "height")
+ * df.describe("age", "height").show()
*
+ * // output:
* // summary age height
* // count 10.0 10.0
* // mean 53.3 178.05
@@ -768,13 +770,17 @@ class DataFrame private[sql](
* // min 18.0 163.0
* // max 92.0 192.0
* }}}
+ *
+ * @group action
*/
@scala.annotation.varargs
def describe(cols: String*): DataFrame = {
- def stddevExpr(expr: Expression) =
+ // TODO: Add stddev as an expression, and remove it from here.
+ def stddevExpr(expr: Expression): Expression =
Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
+ // The list of summary statistics to compute, in the form of expressions.
val statistics = List[(String, Expression => Expression)](
"count" -> Count,
"mean" -> Average,
@@ -782,24 +788,28 @@ class DataFrame private[sql](
"min" -> Min,
"max" -> Max)
- val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
+ val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
- val localAgg = if (aggCols.nonEmpty) {
+ val ret: Seq[Row] = if (outputCols.nonEmpty) {
val aggExprs = statistics.flatMap { case (_, colToAgg) =>
- aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
+ outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
}
- agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
- .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
- Row(statistic :: aggregation.toList: _*)
+ val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+
+ // Pivot the data so each summary is one row
+ row.grouped(outputCols.size).toSeq.zip(statistics).map {
+ case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*)
}
} else {
+ // If there are no output columns, just output a single column that contains the stats.
statistics.map { case (name, _) => Row(name) }
}
- val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType)))
- val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
- sqlContext.createDataFrame(rowRdd, schema)
+ // The first column is string type, and the rest are double type.
+ val schema = StructType(
+ StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes
+ LocalRelation(schema, ret)
}
/**
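A note on the stddevExpr helper introduced above: it encodes the population-variance identity Var(X) = E[X^2] - (E[X])^2 and takes its square root. A minimal standalone sketch of the same computation in plain Scala (the function name and sample data are illustrative only, not part of the patch):

    // Population standard deviation via sqrt(E[X^2] - (E[X])^2),
    // mirroring Sqrt(Subtract(Average(x*x), Multiply(Average(x), Average(x)))).
    def stddev(xs: Seq[Double]): Double = {
      val meanOfSquares = xs.map(x => x * x).sum / xs.size
      val mean = xs.sum / xs.size
      math.sqrt(meanOfSquares - mean * mean)
    }

    // Example: stddev(Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)) == 2.0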
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index afbedd1e58..fbc4065a96 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -444,7 +444,6 @@ class DataFrameSuite extends QueryTest {
}
test("describe") {
-
val describeTestData = Seq(
("Bob", 16, 176),
("Alice", 32, 164),
@@ -465,7 +464,7 @@ class DataFrameSuite extends QueryTest {
Row("min", null, null),
Row("max", null, null))
- def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
+ def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
val describeTwoCols = describeTestData.describe("age", "height")
assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))
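For reference, the pivot step inside describe() regroups a single flat row of aggregate results (laid out statistic-major: all columns for "count", then all for "mean", and so on) so that each statistic becomes its own output row. A self-contained sketch of that reshaping in plain Scala, reusing the numbers from the Scaladoc example above (the stddev values are placeholders):

    val outputCols = Seq("age", "height")
    val statistics = Seq("count", "mean", "stddev", "min", "max")

    // One flat row as agg would return it:
    // count(age), count(height), mean(age), mean(height), ...
    val flat = Seq(10.0, 10.0, 53.3, 178.05, 11.6, 15.7, 18.0, 163.0, 92.0, 192.0)

    // grouped(n) slices the flat row into one group per statistic;
    // zip attaches the statistic name, which becomes the first cell.
    val rows = flat.grouped(outputCols.size).toSeq.zip(statistics).map {
      case (aggregation, statistic) => statistic +: aggregation
    }
    // rows: Seq(Seq("count", 10.0, 10.0), Seq("mean", 53.3, 178.05), ...)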