author    hyukjinkwon <gurwls223@gmail.com>    2016-02-12 11:54:58 -0800
committer Reynold Xin <rxin@databricks.com>    2016-02-12 11:54:58 -0800
commit    ac7d6af1cafc6b159d1df6cf349bb0c7ffca01cd
tree      8b8a0d66c0c6c48b6626b7654e848f6ced61e04d
parent    c4d5ad80c8091c961646a82e85ecbc335b8ffe2d
[SPARK-13260][SQL] count(*) does not work with CSV data source
https://issues.apache.org/jira/browse/SPARK-13260

This is a quick fix for `count(*)`. When `requiredColumns` is empty, the CSV data source currently returns `sqlContext.sparkContext.emptyRDD[Row]`, so the count always comes back as zero. Just like the JSON data source, this PR lets the CSV data source count the rows without parsing each set of tokens.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #11169 from HyukjinKwon/SPARK-13260.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala | 84
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala    |  2
2 files changed, 41 insertions(+), 45 deletions(-)
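For context, here is a minimal, self-contained Scala sketch (not the actual CSVRelation code; the sample `tokenizedRDD` data is hypothetical) showing why short-circuiting to `emptyRDD[Row]` breaks `count(*)` and why emitting a zero-width row per input line fixes it:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.Row

object CountStarSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("count-star-sketch").setMaster("local[2]"))

    // Stand-in for the tokenized CSV lines (three data records).
    val tokenizedRDD = sc.parallelize(Seq(
      Array("2012", "Tesla", "S"),
      Array("1997", "Ford", "E350"),
      Array("2015", "Chevy", "Volt")))

    // Old behaviour: with no required columns the relation returned an empty RDD,
    // so count(*) saw zero rows.
    val oldRows = sc.emptyRDD[Row]
    println(oldRows.count())  // 0 -- wrong for count(*)

    // New behaviour: still iterate over every line, but produce empty rows
    // instead of casting any tokens.
    val newRows = tokenizedRDD.map(_ => Row.empty)
    println(newRows.count())  // 3 -- matches the number of CSV records

    sc.stop()
  }
}
```

The actual patch below achieves the same effect by removing the `requiredColumns.isEmpty` short-circuit: the `flatMap` still runs once per valid line, and with `requiredSize` equal to zero it yields `Row.fromSeq(rowArray.take(0))`, i.e. an empty row per record.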
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index dc449fea95..f8e3a1b6d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -197,52 +197,48 @@ object CSVRelation extends Logging {
} else {
requiredFields
}
- if (requiredColumns.isEmpty) {
- sqlContext.sparkContext.emptyRDD[Row]
- } else {
- val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
- schemaFields.zipWithIndex.filter {
- case (field, _) => safeRequiredFields.contains(field)
- }.foreach {
- case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
- }
- val rowArray = new Array[Any](safeRequiredIndices.length)
- val requiredSize = requiredFields.length
- tokenizedRDD.flatMap { tokens =>
- if (params.dropMalformed && schemaFields.length != tokens.size) {
- logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
- None
- } else if (params.failFast && schemaFields.length != tokens.size) {
- throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
- s"${tokens.mkString(params.delimiter.toString)}")
+ val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
+ schemaFields.zipWithIndex.filter {
+ case (field, _) => safeRequiredFields.contains(field)
+ }.foreach {
+ case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
+ }
+ val rowArray = new Array[Any](safeRequiredIndices.length)
+ val requiredSize = requiredFields.length
+ tokenizedRDD.flatMap { tokens =>
+ if (params.dropMalformed && schemaFields.length != tokens.size) {
+ logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+ None
+ } else if (params.failFast && schemaFields.length != tokens.size) {
+ throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
+ s"${tokens.mkString(params.delimiter.toString)}")
+ } else {
+ val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) {
+ tokens ++ new Array[String](schemaFields.length - tokens.size)
+ } else if (params.permissive && schemaFields.length < tokens.size) {
+ tokens.take(schemaFields.length)
} else {
- val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) {
- tokens ++ new Array[String](schemaFields.length - tokens.size)
- } else if (params.permissive && schemaFields.length < tokens.size) {
- tokens.take(schemaFields.length)
- } else {
- tokens
- }
- try {
- var index: Int = 0
- var subIndex: Int = 0
- while (subIndex < safeRequiredIndices.length) {
- index = safeRequiredIndices(subIndex)
- val field = schemaFields(index)
- rowArray(subIndex) = CSVTypeCast.castTo(
- indexSafeTokens(index),
- field.dataType,
- field.nullable,
- params.nullValue)
- subIndex = subIndex + 1
- }
- Some(Row.fromSeq(rowArray.take(requiredSize)))
- } catch {
- case NonFatal(e) if params.dropMalformed =>
- logWarning("Parse exception. " +
- s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
- None
+ tokens
+ }
+ try {
+ var index: Int = 0
+ var subIndex: Int = 0
+ while (subIndex < safeRequiredIndices.length) {
+ index = safeRequiredIndices(subIndex)
+ val field = schemaFields(index)
+ rowArray(subIndex) = CSVTypeCast.castTo(
+ indexSafeTokens(index),
+ field.dataType,
+ field.nullable,
+ params.nullValue)
+ subIndex = subIndex + 1
}
+ Some(Row.fromSeq(rowArray.take(requiredSize)))
+ } catch {
+ case NonFatal(e) if params.dropMalformed =>
+ logWarning("Parse exception. " +
+ s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+ None
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index fa4f137b70..9d1f4569ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -56,7 +56,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val numRows = if (withHeader) numCars else numCars + 1
// schema
assert(df.schema.fieldNames.length === numColumns)
- assert(df.collect().length === numRows)
+ assert(df.count === numRows)
if (checkHeader) {
if (withHeader) {