Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 107
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala | 182
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala | 191
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala | 207
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala | 406
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala | 90
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala | 283
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala | 4
13 files changed, 1313 insertions, 220 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 265a61592b..f3107f7b51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -27,23 +27,23 @@ import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
import com.fasterxml.jackson.core.JsonFactory
-
import org.apache.commons.lang3.StringUtils
+
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JacksonGenerator
+import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -400,7 +400,9 @@ class DataFrame private[sql](
joined.left,
joined.right,
joinType = Inner,
- Some(EqualTo(joined.left.resolve(usingColumn), joined.right.resolve(usingColumn))))
+ Some(expressions.EqualTo(
+ joined.left.resolve(usingColumn),
+ joined.right.resolve(usingColumn))))
)
}
@@ -465,8 +467,8 @@ class DataFrame private[sql](
// By the time we get here, since we have already run analysis, all attributes should've been
// resolved and become AttributeReference.
val cond = plan.condition.map { _.transform {
- case EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) =>
- EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name))
+ case expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) =>
+ expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name))
}}
plan.copy(condition = cond)
}
@@ -1326,6 +1328,28 @@ class DataFrame private[sql](
/**
* :: Experimental ::
+ * Creates a table from the contents of this DataFrame
+ * based on a given data source, [[SaveMode]] specified by mode, a set of options, and a list of
+ * partition columns.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ * @group output
+ */
+ @Experimental
+ def saveAsTable(
+ tableName: String,
+ source: String,
+ mode: SaveMode,
+ options: java.util.Map[String, String],
+ partitionColumns: java.util.List[String]): Unit = {
+ saveAsTable(tableName, source, mode, options.toMap, partitionColumns)
+ }
+
+ /**
+ * :: Experimental ::
* (Scala-specific)
* Creates a table from the contents of this DataFrame based on a given data source,
* [[SaveMode]] specified by mode, and a set of options.
@@ -1350,6 +1374,7 @@ class DataFrame private[sql](
tableName,
source,
temporary = false,
+ Array.empty[String],
mode,
options,
logicalPlan)
@@ -1359,6 +1384,36 @@ class DataFrame private[sql](
/**
* :: Experimental ::
+ * Creates a table from the contents of this DataFrame
+ * based on a given data source, [[SaveMode]] specified by mode, a set of options, and a list of
+ * partition columns.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ * @group output
+ */
+ @Experimental
+ def saveAsTable(
+ tableName: String,
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String],
+ partitionColumns: Seq[String]): Unit = {
+ sqlContext.executePlan(
+ CreateTableUsingAsSelect(
+ tableName,
+ source,
+ temporary = false,
+ partitionColumns.toArray,
+ mode,
+ options,
+ logicalPlan)).toRdd
+ }
+
+ /**
+ * :: Experimental ::
* Saves the contents of this DataFrame to the given path,
* using the default data source configured by spark.sql.sources.default and
* [[SaveMode.ErrorIfExists]] as the save mode.
@@ -1419,6 +1474,21 @@ class DataFrame private[sql](
/**
* :: Experimental ::
+ * Saves the contents of this DataFrame to the given path based on the given data source,
+ * [[SaveMode]] specified by mode, and partition columns specified by `partitionColumns`.
+ * @group output
+ */
+ @Experimental
+ def save(
+ source: String,
+ mode: SaveMode,
+ options: java.util.Map[String, String],
+ partitionColumns: java.util.List[String]): Unit = {
+ save(source, mode, options.toMap, partitionColumns)
+ }
+
+ /**
+ * :: Experimental ::
* (Scala-specific)
* Saves the contents of this DataFrame based on the given data source,
* [[SaveMode]] specified by mode, and a set of options
@@ -1429,7 +1499,22 @@ class DataFrame private[sql](
source: String,
mode: SaveMode,
options: Map[String, String]): Unit = {
- ResolvedDataSource(sqlContext, source, mode, options, this)
+ ResolvedDataSource(sqlContext, source, Array.empty[String], mode, options, this)
+ }
+
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path based on the given data source,
+ * [[SaveMode]] specified by mode, and partition columns specified by `partitionColumns`.
+ * @group output
+ */
+ @Experimental
+ def save(
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String],
+ partitionColumns: Seq[String]): Unit = {
+ ResolvedDataSource(sqlContext, source, partitionColumns.toArray, mode, options, this)
}
/**
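The hunk above adds partition-aware overloads of `save` and `saveAsTable`. A minimal usage sketch, assuming an active `sqlContext`, a DataFrame `df` that contains `year` and `month` columns, and a writable output path (all names and paths here are illustrative):
{{{
import org.apache.spark.sql.SaveMode

// Write df as Parquet, partitioned on disk by the "year" and "month" columns.
df.save(
  source = "org.apache.spark.sql.parquet",
  mode = SaveMode.Overwrite,
  options = Map("path" -> "/tmp/events_partitioned"),
  partitionColumns = Seq("year", "month"))

// Persist the same data as a catalog table (requires a HiveContext-backed sqlContext).
df.saveAsTable(
  tableName = "events_by_month",
  source = "org.apache.spark.sql.parquet",
  mode = SaveMode.ErrorIfExists,
  options = Map.empty[String, String],
  partitionColumns = Seq("year", "month"))
}}}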
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index dcac97beaf..f07bb196c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -66,6 +66,9 @@ private[spark] object SQLConf {
// to its length exceeds the threshold.
val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold"
+ // Whether to perform partition discovery when loading external data sources. Defaults to true.
+ val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled"
+
// Whether to perform eager analysis when constructing a dataframe.
// Set to false when debugging requires the ability to look at invalid query plans.
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis"
@@ -235,6 +238,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def defaultDataSourceName: String =
getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet")
+ private[spark] def partitionDiscoveryEnabled() =
+ getConf(SQLConf.PARTITION_DISCOVERY_ENABLED, "true").toBoolean
+
// Do not use a value larger than 4000 as the default value of this property.
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
private[spark] def schemaStringLengthThreshold: Int =
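The two hunks above introduce `spark.sql.sources.partitionDiscovery.enabled`, which defaults to true. A minimal sketch of flipping it at runtime, assuming an active `sqlContext`:
{{{
// Turn automatic partition discovery off for subsequently loaded data sources...
sqlContext.setConf("spark.sql.sources.partitionDiscovery.enabled", "false")

// ...and back on (the default).
sqlContext.setConf("spark.sql.sources.partitionDiscovery.enabled", "true")
}}}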
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 648021806f..afee09adaa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -762,7 +762,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def load(source: String, options: Map[String, String]): DataFrame = {
- val resolved = ResolvedDataSource(this, None, source, options)
+ val resolved = ResolvedDataSource(this, None, Array.empty[String], source, options)
DataFrame(this, LogicalRelation(resolved.relation))
}
@@ -783,6 +783,37 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* :: Experimental ::
+ * (Java-specific) Returns the dataset specified by the given data source and
+ * a set of options as a DataFrame, using the given schema as the schema of the DataFrame.
+ *
+ * @group genericdata
+ */
+ @Experimental
+ def load(
+ source: String,
+ schema: StructType,
+ partitionColumns: Array[String],
+ options: java.util.Map[String, String]): DataFrame = {
+ load(source, schema, partitionColumns, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific) Returns the dataset specified by the given data source and
+ * a set of options as a DataFrame, using the given schema as the schema of the DataFrame.
+ * @group genericdata
+ */
+ @Experimental
+ def load(
+ source: String,
+ schema: StructType,
+ options: Map[String, String]): DataFrame = {
+ val resolved = ResolvedDataSource(this, Some(schema), Array.empty[String], source, options)
+ DataFrame(this, LogicalRelation(resolved.relation))
+ }
+
+ /**
+ * :: Experimental ::
* (Scala-specific) Returns the dataset specified by the given data source and
* a set of options as a DataFrame, using the given schema as the schema of the DataFrame.
* @group genericdata
@@ -791,8 +822,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
def load(
source: String,
schema: StructType,
+ partitionColumns: Array[String],
options: Map[String, String]): DataFrame = {
- val resolved = ResolvedDataSource(this, Some(schema), source, options)
+ val resolved = ResolvedDataSource(this, Some(schema), partitionColumns, source, options)
DataFrame(this, LogicalRelation(resolved.relation))
}
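The new `load` overload above accepts both a user-specified schema and an array of partition column names. A hedged sketch, assuming the partitioned layout produced by the earlier `save` example and the same illustrative column names:
{{{
import org.apache.spark.sql.types._

// The full schema, including the "year" and "month" partition columns.
val schema = StructType(Seq(
  StructField("id", LongType, nullable = true),
  StructField("name", StringType, nullable = true),
  StructField("year", IntegerType, nullable = true),
  StructField("month", IntegerType, nullable = true)))

val events = sqlContext.load(
  "org.apache.spark.sql.parquet",
  schema,
  Array("year", "month"),
  Map("path" -> "/tmp/events_partitioned"))
}}}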
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 56a4689eb5..af0029cb84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -343,9 +343,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case c: CreateTableUsing if c.temporary && c.allowExisting =>
sys.error("allowExisting should be set to false when creating a temporary table.")
- case CreateTableUsingAsSelect(tableName, provider, true, mode, opts, query) =>
- val cmd =
- CreateTempTableUsingAsSelect(tableName, provider, mode, opts, query)
+ case CreateTableUsingAsSelect(tableName, provider, true, partitionsCols, mode, opts, query)
+ if partitionsCols.nonEmpty =>
+ sys.error("Cannot create temporary partitioned table.")
+
+ case CreateTableUsingAsSelect(tableName, provider, true, _, mode, opts, query) =>
+ val cmd = CreateTempTableUsingAsSelect(
+ tableName, provider, Array.empty[String], mode, opts, query)
ExecutedCommand(cmd) :: Nil
case c: CreateTableUsingAsSelect if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 85e60733bc..ee4b1c72a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -136,10 +136,6 @@ private[sql] class DefaultSource
}
}
-private[sql] case class Partition(values: Row, path: String)
-
-private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition])
-
/**
* An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is
* intended as a full replacement of the Parquet support in Spark SQL. The old implementation will
@@ -307,7 +303,7 @@ private[sql] case class ParquetRelation2(
if (partitionDirs.nonEmpty) {
// Parses names and values of partition columns, and infer their data types.
- ParquetRelation2.parsePartitions(partitionDirs, defaultPartitionName)
+ PartitioningUtils.parsePartitions(partitionDirs, defaultPartitionName)
} else {
// No partition directories found, makes an empty specification
PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition])
@@ -805,7 +801,7 @@ private[sql] object ParquetRelation2 extends Logging {
val ordinalMap = metastoreSchema.zipWithIndex.map {
case (field, index) => field.name.toLowerCase -> index
}.toMap
- val reorderedParquetSchema = mergedParquetSchema.sortBy(f =>
+ val reorderedParquetSchema = mergedParquetSchema.sortBy(f =>
ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1))
StructType(metastoreSchema.zip(reorderedParquetSchema).map {
@@ -841,178 +837,4 @@ private[sql] object ParquetRelation2 extends Logging {
.filter(_.nullable)
StructType(parquetSchema ++ missingFields)
}
-
-
- // TODO Data source implementations shouldn't touch Catalyst types (`Literal`).
- // However, we are already using Catalyst expressions for partition pruning and predicate
- // push-down here...
- private[parquet] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) {
- require(columnNames.size == literals.size)
- }
-
- /**
- * Given a group of qualified paths, tries to parse them and returns a partition specification.
- * For example, given:
- * {{{
- * hdfs://<host>:<port>/path/to/partition/a=1/b=hello/c=3.14
- * hdfs://<host>:<port>/path/to/partition/a=2/b=world/c=6.28
- * }}}
- * it returns:
- * {{{
- * PartitionSpec(
- * partitionColumns = StructType(
- * StructField(name = "a", dataType = IntegerType, nullable = true),
- * StructField(name = "b", dataType = StringType, nullable = true),
- * StructField(name = "c", dataType = DoubleType, nullable = true)),
- * partitions = Seq(
- * Partition(
- * values = Row(1, "hello", 3.14),
- * path = "hdfs://<host>:<port>/path/to/partition/a=1/b=hello/c=3.14"),
- * Partition(
- * values = Row(2, "world", 6.28),
- * path = "hdfs://<host>:<port>/path/to/partition/a=2/b=world/c=6.28")))
- * }}}
- */
- private[parquet] def parsePartitions(
- paths: Seq[Path],
- defaultPartitionName: String): PartitionSpec = {
- val partitionValues = resolvePartitions(paths.map(parsePartition(_, defaultPartitionName)))
- val fields = {
- val (PartitionValues(columnNames, literals)) = partitionValues.head
- columnNames.zip(literals).map { case (name, Literal(_, dataType)) =>
- StructField(name, dataType, nullable = true)
- }
- }
-
- val partitions = partitionValues.zip(paths).map {
- case (PartitionValues(_, literals), path) =>
- Partition(Row(literals.map(_.value): _*), path.toString)
- }
-
- PartitionSpec(StructType(fields), partitions)
- }
-
- /**
- * Parses a single partition, returns column names and values of each partition column. For
- * example, given:
- * {{{
- * path = hdfs://<host>:<port>/path/to/partition/a=42/b=hello/c=3.14
- * }}}
- * it returns:
- * {{{
- * PartitionValues(
- * Seq("a", "b", "c"),
- * Seq(
- * Literal.create(42, IntegerType),
- * Literal.create("hello", StringType),
- * Literal.create(3.14, FloatType)))
- * }}}
- */
- private[parquet] def parsePartition(
- path: Path,
- defaultPartitionName: String): PartitionValues = {
- val columns = ArrayBuffer.empty[(String, Literal)]
- // Old Hadoop versions don't have `Path.isRoot`
- var finished = path.getParent == null
- var chopped = path
-
- while (!finished) {
- val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName)
- maybeColumn.foreach(columns += _)
- chopped = chopped.getParent
- finished = maybeColumn.isEmpty || chopped.getParent == null
- }
-
- val (columnNames, values) = columns.reverse.unzip
- PartitionValues(columnNames, values)
- }
-
- private def parsePartitionColumn(
- columnSpec: String,
- defaultPartitionName: String): Option[(String, Literal)] = {
- val equalSignIndex = columnSpec.indexOf('=')
- if (equalSignIndex == -1) {
- None
- } else {
- val columnName = columnSpec.take(equalSignIndex)
- assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'")
-
- val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
- assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")
-
- val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName)
- Some(columnName -> literal)
- }
- }
-
- /**
- * Resolves possible type conflicts between partitions by up-casting "lower" types. The up-
- * casting order is:
- * {{{
- * NullType ->
- * IntegerType -> LongType ->
- * FloatType -> DoubleType -> DecimalType.Unlimited ->
- * StringType
- * }}}
- */
- private[parquet] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = {
- // Column names of all partitions must match
- val distinctPartitionsColNames = values.map(_.columnNames).distinct
- assert(distinctPartitionsColNames.size == 1, {
- val list = distinctPartitionsColNames.mkString("\t", "\n", "")
- s"Conflicting partition column names detected:\n$list"
- })
-
- // Resolves possible type conflicts for each column
- val columnCount = values.head.columnNames.size
- val resolvedValues = (0 until columnCount).map { i =>
- resolveTypeConflicts(values.map(_.literals(i)))
- }
-
- // Fills resolved literals back to each partition
- values.zipWithIndex.map { case (d, index) =>
- d.copy(literals = resolvedValues.map(_(index)))
- }
- }
-
- /**
- * Converts a string to a `Literal` with automatic type inference. Currently only supports
- * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and
- * [[StringType]].
- */
- private[parquet] def inferPartitionColumnValue(
- raw: String,
- defaultPartitionName: String): Literal = {
- // First tries integral types
- Try(Literal.create(Integer.parseInt(raw), IntegerType))
- .orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
- // Then falls back to fractional types
- .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType)))
- .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
- .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited)))
- // Then falls back to string
- .getOrElse {
- if (raw == defaultPartitionName) Literal.create(null, NullType)
- else Literal.create(raw, StringType)
- }
- }
-
- private val upCastingOrder: Seq[DataType] =
- Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType)
-
- /**
- * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower"
- * types.
- */
- private def resolveTypeConflicts(literals: Seq[Literal]): Seq[Literal] = {
- val desiredType = {
- val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_))
- // Falls back to string if all values of this column are null or empty string
- if (topType == NullType) StringType else topType
- }
-
- literals.map { case l @ Literal(_, dataType) =>
- Literal.create(Cast(l, desiredType).eval(), desiredType)
- }
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index b3d71f687a..a5410cda0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -17,20 +17,25 @@
package org.apache.spark.sql.sources
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.Logging
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{UTF8String, StringType}
-import org.apache.spark.sql.{Row, Strategy, execution, sources}
+import org.apache.spark.sql.types.{StructType, UTF8String, StringType}
+import org.apache.spark.sql._
/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
-private[sql] object DataSourceStrategy extends Strategy {
+private[sql] object DataSourceStrategy extends Strategy with Logging {
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) =>
pruneFilterProjectRaw(
@@ -53,6 +58,51 @@ private[sql] object DataSourceStrategy extends Strategy {
filters,
(a, _) => t.buildScan(a)) :: Nil
+ // Scanning partitioned FSBasedRelation
+ case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FSBasedRelation))
+ if t.partitionSpec.partitionColumns.nonEmpty =>
+ val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray
+
+ logInfo {
+ val total = t.partitionSpec.partitions.length
+ val selected = selectedPartitions.length
+ val percentPruned = (1 - selected.toDouble / total.toDouble) * 100
+ s"Selected $selected partitions out of $total, pruned $percentPruned% partitions."
+ }
+
+ // Only pushes down predicates that do not reference partition columns.
+ val pushedFilters = {
+ val partitionColumnNames = t.partitionSpec.partitionColumns.map(_.name).toSet
+ filters.filter { f =>
+ val referencedColumnNames = f.references.map(_.name).toSet
+ referencedColumnNames.intersect(partitionColumnNames).isEmpty
+ }
+ }
+
+ buildPartitionedTableScan(
+ l,
+ projectList,
+ pushedFilters,
+ t.partitionSpec.partitionColumns,
+ selectedPartitions) :: Nil
+
+ // Scanning non-partitioned FSBasedRelation
+ case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FSBasedRelation)) =>
+ val inputPaths = t.paths.map(new Path(_)).flatMap { path =>
+ val fs = path.getFileSystem(t.sqlContext.sparkContext.hadoopConfiguration)
+ val qualifiedPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ SparkHadoopUtil.get.listLeafStatuses(fs, qualifiedPath).map(_.getPath).filterNot { path =>
+ val name = path.getName
+ name.startsWith("_") || name.startsWith(".")
+ }.map(fs.makeQualified(_).toString)
+ }
+
+ pruneFilterProject(
+ l,
+ projectList,
+ filters,
+ (a, f) => t.buildScan(a, f, inputPaths)) :: Nil
+
case l @ LogicalRelation(t: TableScan) =>
createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil
@@ -60,9 +110,144 @@ private[sql] object DataSourceStrategy extends Strategy {
l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty =>
execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil
+ case i @ logical.InsertIntoTable(
+ l @ LogicalRelation(t: FSBasedRelation), part, query, overwrite, false) if part.isEmpty =>
+ val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
+ execution.ExecutedCommand(
+ InsertIntoFSBasedRelation(t, query, Array.empty[String], mode)) :: Nil
+
case _ => Nil
}
+ private def buildPartitionedTableScan(
+ logicalRelation: LogicalRelation,
+ projections: Seq[NamedExpression],
+ filters: Seq[Expression],
+ partitionColumns: StructType,
+ partitions: Array[Partition]) = {
+ val output = projections.map(_.toAttribute)
+ val relation = logicalRelation.relation.asInstanceOf[FSBasedRelation]
+
+ // Builds RDD[Row]s for each selected partition.
+ val perPartitionRows = partitions.map { case Partition(partitionValues, dir) =>
+ // Paths to all data files within this partition
+ val dataFilePaths = {
+ val dirPath = new Path(dir)
+ val fs = dirPath.getFileSystem(SparkHadoopUtil.get.conf)
+ fs.listStatus(dirPath).map(_.getPath).filterNot { path =>
+ val name = path.getName
+ name.startsWith("_") || name.startsWith(".")
+ }.map(fs.makeQualified(_).toString)
+ }
+
+ // The table scan operator (PhysicalRDD) which retrieves required columns from data files.
+ // Notice that the schema of data files, represented by `relation.dataSchema`, may contain
+ // some partition column(s).
+ val scan =
+ pruneFilterProject(
+ logicalRelation,
+ projections,
+ filters,
+ (requiredColumns, filters) => {
+ val partitionColNames = partitionColumns.fieldNames
+
+ // Don't scan any partition columns to save I/O. Here we are being optimistic and
+ // assuming that partition column values stored in data files are always consistent with those
+ // partition values encoded in partition directory paths.
+ val nonPartitionColumns = requiredColumns.filterNot(partitionColNames.contains)
+ val dataRows = relation.buildScan(nonPartitionColumns, filters, dataFilePaths)
+
+ // Merges data values with partition values.
+ mergeWithPartitionValues(
+ relation.schema,
+ requiredColumns,
+ partitionColNames,
+ partitionValues,
+ dataRows)
+ })
+
+ scan.execute()
+ }
+
+ val unionedRows = perPartitionRows.reduceOption(_ ++ _).getOrElse {
+ relation.sqlContext.emptyResult
+ }
+
+ createPhysicalRDD(logicalRelation.relation, output, unionedRows)
+ }
+
+ private def mergeWithPartitionValues(
+ schema: StructType,
+ requiredColumns: Array[String],
+ partitionColumns: Array[String],
+ partitionValues: Row,
+ dataRows: RDD[Row]): RDD[Row] = {
+ val nonPartitionColumns = requiredColumns.filterNot(partitionColumns.contains)
+
+ // If output columns contain any partition column(s), we need to merge scanned data
+ // columns and requested partition columns to form the final result.
+ if (!requiredColumns.sameElements(nonPartitionColumns)) {
+ val mergers = requiredColumns.zipWithIndex.map { case (name, index) =>
+ // To see whether the `index`-th column is a partition column...
+ val i = partitionColumns.indexOf(name)
+ if (i != -1) {
+ // If yes, gets column value from partition values.
+ (mutableRow: MutableRow, dataRow: expressions.Row, ordinal: Int) => {
+ mutableRow(ordinal) = partitionValues(i)
+ }
+ } else {
+ // Otherwise, inherits the value from scanned data.
+ val i = nonPartitionColumns.indexOf(name)
+ (mutableRow: MutableRow, dataRow: expressions.Row, ordinal: Int) => {
+ mutableRow(ordinal) = dataRow(i)
+ }
+ }
+ }
+
+ dataRows.mapPartitions { iterator =>
+ val dataTypes = requiredColumns.map(schema(_).dataType)
+ val mutableRow = new SpecificMutableRow(dataTypes)
+ iterator.map { dataRow =>
+ var i = 0
+ while (i < mutableRow.length) {
+ mergers(i)(mutableRow, dataRow, i)
+ i += 1
+ }
+ mutableRow.asInstanceOf[expressions.Row]
+ }
+ }
+ } else {
+ dataRows
+ }
+ }
+
+ protected def prunePartitions(
+ predicates: Seq[Expression],
+ partitionSpec: PartitionSpec): Seq[Partition] = {
+ val PartitionSpec(partitionColumns, partitions) = partitionSpec
+ val partitionColumnNames = partitionColumns.map(_.name).toSet
+ val partitionPruningPredicates = predicates.filter {
+ _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
+ }
+
+ if (partitionPruningPredicates.nonEmpty) {
+ val predicate =
+ partitionPruningPredicates
+ .reduceOption(expressions.And)
+ .getOrElse(Literal(true))
+
+ val boundPredicate = InterpretedPredicate.create(predicate.transform {
+ case a: AttributeReference =>
+ val index = partitionColumns.indexWhere(a.name == _.name)
+ BoundReference(index, partitionColumns(index).dataType, nullable = true)
+ })
+
+ partitions.filter { case Partition(values, _) => boundPredicate(values) }
+ } else {
+ partitions
+ }
+ }
+
// Based on Public API.
protected def pruneFilterProject(
relation: LogicalRelation,
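The partitioned-scan path above reads only non-partition columns from data files and splices the partition values, parsed from the directory names, back into each output row via `mergeWithPartitionValues`. A plain-Scala illustration of that merging step, using no Spark types and purely illustrative column names:
{{{
val requiredColumns  = Seq("id", "name", "year")            // columns requested by the query
val partitionColumns = Seq("year")                          // encoded in the directory path
val partitionValues  = Map("year" -> 2015)                  // parsed from ".../year=2015"
val dataRow          = Map[String, Any]("id" -> 1L, "name" -> "alice")  // read from the data file

// A partition column's value comes from the path; everything else comes from the file.
val mergedRow = requiredColumns.map(col => partitionValues.getOrElse(col, dataRow(col)))
// mergedRow == Seq(1L, "alice", 2015)
}}}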
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala
new file mode 100644
index 0000000000..d30f7f65e2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.lang.{Double => JDouble, Float => JFloat, Long => JLong}
+import java.math.{BigDecimal => JBigDecimal}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Try
+
+import com.google.common.cache.{CacheBuilder, Cache}
+import org.apache.hadoop.fs.{FileStatus, Path}
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
+import org.apache.spark.sql.types._
+
+private[sql] case class Partition(values: Row, path: String)
+
+private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition])
+
+private[sql] object PartitioningUtils {
+ private[sql] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) {
+ require(columnNames.size == literals.size)
+ }
+
+ /**
+ * Given a group of qualified paths, tries to parse them and returns a partition specification.
+ * For example, given:
+ * {{{
+ * hdfs://<host>:<port>/path/to/partition/a=1/b=hello/c=3.14
+ * hdfs://<host>:<port>/path/to/partition/a=2/b=world/c=6.28
+ * }}}
+ * it returns:
+ * {{{
+ * PartitionSpec(
+ * partitionColumns = StructType(
+ * StructField(name = "a", dataType = IntegerType, nullable = true),
+ * StructField(name = "b", dataType = StringType, nullable = true),
+ * StructField(name = "c", dataType = DoubleType, nullable = true)),
+ * partitions = Seq(
+ * Partition(
+ * values = Row(1, "hello", 3.14),
+ * path = "hdfs://<host>:<port>/path/to/partition/a=1/b=hello/c=3.14"),
+ * Partition(
+ * values = Row(2, "world", 6.28),
+ * path = "hdfs://<host>:<port>/path/to/partition/a=2/b=world/c=6.28")))
+ * }}}
+ */
+ private[sql] def parsePartitions(
+ paths: Seq[Path],
+ defaultPartitionName: String): PartitionSpec = {
+ val partitionValues = resolvePartitions(paths.map(parsePartition(_, defaultPartitionName)))
+ val fields = {
+ val (PartitionValues(columnNames, literals)) = partitionValues.head
+ columnNames.zip(literals).map { case (name, Literal(_, dataType)) =>
+ StructField(name, dataType, nullable = true)
+ }
+ }
+
+ val partitions = partitionValues.zip(paths).map {
+ case (PartitionValues(_, literals), path) =>
+ Partition(Row(literals.map(_.value): _*), path.toString)
+ }
+
+ PartitionSpec(StructType(fields), partitions)
+ }
+
+ /**
+ * Parses a single partition, returns column names and values of each partition column. For
+ * example, given:
+ * {{{
+ * path = hdfs://<host>:<port>/path/to/partition/a=42/b=hello/c=3.14
+ * }}}
+ * it returns:
+ * {{{
+ * PartitionValues(
+ * Seq("a", "b", "c"),
+ * Seq(
+ * Literal.create(42, IntegerType),
+ * Literal.create("hello", StringType),
+ * Literal.create(3.14, FloatType)))
+ * }}}
+ */
+ private[sql] def parsePartition(
+ path: Path,
+ defaultPartitionName: String): PartitionValues = {
+ val columns = ArrayBuffer.empty[(String, Literal)]
+ // Old Hadoop versions don't have `Path.isRoot`
+ var finished = path.getParent == null
+ var chopped = path
+
+ while (!finished) {
+ val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName)
+ maybeColumn.foreach(columns += _)
+ chopped = chopped.getParent
+ finished = maybeColumn.isEmpty || chopped.getParent == null
+ }
+
+ val (columnNames, values) = columns.reverse.unzip
+ PartitionValues(columnNames, values)
+ }
+
+ private def parsePartitionColumn(
+ columnSpec: String,
+ defaultPartitionName: String): Option[(String, Literal)] = {
+ val equalSignIndex = columnSpec.indexOf('=')
+ if (equalSignIndex == -1) {
+ None
+ } else {
+ val columnName = columnSpec.take(equalSignIndex)
+ assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'")
+
+ val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
+ assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")
+
+ val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName)
+ Some(columnName -> literal)
+ }
+ }
+
+ /**
+ * Resolves possible type conflicts between partitions by up-casting "lower" types. The up-
+ * casting order is:
+ * {{{
+ * NullType ->
+ * IntegerType -> LongType ->
+ * FloatType -> DoubleType -> DecimalType.Unlimited ->
+ * StringType
+ * }}}
+ */
+ private[sql] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = {
+ // Column names of all partitions must match
+ val distinctPartitionsColNames = values.map(_.columnNames).distinct
+ assert(distinctPartitionsColNames.size == 1, {
+ val list = distinctPartitionsColNames.mkString("\t", "\n", "")
+ s"Conflicting partition column names detected:\n$list"
+ })
+
+ // Resolves possible type conflicts for each column
+ val columnCount = values.head.columnNames.size
+ val resolvedValues = (0 until columnCount).map { i =>
+ resolveTypeConflicts(values.map(_.literals(i)))
+ }
+
+ // Fills resolved literals back to each partition
+ values.zipWithIndex.map { case (d, index) =>
+ d.copy(literals = resolvedValues.map(_(index)))
+ }
+ }
+
+ /**
+ * Converts a string to a `Literal` with automatic type inference. Currently only supports
+ * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and
+ * [[StringType]].
+ */
+ private[sql] def inferPartitionColumnValue(
+ raw: String,
+ defaultPartitionName: String): Literal = {
+ // First tries integral types
+ Try(Literal.create(Integer.parseInt(raw), IntegerType))
+ .orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
+ // Then falls back to fractional types
+ .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType)))
+ .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
+ .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited)))
+ // Then falls back to string
+ .getOrElse {
+ if (raw == defaultPartitionName) Literal.create(null, NullType)
+ else Literal.create(raw, StringType)
+ }
+ }
+
+ private val upCastingOrder: Seq[DataType] =
+ Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType)
+
+ /**
+ * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower"
+ * types.
+ */
+ private def resolveTypeConflicts(literals: Seq[Literal]): Seq[Literal] = {
+ val desiredType = {
+ val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_))
+ // Falls back to string if all values of this column are null or empty string
+ if (topType == NullType) StringType else topType
+ }
+
+ literals.map { case l @ Literal(_, dataType) =>
+ Literal.create(Cast(l, desiredType).eval(), desiredType)
+ }
+ }
+}
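For reference, a sketch of what `parsePartition` yields for a typical partition directory. `PartitioningUtils` is `private[sql]`, so this only compiles from code living under the `org.apache.spark.sql` package; the host, port, and default partition name below are placeholders:
{{{
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.sources.PartitioningUtils

val values = PartitioningUtils.parsePartition(
  new Path("hdfs://namenode:8020/table/a=42/b=hello/c=3.14"),
  defaultPartitionName = "__HIVE_DEFAULT_PARTITION__")

// values.columnNames == Seq("a", "b", "c")
// The literals are inferred as IntegerType, StringType, and FloatType respectively;
// resolvePartitions would later up-cast conflicting types across partitions.
}}}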
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index dbdb0d39c2..127133bfaf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -14,12 +14,28 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.spark.sql.sources
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import java.util.Date
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat}
+import org.apache.hadoop.util.Shell
+import parquet.hadoop.util.ContextUtil
+
+import org.apache.spark._
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
+import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
private[sql] case class InsertIntoDataSource(
logicalRelation: LogicalRelation,
@@ -41,3 +57,391 @@ private[sql] case class InsertIntoDataSource(
Seq.empty[Row]
}
}
+
+private[sql] case class InsertIntoFSBasedRelation(
+ @transient relation: FSBasedRelation,
+ @transient query: LogicalPlan,
+ partitionColumns: Array[String],
+ mode: SaveMode)
+ extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ require(
+ relation.paths.length == 1,
+ s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}")
+
+ val hadoopConf = sqlContext.sparkContext.hadoopConfiguration
+ val outputPath = new Path(relation.paths.head)
+ val fs = outputPath.getFileSystem(hadoopConf)
+ val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+
+ val doInsertion = (mode, fs.exists(qualifiedOutputPath)) match {
+ case (SaveMode.ErrorIfExists, true) =>
+ sys.error(s"path $qualifiedOutputPath already exists.")
+ case (SaveMode.Overwrite, true) =>
+ fs.delete(qualifiedOutputPath, true)
+ true
+ case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
+ true
+ case (SaveMode.Ignore, exists) =>
+ !exists
+ }
+
+ if (doInsertion) {
+ val job = Job.getInstance(hadoopConf)
+ job.setOutputKeyClass(classOf[Void])
+ job.setOutputValueClass(classOf[Row])
+ FileOutputFormat.setOutputPath(job, qualifiedOutputPath)
+
+ val df = sqlContext.createDataFrame(
+ DataFrame(sqlContext, query).queryExecution.toRdd,
+ relation.schema,
+ needsConversion = false)
+
+ if (partitionColumns.isEmpty) {
+ insert(new DefaultWriterContainer(relation, job), df)
+ } else {
+ val writerContainer = new DynamicPartitionWriterContainer(
+ relation, job, partitionColumns, "__HIVE_DEFAULT_PARTITION__")
+ insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns)
+ }
+ }
+
+ Seq.empty[Row]
+ }
+
+ private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = {
+ // Uses local vals for serialization
+ val needsConversion = relation.needConversion
+ val dataSchema = relation.dataSchema
+
+ try {
+ writerContainer.driverSideSetup()
+ df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
+ writerContainer.commitJob()
+ relation.refresh()
+ } catch { case cause: Throwable =>
+ writerContainer.abortJob()
+ throw new SparkException("Job aborted.", cause)
+ }
+
+ def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = {
+ writerContainer.executorSideSetup(taskContext)
+
+ try {
+ if (needsConversion) {
+ val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
+ while (iterator.hasNext) {
+ val row = converter(iterator.next()).asInstanceOf[Row]
+ writerContainer.outputWriterForRow(row).write(row)
+ }
+ } else {
+ while (iterator.hasNext) {
+ val row = iterator.next()
+ writerContainer.outputWriterForRow(row).write(row)
+ }
+ }
+ writerContainer.commitTask()
+ } catch { case cause: Throwable =>
+ writerContainer.abortTask()
+ throw new SparkException("Task failed while writing rows.", cause)
+ }
+ }
+ }
+
+ private def insertWithDynamicPartitions(
+ sqlContext: SQLContext,
+ writerContainer: BaseWriterContainer,
+ df: DataFrame,
+ partitionColumns: Array[String]): Unit = {
+ // Uses a local val for serialization
+ val needsConversion = relation.needConversion
+ val dataSchema = relation.dataSchema
+
+ require(
+ df.schema == relation.schema,
+ s"""DataFrame must have the same schema as the relation to which is inserted.
+ |DataFrame schema: ${df.schema}
+ |Relation schema: ${relation.schema}
+ """.stripMargin)
+
+ val partitionColumnsInSpec = relation.partitionColumns.fieldNames
+ require(
+ partitionColumnsInSpec.sameElements(partitionColumns),
+ s"""Partition columns mismatch.
+ |Expected: ${partitionColumnsInSpec.mkString(", ")}
+ |Actual: ${partitionColumns.mkString(", ")}
+ """.stripMargin)
+
+ val output = df.queryExecution.executedPlan.output
+ val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name))
+ val codegenEnabled = df.sqlContext.conf.codegenEnabled
+
+ try {
+ writerContainer.driverSideSetup()
+ df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
+ writerContainer.commitJob()
+ relation.refresh()
+ } catch { case cause: Throwable =>
+ logError("Aborting job.", cause)
+ writerContainer.abortJob()
+ throw new SparkException("Job aborted.", cause)
+ }
+
+ def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = {
+ writerContainer.executorSideSetup(taskContext)
+
+ val partitionProj = newProjection(codegenEnabled, partitionOutput, output)
+ val dataProj = newProjection(codegenEnabled, dataOutput, output)
+
+ if (needsConversion) {
+ val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
+ while (iterator.hasNext) {
+ val row = iterator.next()
+ val partitionPart = partitionProj(row)
+ val dataPart = dataProj(row)
+ val convertedDataPart = converter(dataPart).asInstanceOf[Row]
+ writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart)
+ }
+ } else {
+ while (iterator.hasNext) {
+ val row = iterator.next()
+ val partitionPart = partitionProj(row)
+ val dataPart = dataProj(row)
+ writerContainer.outputWriterForRow(partitionPart).write(dataPart)
+ }
+ }
+
+ writerContainer.commitTask()
+ }
+ }
+
+ // This is copied from SparkPlan, probably should move this to a more general place.
+ private def newProjection(
+ codegenEnabled: Boolean,
+ expressions: Seq[Expression],
+ inputSchema: Seq[Attribute]): Projection = {
+ log.debug(
+ s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+ if (codegenEnabled) {
+ GenerateProjection.generate(expressions, inputSchema)
+ } else {
+ new InterpretedProjection(expressions, inputSchema)
+ }
+ }
+}
+
+private[sql] abstract class BaseWriterContainer(
+ @transient val relation: FSBasedRelation,
+ @transient job: Job)
+ extends SparkHadoopMapReduceUtil
+ with Logging
+ with Serializable {
+
+ protected val serializableConf = new SerializableWritable(ContextUtil.getConfiguration(job))
+
+ // This is only used on driver side.
+ @transient private val jobContext: JobContext = job
+
+ // The following fields are initialized and used on both driver and executor side.
+ @transient protected var outputCommitter: FileOutputCommitter = _
+ @transient private var jobId: JobID = _
+ @transient private var taskId: TaskID = _
+ @transient private var taskAttemptId: TaskAttemptID = _
+ @transient protected var taskAttemptContext: TaskAttemptContext = _
+
+ protected val outputPath: String = {
+ assert(
+ relation.paths.length == 1,
+ s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}")
+ relation.paths.head
+ }
+
+ protected val dataSchema = relation.dataSchema
+
+ protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass
+
+ private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _
+
+ def driverSideSetup(): Unit = {
+ setupIDs(0, 0, 0)
+ setupConf()
+ taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
+ relation.prepareForWrite(job)
+ outputFormatClass = job.getOutputFormatClass
+ outputCommitter = newOutputCommitter(taskAttemptContext)
+ outputCommitter.setupJob(jobContext)
+ }
+
+ def executorSideSetup(taskContext: TaskContext): Unit = {
+ setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber())
+ setupConf()
+ taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
+ outputCommitter = newOutputCommitter(taskAttemptContext)
+ outputCommitter.setupTask(taskAttemptContext)
+ initWriters()
+ }
+
+ private def newOutputCommitter(context: TaskAttemptContext): FileOutputCommitter = {
+ outputFormatClass.newInstance().getOutputCommitter(context) match {
+ case f: FileOutputCommitter => f
+ case f => sys.error(
+ s"FileOutputCommitter or its subclass is expected, but got a ${f.getClass.getName}.")
+ }
+ }
+
+ private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = {
+ this.jobId = SparkHadoopWriter.createJobID(new Date, jobId)
+ this.taskId = new TaskID(this.jobId, true, splitId)
+ this.taskAttemptId = new TaskAttemptID(taskId, attemptId)
+ }
+
+ private def setupConf(): Unit = {
+ serializableConf.value.set("mapred.job.id", jobId.toString)
+ serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString)
+ serializableConf.value.set("mapred.task.id", taskAttemptId.toString)
+ serializableConf.value.setBoolean("mapred.task.is.map", true)
+ serializableConf.value.setInt("mapred.task.partition", 0)
+ }
+
+ // Called on executor side when writing rows
+ def outputWriterForRow(row: Row): OutputWriter
+
+ protected def initWriters(): Unit
+
+ def commitTask(): Unit = {
+ SparkHadoopMapRedUtil.commitTask(
+ outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId)
+ }
+
+ def abortTask(): Unit = {
+ outputCommitter.abortTask(taskAttemptContext)
+ logError(s"Task attempt $taskAttemptId aborted.")
+ }
+
+ def commitJob(): Unit = {
+ outputCommitter.commitJob(jobContext)
+ logInfo(s"Job $jobId committed.")
+ }
+
+ def abortJob(): Unit = {
+ outputCommitter.abortJob(jobContext, JobStatus.State.FAILED)
+ logError(s"Job $jobId aborted.")
+ }
+}
+
+private[sql] class DefaultWriterContainer(
+ @transient relation: FSBasedRelation,
+ @transient job: Job)
+ extends BaseWriterContainer(relation, job) {
+
+ @transient private var writer: OutputWriter = _
+
+ override protected def initWriters(): Unit = {
+ writer = outputWriterClass.newInstance()
+ writer.init(outputCommitter.getWorkPath.toString, dataSchema, taskAttemptContext)
+ }
+
+ override def outputWriterForRow(row: Row): OutputWriter = writer
+
+ override def commitTask(): Unit = {
+ writer.close()
+ super.commitTask()
+ }
+
+ override def abortTask(): Unit = {
+ writer.close()
+ super.abortTask()
+ }
+}
+
+private[sql] class DynamicPartitionWriterContainer(
+ @transient relation: FSBasedRelation,
+ @transient job: Job,
+ partitionColumns: Array[String],
+ defaultPartitionName: String)
+ extends BaseWriterContainer(relation, job) {
+
+ // All output writers are created on executor side.
+ @transient protected var outputWriters: mutable.Map[String, OutputWriter] = _
+
+ override protected def initWriters(): Unit = {
+ outputWriters = mutable.Map.empty[String, OutputWriter]
+ }
+
+ override def outputWriterForRow(row: Row): OutputWriter = {
+ val partitionPath = partitionColumns.zip(row.toSeq).map { case (col, rawValue) =>
+ val string = if (rawValue == null) null else String.valueOf(rawValue)
+ val valueString = if (string == null || string.isEmpty) {
+ defaultPartitionName
+ } else {
+ DynamicPartitionWriterContainer.escapePathName(string)
+ }
+ s"/$col=$valueString"
+ }.mkString
+
+ outputWriters.getOrElseUpdate(partitionPath, {
+ val path = new Path(outputCommitter.getWorkPath, partitionPath.stripPrefix(Path.SEPARATOR))
+ val writer = outputWriterClass.newInstance()
+ writer.init(path.toString, dataSchema, taskAttemptContext)
+ writer
+ })
+ }
+
+ override def commitTask(): Unit = {
+ outputWriters.values.foreach(_.close())
+ super.commitTask()
+ }
+
+ override def abortTask(): Unit = {
+ outputWriters.values.foreach(_.close())
+ super.abortTask()
+ }
+}
+
+private[sql] object DynamicPartitionWriterContainer {
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+ // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils).
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+
+ val charToEscape = {
+ val bitSet = new java.util.BitSet(128)
+
+ /**
+ * ASCII 01-1F are HTTP control characters that need to be escaped.
+ * \u000A and \u000D are \n and \r, respectively.
+ */
+ val clist = Array(
+ '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009',
+ '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013',
+ '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C',
+ '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F',
+ '{', '[', ']', '^')
+
+ clist.foreach(bitSet.set(_))
+
+ if (Shell.WINDOWS) {
+ Array(' ', '<', '>', '|').foreach(bitSet.set(_))
+ }
+
+ bitSet
+ }
+
+ def needsEscaping(c: Char): Boolean = {
+ c >= 0 && c < charToEscape.size() && charToEscape.get(c)
+ }
+
+ def escapePathName(path: String): String = {
+ val builder = new StringBuilder()
+ path.foreach { c =>
+ if (DynamicPartitionWriterContainer.needsEscaping(c)) {
+ builder.append('%')
+ builder.append(f"${c.asInstanceOf[Int]}%02x")
+ } else {
+ builder.append(c)
+ }
+ }
+
+ builder.toString()
+ }
+}
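The escaping helper above replaces each reserved character with '%' followed by its two-digit hex code, keeping partition values such as dates or times inside a single path segment. A small sketch (again only callable from within the `org.apache.spark.sql` package, since the object is `private[sql]`):
{{{
import org.apache.spark.sql.sources.DynamicPartitionWriterContainer

// '/' is 0x2f and ':' is 0x3a, both of which are in the escape set above.
DynamicPartitionWriterContainer.escapePathName("2015/04")  // returns "2015%2f04"
DynamicPartitionWriterContainer.escapePathName("12:30")    // returns "12%3a30"
}}}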
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 06c64f2bdd..595c5eb40e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -17,18 +17,20 @@
package org.apache.spark.sql.sources
-import scala.language.existentials
+import scala.language.{existentials, implicitConversions}
import scala.util.matching.Regex
-import scala.language.implicitConversions
+
+import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
-import org.apache.spark.sql.{AnalysisException, SaveMode, DataFrame, SQLContext}
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types._
+import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode}
import org.apache.spark.util.Utils
/**
@@ -111,6 +113,7 @@ private[sql] class DDLParser(
CreateTableUsingAsSelect(tableName,
provider,
temp.isDefined,
+ Array.empty[String],
mode,
options,
queryPlan)
@@ -157,7 +160,7 @@ private[sql] class DDLParser(
protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}
override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch(
- s"identifier matching regex ${regex}", {
+ s"identifier matching regex $regex", {
case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str
case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str
}
@@ -214,6 +217,7 @@ private[sql] object ResolvedDataSource {
def apply(
sqlContext: SQLContext,
userSpecifiedSchema: Option[StructType],
+ partitionColumns: Array[String],
provider: String,
options: Map[String, String]): ResolvedDataSource = {
val clazz: Class[_] = lookupDataSource(provider)
@@ -222,6 +226,27 @@ private[sql] object ResolvedDataSource {
case Some(schema: StructType) => clazz.newInstance() match {
case dataSource: SchemaRelationProvider =>
dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+ case dataSource: FSBasedRelationProvider =>
+ val maybePartitionsSchema = if (partitionColumns.isEmpty) {
+ None
+ } else {
+ Some(partitionColumnsSchema(schema, partitionColumns))
+ }
+
+ val caseInsensitiveOptions = new CaseInsensitiveMap(options)
+ val paths = {
+ val patternPath = new Path(caseInsensitiveOptions("path"))
+ SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray
+ }
+
+ val dataSchema = StructType(schema.filterNot(f => partitionColumns.contains(f.name)))
+
+ dataSource.createRelation(
+ sqlContext,
+ paths,
+ Some(schema),
+ maybePartitionsSchema,
+ caseInsensitiveOptions)
case dataSource: org.apache.spark.sql.sources.RelationProvider =>
throw new AnalysisException(s"$className does not allow user-specified schemas.")
case _ =>
@@ -231,20 +256,39 @@ private[sql] object ResolvedDataSource {
case None => clazz.newInstance() match {
case dataSource: RelationProvider =>
dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+ case dataSource: FSBasedRelationProvider =>
+ val caseInsensitiveOptions = new CaseInsensitiveMap(options)
+ val paths = {
+ val patternPath = new Path(caseInsensitiveOptions("path"))
+ SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray
+ }
+ dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions)
case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
throw new AnalysisException(
s"A schema needs to be specified when using $className.")
case _ =>
- throw new AnalysisException(s"$className is not a RelationProvider.")
+ throw new AnalysisException(
+ s"$className is neither a RelationProvider nor a FSBasedRelationProvider.")
}
}
new ResolvedDataSource(clazz, relation)
}
+ private def partitionColumnsSchema(
+ schema: StructType,
+ partitionColumns: Array[String]): StructType = {
+ StructType(partitionColumns.map { col =>
+ schema.find(_.name == col).getOrElse {
+ throw new RuntimeException(s"Partition column $col not found in schema $schema")
+ }
+ }).asNullable
+ }
+
/** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */
def apply(
sqlContext: SQLContext,
provider: String,
+ partitionColumns: Array[String],
mode: SaveMode,
options: Map[String, String],
data: DataFrame): ResolvedDataSource = {
@@ -252,6 +296,31 @@ private[sql] object ResolvedDataSource {
val relation = clazz.newInstance() match {
case dataSource: CreatableRelationProvider =>
dataSource.createRelation(sqlContext, mode, options, data)
+ case dataSource: FSBasedRelationProvider =>
+ // Don't glob path for the write path. The contracts here are:
+ // 1. Only one output path can be specified on the write path;
+ // 2. Output path must be a legal HDFS style file system path;
+ // 3. It's OK that the output path doesn't exist yet;
+ val caseInsensitiveOptions = new CaseInsensitiveMap(options)
+ val outputPath = {
+ val path = new Path(caseInsensitiveOptions("path"))
+ val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+ path.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ }
+ val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name)))
+ val r = dataSource.createRelation(
+ sqlContext,
+ Array(outputPath.toString),
+ Some(dataSchema.asNullable),
+ Some(partitionColumnsSchema(data.schema, partitionColumns)),
+ caseInsensitiveOptions)
+ sqlContext.executePlan(
+ InsertIntoFSBasedRelation(
+ r,
+ data.logicalPlan,
+ partitionColumns.toArray,
+ mode)).toRdd
+ r
case _ =>
sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
}
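For reference, a hedged sketch of how this write-path overload is invoked; the provider class name, partition columns, and path are invented, and `sqlContext` and a DataFrame `df` are assumed to be in scope.

import org.apache.spark.sql.SaveMode

// Resolves the provider, creates the relation rooted at the (qualified) output path,
// and runs an InsertIntoFSBasedRelation job before returning.
val resolved = ResolvedDataSource(
  sqlContext,
  "com.example.text.DefaultSource",  // provider
  Array("year", "month"),            // partitionColumns
  SaveMode.Append,                   // mode
  Map("path" -> "/data/events"),     // options
  df)                                // data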
@@ -310,6 +379,7 @@ private[sql] case class CreateTableUsingAsSelect(
tableName: String,
provider: String,
temporary: Boolean,
+ partitionColumns: Array[String],
mode: SaveMode,
options: Map[String, String],
child: LogicalPlan) extends UnaryNode {
@@ -324,8 +394,9 @@ private[sql] case class CreateTempTableUsing(
provider: String,
options: Map[String, String]) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options)
+ def run(sqlContext: SQLContext): Seq[Row] = {
+ val resolved = ResolvedDataSource(
+ sqlContext, userSpecifiedSchema, Array.empty[String], provider, options)
sqlContext.registerDataFrameAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
Seq.empty
@@ -335,13 +406,14 @@ private[sql] case class CreateTempTableUsing(
private[sql] case class CreateTempTableUsingAsSelect(
tableName: String,
provider: String,
+ partitionColumns: Array[String],
mode: SaveMode,
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[Row] = {
val df = DataFrame(sqlContext, query)
- val resolved = ResolvedDataSource(sqlContext, provider, mode, options, df)
+ val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df)
sqlContext.registerDataFrameAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index ca53dcdb92..5e010d2112 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -17,11 +17,19 @@
package org.apache.spark.sql.sources
-import org.apache.spark.annotation.{Experimental, DeveloperApi}
+import scala.util.Try
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SaveMode, DataFrame, Row, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.types.{StructField, StructType}
/**
* ::DeveloperApi::
@@ -78,6 +86,41 @@ trait SchemaRelationProvider {
schema: StructType): BaseRelation
}
+/**
+ * ::DeveloperApi::
+ * Implemented by objects that produce relations for a specific kind of data source
+ * with a given schema and partition columns. When Spark SQL is given a DDL operation with a
+ * USING clause (to specify the implementing [[FSBasedRelationProvider]] class), a user-defined
+ * schema, and an optional list of partition columns, this interface is used to pass in the
+ * parameters specified by the user.
+ *
+ * Users may specify the fully qualified class name of a given data source. When that class is
+ * not found, Spark SQL will append the class name `DefaultSource` to the path, allowing for
+ * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the
+ * data source 'org.apache.spark.sql.json.DefaultSource'.
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ *
+ * The difference between a [[RelationProvider]] and a [[FSBasedRelationProvider]] is
+ * that users need to provide a schema and a (possibly empty) list of partition columns when
+ * using a [[FSBasedRelationProvider]]. A relation provider can inherit both [[RelationProvider]]
+ * and [[FSBasedRelationProvider]] if it can support schema inference, user-specified
+ * schemas, and accessing partitioned relations.
+ */
+@DeveloperApi
+trait FSBasedRelationProvider {
+ /**
+ * Returns a new base relation with the given parameters, a user-defined schema, and a list of
+ * partition columns. Note: the parameters' keywords are case insensitive and this insensitivity
+ * is enforced by the Map that is passed to the function.
+ */
+ def createRelation(
+ sqlContext: SQLContext,
+ paths: Array[String],
+ schema: Option[StructType],
+ partitionColumns: Option[StructType],
+ parameters: Map[String, String]): FSBasedRelation
+}
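A minimal sketch of an implementation of this trait, assuming a hypothetical text data source; `SimpleTextRelation` is a made-up FSBasedRelation subclass sketched after the FSBasedRelation definition further below.

package com.example.text

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{FSBasedRelation, FSBasedRelationProvider}
import org.apache.spark.sql.types.StructType

class DefaultSource extends FSBasedRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      paths: Array[String],
      schema: Option[StructType],
      partitionColumns: Option[StructType],
      parameters: Map[String, String]): FSBasedRelation = {
    // Hypothetical relation class; see the sketch below FSBasedRelation.
    new SimpleTextRelation(
      paths,
      schema,
      partitionColumns.getOrElse(StructType(Nil)),
      parameters)(sqlContext)
  }
}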
+
@DeveloperApi
trait CreatableRelationProvider {
/**
@@ -207,3 +250,235 @@ trait InsertableRelation {
trait CatalystScan {
def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row]
}
+
+/**
+ * ::Experimental::
+ * [[OutputWriter]] is used together with [[FSBasedRelation]] for persisting rows to the
+ * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor.
+ * An [[OutputWriter]] instance is created and initialized when a new output file is opened on
+ * the executor side. This instance is used to persist rows to this single output file.
+ */
+@Experimental
+abstract class OutputWriter {
+ /**
+ * Initializes this [[OutputWriter]] before any rows are persisted.
+ *
+ * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that
+ * this may not point to the final output file. For example, `FileOutputFormat` writes to
+ * temporary directories and then merges written files back to the final destination. In
+ * this case, `path` points to a temporary output file under the temporary directory.
+ * @param dataSchema Schema of the rows to be written. Partition columns are not included in the
+ * schema if the corresponding relation is partitioned.
+ * @param context The Hadoop MapReduce task context.
+ */
+ def init(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): Unit = ()
+
+ /**
+ * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
+ * tables, dynamic partition columns are not included in rows to be written.
+ */
+ def write(row: Row): Unit
+
+ /**
+ * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before
+ * the task output is committed.
+ */
+ def close(): Unit
+}
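A hedged sketch of an OutputWriter for the same made-up text format; it writes each row as a comma-separated line and deliberately ignores escaping, compression, and error handling.

package com.example.text

import org.apache.hadoop.fs.{FSDataOutputStream, Path}
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.OutputWriter
import org.apache.spark.sql.types.StructType

class SimpleTextOutputWriter extends OutputWriter {
  private var stream: FSDataOutputStream = _

  override def init(
      path: String,
      dataSchema: StructType,
      context: TaskAttemptContext): Unit = {
    // `path` is chosen by the framework and may be a temporary location.
    val outputPath = new Path(path)
    val fs = outputPath.getFileSystem(context.getConfiguration)
    stream = fs.create(outputPath, false)
  }

  override def write(row: Row): Unit = {
    // For partitioned tables, partition column values are already stripped from `row`.
    stream.writeBytes(row.toSeq.mkString("", ",", "\n"))
  }

  override def close(): Unit = stream.close()
}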
+
+/**
+ * ::Experimental::
+ * A [[BaseRelation]] that provides much of the common code required for formats that store their
+ * data to an HDFS compatible filesystem.
+ *
+ * For the read path, similar to [[PrunedFilteredScan]], it can eliminate unneeded columns and
+ * filter using selected predicates before producing an RDD containing all matching tuples as
+ * [[Row]] objects. In addition, when reading from Hive-style partitioned tables stored in file
+ * systems, it is able to discover partitioning information from the paths of input directories and
+ * perform partition pruning before reading the data. Subclasses of [[FSBasedRelation]] must
+ * override one of the three `buildScan` methods to implement the read path.
+ *
+ * For the write path, it provides the ability to write to both non-partitioned and partitioned
+ * tables. Directory layout of the partitioned tables is compatible with Hive.
+ *
+ * @constructor This constructor is for internal use only. The [[PartitionSpec]] argument is for
+ * implementing metastore table conversion.
+ * @param paths Base paths of this relation. For partitioned relations, these should be the root
+ * directories of all partition directories.
+ * @param maybePartitionSpec An [[FSBasedRelation]] can be created with an optional
+ * [[PartitionSpec]], so that partition discovery can be skipped.
+ */
+@Experimental
+abstract class FSBasedRelation private[sql](
+ val paths: Array[String],
+ maybePartitionSpec: Option[PartitionSpec])
+ extends BaseRelation {
+
+ /**
+ * Constructs an [[FSBasedRelation]].
+ *
+ * @param paths Base paths of this relation. For partitioned relations, these should be the root
+ * directories of all partition directories.
+ * @param partitionColumns Partition columns of this relation.
+ */
+ def this(paths: Array[String], partitionColumns: StructType) =
+ this(paths, {
+ if (partitionColumns.isEmpty) None
+ else Some(PartitionSpec(partitionColumns, Array.empty[Partition]))
+ })
+
+ /**
+ * Constructs an [[FSBasedRelation]].
+ *
+ * @param paths Base paths of this relation. For partitioned relations, these should be the root
+ * directories of all partition directories.
+ */
+ def this(paths: Array[String]) = this(paths, None)
+
+ private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
+
+ private val codegenEnabled = sqlContext.conf.codegenEnabled
+
+ private var _partitionSpec: PartitionSpec = maybePartitionSpec.map { spec =>
+ spec.copy(partitionColumns = spec.partitionColumns.asNullable)
+ }.getOrElse {
+ if (sqlContext.conf.partitionDiscoveryEnabled()) {
+ discoverPartitions()
+ } else {
+ PartitionSpec(StructType(Nil), Array.empty[Partition])
+ }
+ }
+
+ private[sql] def partitionSpec: PartitionSpec = _partitionSpec
+
+ /**
+ * Partition columns. Note that they are always nullable.
+ */
+ def partitionColumns: StructType = partitionSpec.partitionColumns
+
+ private[sql] def refresh(): Unit = {
+ if (sqlContext.conf.partitionDiscoveryEnabled()) {
+ _partitionSpec = discoverPartitions()
+ }
+ }
+
+ private def discoverPartitions(): PartitionSpec = {
+ val basePaths = paths.map(new Path(_))
+ val leafDirs = basePaths.flatMap { path =>
+ val fs = path.getFileSystem(hadoopConf)
+ Try(fs.getFileStatus(path.makeQualified(fs.getUri, fs.getWorkingDirectory)))
+ .filter(_.isDir)
+ .map(SparkHadoopUtil.get.listLeafDirStatuses(fs, _))
+ .getOrElse(Seq.empty[FileStatus])
+ }.map(_.getPath)
+
+ if (leafDirs.nonEmpty) {
+ PartitioningUtils.parsePartitions(leafDirs, "__HIVE_DEFAULT_PARTITION__")
+ } else {
+ PartitionSpec(StructType(Array.empty[StructField]), Array.empty[Partition])
+ }
+ }
+
+ /**
+ * Schema of this relation. It consists of columns appearing in [[dataSchema]] and all partition
+ * columns not appearing in [[dataSchema]].
+ */
+ override lazy val schema: StructType = {
+ val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet
+ StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column =>
+ dataSchemaColumnNames.contains(column.name.toLowerCase)
+ })
+ }
+
+ /**
+ * Specifies the schema of actual data files. For partitioned relations, if one or more partition
+ * columns are contained in the data files, they should also appear in `dataSchema`.
+ */
+ def dataSchema: StructType
+
+ /**
+ * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within
+ * this relation. For partitioned relations, this method is called for each selected partition,
+ * and builds an `RDD[Row]` containing all rows within that single partition.
+ *
+ * @param inputPaths For a non-partitioned relation, it contains paths of all data files in the
+ * relation. For a partitioned relation, it contains paths of all data files in a single
+ * selected partition.
+ */
+ def buildScan(inputPaths: Array[String]): RDD[Row] = {
+ throw new RuntimeException(
+ "At least one buildScan() method should be overridden to read the relation.")
+ }
+
+ /**
+ * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within
+ * this relation. For partitioned relations, this method is called for each selected partition,
+ * and builds an `RDD[Row]` containing all rows within that single partition.
+ *
+ * @param requiredColumns Required columns.
+ * @param inputPaths For a non-partitioned relation, it contains paths of all data files in the
+ * relation. For a partitioned relation, it contains paths of all data files in a single
+ * selected partition.
+ */
+ def buildScan(requiredColumns: Array[String], inputPaths: Array[String]): RDD[Row] = {
+ // Take local copies so that the closure below doesn't capture (and serialize) the whole relation.
+ val dataSchema = this.dataSchema
+ val codegenEnabled = this.codegenEnabled
+
+ val requiredOutput = requiredColumns.map { col =>
+ val field = dataSchema(col)
+ BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable)
+ }.toSeq
+
+ buildScan(inputPaths).mapPartitions { rows =>
+ val buildProjection = if (codegenEnabled) {
+ GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes)
+ } else {
+ () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes)
+ }
+
+ val mutableProjection = buildProjection()
+ rows.map(mutableProjection)
+ }
+ }
+
+ /**
+ * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within
+ * this relation. For partitioned relations, this method is called for each selected partition,
+ * and builds an `RDD[Row]` containing all rows within that single partition.
+ *
+ * @param requiredColumns Required columns.
+ * @param filters Candidate filters to be pushed down. The actual filter should be the conjunction
+ * of all `filters`. The pushed down filters are currently purely an optimization as they
+ * will all be evaluated again. This means it is safe to use them with methods that produce
+ * false positives such as filtering partitions based on a bloom filter.
+ * @param inputPaths For a non-partitioned relation, it contains paths of all data files in the
+ * relation. For a partitioned relation, it contains paths of all data files in a single
+ * selected partition.
+ */
+ def buildScan(
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ inputPaths: Array[String]): RDD[Row] = {
+ buildScan(requiredColumns, inputPaths)
+ }
+
+ /**
+ * Client-side preparation for data writing can be put here. For example, a user-defined output
+ * committer can be configured here.
+ *
+ * Note that the only side effect expected here is mutating `job` via its setters. In particular,
+ * since Spark SQL caches [[BaseRelation]] instances for performance, mutating a relation's
+ * internal state may cause unexpected behavior.
+ */
+ def prepareForWrite(job: Job): Unit = ()
+
+ /**
+ * Returns the [[OutputWriter]] class used to write data. A new instance of that class is
+ * created for each output file opened on the executor side.
+ */
+ def outputWriterClass: Class[_ <: OutputWriter]
+}
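Tying the pieces together, a hedged sketch of a hypothetical FSBasedRelation subclass for the same made-up text format; it treats every column as a string and only overrides the simplest `buildScan` variant, leaving column pruning and filter push-down to the default implementations above.

package com.example.text

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{FSBasedRelation, OutputWriter}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class SimpleTextRelation(
    paths: Array[String],
    maybeDataSchema: Option[StructType],
    partitionsSchema: StructType,
    parameters: Map[String, String])(
    @transient val sqlContext: SQLContext)
  extends FSBasedRelation(paths, partitionsSchema) {

  // Fall back to a single string column when no schema is supplied.
  override val dataSchema: StructType =
    maybeDataSchema.getOrElse(StructType(StructField("value", StringType) :: Nil))

  override def outputWriterClass: Class[_ <: OutputWriter] = classOf[SimpleTextOutputWriter]

  // Simplest read path: parse each line into string fields; type coercion is omitted.
  override def buildScan(inputPaths: Array[String]): RDD[Row] = {
    val numFields = dataSchema.size
    sqlContext.sparkContext.textFile(inputPaths.mkString(",")).map { line =>
      Row.fromSeq(line.split(",", numFields).toSeq)
    }
  }
}

With DefaultSource, SimpleTextOutputWriter, and SimpleTextRelation wired up like this, a CREATE TEMPORARY TABLE ... USING com.example.text statement (or the ResolvedDataSource calls shown earlier) would exercise both the read and write paths introduced in this patch.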
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
index 6ed68d179e..aad1d248d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
@@ -101,13 +101,13 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
}
}
- case i @ logical.InsertIntoTable(
- l: LogicalRelation, partition, query, overwrite, ifNotExists)
- if !l.isInstanceOf[InsertableRelation] =>
+ case logical.InsertIntoTable(LogicalRelation(_: InsertableRelation), _, _, _, _) => // OK
+ case logical.InsertIntoTable(LogicalRelation(_: FSBasedRelation), _, _, _, _) => // OK
+ case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) =>
// The relation in l is not an InsertableRelation.
failAnalysis(s"$l does not allow insertion.")
- case CreateTableUsingAsSelect(tableName, _, _, SaveMode.Overwrite, _, query) =>
+ case CreateTableUsingAsSelect(tableName, _, _, _, SaveMode.Overwrite, _, query) =>
// When the SaveMode is Overwrite, we need to check if the table is an input table of
// the query. If so, we will throw an AnalysisException to let users know it is not allowed.
if (catalog.tableExists(Seq(tableName))) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
index b7561ce729..bea568ed40 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
@@ -21,7 +21,8 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.parquet.ParquetRelation2._
+import org.apache.spark.sql.sources.PartitioningUtils._
+import org.apache.spark.sql.sources.{Partition, PartitionSpec}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLContext}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 54f2f3cdec..4e54b2eb8d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.sources
-import java.io.{IOException, File}
+import java.io.{File, IOException}
-import org.apache.spark.sql.AnalysisException
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.util.Utils
class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {