aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala8
1 files changed, 7 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index fff86686b5..5e9e6ff1a5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
@@ -46,12 +47,17 @@ private[libsvm] class LibSVMOutputWriter(
context: TaskAttemptContext)
extends OutputWriter {
+ override val path: String = {
+ val compressionExtension = TextOutputWriter.getCompressionExtension(context)
+ new Path(stagingDir, fileNamePrefix + ".libsvm" + compressionExtension).toString
+ }
+
private[this] val buffer = new Text()
private val recordWriter: RecordWriter[NullWritable, Text] = {
new TextOutputFormat[NullWritable, Text]() {
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
- new Path(stagingDir, fileNamePrefix + extension)
+ new Path(path)
}
}.getRecordWriter(context)
}