-rw-r--r--   mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala                    33
-rw-r--r--   sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala             5
-rw-r--r--   sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala   11
3 files changed, 18 insertions, 31 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 62e09d2e0c..4988dd66f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -120,9 +120,12 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def toString: String = "LibSVM"
private def verifySchema(dataSchema: StructType): Unit = {
- if (dataSchema.size != 2 ||
- (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
- || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
+ if (
+ dataSchema.size != 2 ||
+ !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
+ !dataSchema(1).dataType.sameType(new VectorUDT()) ||
+ !(dataSchema(1).metadata.getLong("numFeatures").toInt > 0)
+ ) {
throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
}
}
@@ -131,17 +134,8 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
- Some(
- StructType(
- StructField("label", DoubleType, nullable = false) ::
- StructField("features", new VectorUDT(), nullable = false) :: Nil))
- }
-
- override def prepareRead(
- sparkSession: SparkSession,
- options: Map[String, String],
- files: Seq[FileStatus]): Map[String, String] = {
- val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+ val numFeatures: Int = options.get("numFeatures").map(_.toInt).filter(_ > 0).getOrElse {
+ // Infers the number of features if the user doesn't specify a valid one.
val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
val path = if (dataFiles.length == 1) {
dataFiles.head.getPath.toUri.toString
@@ -156,7 +150,14 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
MLUtils.computeNumFeatures(parsed)
}
- new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+ val featuresMetadata = new MetadataBuilder()
+ .putLong("numFeatures", numFeatures)
+ .build()
+
+ Some(
+ StructType(
+ StructField("label", DoubleType, nullable = false) ::
+ StructField("features", new VectorUDT(), nullable = false, featuresMetadata) :: Nil))
}
override def prepareWrite(
@@ -185,7 +186,7 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
verifySchema(dataSchema)
- val numFeatures = options("numFeatures").toInt
+ val numFeatures = dataSchema("features").metadata.getLong("numFeatures").toInt
assert(numFeatures > 0)
val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
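Note: after this change, numFeatures travels in the metadata of the "features" column rather than in the data source options. A minimal, self-contained sketch of that round-trip is shown below; VectorUDT is replaced by DoubleType only so the snippet compiles outside Spark's own packages, and the value 10 is an arbitrary example standing in for the count that inferSchema reads from the options or computes from the data files.

    import org.apache.spark.sql.types._

    object NumFeaturesMetadataSketch {
      def main(args: Array[String]): Unit = {
        // inferSchema side: attach the feature count to the "features" column metadata.
        val featuresMetadata = new MetadataBuilder()
          .putLong("numFeatures", 10L)
          .build()
        val dataSchema = StructType(
          StructField("label", DoubleType, nullable = false) ::
          StructField("features", DoubleType, nullable = false, featuresMetadata) :: Nil)

        // buildReader side: read the count back from the schema, mirroring the lookup above.
        val numFeatures = dataSchema("features").metadata.getLong("numFeatures").toInt
        assert(numFeatures > 0)
        println(s"numFeatures = $numFeatures")
      }
    }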
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index d3273025b6..7f3683fc98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -386,9 +386,6 @@ case class DataSource(
"It must be specified manually")
}
- val enrichedOptions =
- format.prepareRead(sparkSession, caseInsensitiveOptions, fileCatalog.allFiles())
-
HadoopFsRelation(
sparkSession,
fileCatalog,
@@ -396,7 +393,7 @@ case class DataSource(
dataSchema = dataSchema.asNullable,
bucketSpec = bucketSpec,
format,
- enrichedOptions)
+ caseInsensitiveOptions)
case _ =>
throw new AnalysisException(
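Note: with prepareRead removed, DataSource passes the original case-insensitive options to HadoopFsRelation unchanged, and any per-format global state (such as LibSVM's numFeatures) is carried by the inferred schema instead. The user-facing API is unaffected; a usage sketch follows, where spark is an existing SparkSession and the path is only an example.

    // numFeatures may be given explicitly; if it is omitted or not positive,
    // LibSVMFileFormat.inferSchema computes it by scanning the data files.
    val df = spark.read
      .format("libsvm")
      .option("numFeatures", "780")
      .load("data/mllib/sample_libsvm_data.txt")
    df.printSchema()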
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index 641c5cb02b..4ac555be7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -26,7 +26,6 @@ import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompres
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
-import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
@@ -187,15 +186,6 @@ trait FileFormat {
files: Seq[FileStatus]): Option[StructType]
/**
- * Prepares a read job and returns a potentially updated data source option [[Map]]. This method
- * can be useful for collecting necessary global information for scanning input data.
- */
- def prepareRead(
- sparkSession: SparkSession,
- options: Map[String, String],
- files: Seq[FileStatus]): Map[String, String] = options
-
- /**
* Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
* be put here. For example, user defined output committer can be configured here
* by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
@@ -454,7 +444,6 @@ private[sql] object HadoopFsRelation extends Logging {
logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}")
val sparkContext = sparkSession.sparkContext
- val sqlConf = sparkSession.sessionState.conf
val serializableConfiguration = new SerializableConfiguration(hadoopConf)
val serializedPaths = paths.map(_.toString)
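Note on the surrounding context (unchanged by this patch): Hadoop's Configuration is not Java-serializable, so it is wrapped before the listing work is distributed to executors. Spark's SerializableConfiguration is private[spark]; the sketch below is only a minimal equivalent of that wrapping pattern, not the class used above.

    import java.io.{ObjectInputStream, ObjectOutputStream}
    import org.apache.hadoop.conf.Configuration

    // Illustration only: wrap a non-serializable Hadoop Configuration so it can be
    // captured by closures shipped to executors (Configuration implements Writable).
    class WrappedHadoopConf(@transient var value: Configuration) extends Serializable {
      private def writeObject(out: ObjectOutputStream): Unit = {
        out.defaultWriteObject()
        value.write(out)
      }
      private def readObject(in: ObjectInputStream): Unit = {
        in.defaultReadObject()
        value = new Configuration(false)
        value.readFields(in)
      }
    }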