 R/pkg/R/DataFrame.R                                                 |  4
 R/pkg/inst/tests/testthat/test_sparkSQL.R                           |  4
 python/pyspark/sql/dataframe.py                                     |  8
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala          | 16
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala   | 36
 5 files changed, 39 insertions(+), 29 deletions(-)
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index a18eee3a0f..47f9203ace 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2587,8 +2587,8 @@ setMethod("saveAsTable",
#' summary
#'
-#' Computes statistics for numeric columns.
-#' If no columns are given, this function computes statistics for all numerical columns.
+#' Computes statistics for numeric and string columns.
+#' If no columns are given, this function computes statistics for all numerical or string columns.
#'
#' @param x A SparkDataFrame to be computed.
#' @param col A string of name
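A minimal sketch of the user-facing effect (Scala, assuming a local SparkSession; the data below is made up for illustration and is not part of the patch): with this change, describe() with no arguments also covers string columns, reporting count/min/max for them and leaving mean/stddev null.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("describe-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical data, purely illustrative.
    val df = Seq(("Alice", 16), ("Andy", 30), ("Justin", 19)).toDF("name", "age")

    // Previously only "age" appeared in the output; now "name" is included as well,
    // with count/min/max filled in and mean/stddev reported as null.
    df.describe().show()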
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index e2a1da0f1e..fdd6020db9 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1824,11 +1824,11 @@ test_that("describe() and summarize() on a DataFrame", {
expect_equal(collect(stats)[2, "age"], "24.5")
expect_equal(collect(stats)[3, "age"], "7.7781745930520225")
stats <- describe(df)
- expect_equal(collect(stats)[4, "name"], NULL)
+ expect_equal(collect(stats)[4, "name"], "Andy")
expect_equal(collect(stats)[5, "age"], "30")
stats2 <- summary(df)
- expect_equal(collect(stats2)[4, "name"], NULL)
+ expect_equal(collect(stats2)[4, "name"], "Andy")
expect_equal(collect(stats2)[5, "age"], "30")
# SPARK-16425: SparkR summary() fails on column of type logical
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index dd670a9b3d..ab41e88620 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -751,15 +751,15 @@ class DataFrame(object):
@since("1.3.1")
def describe(self, *cols):
- """Computes statistics for numeric columns.
+ """Computes statistics for numeric and string columns.
This include count, mean, stddev, min, and max. If no columns are
- given, this function computes statistics for all numerical columns.
+ given, this function computes statistics for all numerical or string columns.
.. note:: This function is meant for exploratory data analysis, as we make no \
guarantee about the backward compatibility of the schema of the resulting DataFrame.
- >>> df.describe().show()
+ >>> df.describe(['age']).show()
+-------+------------------+
|summary| age|
+-------+------------------+
@@ -769,7 +769,7 @@ class DataFrame(object):
| min| 2|
| max| 5|
+-------+------------------+
- >>> df.describe(['age', 'name']).show()
+ >>> df.describe().show()
+-------+------------------+-----+
|summary| age| name|
+-------+------------------+-----+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ededf7f4fe..ed4ccdb4c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -228,6 +228,15 @@ class Dataset[T] private[sql](
}
}
+ private def aggregatableColumns: Seq[Expression] = {
+ schema.fields
+ .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType])
+ .map { n =>
+ queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver)
+ .get
+ }
+ }
+
/**
* Compose the string representing rows for output
*
@@ -1886,8 +1895,9 @@ class Dataset[T] private[sql](
}
/**
- * Computes statistics for numeric columns, including count, mean, stddev, min, and max.
- * If no columns are given, this function computes statistics for all numerical columns.
+ * Computes statistics for numeric and string columns, including count, mean, stddev, min, and
+ * max. If no columns are given, this function computes statistics for all numerical or string
+ * columns.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting Dataset. If you want to
@@ -1920,7 +1930,7 @@ class Dataset[T] private[sql](
"max" -> ((child: Expression) => Max(child).toAggregateExpression()))
val outputCols =
- (if (cols.isEmpty) numericColumns.map(usePrettyExpression(_).sql) else cols).toList
+ (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList
val ret: Seq[Row] = if (outputCols.nonEmpty) {
val aggExprs = statistics.flatMap { case (_, colToAgg) =>
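A rough, standalone sketch of the selection rule that aggregatableColumns introduces above (the helper name and example schema here are hypothetical; only the numeric-or-string filter mirrors the patch): when the caller passes no column names, describe() now keeps exactly the fields whose type is numeric or string.

    import org.apache.spark.sql.types._

    // Keep only fields whose data type is numeric or string,
    // as describe() now does when no columns are given.
    def aggregatableFieldNames(schema: StructType): Seq[String] =
      schema.fields
        .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType == StringType)
        .map(_.name)

    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType),
      StructField("registered", BooleanType)))

    aggregatableFieldNames(schema)  // Seq("name", "age") -- the boolean column is skipped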
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9d53be8e2b..905da554f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -651,44 +651,44 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
("Amy", 24, 180)).toDF("name", "age", "height")
val describeResult = Seq(
- Row("count", "4", "4"),
- Row("mean", "33.0", "178.0"),
- Row("stddev", "19.148542155126762", "11.547005383792516"),
- Row("min", "16", "164"),
- Row("max", "60", "192"))
+ Row("count", "4", "4", "4"),
+ Row("mean", null, "33.0", "178.0"),
+ Row("stddev", null, "19.148542155126762", "11.547005383792516"),
+ Row("min", "Alice", "16", "164"),
+ Row("max", "David", "60", "192"))
val emptyDescribeResult = Seq(
- Row("count", "0", "0"),
- Row("mean", null, null),
- Row("stddev", null, null),
- Row("min", null, null),
- Row("max", null, null))
+ Row("count", "0", "0", "0"),
+ Row("mean", null, null, null),
+ Row("stddev", null, null, null),
+ Row("min", null, null, null),
+ Row("max", null, null, null))
def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
- val describeTwoCols = describeTestData.describe("age", "height")
- assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))
+ val describeTwoCols = describeTestData.describe("name", "age", "height")
+ assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height"))
checkAnswer(describeTwoCols, describeResult)
// All aggregate value should have been cast to string
describeTwoCols.collect().foreach { row =>
- assert(row.get(1).isInstanceOf[String], "expected string but found " + row.get(1).getClass)
assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass)
+ assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass)
}
val describeAllCols = describeTestData.describe()
- assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))
+ assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height"))
checkAnswer(describeAllCols, describeResult)
val describeOneCol = describeTestData.describe("age")
assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
- checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} )
+ checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} )
val describeNoCol = describeTestData.select("name").describe()
- assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
- checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} )
+ assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name"))
+ checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} )
val emptyDescription = describeTestData.limit(0).describe()
- assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height"))
+ assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
checkAnswer(emptyDescription, emptyDescribeResult)
}
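For reference, the updated expectations can be reproduced interactively. The sketch below assumes a local SparkSession; only the last tuple of describeTestData is visible in the diff, so the first three rows are reconstructed to be consistent with the expected count/mean/stddev/min/max values above, and the pairing of names with ages is illustrative.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("describe-strings").getOrCreate()
    import spark.implicits._

    val people = Seq(
      ("Bob", 16, 176),
      ("Alice", 32, 164),
      ("David", 60, 192),
      ("Amy", 24, 180)).toDF("name", "age", "height")

    // Expected output matches describeResult above: count 4/4/4, mean null/33.0/178.0,
    // min "Alice"/16/164, max "David"/60/192, with stddev null for the string column.
    people.describe("name", "age", "height").show()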