aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2016-06-16 10:24:29 -0700
committerWenchen Fan <wenchen@databricks.com>2016-06-16 10:24:29 -0700
commit9ea0d5e326e08b914aa46f1eec8795688a61bf74 (patch)
tree41b7df75ad7d557e09297cacbf4b48f000fa9706 /mllib
parent6451cf9270b55465d8ecea4c4031329a1058561a (diff)
downloadspark-9ea0d5e326e08b914aa46f1eec8795688a61bf74.tar.gz
spark-9ea0d5e326e08b914aa46f1eec8795688a61bf74.tar.bz2
spark-9ea0d5e326e08b914aa46f1eec8795688a61bf74.zip
[SPARK-15983][SQL] Removes FileFormat.prepareRead
## What changes were proposed in this pull request? Interface method `FileFormat.prepareRead()` was added in #12088 to handle a special case in the LibSVM data source. However, the semantics of this interface method isn't intuitive: it returns a modified version of the data source options map. Considering that the LibSVM case can be easily handled using schema metadata inside `inferSchema`, we can remove this interface method to keep the `FileFormat` interface clean. ## How was this patch tested? Existing tests. Author: Cheng Lian <lian@databricks.com> Closes #13698 from liancheng/remove-prepare-read.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala33
1 file changed, 17 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 62e09d2e0c..4988dd66f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -120,9 +120,12 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def toString: String = "LibSVM"
private def verifySchema(dataSchema: StructType): Unit = {
- if (dataSchema.size != 2 ||
- (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
- || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
+ if (
+ dataSchema.size != 2 ||
+ !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
+ !dataSchema(1).dataType.sameType(new VectorUDT()) ||
+ !(dataSchema(1).metadata.getLong("numFeatures").toInt > 0)
+ ) {
throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
}
}
@@ -131,17 +134,8 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
- Some(
- StructType(
- StructField("label", DoubleType, nullable = false) ::
- StructField("features", new VectorUDT(), nullable = false) :: Nil))
- }
-
- override def prepareRead(
- sparkSession: SparkSession,
- options: Map[String, String],
- files: Seq[FileStatus]): Map[String, String] = {
- val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+ val numFeatures: Int = options.get("numFeatures").map(_.toInt).filter(_ > 0).getOrElse {
+ // Infers number of features if the user doesn't specify (a valid) one.
val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
val path = if (dataFiles.length == 1) {
dataFiles.head.getPath.toUri.toString
@@ -156,7 +150,14 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
MLUtils.computeNumFeatures(parsed)
}
- new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+ val featuresMetadata = new MetadataBuilder()
+ .putLong("numFeatures", numFeatures)
+ .build()
+
+ Some(
+ StructType(
+ StructField("label", DoubleType, nullable = false) ::
+ StructField("features", new VectorUDT(), nullable = false, featuresMetadata) :: Nil))
}
override def prepareWrite(
@@ -185,7 +186,7 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
verifySchema(dataSchema)
- val numFeatures = options("numFeatures").toInt
+ val numFeatures = dataSchema("features").metadata.getLong("numFeatures").toInt
assert(numFeatures > 0)
val sparse = options.getOrElse("vectorType", "sparse") == "sparse"