aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2016-05-04 14:16:57 +0800
committerCheng Lian <lian@databricks.com>2016-05-04 14:16:57 +0800
commitbc3760d405cc8c3ffcd957b188afa8b7e3b1f824 (patch)
treeaa6fae43f4eb0e9e88a0f2574bb2fa619954f98a /mllib/src/main
parent695f0e9195209c75bfc62fc70bfc6d7d9f1047b3 (diff)
downloadspark-bc3760d405cc8c3ffcd957b188afa8b7e3b1f824.tar.gz
spark-bc3760d405cc8c3ffcd957b188afa8b7e3b1f824.tar.bz2
spark-bc3760d405cc8c3ffcd957b188afa8b7e3b1f824.zip
[SPARK-14237][SQL] De-duplicate partition value appending logic in various buildReader() implementations
## What changes were proposed in this pull request? Currently, various `FileFormat` data sources share approximately the same code for partition value appending. This PR tries to eliminate this duplication. A new method `buildReaderWithPartitionValues()` is added to `FileFormat` with a default implementation that appends partition values to `InternalRow`s produced by the reader function returned by `buildReader()`. Special data sources like Parquet, which implements partition value appending inside `buildReader()` because of the vectorized reader, and the Text data source, which doesn't support partitioning, override `buildReaderWithPartitionValues()` and simply delegate to `buildReader()`. This PR brings two benefits: 1. Apparently, it de-duplicates partition value appending logic 2. Now the reader function returned by `buildReader()` is only required to produce `InternalRow`s rather than `UnsafeRow`s if the data source doesn't override `buildReaderWithPartitionValues()`. Because the safe-to-unsafe conversion is also performed while appending partition values. This makes 3rd-party data sources (e.g. spark-avro) easier to implement since they no longer need to access private APIs involving `UnsafeRow`. ## How was this patch tested? Existing tests should do the work. Author: Cheng Lian <lian@databricks.com> Closes #12866 from liancheng/spark-14237-simplify-partition-values-appending.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala17
1 files changed, 1 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index ba2e1e2bc2..5f78fab4dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -204,25 +204,10 @@ class DefaultSource extends FileFormat with DataSourceRegister {
val converter = RowEncoder(dataSchema)
- val unsafeRowIterator = points.map { pt =>
+ points.map { pt =>
val features = if (sparse) pt.features.toSparse else pt.features.toDense
converter.toRow(Row(pt.label, features))
}
-
- def toAttribute(f: StructField): AttributeReference =
- AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
-
- // Appends partition values
- val fullOutput = (dataSchema ++ partitionSchema).map(toAttribute)
- val requiredOutput = fullOutput.filter { a =>
- requiredSchema.fieldNames.contains(a.name) || partitionSchema.fieldNames.contains(a.name)
- }
- val joinedRow = new JoinedRow()
- val appendPartitionColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
-
- unsafeRowIterator.map { dataRow =>
- appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
- }
}
}
}