3 files changed, 18 insertions, 31 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 62e09d2e0c..4988dd66f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -120,9 +120,12 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def toString: String = "LibSVM"
 
   private def verifySchema(dataSchema: StructType): Unit = {
-    if (dataSchema.size != 2 ||
-      (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
-        || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
+    if (
+      dataSchema.size != 2 ||
+        !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
+        !dataSchema(1).dataType.sameType(new VectorUDT()) ||
+        !(dataSchema(1).metadata.getLong("numFeatures").toInt > 0)
+    ) {
       throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
   }
@@ -131,17 +134,8 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
-    Some(
-      StructType(
-        StructField("label", DoubleType, nullable = false) ::
-        StructField("features", new VectorUDT(), nullable = false) :: Nil))
-  }
-
-  override def prepareRead(
-      sparkSession: SparkSession,
-      options: Map[String, String],
-      files: Seq[FileStatus]): Map[String, String] = {
-    val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+    val numFeatures: Int = options.get("numFeatures").map(_.toInt).filter(_ > 0).getOrElse {
+      // Infers number of features if the user doesn't specify (a valid) one.
       val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
       val path = if (dataFiles.length == 1) {
         dataFiles.head.getPath.toUri.toString
@@ -156,7 +150,14 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
       MLUtils.computeNumFeatures(parsed)
     }
 
-    new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+    val featuresMetadata = new MetadataBuilder()
+      .putLong("numFeatures", numFeatures)
+      .build()
+
+    Some(
+      StructType(
+        StructField("label", DoubleType, nullable = false) ::
+        StructField("features", new VectorUDT(), nullable = false, featuresMetadata) :: Nil))
   }
 
   override def prepareWrite(
@@ -185,7 +186,7 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
     verifySchema(dataSchema)
-    val numFeatures = options("numFeatures").toInt
+    val numFeatures = dataSchema("features").metadata.getLong("numFeatures").toInt
     assert(numFeatures > 0)
 
     val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index d3273025b6..7f3683fc98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -386,9 +386,6 @@ case class DataSource(
             "It must be specified manually")
         }
 
-        val enrichedOptions =
-          format.prepareRead(sparkSession, caseInsensitiveOptions, fileCatalog.allFiles())
-
         HadoopFsRelation(
           sparkSession,
           fileCatalog,
@@ -396,7 +393,7 @@ case class DataSource(
           dataSchema = dataSchema.asNullable,
           bucketSpec = bucketSpec,
           format,
-          enrichedOptions)
+          caseInsensitiveOptions)
 
       case _ =>
         throw new AnalysisException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index 641c5cb02b..4ac555be7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -26,7 +26,6 @@ import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompres
 import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
-import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
@@ -187,15 +186,6 @@ trait FileFormat {
       files: Seq[FileStatus]): Option[StructType]
 
   /**
-   * Prepares a read job and returns a potentially updated data source option [[Map]]. This method
-   * can be useful for collecting necessary global information for scanning input data.
-   */
-  def prepareRead(
-      sparkSession: SparkSession,
-      options: Map[String, String],
-      files: Seq[FileStatus]): Map[String, String] = options
-
-  /**
    * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
    * be put here. For example, user defined output committer can be configured here
    * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
@@ -454,7 +444,6 @@ private[sql] object HadoopFsRelation extends Logging {
     logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}")
 
     val sparkContext = sparkSession.sparkContext
-    val sqlConf = sparkSession.sessionState.conf
     val serializableConfiguration = new SerializableConfiguration(hadoopConf)
     val serializedPaths = paths.map(_.toString)
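
For context, a minimal sketch of how a caller observes this change: after the patch, numFeatures is carried in the metadata of the "features" column produced by inferSchema rather than being threaded through the data source options via prepareRead, so it can be read back from the loaded DataFrame's schema. The file path and feature count below are illustrative assumptions, not part of the patch.

// Minimal usage sketch, assuming Spark 2.x with this patch applied.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("libsvm-metadata-sketch").getOrCreate()

// "numFeatures" is optional; if omitted (or not a positive integer), the patched
// inferSchema scans the data files and computes it via MLUtils.computeNumFeatures.
val df = spark.read
  .format("libsvm")
  .option("numFeatures", "780")                  // hypothetical value
  .load("data/mllib/sample_libsvm_data.txt")     // hypothetical path

// The feature count now lives in the column metadata, which is exactly what the
// patched buildReader consults instead of options("numFeatures").
val numFeatures = df.schema("features").metadata.getLong("numFeatures").toInt
println(s"numFeatures = $numFeatures")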