aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2015-07-02 22:10:24 -0700
committerReynold Xin <rxin@databricks.com>2015-07-02 22:12:47 -0700
commitff76b33b67cec0614143e60f8dc88435521cf3aa (patch)
tree8feaa64e1e5be7ea98d321147c0280b8272f5f34
parentf142867ecee59a635df91aee888351bee5f29c0e (diff)
downloadspark-ff76b33b67cec0614143e60f8dc88435521cf3aa.tar.gz
spark-ff76b33b67cec0614143e60f8dc88435521cf3aa.tar.bz2
spark-ff76b33b67cec0614143e60f8dc88435521cf3aa.zip
[SPARK-8803] handle special characters in elements in crosstab
cc rxin Having back ticks or null as elements causes problems. Since elements become column names, we have to drop them from the element as back ticks are special characters. Having null throws exceptions, we could replace them with empty strings. Handling back ticks should be improved for 1.5 Author: Burak Yavuz <brkyvz@gmail.com> Closes #7201 from brkyvz/weird-ct-elements and squashes the following commits: e06b840 [Burak Yavuz] fix scalastyle 93a0d3f [Burak Yavuz] added tests for NaN and Infinity 9dba6ce [Burak Yavuz] address cr1 db71dbd [Burak Yavuz] handle special characters in elements in crosstab (cherry picked from commit 9b23e92c727881ff9038b4fe9643c49b96914159) Signed-off-by: Reynold Xin <rxin@databricks.com> Conflicts: sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala30
4 files changed, 50 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index b4c2daa055..8681a56c82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -391,7 +391,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
*/
private def fillCol[T](col: StructField, replacement: T): Column = {
- coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
+ coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index edb9ed7bba..587869e57f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -78,6 +78,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* The first column of each row will be the distinct values of `col1` and the column names will
* be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts
* will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
+ * Null elements will be replaced by "null", and back ticks will be dropped from elements if they
+ * exist.
+ *
*
* @param col1 The name of the first column. Distinct items will make the first item of
* each row.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index e4a525e376..5a0c9a66b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -109,8 +109,12 @@ private[sql] object StatFunctions extends Logging {
logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
"the pairs. Please try reducing the amount of distinct items in your columns.")
}
+ def cleanElement(element: Any): String = {
+ if (element == null) "null" else element.toString
+ }
// get the distinct values of column 2, so that we can make them the column names
- val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap
+ val distinctCol2: Map[Any, Int] =
+ counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap
val columnSize = distinctCol2.size
require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
s"exceed 1e4. Currently $columnSize")
@@ -120,15 +124,23 @@ private[sql] object StatFunctions extends Logging {
// row.get(0) is column 1
// row.get(1) is column 2
// row.get(2) is the frequency
- countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
+ val columnIndex = distinctCol2.get(cleanElement(row.get(1))).get
+ countsRow.setLong(columnIndex + 1, row.getLong(2))
}
// the value of col1 is the first value, the rest are the counts
- countsRow.setString(0, col1Item.toString)
+ countsRow.setString(0, cleanElement(col1Item.toString))
countsRow
}.toSeq
+ // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept
+ // special keywords and `.`, wrap the column names in ``.
+ def cleanColumnName(name: String): String = {
+ name.replace("`", "")
+ }
// In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
// SPARK-8681. We need to explicitly sort by the column index and assign the column names.
- val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType))
+ val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r =>
+ StructField(cleanColumnName(r._1.toString), LongType)
+ }
val schema = StructType(StructField(tableName, StringType) +: headerNames)
new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 3e87ebad08..7e2605bc68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -85,6 +85,36 @@ class DataFrameStatSuite extends SparkFunSuite {
}
}
+ test("special crosstab elements (., '', null, ``)") {
+ val data = Seq(
+ ("a", Double.NaN, "ho"),
+ (null, 2.0, "ho"),
+ ("a.b", Double.NegativeInfinity, ""),
+ ("b", Double.PositiveInfinity, "`ha`"),
+ ("a", 1.0, null)
+ )
+ val df = data.toDF("1", "2", "3")
+ val ct1 = df.stat.crosstab("1", "2")
+ // column fields should be 1 + distinct elements of second column
+ assert(ct1.schema.fields.length === 6)
+ assert(ct1.collect().length === 4)
+ val ct2 = df.stat.crosstab("1", "3")
+ assert(ct2.schema.fields.length === 5)
+ assert(ct2.schema.fieldNames.contains("ha"))
+ assert(ct2.collect().length === 4)
+ val ct3 = df.stat.crosstab("3", "2")
+ assert(ct3.schema.fields.length === 6)
+ assert(ct3.schema.fieldNames.contains("NaN"))
+ assert(ct3.schema.fieldNames.contains("Infinity"))
+ assert(ct3.schema.fieldNames.contains("-Infinity"))
+ assert(ct3.collect().length === 4)
+ val ct4 = df.stat.crosstab("3", "1")
+ assert(ct4.schema.fields.length === 5)
+ assert(ct4.schema.fieldNames.contains("null"))
+ assert(ct4.schema.fieldNames.contains("a.b"))
+ assert(ct4.collect().length === 4)
+ }
+
test("Frequent Items") {
val rows = Seq.tabulate(1000) { i =>
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)