 sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala         |  2 +-
 sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala       |  3 +++
 sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala | 20 ++++++++++++++++----
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala           | 30 ++++++++++++++++++++++++++++++
 4 files changed, 50 insertions(+), 5 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index b4c2daa055..8681a56c82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -391,7 +391,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
    */
   private def fillCol[T](col: StructField, replacement: T): Column = {
-    coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
+    coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
   }
 
   /**
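
Why the quoting matters: `df.col(name)` parses its argument, so a dot in a literal column name is read as nested-field access, and special keywords fail to resolve; wrapping the raw name in backticks makes the resolver take it verbatim. A minimal sketch of the case this fixes (data and column names are invented for illustration, and the usual `sqlContext.implicits._` import is assumed):

    // Hypothetical repro: a column literally named "a.b".
    // Before the fix, fillCol called df.col("a.b"), which tried to resolve
    // a field `b` inside a struct `a` and failed; df.col("`a.b`") resolves
    // the flat column directly.
    val df = Seq((Some(1.0), "x"), (None, "y")).toDF("a.b", "c")
    df.na.fill(0.0).show()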
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index edb9ed7bba..587869e57f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -78,6 +78,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * The first column of each row will be the distinct values of `col1` and the column names will
    * be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts
    * will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
+   * Null elements will be replaced by "null", and back ticks will be dropped from elements if they
+   * exist.
+   *
    *
    * @param col1 The name of the first column. Distinct items will make the first item of
    *             each row.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 23ddfa9839..00231d65a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -110,8 +110,12 @@ private[sql] object StatFunctions extends Logging {
logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
"the pairs. Please try reducing the amount of distinct items in your columns.")
}
+ def cleanElement(element: Any): String = {
+ if (element == null) "null" else element.toString
+ }
// get the distinct values of column 2, so that we can make them the column names
- val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap
+ val distinctCol2: Map[Any, Int] =
+ counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap
val columnSize = distinctCol2.size
require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
s"exceed 1e4. Currently $columnSize")
@@ -121,15 +125,23 @@ private[sql] object StatFunctions extends Logging {
         // row.get(0) is column 1
         // row.get(1) is column 2
         // row.get(2) is the frequency
-        countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
+        val columnIndex = distinctCol2.get(cleanElement(row.get(1))).get
+        countsRow.setLong(columnIndex + 1, row.getLong(2))
       }
       // the value of col1 is the first value, the rest are the counts
-      countsRow.update(0, UTF8String.fromString(col1Item.toString))
+      countsRow.update(0, UTF8String.fromString(cleanElement(col1Item)))
       countsRow
     }.toSeq
+    // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept
+    // special keywords and `.`, wrap the column names in ``.
+    def cleanColumnName(name: String): String = {
+      name.replace("`", "")
+    }
     // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
     // SPARK-8681. We need to explicitly sort by the column index and assign the column names.
-    val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType))
+    val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r =>
+      StructField(cleanColumnName(r._1.toString), LongType)
+    }
     val schema = StructType(StructField(tableName, StringType) +: headerNames)
 
     new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
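
Note how the two helpers divide the work: `cleanElement` normalizes cell values, while `cleanColumnName` strips backticks from the header names because the closing `na.fill(0.0)` call re-quotes every column name in backticks (the DataFrameNaFunctions change above), and a name that still contained a backtick would corrupt that quoting. Roughly:

    cleanColumnName("`ha`")  // "ha"  -- later re-quoted as `ha` by fillCol
    cleanColumnName("a.b")   // "a.b" -- a dot is safe once wrapped in backticks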
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 765094da6b..7ba4ba73e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -85,6 +85,36 @@ class DataFrameStatSuite extends SparkFunSuite {
     }
   }
 
+  test("special crosstab elements (., '', null, ``)") {
+    val data = Seq(
+      ("a", Double.NaN, "ho"),
+      (null, 2.0, "ho"),
+      ("a.b", Double.NegativeInfinity, ""),
+      ("b", Double.PositiveInfinity, "`ha`"),
+      ("a", 1.0, null)
+    )
+    val df = data.toDF("1", "2", "3")
+    val ct1 = df.stat.crosstab("1", "2")
+    // column fields should be 1 + distinct elements of second column
+    assert(ct1.schema.fields.length === 6)
+    assert(ct1.collect().length === 4)
+    val ct2 = df.stat.crosstab("1", "3")
+    assert(ct2.schema.fields.length === 5)
+    assert(ct2.schema.fieldNames.contains("ha"))
+    assert(ct2.collect().length === 4)
+    val ct3 = df.stat.crosstab("3", "2")
+    assert(ct3.schema.fields.length === 6)
+    assert(ct3.schema.fieldNames.contains("NaN"))
+    assert(ct3.schema.fieldNames.contains("Infinity"))
+    assert(ct3.schema.fieldNames.contains("-Infinity"))
+    assert(ct3.collect().length === 4)
+    val ct4 = df.stat.crosstab("3", "1")
+    assert(ct4.schema.fields.length === 5)
+    assert(ct4.schema.fieldNames.contains("null"))
+    assert(ct4.schema.fieldNames.contains("a.b"))
+    assert(ct4.collect().length === 4)
+  }
+
   test("Frequent Items") {
     val rows = Seq.tabulate(1000) { i =>
       if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
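
Reading the last case back through the new rules: `ct4` cross-tabulates column "3" against column "1", so its header is the `3_1` label plus the four cleaned distinct values of column "1". Presumably something like the following, with the count-column order being data-dependent:

    ct4.schema.fieldNames
    // e.g. Array("3_1", "a", "null", "a.b", "b") -- 5 fields, as asserted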