From ac7d6af1cafc6b159d1df6cf349bb0c7ffca01cd Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 12 Feb 2016 11:54:58 -0800
Subject: [SPARK-13260][SQL] count(*) does not work with CSV data source

https://issues.apache.org/jira/browse/SPARK-13260

This is a quick fix for `count(*)`. When `requiredColumns` is empty, the CSV data source
currently returns `sqlContext.sparkContext.emptyRDD[Row]`, so `count(*)` yields 0 instead
of the actual number of rows. Just like the JSON data source, this PR lets the CSV data
source count the rows without parsing each set of tokens.

Author: hyukjinkwon

Closes #11169 from HyukjinKwon/SPARK-13260.
---
 .../execution/datasources/csv/CSVRelation.scala   | 84 +++++++++++-----------
 .../sql/execution/datasources/csv/CSVSuite.scala  |  2 +-
 2 files changed, 41 insertions(+), 45 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index dc449fea95..f8e3a1b6d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -197,52 +197,48 @@ object CSVRelation extends Logging {
     } else {
       requiredFields
     }
-    if (requiredColumns.isEmpty) {
-      sqlContext.sparkContext.emptyRDD[Row]
-    } else {
-      val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
-      schemaFields.zipWithIndex.filter {
-        case (field, _) => safeRequiredFields.contains(field)
-      }.foreach {
-        case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
-      }
-      val rowArray = new Array[Any](safeRequiredIndices.length)
-      val requiredSize = requiredFields.length
-      tokenizedRDD.flatMap { tokens =>
-        if (params.dropMalformed && schemaFields.length != tokens.size) {
-          logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
-          None
-        } else if (params.failFast && schemaFields.length != tokens.size) {
-          throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
-            s"${tokens.mkString(params.delimiter.toString)}")
+    val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
+    schemaFields.zipWithIndex.filter {
+      case (field, _) => safeRequiredFields.contains(field)
+    }.foreach {
+      case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
+    }
+    val rowArray = new Array[Any](safeRequiredIndices.length)
+    val requiredSize = requiredFields.length
+    tokenizedRDD.flatMap { tokens =>
+      if (params.dropMalformed && schemaFields.length != tokens.size) {
+        logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+        None
+      } else if (params.failFast && schemaFields.length != tokens.size) {
+        throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
+          s"${tokens.mkString(params.delimiter.toString)}")
+      } else {
+        val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) {
+          tokens ++ new Array[String](schemaFields.length - tokens.size)
+        } else if (params.permissive && schemaFields.length < tokens.size) {
+          tokens.take(schemaFields.length)
         } else {
-          val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) {
-            tokens ++ new Array[String](schemaFields.length - tokens.size)
-          } else if (params.permissive && schemaFields.length < tokens.size) {
-            tokens.take(schemaFields.length)
-          } else {
-            tokens
-          }
-          try {
-            var index: Int = 0
-            var subIndex: Int = 0
-            while (subIndex < safeRequiredIndices.length) {
-              index = safeRequiredIndices(subIndex)
-              val field = schemaFields(index)
-              rowArray(subIndex) = CSVTypeCast.castTo(
-                indexSafeTokens(index),
-                field.dataType,
-                field.nullable,
-                params.nullValue)
-              subIndex = subIndex + 1
-            }
-            Some(Row.fromSeq(rowArray.take(requiredSize)))
-          } catch {
-            case NonFatal(e) if params.dropMalformed =>
-              logWarning("Parse exception. " +
-                s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
-              None
+          tokens
+        }
+        try {
+          var index: Int = 0
+          var subIndex: Int = 0
+          while (subIndex < safeRequiredIndices.length) {
+            index = safeRequiredIndices(subIndex)
+            val field = schemaFields(index)
+            rowArray(subIndex) = CSVTypeCast.castTo(
+              indexSafeTokens(index),
+              field.dataType,
+              field.nullable,
+              params.nullValue)
+            subIndex = subIndex + 1
           }
+          Some(Row.fromSeq(rowArray.take(requiredSize)))
+        } catch {
+          case NonFatal(e) if params.dropMalformed =>
+            logWarning("Parse exception. " +
+              s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+            None
         }
       }
     }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index fa4f137b70..9d1f4569ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -56,7 +56,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     val numRows = if (withHeader) numCars else numCars + 1
     // schema
     assert(df.schema.fieldNames.length === numColumns)
-    assert(df.collect().length === numRows)
+    assert(df.count === numRows)

     if (checkHeader) {
       if (withHeader) {
-- 
cgit v1.2.3
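
A minimal reproduction sketch of the behavior this patch fixes. The `cars.csv` path and the
local SparkContext/SQLContext setup below are illustrative assumptions, not part of the patch;
in Spark's own test suites a shared SQLContext is provided by the test harness.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Local setup for the sketch only.
val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("csv-count"))
val sqlContext = new SQLContext(sc)

// Read a file with the built-in CSV data source.
val df = sqlContext.read
  .format("csv")
  .option("header", "true")
  .load("cars.csv") // hypothetical sample file

// count() plans a scan that requires no columns. Before this patch the CSV relation answered
// such a scan with an empty RDD, so the result was 0; with this patch every input line still
// produces a Row (with no columns cast), so the count comes back correct.
println(df.count())
```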