diff options
author | Liang-Chi Hsieh <simonh@tw.ibm.com> | 2016-05-09 15:05:06 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-05-09 15:05:06 +0800 |
commit | 635ef407e11dec41ae9bc428935fb8fdaa482f7e (patch) | |
tree | 436e0a6ad3ce481bec73635fb86c53242a751ea5 /mllib/src | |
parent | a59ab594cac5189ecf4158fc0ada200eaa874158 (diff) | |
download | spark-635ef407e11dec41ae9bc428935fb8fdaa482f7e.tar.gz spark-635ef407e11dec41ae9bc428935fb8fdaa482f7e.tar.bz2 spark-635ef407e11dec41ae9bc428935fb8fdaa482f7e.zip |
[SPARK-15211][SQL] Select features column from LibSVMRelation causes failure
## What changes were proposed in this pull request?
We need to use `requiredSchema` in `LibSVMRelation` to project and fetch only the required columns when loading data from this data source. Otherwise, when users try to select the `features` column, it will cause a failure.
## How was this patch tested?
`LibSVMRelationSuite`.
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Closes #12986 from viirya/fix-libsvmrelation.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 10 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala | 1 |
2 files changed, 10 insertions, 1 deletion
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 5f78fab4dd..68a855c99f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -203,10 +203,18 @@ class DefaultSource extends FileFormat with DataSourceRegister { } val converter = RowEncoder(dataSchema) + val fullOutput = dataSchema.map { f => + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() + } + val requiredOutput = fullOutput.filter { a => + requiredSchema.fieldNames.contains(a.name) + } + + val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput) points.map { pt => val features = if (sparse) pt.features.toSparse else pt.features.toDense - converter.toRow(Row(pt.label, features)) + requiredColumns(converter.toRow(Row(pt.label, features))) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index e52fbd74a7..1d7144f4e5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -108,5 +108,6 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { test("select features from libsvm relation") { val df = sqlContext.read.format("libsvm").load(path) df.select("features").rdd.map { case Row(d: Vector) => d }.first + df.select("features").collect } } |