diff options
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index fff86686b5..5e9e6ff1a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.text.TextOutputWriter import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -46,12 +47,17 @@ private[libsvm] class LibSVMOutputWriter( context: TaskAttemptContext) extends OutputWriter { + override val path: String = { + val compressionExtension = TextOutputWriter.getCompressionExtension(context) + new Path(stagingDir, fileNamePrefix + ".libsvm" + compressionExtension).toString + } + private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - new Path(stagingDir, fileNamePrefix + extension) + new Path(path) } }.getRecordWriter(context) } |