-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala       | 9
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala  | 9
2 files changed, 13 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index e8b0dd61f3..dc2a6f5275 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -202,7 +202,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
         LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
       }

-      val converter = RowEncoder(requiredSchema)
+      val converter = RowEncoder(dataSchema)

       val unsafeRowIterator = points.map { pt =>
         val features = if (sparse) pt.features.toSparse else pt.features.toDense
@@ -213,9 +213,12 @@ class DefaultSource extends FileFormat with DataSourceRegister {
         AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()

       // Appends partition values
-      val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute)
+      val fullOutput = (dataSchema ++ partitionSchema).map(toAttribute)
+      val requiredOutput = fullOutput.filter { a =>
+        requiredSchema.fieldNames.contains(a.name) || partitionSchema.fieldNames.contains(a.name)
+      }
       val joinedRow = new JoinedRow()
-      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+      val appendPartitionColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)

       unsafeRowIterator.map { dataRow =>
         appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
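
In the hunks above, the row converter now materializes the full dataSchema (label plus features), and column pruning moves into the generated projection: requiredOutput keeps only the attributes the query asked for, while the projection still reads from the full (dataSchema ++ partitionSchema) input. A standalone sketch of that filtering rule, using a hypothetical Attr case class in place of Catalyst's AttributeReference:

    // Sketch of the pruning rule above; Attr is a hypothetical stand-in
    // for Catalyst's AttributeReference.
    object ProjectionPruningSketch {
      case class Attr(name: String)

      // Keep only the attributes the query requires, preserving the order
      // of the full (dataSchema ++ partitionSchema) output.
      def requiredOutput(
          fullOutput: Seq[Attr],
          requiredNames: Set[String],
          partitionNames: Set[String]): Seq[Attr] =
        fullOutput.filter(a => requiredNames(a.name) || partitionNames(a.name))

      def main(args: Array[String]): Unit = {
        val full = Seq(Attr("label"), Attr("features"), Attr("year"))
        // df.select("features") prunes "label" but keeps the partition
        // column "year" that the scan appends:
        println(requiredOutput(full, Set("features"), Set("year")))
        // prints: List(Attr(features), Attr(year))
      }
    }

Partition columns are always kept because joinedRow appends file.partitionValues unconditionally; dropping them from the projection output would misalign the generated code with the joined input.
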
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 0bd14978b2..e52fbd74a7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -23,9 +23,9 @@ import java.nio.charset.StandardCharsets
 import com.google.common.io.Files

 import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.util.Utils
@@ -104,4 +104,9 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
       df.write.format("libsvm").save(path + "_2")
     }
   }
+
+  test("select features from libsvm relation") {
+    val df = sqlContext.read.format("libsvm").load(path)
+    df.select("features").rdd.map { case Row(d: Vector) => d }.first
+  }
 }
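
The new test exercises column pruning end to end. A minimal standalone reproduction in the same spirit (a sketch against the SQLContext-era API used in this patch; the file path is a placeholder, any LibSVM-format file works):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.{Row, SQLContext}

    object SelectFeaturesRepro {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("libsvm-select").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)

        // Placeholder path: point this at any LibSVM-format file.
        val df = sqlContext.read.format("libsvm").load("data/sample.libsvm")

        // Before the fix, selecting a subset of columns failed because the
        // generated projection expected only the required columns while the
        // converter emitted rows with the full (label, features) schema.
        df.select("features").rdd.map { case Row(v: Vector) => v }.first()

        sc.stop()
      }
    }
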