aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <simonh@tw.ibm.com>2016-05-09 15:05:06 +0800
committerCheng Lian <lian@databricks.com>2016-05-09 15:05:06 +0800
commit635ef407e11dec41ae9bc428935fb8fdaa482f7e (patch)
tree436e0a6ad3ce481bec73635fb86c53242a751ea5 /mllib/src/main/scala
parenta59ab594cac5189ecf4158fc0ada200eaa874158 (diff)
downloadspark-635ef407e11dec41ae9bc428935fb8fdaa482f7e.tar.gz
spark-635ef407e11dec41ae9bc428935fb8fdaa482f7e.tar.bz2
spark-635ef407e11dec41ae9bc428935fb8fdaa482f7e.zip
[SPARK-15211][SQL] Select features column from LibSVMRelation causes failure
## What changes were proposed in this pull request? We need to use `requiredSchema` in `LibSVMRelation` to project the fetch required columns when loading data from this data source. Otherwise, when users try to select `features` column, it will cause failure. ## How was this patch tested? `LibSVMRelationSuite`. Author: Liang-Chi Hsieh <simonh@tw.ibm.com> Closes #12986 from viirya/fix-libsvmrelation.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala10
1 files changed, 9 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 5f78fab4dd..68a855c99f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -203,10 +203,18 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
val converter = RowEncoder(dataSchema)
+ val fullOutput = dataSchema.map { f =>
+ AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+ }
+ val requiredOutput = fullOutput.filter { a =>
+ requiredSchema.fieldNames.contains(a.name)
+ }
+
+ val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
points.map { pt =>
val features = if (sparse) pt.features.toSparse else pt.features.toDense
- converter.toRow(Row(pt.label, features))
+ requiredColumns(converter.toRow(Row(pt.label, features)))
}
}
}