From 9b23e92c727881ff9038b4fe9643c49b96914159 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 2 Jul 2015 22:10:24 -0700 Subject: [SPARK-8803] handle special characters in elements in crosstab cc rxin Having back ticks or null as elements causes problems. Since elements become column names, we have to drop them from the element as back ticks are special characters. Having null throws exceptions, we could replace them with empty strings. Handling back ticks should be improved for 1.5 Author: Burak Yavuz Closes #7201 from brkyvz/weird-ct-elements and squashes the following commits: e06b840 [Burak Yavuz] fix scalastyle 93a0d3f [Burak Yavuz] added tests for NaN and Infinity 9dba6ce [Burak Yavuz] address cr1 db71dbd [Burak Yavuz] handle special characters in elements in crosstab --- .../apache/spark/sql/DataFrameNaFunctions.scala | 2 +- .../apache/spark/sql/DataFrameStatFunctions.scala | 3 +++ .../spark/sql/execution/stat/StatFunctions.scala | 20 ++++++++++++--- .../org/apache/spark/sql/DataFrameStatSuite.scala | 30 ++++++++++++++++++++++ 4 files changed, 50 insertions(+), 5 deletions(-) (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index b4c2daa055..8681a56c82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -391,7 +391,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. */ private def fillCol[T](col: StructField, replacement: T): Column = { - coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name) + coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index edb9ed7bba..587869e57f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -78,6 +78,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * The first column of each row will be the distinct values of `col1` and the column names will * be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts * will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts. + * Null elements will be replaced by "null", and back ticks will be dropped from elements if they + * exist. + * * * @param col1 The name of the first column. Distinct items will make the first item of * each row. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 23ddfa9839..00231d65a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -110,8 +110,12 @@ private[sql] object StatFunctions extends Logging { logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " + "the pairs. Please try reducing the amount of distinct items in your columns.") } + def cleanElement(element: Any): String = { + if (element == null) "null" else element.toString + } // get the distinct values of column 2, so that we can make them the column names - val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap + val distinctCol2: Map[Any, Int] = + counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap val columnSize = distinctCol2.size require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") @@ -121,15 +125,23 @@ private[sql] object StatFunctions extends Logging { // row.get(0) is column 1 // row.get(1) is column 2 // row.get(2) is the frequency - countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) + val columnIndex = distinctCol2.get(cleanElement(row.get(1))).get + countsRow.setLong(columnIndex + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts - countsRow.update(0, UTF8String.fromString(col1Item.toString)) + countsRow.update(0, UTF8String.fromString(cleanElement(col1Item))) countsRow }.toSeq + // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept + // special keywords and `.`, wrap the column names in ``. + def cleanColumnName(name: String): String = { + name.replace("`", "") + } // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in // SPARK-8681. We need to explicitly sort by the column index and assign the column names. - val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType)) + val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r => + StructField(cleanColumnName(r._1.toString), LongType) + } val schema = StructType(StructField(tableName, StringType) +: headerNames) new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 765094da6b..7ba4ba73e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -85,6 +85,36 @@ class DataFrameStatSuite extends SparkFunSuite { } } + test("special crosstab elements (., '', null, ``)") { + val data = Seq( + ("a", Double.NaN, "ho"), + (null, 2.0, "ho"), + ("a.b", Double.NegativeInfinity, ""), + ("b", Double.PositiveInfinity, "`ha`"), + ("a", 1.0, null) + ) + val df = data.toDF("1", "2", "3") + val ct1 = df.stat.crosstab("1", "2") + // column fields should be 1 + distinct elements of second column + assert(ct1.schema.fields.length === 6) + assert(ct1.collect().length === 4) + val ct2 = df.stat.crosstab("1", "3") + assert(ct2.schema.fields.length === 5) + assert(ct2.schema.fieldNames.contains("ha")) + assert(ct2.collect().length === 4) + val ct3 = df.stat.crosstab("3", "2") + assert(ct3.schema.fields.length === 6) + assert(ct3.schema.fieldNames.contains("NaN")) + assert(ct3.schema.fieldNames.contains("Infinity")) + assert(ct3.schema.fieldNames.contains("-Infinity")) + assert(ct3.collect().length === 4) + val ct4 = df.stat.crosstab("3", "1") + assert(ct4.schema.fields.length === 5) + assert(ct4.schema.fieldNames.contains("null")) + assert(ct4.schema.fieldNames.contains("a.b")) + assert(ct4.collect().length === 4) + } + test("Frequent Items") { val rows = Seq.tabulate(1000) { i => if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) -- cgit v1.2.3