-rw-r--r--  python/pyspark/sql/dataframe.py                                   | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala      | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala |  3
3 files changed, 48 insertions(+), 20 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index bf7c47b726..d51309f7ef 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -520,6 +520,25 @@ class DataFrame(object):
    orderBy = sort

+    def describe(self, *cols):
+        """Computes statistics for numeric columns.
+
+        This includes count, mean, stddev, min, and max. If no columns are
+        given, this function computes statistics for all numeric columns.
+
+        >>> df.describe().show()
+        summary age
+        count   2
+        mean    3.5
+        stddev  1.5
+        min     2
+        max     5
+        """
+        cols = ListConverter().convert(cols,
+                                       self.sql_ctx._sc._gateway._gateway_client)
+        jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
+        return DataFrame(jdf, self.sql_ctx)
+
    def head(self, n=None):
        """ Return the first `n` rows or the first row if n is None.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index db561825e6..4c80359cf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{expressions, ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils
@@ -752,15 +752,17 @@ class DataFrame private[sql](
  }

  /**
-   * Compute numerical statistics for given columns of this [[DataFrame]]:
-   * count, mean (avg), stddev (standard deviation), min, max.
-   * Each row of the resulting [[DataFrame]] contains column with statistic name
-   * and columns with statistic results for each given column.
-   * If no columns are given then computes for all numerical columns.
+   * Computes statistics for numeric columns, including count, mean, stddev, min, and max.
+   * If no columns are given, this function computes statistics for all numeric columns.
+   *
+   * This function is meant for exploratory data analysis, as we make no guarantee about the
+   * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to
+   * programmatically compute summary statistics, use the `agg` function instead.
   *
   * {{{
-   *   df.describe("age", "height")
+   *   df.describe("age", "height").show()
   *
+   *   // output:
   *   // summary age   height
   *   // count   10.0  10.0
   *   // mean    53.3  178.05
@@ -768,13 +770,17 @@ class DataFrame private[sql](
   *   // min     18.0  163.0
   *   // max     92.0  192.0
   * }}}
+   *
+   * @group action
   */
  @scala.annotation.varargs
  def describe(cols: String*): DataFrame = {

-    def stddevExpr(expr: Expression) =
+    // TODO: Add stddev as an expression, and remove it from here.
+    def stddevExpr(expr: Expression): Expression =
      Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))

+    // The list of summary statistics to compute, in the form of expressions.
    val statistics = List[(String, Expression => Expression)](
      "count" -> Count,
      "mean" -> Average,
@@ -782,24 +788,28 @@
      "min" -> Min,
      "max" -> Max)

-    val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
+    val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList

-    val localAgg = if (aggCols.nonEmpty) {
+    val ret: Seq[Row] = if (outputCols.nonEmpty) {
      val aggExprs = statistics.flatMap { case (_, colToAgg) =>
-        aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
+        outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
      }
-      agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
-        .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
-          Row(statistic :: aggregation.toList: _*)
+      val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+
+      // Pivot the data so each summary is one row
+      row.grouped(outputCols.size).toSeq.zip(statistics).map {
+        case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*)
      }
    } else {
+      // If there are no output columns, just output a single column that contains the stats.
      statistics.map { case (name, _) => Row(name) }
    }

-    val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType)))
-    val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
-    sqlContext.createDataFrame(rowRdd, schema)
+    // The first column is string type, and the rest are double type.
+    val schema = StructType(
+      StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes
+    LocalRelation(schema, ret)
  }
/**
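
The stddevExpr helper above encodes the population-standard-deviation identity stddev(X) = sqrt(E[X^2] - E[X]^2). As the new doc comment advises, the same summary can be computed programmatically through `agg`; a rough PySpark sketch of that equivalence, assuming a DataFrame `df` with a numeric `age` column:

    from pyspark.sql import functions as F

    # Population stddev via sqrt(E[X^2] - E[X]^2), the same identity
    # used by stddevExpr, expressed through the public agg API.
    df.agg(
        F.count("age").alias("count"),
        F.avg("age").alias("mean"),
        F.sqrt(F.avg(df.age * df.age) - F.avg("age") * F.avg("age")).alias("stddev"),
        F.min("age").alias("min"),
        F.max("age").alias("max")).show()

Note that this is the biased (population) formulation: it divides by n rather than n - 1, which is why the doctest reports 1.5 for the values {2, 5}.
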
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index afbedd1e58..fbc4065a96 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -444,7 +444,6 @@ class DataFrameSuite extends QueryTest {
  }

  test("describe") {
-
    val describeTestData = Seq(
      ("Bob", 16, 176),
      ("Alice", 32, 164),
@@ -465,7 +464,7 @@ class DataFrameSuite extends QueryTest {
Row("min", null, null),
Row("max", null, null))
- def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
+ def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
val describeTwoCols = describeTestData.describe("age", "height")
assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))
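
A Python-side analogue of the schema assertion above, as a hedged sketch (the doctest's `df` fixture with its numeric `age` column is assumed):

    # The "summary" column should come first, followed by the requested
    # numeric columns, mirroring the Scala test's schema check.
    result = df.describe("age")
    assert [f.name for f in result.schema.fields] == ["summary", "age"]
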