 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala | 84 ++++++++++----------
 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala    |  2 +-
 2 files changed, 41 insertions(+), 45 deletions(-)
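
This change removes the requiredColumns.isEmpty short-circuit in CSVRelation: the relation used to answer a zero-column projection (the plan produced by queries such as count()) with an empty RDD, so those queries reported 0 rows regardless of the file's contents. With the short-circuit gone, every well-formed line still yields a (zero-width) Row, and only the test assertion needed updating to exercise that path. A minimal repro sketch, assuming a Spark 2.0-style SQLContext; the file path and its contents are illustrative assumptions, not taken from this patch:

    // Hypothetical repro; /tmp/cars.csv is assumed to hold a header
    // line plus some data rows.
    val df = sqlContext.read
      .format("csv")
      .option("header", "true")
      .load("/tmp/cars.csv")

    // count() pushes down an empty set of required columns. Before this
    // patch CSVRelation returned sqlContext.sparkContext.emptyRDD[Row]
    // for that case, so the result was 0; with the short-circuit gone,
    // each valid line produces a zero-width Row and the count is correct.
    println(df.count()) // number of data rows, not 0
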
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index dc449fea95..f8e3a1b6d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -197,52 +197,48 @@ object CSVRelation extends Logging {
} else {
requiredFields
}
- if (requiredColumns.isEmpty) {
- sqlContext.sparkContext.emptyRDD[Row]
- } else {
- val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
- schemaFields.zipWithIndex.filter {
- case (field, _) => safeRequiredFields.contains(field)
- }.foreach {
- case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
- }
- val rowArray = new Array[Any](safeRequiredIndices.length)
- val requiredSize = requiredFields.length
- tokenizedRDD.flatMap { tokens =>
- if (params.dropMalformed && schemaFields.length != tokens.size) {
- logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
- None
- } else if (params.failFast && schemaFields.length != tokens.size) {
- throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
- s"${tokens.mkString(params.delimiter.toString)}")
+ val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
+ schemaFields.zipWithIndex.filter {
+ case (field, _) => safeRequiredFields.contains(field)
+ }.foreach {
+ case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
+ }
+ val rowArray = new Array[Any](safeRequiredIndices.length)
+ val requiredSize = requiredFields.length
+ tokenizedRDD.flatMap { tokens =>
+ if (params.dropMalformed && schemaFields.length != tokens.size) {
+ logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+ None
+ } else if (params.failFast && schemaFields.length != tokens.size) {
+ throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
+ s"${tokens.mkString(params.delimiter.toString)}")
+ } else {
+ val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) {
+ tokens ++ new Array[String](schemaFields.length - tokens.size)
+ } else if (params.permissive && schemaFields.length < tokens.size) {
+ tokens.take(schemaFields.length)
} else {
- val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) {
- tokens ++ new Array[String](schemaFields.length - tokens.size)
- } else if (params.permissive && schemaFields.length < tokens.size) {
- tokens.take(schemaFields.length)
- } else {
- tokens
- }
- try {
- var index: Int = 0
- var subIndex: Int = 0
- while (subIndex < safeRequiredIndices.length) {
- index = safeRequiredIndices(subIndex)
- val field = schemaFields(index)
- rowArray(subIndex) = CSVTypeCast.castTo(
- indexSafeTokens(index),
- field.dataType,
- field.nullable,
- params.nullValue)
- subIndex = subIndex + 1
- }
- Some(Row.fromSeq(rowArray.take(requiredSize)))
- } catch {
- case NonFatal(e) if params.dropMalformed =>
- logWarning("Parse exception. " +
- s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
- None
+ tokens
+ }
+ try {
+ var index: Int = 0
+ var subIndex: Int = 0
+ while (subIndex < safeRequiredIndices.length) {
+ index = safeRequiredIndices(subIndex)
+ val field = schemaFields(index)
+ rowArray(subIndex) = CSVTypeCast.castTo(
+ indexSafeTokens(index),
+ field.dataType,
+ field.nullable,
+ params.nullValue)
+ subIndex = subIndex + 1
}
+ Some(Row.fromSeq(rowArray.take(requiredSize)))
+ } catch {
+ case NonFatal(e) if params.dropMalformed =>
+ logWarning("Parse exception. " +
+ s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+ None
}
}
}
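
With the early return gone, the flatMap body above now runs for every projection, including empty ones. The token-alignment step it keeps is worth calling out: in PERMISSIVE mode, lines shorter than the schema are padded with nulls up to the schema width, and longer lines are truncated to it. A standalone sketch of that alignment under assumed names (indexSafe and schemaLength are illustrative, not identifiers from this patch):

    // Standalone sketch of the PERMISSIVE alignment performed above;
    // schemaLength stands in for schemaFields.length.
    def indexSafe(tokens: Array[String], schemaLength: Int): Array[String] =
      if (schemaLength > tokens.length) {
        tokens ++ new Array[String](schemaLength - tokens.length) // pad with nulls
      } else if (schemaLength < tokens.length) {
        tokens.take(schemaLength) // drop surplus trailing tokens
      } else {
        tokens
      }
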
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index fa4f137b70..9d1f4569ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -56,7 +56,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val numRows = if (withHeader) numCars else numCars + 1
// schema
assert(df.schema.fieldNames.length === numColumns)
- assert(df.collect().length === numRows)
+ assert(df.count === numRows)
if (checkHeader) {
if (withHeader) {
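
The test now asserts with df.count rather than df.collect().length, and the distinction matters here: collect() materializes every column, so it never plans a zero-column scan and the old assertion passed even with the bug present, whereas count() is exactly the query shape the CSVRelation change fixes. A hedged sketch of the two assertions side by side, assuming the suite's df and numRows:

    // collect() requires all columns, so the old emptyRDD branch was
    // never taken and this assertion passed even with the bug.
    assert(df.collect().length === numRows)
    // count() plans a zero-column scan; this assertion fails without
    // the CSVRelation fix above.
    assert(df.count === numRows)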