diff options
author | Cheng Lian <lian@databricks.com> | 2016-06-16 10:24:29 -0700 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-06-16 10:24:29 -0700 |
commit | 9ea0d5e326e08b914aa46f1eec8795688a61bf74 (patch) | |
tree | 41b7df75ad7d557e09297cacbf4b48f000fa9706 /mllib | |
parent | 6451cf9270b55465d8ecea4c4031329a1058561a (diff) | |
download | spark-9ea0d5e326e08b914aa46f1eec8795688a61bf74.tar.gz spark-9ea0d5e326e08b914aa46f1eec8795688a61bf74.tar.bz2 spark-9ea0d5e326e08b914aa46f1eec8795688a61bf74.zip |
[SPARK-15983][SQL] Removes FileFormat.prepareRead
## What changes were proposed in this pull request?
Interface method `FileFormat.prepareRead()` was added in #12088 to handle a special case in the LibSVM data source.
However, the semantics of this interface method aren't intuitive: it returns a modified version of the data source options map. Considering that the LibSVM case can be easily handled using schema metadata inside `inferSchema`, we can remove this interface method to keep the `FileFormat` interface clean.
## How was this patch tested?
Existing tests.
Author: Cheng Lian <lian@databricks.com>
Closes #13698 from liancheng/remove-prepare-read.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 33 |
1 files changed, 17 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 62e09d2e0c..4988dd66f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -120,9 +120,12 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister { override def toString: String = "LibSVM" private def verifySchema(dataSchema: StructType): Unit = { - if (dataSchema.size != 2 || - (!dataSchema(0).dataType.sameType(DataTypes.DoubleType) - || !dataSchema(1).dataType.sameType(new VectorUDT()))) { + if ( + dataSchema.size != 2 || + !dataSchema(0).dataType.sameType(DataTypes.DoubleType) || + !dataSchema(1).dataType.sameType(new VectorUDT()) || + !(dataSchema(1).metadata.getLong("numFeatures").toInt > 0) + ) { throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema") } } @@ -131,17 +134,8 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - Some( - StructType( - StructField("label", DoubleType, nullable = false) :: - StructField("features", new VectorUDT(), nullable = false) :: Nil)) - } - - override def prepareRead( - sparkSession: SparkSession, - options: Map[String, String], - files: Seq[FileStatus]): Map[String, String] = { - val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse { + val numFeatures: Int = options.get("numFeatures").map(_.toInt).filter(_ > 0).getOrElse { + // Infers number of features if the user doesn't specify (a valid) one. 
val dataFiles = files.filterNot(_.getPath.getName startsWith "_") val path = if (dataFiles.length == 1) { dataFiles.head.getPath.toUri.toString @@ -156,7 +150,14 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister { MLUtils.computeNumFeatures(parsed) } - new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString)) + val featuresMetadata = new MetadataBuilder() + .putLong("numFeatures", numFeatures) + .build() + + Some( + StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false, featuresMetadata) :: Nil)) } override def prepareWrite( @@ -185,7 +186,7 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister { options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { verifySchema(dataSchema) - val numFeatures = options("numFeatures").toInt + val numFeatures = dataSchema("features").metadata.getLong("numFeatures").toInt assert(numFeatures > 0) val sparse = options.getOrElse("vectorType", "sparse") == "sparse" |