diff options
author | Liang-Chi Hsieh <simonh@tw.ibm.com> | 2016-04-23 01:11:36 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-04-23 01:11:36 +0800 |
commit | 8098f158576b07343f74e2061d217b106c71b62d (patch) | |
tree | 82622c423578c8b535cd486d4a83558f7e29f573 | |
parent | c089c6f4e83d85e622b8d13f466a656c2852702b (diff) | |
download | spark-8098f158576b07343f74e2061d217b106c71b62d.tar.gz spark-8098f158576b07343f74e2061d217b106c71b62d.tar.bz2 spark-8098f158576b07343f74e2061d217b106c71b62d.zip |
[SPARK-14843][ML] Fix encoding error in LibSVMRelation
## What changes were proposed in this pull request?
We use `RowEncoder` in libsvm data source to serialize the label and features read from libsvm files. However, the schema passed in this encoder is not correct. As a result, we can't correctly select the `features` column from the DataFrame. We should use the full data schema instead of `requiredSchema` to serialize the data read in, then do a projection to select the required columns later.
## How was this patch tested?
`LibSVMRelationSuite`.
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Closes #12611 from viirya/fix-libsvm.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 9 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala | 9 |
2 files changed, 13 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index e8b0dd61f3..dc2a6f5275 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -202,7 +202,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { LabeledPoint(label, Vectors.sparse(numFeatures, indices, values)) } - val converter = RowEncoder(requiredSchema) + val converter = RowEncoder(dataSchema) val unsafeRowIterator = points.map { pt => val features = if (sparse) pt.features.toSparse else pt.features.toDense @@ -213,9 +213,12 @@ class DefaultSource extends FileFormat with DataSourceRegister { AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() // Appends partition values - val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute) + val fullOutput = (dataSchema ++ partitionSchema).map(toAttribute) + val requiredOutput = fullOutput.filter { a => + requiredSchema.fieldNames.contains(a.name) || partitionSchema.fieldNames.contains(a.name) + } val joinedRow = new JoinedRow() - val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput) + val appendPartitionColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput) unsafeRowIterator.map { dataRow => appendPartitionColumns(joinedRow(dataRow, file.partitionValues)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 0bd14978b2..e52fbd74a7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -23,9 +23,9 @@ import java.nio.charset.StandardCharsets import com.google.common.io.Files import 
org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.util.Utils @@ -104,4 +104,9 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { df.write.format("libsvm").save(path + "_2") } } + + test("select features from libsvm relation") { + val df = sqlContext.read.format("libsvm").load(path) + df.select("features").rdd.map { case Row(d: Vector) => d }.first + } } |