diff options
Diffstat (limited to 'sql/hive/src/test')
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala | 7 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala | 8 |
2 files changed, 8 insertions, 7 deletions
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala index abc7c8cc4d..7501334f94 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.sources import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.StructType @@ -42,14 +43,14 @@ class CommitFailureTestSource extends SimpleTextSource { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) { + new SimpleTextOutputWriter(path, dataSchema, context) { var failed = false TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) => failed = true SimpleTextRelation.callbackCalled = true } - override def write(row: Row): Unit = { + override def write(row: InternalRow): Unit = { if (SimpleTextRelation.failWriter) { sys.error("Intentional task writer failure for testing purpose.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 5fdf615259..1607c97cd6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -50,7 +50,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) + new SimpleTextOutputWriter(path, dataSchema, context) } override def getFileExtension(context: TaskAttemptContext): String = "" @@ -117,13 +117,13 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { } } -class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) +class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) - override def write(row: Row): Unit = { - val serialized = row.toSeq.map { v => + override def write(row: InternalRow): Unit = { + val serialized = row.toSeq(dataSchema).map { v => if (v == null) "" else v.toString }.mkString(",") |