about | summary | refs | log | tree | commit | diff
path: root/mllib/src
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <simonh@tw.ibm.com>2016-05-09 15:05:06 +0800
committerCheng Lian <lian@databricks.com>2016-05-09 15:05:06 +0800
commit635ef407e11dec41ae9bc428935fb8fdaa482f7e (patch)
tree436e0a6ad3ce481bec73635fb86c53242a751ea5 /mllib/src
parenta59ab594cac5189ecf4158fc0ada200eaa874158 (diff)
downloadspark-635ef407e11dec41ae9bc428935fb8fdaa482f7e.tar.gz
spark-635ef407e11dec41ae9bc428935fb8fdaa482f7e.tar.bz2
spark-635ef407e11dec41ae9bc428935fb8fdaa482f7e.zip
[SPARK-15211][SQL] Select features column from LibSVMRelation causes failure
## What changes were proposed in this pull request? We need to use `requiredSchema` in `LibSVMRelation` to project and fetch only the required columns when loading data from this data source. Otherwise, when users try to select the `features` column, it will cause a failure. ## How was this patch tested? `LibSVMRelationSuite`. Author: Liang-Chi Hsieh <simonh@tw.ibm.com> Closes #12986 from viirya/fix-libsvmrelation.
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 10
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala | 1
2 files changed, 10 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 5f78fab4dd..68a855c99f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -203,10 +203,18 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
val converter = RowEncoder(dataSchema)
+ val fullOutput = dataSchema.map { f =>
+ AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+ }
+ val requiredOutput = fullOutput.filter { a =>
+ requiredSchema.fieldNames.contains(a.name)
+ }
+
+ val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
points.map { pt =>
val features = if (sparse) pt.features.toSparse else pt.features.toDense
- converter.toRow(Row(pt.label, features))
+ requiredColumns(converter.toRow(Row(pt.label, features)))
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index e52fbd74a7..1d7144f4e5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -108,5 +108,6 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
test("select features from libsvm relation") {
val df = sqlContext.read.format("libsvm").load(path)
df.select("features").rdd.map { case Row(d: Vector) => d }.first
+ df.select("features").collect
}
}