author     Michael Armbrust <michael@databricks.com>  2014-11-02 15:08:35 -0800
committer  Michael Armbrust <michael@databricks.com>  2014-11-02 15:08:35 -0800
commit     9c0eb57c737dd7d97d2cbd4516ddd2cf5d06e4b2 (patch)
tree       285005822bdff25308e471738bda5e5e157cd9b8 /sql
parent     f0a4b630abf0766cc0c41e682691e0d435caca04 (diff)
[SPARK-3247][SQL] An API for adding data sources to Spark SQL
This PR introduces a new set of APIs to Spark SQL that allow other developers to add support for reading data from new sources in `org.apache.spark.sql.sources`.

New sources must implement the interface `BaseRelation`, which is responsible for describing the schema of the data. BaseRelations have three `Scan` subclasses, which are responsible for producing an RDD containing row objects. The [various Scan interfaces](https://github.com/marmbrus/spark/blob/foreign/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala#L50) allow for optimizations such as column pruning and filter push-down when the underlying data source can handle these operations.

By implementing a class that inherits from RelationProvider, these data sources can be accessed using pure SQL. I've used this functionality to update the JSON support so it can now be used as follows:

```sql
CREATE TEMPORARY TABLE jsonTableSQL
USING org.apache.spark.sql.json
OPTIONS (
  path '/home/michael/data.json'
)
```

Further example usage can be found in the test cases: https://github.com/marmbrus/spark/tree/foreign/sql/core/src/test/scala/org/apache/spark/sql/sources

There is also a library that uses this new API to read Avro data, available here: https://github.com/marmbrus/sql-avro

Author: Michael Armbrust <michael@databricks.com>

Closes #2475 from marmbrus/foreign and squashes the following commits:

1ed6010 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into foreign
ab2c31f [Michael Armbrust] fix test
1d41bb5 [Michael Armbrust] unify argument names
5b47901 [Michael Armbrust] Remove sealed, more filter types
fab154a [Michael Armbrust] Merge remote-tracking branch 'origin/master' into foreign
e3e690e [Michael Armbrust] Add hook for extraStrategies
a70d602 [Michael Armbrust] Fix style, more tests, FilteredSuite => PrunedFilteredSuite
70da6d9 [Michael Armbrust] Modify API to ease binary compatibility and interop with Java
7d948ae [Michael Armbrust] Fix equality of AttributeReference.
5545491 [Michael Armbrust] Address comments
5031ac3 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into foreign
22963ef [Michael Armbrust] package objects compile wierdly...
b069146 [Michael Armbrust] traits => abstract classes
34f836a [Michael Armbrust] Make @DeveloperApi
0d74bcf [Michael Armbrust] Add documention on object life cycle
3e06776 [Michael Armbrust] remove line wraps
de3b68c [Michael Armbrust] Remove empty file
360cb30 [Michael Armbrust] style and java api
2957875 [Michael Armbrust] add override
0fd3a07 [Michael Armbrust] Draft of data sources API
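To make the shape of the API concrete, here is a minimal sketch of a data source built against these interfaces. All names here are hypothetical (`com.example.range`, `RangeRelation`, the `start`/`end` options); only `RelationProvider`, `TableScan`, and the DDL syntax come from this patch.

```scala
package com.example.range

import org.apache.spark.sql.{IntegerType, Row, SQLContext, StructField, StructType}
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, TableScan}

// Resolved automatically when the USING clause names the package `com.example.range`,
// because Spark SQL falls back to appending `.DefaultSource` to the provider name.
class DefaultSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation =
    RangeRelation(parameters("start").toInt, parameters("end").toInt)(sqlContext)
}

// A BaseRelation is responsible for its schema; TableScan adds buildScan() for producing the rows.
case class RangeRelation(start: Int, end: Int)(@transient val sqlContext: SQLContext)
  extends TableScan {

  override def schema =
    StructType(StructField("value", IntegerType, nullable = false) :: Nil)

  override def buildScan() =
    sqlContext.sparkContext.parallelize(start to end).map(Row(_))
}
```

With this on the classpath, registering and querying the source needs no Scala at all: `CREATE TEMPORARY TABLE numbers USING com.example.range OPTIONS (start '1', end '100')` followed by `SELECT value FROM numbers`, exactly as in the JSON example above.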
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala | 20
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala | 35
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala | 49
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/package.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala | 112
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala | 54
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala | 108
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala | 26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala | 86
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala | 22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 30
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 26
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala | 34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala | 176
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala | 137
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala | 125
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 6
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala | 2
26 files changed, 1074 insertions(+), 42 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 3310566087..fc90a54a58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -134,7 +134,7 @@ case class AttributeReference(
val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] {
override def equals(other: Any) = other match {
- case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
+ case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType
case _ => false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
index bdd07bbeb2..a38079ced3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql
+/**
+ * Catalyst is a library for manipulating relational query plans. All classes in catalyst are
+ * considered an internal API to Spark SQL and are subject to change between minor releases.
+ */
package object catalyst {
/**
* A JVM-global lock that should be used to prevent thread safety issues when using things in
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
index 5839c9f7c4..51b5699aff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
@@ -22,6 +22,15 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
/**
+ * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can
+ * be used for execution. If this strategy does not apply to the given logical operation then an
+ * empty list should be returned.
+ */
+abstract class GenericStrategy[PhysicalPlan <: TreeNode[PhysicalPlan]] extends Logging {
+ def apply(plan: LogicalPlan): Seq[PhysicalPlan]
+}
+
+/**
* Abstract class for transforming [[plans.logical.LogicalPlan LogicalPlan]]s into physical plans.
* Child classes are responsible for specifying a list of [[Strategy]] objects that each of which
* can return a list of possible physical plan options. If a given strategy is unable to plan all
@@ -35,16 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
*/
abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
/** A list of execution strategies that can be used by the planner */
- def strategies: Seq[Strategy]
-
- /**
- * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can
- * be used for execution. If this strategy does not apply to the give logical operation then an
- * empty list should be returned.
- */
- abstract protected class Strategy extends Logging {
- def apply(plan: LogicalPlan): Seq[PhysicalPlan]
- }
+ def strategies: Seq[GenericStrategy[PhysicalPlan]]
/**
* Returns a placeholder for a physical plan that executes `plan`. This placeholder will be
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 8dda0b1828..d25f3a619d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -455,7 +455,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
case class StructField(
name: String,
dataType: DataType,
- nullable: Boolean,
+ nullable: Boolean = true,
metadata: Metadata = Metadata.empty) {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 4953f8399a..4cded98c80 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.{SparkStrategies, _}
import org.apache.spark.sql.json._
import org.apache.spark.sql.parquet.ParquetRelation
+import org.apache.spark.sql.sources.{DataSourceStrategy, BaseRelation, DDLParser, LogicalRelation}
/**
* :: AlphaComponent ::
@@ -69,12 +70,18 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer
@transient
+ protected[sql] val ddlParser = new DDLParser
+
+ @transient
protected[sql] val sqlParser = {
val fallback = new catalyst.SqlParser
new catalyst.SparkSQLParser(fallback(_))
}
- protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql)
+ protected[sql] def parseSql(sql: String): LogicalPlan = {
+ ddlParser(sql).getOrElse(sqlParser(sql))
+ }
+
protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))
protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
@@ -104,6 +111,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self))
}
+ implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = {
+ logicalPlanToSparkQuery(LogicalRelation(baseRelation))
+ }
+
/**
* :: DeveloperApi ::
* Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
@@ -283,6 +294,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
def table(tableName: String): SchemaRDD =
new SchemaRDD(this, catalog.lookupRelation(None, tableName))
+ /**
+ * :: DeveloperApi ::
+ * Allows extra strategies to be injected into the query planner at runtime. Note this API
+ * should be considered experimental and is not intended to be stable across releases.
+ */
+ @DeveloperApi
+ var extraStrategies: Seq[Strategy] = Nil
+
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext
@@ -293,7 +312,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
def numPartitions = self.numShufflePartitions
val strategies: Seq[Strategy] =
+ extraStrategies ++ (
CommandStrategy(self) ::
+ DataSourceStrategy ::
TakeOrdered ::
HashAggregation ::
LeftSemiJoin ::
@@ -302,7 +323,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
ParquetOperations ::
BasicOperators ::
CartesianProduct ::
- BroadcastNestedLoopJoin :: Nil
+ BroadcastNestedLoopJoin :: Nil)
/**
* Used to build table scan operators where complex projection and filtering are done using
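The new `extraStrategies` hook above means custom planning rules can be injected without subclassing `SQLContext`; note that injected strategies are consulted before the built-in ones (`extraStrategies ++ (...)`). A minimal sketch follows; the strategy name and the hard-coded `local` master are illustrative assumptions.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

// A strategy that declines to plan anything, so the built-in strategies always take over.
// A real strategy would pattern-match on logical operators it knows how to execute.
object PassThroughStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
}

val sc = new SparkContext("local", "extra-strategies-sketch")
val sqlContext = new SQLContext(sc)

// Registered at runtime; subsequent query planning will try this strategy first.
sqlContext.extraStrategies = PassThroughStrategy :: Nil
```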
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 876b1c6ede..60065509bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation}
import org.apache.spark.sql.types.util.DataTypeConversions
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
@@ -39,6 +40,10 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc))
+ def baseRelationToSchemaRDD(baseRelation: BaseRelation): JavaSchemaRDD = {
+ new JavaSchemaRDD(sqlContext, LogicalRelation(baseRelation))
+ }
+
/**
* Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is
* used for SQL parsing can be configured with 'spark.sql.dialect'.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 04c51a1ee4..d64c5af89e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -50,12 +50,6 @@ object RDDConversions {
}
}
}
-
- /*
- def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = {
- LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))
- }
- */
}
case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 79e4ddb8c4..2cd3063bc3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.{SQLContext, execution}
+import org.apache.spark.sql.{SQLContext, Strategy, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
@@ -304,6 +304,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case class CommandStrategy(context: SQLContext) extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case r: RunnableCommand => ExecutedCommand(r) :: Nil
case logical.SetCommand(kv) =>
Seq(execution.SetCommand(kv, plan.output)(context))
case logical.ExplainCommand(logicalPlan, extended) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 5859eba408..e658e6fc4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -21,10 +21,12 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Row, Attribute}
+import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.{Row, SQLConf, SQLContext}
+import org.apache.spark.sql.{SQLConf, SQLContext}
+// TODO: DELETE ME...
trait Command {
this: SparkPlan =>
@@ -44,6 +46,35 @@ trait Command {
override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}
+// TODO: Replace command with runnable command.
+trait RunnableCommand extends logical.Command {
+ self: Product =>
+
+ def output: Seq[Attribute]
+ def run(sqlContext: SQLContext): Seq[Row]
+}
+
+case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan {
+ /**
+ * A concrete command should override this lazy field to wrap up any side effects caused by the
+ * command or any other computation that should be evaluated exactly once. The value of this field
+ * can be used as the contents of the corresponding RDD generated from the physical plan of this
+ * command.
+ *
+ * The `execute()` method of all the physical command classes should reference `sideEffectResult`
+ * so that the command can be executed eagerly right after the command query is created.
+ */
+ protected[sql] lazy val sideEffectResult: Seq[Row] = cmd.run(sqlContext)
+
+ override def output = cmd.output
+
+ override def children = Nil
+
+ override def executeCollect(): Array[Row] = sideEffectResult.toArray
+
+ override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1)
+}
+
/**
* :: DeveloperApi ::
*/
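To illustrate the new command path, here is a minimal sketch of a `RunnableCommand` (the `EchoCommand` name and its output column are hypothetical). `ExecutedCommand` above takes care of running it eagerly and exposing the result as an `RDD[Row]`.

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.types.StringType
import org.apache.spark.sql.execution.RunnableCommand

// A side-effect-free command that returns its message as a single-row, single-column result.
case class EchoCommand(message: String) extends RunnableCommand {
  override def output: Seq[Attribute] =
    AttributeReference("message", StringType, nullable = false)() :: Nil

  override def run(sqlContext: SQLContext): Seq[Row] = Row(message) :: Nil
}
```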
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
new file mode 100644
index 0000000000..fc70c18343
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.json
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.sources._
+
+private[sql] class DefaultSource extends RelationProvider {
+ /** Returns a new base relation with the given parameters. */
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+ val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+ JSONRelation(fileName, samplingRatio)(sqlContext)
+ }
+}
+
+private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)(
+ @transient val sqlContext: SQLContext)
+ extends TableScan {
+
+ private def baseRDD = sqlContext.sparkContext.textFile(fileName)
+
+ override val schema =
+ JsonRDD.inferSchema(
+ baseRDD,
+ samplingRatio,
+ sqlContext.columnNameOfCorruptRecord)
+
+ override def buildScan() =
+ JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 05926a24c5..51dad54f1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.execution.SparkPlan
/**
* Allows the execution of relational queries, including those expressed in SQL using Spark.
@@ -433,6 +434,12 @@ package object sql {
val StructField = catalyst.types.StructField
/**
+ * Converts a logical plan into zero or more SparkPlans.
+ */
+ @DeveloperApi
+ type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan]
+
+ /**
* :: DeveloperApi ::
*
* Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean,
@@ -448,7 +455,9 @@ package object sql {
type Metadata = catalyst.util.Metadata
/**
+ * :: DeveloperApi ::
* Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former.
*/
+ @DeveloperApi
type MetadataBuilder = catalyst.util.MetadataBuilder
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
new file mode 100644
index 0000000000..9b8c6a56b9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * A Strategy for planning scans over data sources defined using the sources API.
+ */
+private[sql] object DataSourceStrategy extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) =>
+ pruneFilterProject(
+ l,
+ projectList,
+ filters,
+ (a, f) => t.buildScan(a, f)) :: Nil
+
+ case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) =>
+ pruneFilterProject(
+ l,
+ projectList,
+ filters,
+ (a, _) => t.buildScan(a)) :: Nil
+
+ case l @ LogicalRelation(t: TableScan) =>
+ execution.PhysicalRDD(l.output, t.buildScan()) :: Nil
+
+ case _ => Nil
+ }
+
+ protected def pruneFilterProject(
+ relation: LogicalRelation,
+ projectList: Seq[NamedExpression],
+ filterPredicates: Seq[Expression],
+ scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = {
+
+ val projectSet = AttributeSet(projectList.flatMap(_.references))
+ val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
+ val filterCondition = filterPredicates.reduceLeftOption(And)
+
+ val pushedFilters = selectFilters(filterPredicates.map { _ transform {
+ case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes.
+ }}).toArray
+
+ if (projectList.map(_.toAttribute) == projectList &&
+ projectSet.size == projectList.size &&
+ filterSet.subsetOf(projectSet)) {
+ // When it is possible to just use column pruning to get the right projection and
+ // when the columns of this projection are enough to evaluate all filter conditions,
+ // just do a scan followed by a filter, with no extra project.
+ val requestedColumns =
+ projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above.
+ .map(relation.attributeMap) // Match original case of attributes.
+ .map(_.name)
+ .toArray
+
+ val scan =
+ execution.PhysicalRDD(
+ projectList.map(_.toAttribute),
+ scanBuilder(requestedColumns, pushedFilters))
+ filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)
+ } else {
+ val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq
+ val columnNames = requestedColumns.map(_.name).toArray
+
+ val scan = execution.PhysicalRDD(requestedColumns, scanBuilder(columnNames, pushedFilters))
+ execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan))
+ }
+ }
+
+ protected def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect {
+ case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v)
+ case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v)
+
+ case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v)
+ case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v)
+
+ case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v)
+ case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v)
+
+ case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
+ GreaterThanOrEqual(a.name, v)
+ case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
+ LessThanOrEqual(a.name, v)
+
+ case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v)
+ case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v)
+ }
+}
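As a rough illustration of what `selectFilters` produces (the query and values here are made up): for `WHERE a > 1 AND 2 >= a`, the attribute is normalized to the left-hand side, so the data source receives `GreaterThan("a", 1)` and `LessThanOrEqual("a", 2)`. A relation could turn those into a row-level predicate along these lines:

```scala
import org.apache.spark.sql.sources.{Filter, GreaterThan, LessThanOrEqual}

// Filters as a PrunedFilteredScan would receive them for: WHERE a > 1 AND 2 >= a
val pushed: Array[Filter] = Array(GreaterThan("a", 1), LessThanOrEqual("a", 2))

// Translate each recognized filter into a predicate over the column value; anything the
// source does not understand can be ignored, since Spark SQL re-evaluates all filters anyway.
def toPredicate(filter: Filter): Int => Boolean = filter match {
  case GreaterThan("a", v: Int)     => _ > v
  case LessThanOrEqual("a", v: Int) => _ <= v
  case _                            => _ => true
}

val keepRow: Int => Boolean = i => pushed.forall(f => toPredicate(f)(i))
```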
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala
new file mode 100644
index 0000000000..82a2cf8402
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.sources
+
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.expressions.AttributeMap
+import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, LogicalPlan}
+
+/**
+ * Used to link a [[BaseRelation]] in to a logical query plan.
+ */
+private[sql] case class LogicalRelation(relation: BaseRelation)
+ extends LeafNode
+ with MultiInstanceRelation {
+
+ override val output = relation.schema.toAttributes
+
+ // Logical Relations are distinct if they have different output for the sake of transformations.
+ override def equals(other: Any) = other match {
+ case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output
+ case _ => false
+ }
+
+ override def sameResult(otherPlan: LogicalPlan) = otherPlan match {
+ case LogicalRelation(otherRelation) => relation == otherRelation
+ case _ => false
+ }
+
+ @transient override lazy val statistics = Statistics(
+ // TODO: Allow datasources to provide statistics as well.
+ sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes)
+ )
+
+ /** Used to lookup original attribute capitalization */
+ val attributeMap = AttributeMap(output.map(o => (o, o)))
+
+ def newInstance() = LogicalRelation(relation).asInstanceOf[this.type]
+
+ override def simpleString = s"Relation[${output.mkString(",")}] $relation"
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
new file mode 100644
index 0000000000..9168ca2fc6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.util.Utils
+
+import scala.language.implicitConversions
+import scala.util.parsing.combinator.lexical.StdLexical
+import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+import scala.util.parsing.combinator.PackratParsers
+
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.SqlLexical
+
+/**
+ * A parser for foreign DDL commands.
+ */
+private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging {
+
+ def apply(input: String): Option[LogicalPlan] = {
+ phrase(ddl)(new lexical.Scanner(input)) match {
+ case Success(r, x) => Some(r)
+ case x =>
+ logDebug(s"Not recognized as DDL: $x")
+ None
+ }
+ }
+
+ protected case class Keyword(str: String)
+
+ protected implicit def asParser(k: Keyword): Parser[String] =
+ lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+
+ protected val CREATE = Keyword("CREATE")
+ protected val TEMPORARY = Keyword("TEMPORARY")
+ protected val TABLE = Keyword("TABLE")
+ protected val USING = Keyword("USING")
+ protected val OPTIONS = Keyword("OPTIONS")
+
+ // Use reflection to find the reserved words defined in this class.
+ protected val reservedWords =
+ this.getClass
+ .getMethods
+ .filter(_.getReturnType == classOf[Keyword])
+ .map(_.invoke(this).asInstanceOf[Keyword].str)
+
+ override val lexical = new SqlLexical(reservedWords)
+
+ protected lazy val ddl: Parser[LogicalPlan] = createTable
+
+ /**
+ * CREATE FOREIGN TEMPORARY TABLE avroTable
+ * USING org.apache.spark.sql.avro
+ * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
+ */
+ protected lazy val createTable: Parser[LogicalPlan] =
+ CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
+ case tableName ~ provider ~ opts =>
+ CreateTableUsing(tableName, provider, opts)
+ }
+
+ protected lazy val options: Parser[Map[String, String]] =
+ "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
+
+ protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}
+
+ protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) }
+}
+
+private[sql] case class CreateTableUsing(
+ tableName: String,
+ provider: String,
+ options: Map[String, String]) extends RunnableCommand {
+
+ def run(sqlContext: SQLContext) = {
+ val loader = Utils.getContextOrSparkClassLoader
+ val clazz: Class[_] = try loader.loadClass(provider) catch {
+ case cnf: java.lang.ClassNotFoundException =>
+ try loader.loadClass(provider + ".DefaultSource") catch {
+ case cnf: java.lang.ClassNotFoundException =>
+ sys.error(s"Failed to load class for data source: $provider")
+ }
+ }
+ val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
+ val relation = dataSource.createRelation(sqlContext, options)
+
+ sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
+ Seq.empty
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
new file mode 100644
index 0000000000..e72a2aeb8f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+abstract class Filter
+
+case class EqualTo(attribute: String, value: Any) extends Filter
+case class GreaterThan(attribute: String, value: Any) extends Filter
+case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter
+case class LessThan(attribute: String, value: Any) extends Filter
+case class LessThanOrEqual(attribute: String, value: Any) extends Filter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
new file mode 100644
index 0000000000..ac3bf9d8e1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -0,0 +1,86 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+package org.apache.spark.sql.sources
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext, StructType}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
+
+/**
+ * Implemented by objects that produce relations for a specific kind of data source. When
+ * Spark SQL is given a DDL operation with a USING clause specified, this interface is used to
+ * pass in the parameters specified by a user.
+ *
+ * Users may specify the fully qualified class name of a given data source. When that class is
+ * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for
+ * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the
+ * data source 'org.apache.spark.sql.json.DefaultSource'
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ */
+@DeveloperApi
+trait RelationProvider {
+ /** Returns a new base relation with the given parameters. */
+ def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
+}
+
+/**
+ * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must
+ * be able to produce the schema of their data in the form of a [[StructType]]. Concrete
+ * implementations should inherit from one of the descendant `Scan` classes, which define various
+ * abstract methods for execution.
+ *
+ * BaseRelations must also define an equality function that only returns true when the two
+ * instances will return the same data. This equality function is used when determining when
+ * it is safe to substitute cached results for a given relation.
+ */
+@DeveloperApi
+abstract class BaseRelation {
+ def sqlContext: SQLContext
+ def schema: StructType
+}
+
+/**
+ * A BaseRelation that can produce all of its tuples as an RDD of Row objects.
+ */
+@DeveloperApi
+abstract class TableScan extends BaseRelation {
+ def buildScan(): RDD[Row]
+}
+
+/**
+ * A BaseRelation that can eliminate unneeded columns before producing an RDD
+ * containing all of its tuples as Row objects.
+ */
+@DeveloperApi
+abstract class PrunedScan extends BaseRelation {
+ def buildScan(requiredColumns: Array[String]): RDD[Row]
+}
+
+/**
+ * A BaseRelation that can eliminate unneeded columns and filter using selected
+ * predicates before producing an RDD containing all matching tuples as Row objects.
+ *
+ * The pushed down filters are currently purely an optimization as they will all be evaluated
+ * again. This means it is safe to use them with methods that produce false positives such
+ * as filtering partitions based on a bloom filter.
+ */
+@DeveloperApi
+abstract class PrunedFilteredScan extends BaseRelation {
+ def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row]
+}
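The note above that pushed filters are "purely an optimization" has a practical consequence: a source may apply them approximately, for example only to skip whole partitions, because Spark SQL re-applies every filter to the rows it gets back. A minimal sketch under that assumption; the relation, its single column `i`, and the `Range` partitions are all hypothetical.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{IntegerType, Row, SQLContext, StructField, StructType}
import org.apache.spark.sql.sources.{Filter, GreaterThan, PrunedFilteredScan}

// Each "partition" of this hypothetical source is a contiguous, non-empty range of integers.
case class RangePartitionedRelation(partitions: Seq[Range])(@transient val sqlContext: SQLContext)
  extends PrunedFilteredScan {

  override def schema =
    StructType(StructField("i", IntegerType, nullable = false) :: Nil)

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    // Use pushed filters only to skip whole partitions that cannot possibly match.
    val survivors = partitions.filter { range =>
      filters.forall {
        case GreaterThan("i", v: Int) => range.last > v
        case _ => true // anything else is left for Spark SQL to evaluate on the returned rows
      }
    }
    // There is only one column, so requiredColumns can be ignored here.
    sqlContext.sparkContext.parallelize(survivors.flatMap(_.map(i => Row(i))))
  }
}
```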
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala
new file mode 100644
index 0000000000..8393c510f4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+/**
+ * A set of APIs for adding data sources to Spark SQL.
+ */
+package object sources
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 1a5d87d524..44a2961b27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -27,18 +27,6 @@ case class BigData(s: String)
class CachedTableSuite extends QueryTest {
TestData // Load test tables.
- def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
- val planWithCaching = query.queryExecution.withCachedData
- val cachedData = planWithCaching collect {
- case cached: InMemoryRelation => cached
- }
-
- assert(
- cachedData.size == numCachedTables,
- s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
- planWithCaching)
- }
-
def rddIdOf(tableName: String): Int = {
val executedPlan = table(tableName).queryExecution.executedPlan
executedPlan.collect {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 042f61f5a4..3d9f0cbf80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -19,8 +19,10 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.columnar.InMemoryRelation
class QueryTest extends PlanTest {
+
/**
* Runs the plan and makes sure the answer contains all of the keywords, or the
* none of keywords are listed in the answer
@@ -78,11 +80,31 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${convertedAnswer.size} ==" +:
- prepareAnswer(convertedAnswer).map(_.toString),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
+ s"== Correct Answer - ${convertedAnswer.size} ==" +:
+ prepareAnswer(convertedAnswer).map(_.toString),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
""".stripMargin)
}
}
+
+ def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = {
+ test(sqlString) {
+ checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ }
+ }
+
+ /** Asserts that a given SchemaRDD will be executed using the given number of cached results. */
+ def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
+ val planWithCaching = query.queryExecution.withCachedData
+ val cachedData = planWithCaching collect {
+ case cached: InMemoryRelation => cached
+ }
+
+ assert(
+ cachedData.size == numCachedTables,
+ s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
+ planWithCaching)
+ }
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 1cb6c23c58..362c7e1a52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -549,6 +549,32 @@ class JsonSuite extends QueryTest {
)
}
+ test("Loading a JSON dataset from a text file with SQL") {
+ val file = getTempFilePath("json")
+ val path = file.toString
+ primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
+
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE jsonTableSQL
+ |USING org.apache.spark.sql.json
+ |OPTIONS (
+ | path '$path'
+ |)
+ """.stripMargin)
+
+ checkAnswer(
+ sql("select * from jsonTableSQL"),
+ (BigDecimal("92233720368547758070"),
+ true,
+ 1.7976931348623157E308,
+ 10,
+ 21474836470L,
+ null,
+ "this is a simple string.") :: Nil
+ )
+ }
+
test("Applying schemas") {
val file = getTempFilePath("json")
val path = file.toString
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
new file mode 100644
index 0000000000..9626252e74
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -0,0 +1,34 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.sources
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.Analyzer
+import org.apache.spark.sql.test.TestSQLContext
+import org.scalatest.BeforeAndAfter
+
+abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
+ // Case sensitivity is not configurable yet, but we want to test some edge cases.
+ // TODO: Remove when it is configurable
+ implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) {
+ @transient
+ override protected[sql] lazy val analyzer: Analyzer =
+ new Analyzer(catalog, functionRegistry, caseSensitive = false)
+ }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
new file mode 100644
index 0000000000..8b2f1591d5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -0,0 +1,176 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.sources
+
+import scala.language.existentials
+
+import org.apache.spark.sql._
+
+class FilteredScanSource extends RelationProvider {
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+ }
+}
+
+case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+ extends PrunedFilteredScan {
+
+ override def schema =
+ StructType(
+ StructField("a", IntegerType, nullable = false) ::
+ StructField("b", IntegerType, nullable = false) :: Nil)
+
+ override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = {
+ val rowBuilders = requiredColumns.map {
+ case "a" => (i: Int) => Seq(i)
+ case "b" => (i: Int) => Seq(i * 2)
+ }
+
+ FiltersPushed.list = filters
+
+ val filterFunctions = filters.collect {
+ case EqualTo("a", v) => (a: Int) => a == v
+ case LessThan("a", v: Int) => (a: Int) => a < v
+ case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
+ case GreaterThan("a", v: Int) => (a: Int) => a > v
+ case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
+ }
+
+ def eval(a: Int) = !filterFunctions.map(_(a)).contains(false)
+
+ sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i =>
+ Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
+ }
+}
+
+// A hack for better error messages when filter pushdown fails.
+object FiltersPushed {
+ var list: Seq[Filter] = Nil
+}
+
+class FilteredScanSuite extends DataSourceTest {
+
+ import caseInsensisitiveContext._
+
+ before {
+ sql(
+ """
+ |CREATE TEMPORARY TABLE oneToTenFiltered
+ |USING org.apache.spark.sql.sources.FilteredScanSource
+ |OPTIONS (
+ | from '1',
+ | to '10'
+ |)
+ """.stripMargin)
+ }
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered",
+ (1 to 10).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT a, b FROM oneToTenFiltered",
+ (1 to 10).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT b, a FROM oneToTenFiltered",
+ (1 to 10).map(i => Row(i * 2, i)).toSeq)
+
+ sqlTest(
+ "SELECT a FROM oneToTenFiltered",
+ (1 to 10).map(i => Row(i)).toSeq)
+
+ sqlTest(
+ "SELECT b FROM oneToTenFiltered",
+ (1 to 10).map(i => Row(i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT a * 2 FROM oneToTenFiltered",
+ (1 to 10).map(i => Row(i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT A AS b FROM oneToTenFiltered",
+ (1 to 10).map(i => Row(i)).toSeq)
+
+ sqlTest(
+ "SELECT x.b, y.a FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b",
+ (1 to 5).map(i => Row(i * 4, i)).toSeq)
+
+ sqlTest(
+ "SELECT x.a, y.b FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b",
+ (2 to 10 by 2).map(i => Row(i, i)).toSeq)
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE a = 1",
+ Seq(1).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE A = 1",
+ Seq(1).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE b = 2",
+ Seq(1).map(i => Row(i, i * 2)).toSeq)
+
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1)
+ testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1)
+ testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1)
+ testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1)
+
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9)
+
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9)
+
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2)
+
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2)
+
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)
+
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
+
+ def testPushDown(sqlString: String, expectedCount: Int): Unit = {
+ test(s"PushDown Returns $expectedCount: $sqlString") {
+ val queryExecution = sql(sqlString).queryExecution
+ val rawPlan = queryExecution.executedPlan.collect {
+ case p: execution.PhysicalRDD => p
+ } match {
+ case Seq(p) => p
+ case _ => fail(s"More than one PhysicalRDD found\n$queryExecution")
+ }
+ val rawCount = rawPlan.execute().count()
+
+ if (rawCount != expectedCount) {
+ fail(
+ s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" +
+ s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" +
+ queryExecution)
+ }
+ }
+ }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
new file mode 100644
index 0000000000..fee2e22611
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -0,0 +1,137 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.sources
+
+import org.apache.spark.sql._
+
+class PrunedScanSource extends RelationProvider {
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+ }
+}
+
+case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+ extends PrunedScan {
+
+ override def schema =
+ StructType(
+ StructField("a", IntegerType, nullable = false) ::
+ StructField("b", IntegerType, nullable = false) :: Nil)
+
+ override def buildScan(requiredColumns: Array[String]) = {
+ val rowBuilders = requiredColumns.map {
+ case "a" => (i: Int) => Seq(i)
+ case "b" => (i: Int) => Seq(i * 2)
+ }
+
+ sqlContext.sparkContext.parallelize(from to to).map(i =>
+ Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
+ }
+}
+
+class PrunedScanSuite extends DataSourceTest {
+ import caseInsensisitiveContext._
+
+ before {
+ sql(
+ """
+ |CREATE TEMPORARY TABLE oneToTenPruned
+ |USING org.apache.spark.sql.sources.PrunedScanSource
+ |OPTIONS (
+ | from '1',
+ | to '10'
+ |)
+ """.stripMargin)
+ }
+
+ sqlTest(
+ "SELECT * FROM oneToTenPruned",
+ (1 to 10).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT a, b FROM oneToTenPruned",
+ (1 to 10).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT b, a FROM oneToTenPruned",
+ (1 to 10).map(i => Row(i * 2, i)).toSeq)
+
+ sqlTest(
+ "SELECT a FROM oneToTenPruned",
+ (1 to 10).map(i => Row(i)).toSeq)
+
+ sqlTest(
+ "SELECT a, a FROM oneToTenPruned",
+ (1 to 10).map(i => Row(i, i)).toSeq)
+
+ sqlTest(
+ "SELECT b FROM oneToTenPruned",
+ (1 to 10).map(i => Row(i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT a * 2 FROM oneToTenPruned",
+ (1 to 10).map(i => Row(i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT A AS b FROM oneToTenPruned",
+ (1 to 10).map(i => Row(i)).toSeq)
+
+ sqlTest(
+ "SELECT x.b, y.a FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b",
+ (1 to 5).map(i => Row(i * 4, i)).toSeq)
+
+ sqlTest(
+ "SELECT x.a, y.b FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b",
+ (2 to 10 by 2).map(i => Row(i, i)).toSeq)
+
+ testPruning("SELECT * FROM oneToTenPruned", "a", "b")
+ testPruning("SELECT a, b FROM oneToTenPruned", "a", "b")
+ testPruning("SELECT b, a FROM oneToTenPruned", "b", "a")
+ testPruning("SELECT b, b FROM oneToTenPruned", "b")
+ testPruning("SELECT a FROM oneToTenPruned", "a")
+ testPruning("SELECT b FROM oneToTenPruned", "b")
+
+ def testPruning(sqlString: String, expectedColumns: String*): Unit = {
+ test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
+ val queryExecution = sql(sqlString).queryExecution
+ val rawPlan = queryExecution.executedPlan.collect {
+ case p: execution.PhysicalRDD => p
+ } match {
+ case Seq(p) => p
+ case _ => fail(s"More than one PhysicalRDD found\n$queryExecution")
+ }
+ val rawColumns = rawPlan.output.map(_.name)
+ val rawOutput = rawPlan.execute().first()
+
+ if (rawColumns != expectedColumns) {
+ fail(
+ s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" +
+ s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" +
+ queryExecution)
+ }
+
+ if (rawOutput.size != expectedColumns.size) {
+ fail(s"Wrong output row. Got $rawOutput\n$queryExecution")
+ }
+ }
+ }
+
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
new file mode 100644
index 0000000000..b254b0620c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -0,0 +1,125 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.sources
+
+import org.apache.spark.sql._
+
+class DefaultSource extends SimpleScanSource
+
+class SimpleScanSource extends RelationProvider {
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ SimpleScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+ }
+}
+
+case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+ extends TableScan {
+
+ override def schema =
+ StructType(StructField("i", IntegerType, nullable = false) :: Nil)
+
+ override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_))
+}
+
+class TableScanSuite extends DataSourceTest {
+ import caseInsensisitiveContext._
+
+ before {
+ sql(
+ """
+ |CREATE TEMPORARY TABLE oneToTen
+ |USING org.apache.spark.sql.sources.SimpleScanSource
+ |OPTIONS (
+ | from '1',
+ | to '10'
+ |)
+ """.stripMargin)
+ }
+
+ sqlTest(
+ "SELECT * FROM oneToTen",
+ (1 to 10).map(Row(_)).toSeq)
+
+ sqlTest(
+ "SELECT i FROM oneToTen",
+ (1 to 10).map(Row(_)).toSeq)
+
+ sqlTest(
+ "SELECT i FROM oneToTen WHERE i < 5",
+ (1 to 4).map(Row(_)).toSeq)
+
+ sqlTest(
+ "SELECT i * 2 FROM oneToTen",
+ (1 to 10).map(i => Row(i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1",
+ (2 to 10).map(i => Row(i, i - 1)).toSeq)
+
+
+ test("Caching") {
+ // Cached Query Execution
+ cacheTable("oneToTen")
+ assertCached(sql("SELECT * FROM oneToTen"))
+ checkAnswer(
+ sql("SELECT * FROM oneToTen"),
+ (1 to 10).map(Row(_)).toSeq)
+
+ assertCached(sql("SELECT i FROM oneToTen"))
+ checkAnswer(
+ sql("SELECT i FROM oneToTen"),
+ (1 to 10).map(Row(_)).toSeq)
+
+ assertCached(sql("SELECT i FROM oneToTen WHERE i < 5"))
+ checkAnswer(
+ sql("SELECT i FROM oneToTen WHERE i < 5"),
+ (1 to 4).map(Row(_)).toSeq)
+
+ assertCached(sql("SELECT i * 2 FROM oneToTen"))
+ checkAnswer(
+ sql("SELECT i * 2 FROM oneToTen"),
+ (1 to 10).map(i => Row(i * 2)).toSeq)
+
+ assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
+ checkAnswer(
+ sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
+ (2 to 10).map(i => Row(i, i - 1)).toSeq)
+
+ // Verify uncaching
+ uncacheTable("oneToTen")
+ assertCached(sql("SELECT * FROM oneToTen"), 0)
+ }
+
+ test("defaultSource") {
+ sql(
+ """
+ |CREATE TEMPORARY TABLE oneToTenDef
+ |USING org.apache.spark.sql.sources
+ |OPTIONS (
+ | from '1',
+ | to '10'
+ |)
+ """.stripMargin)
+
+ checkAnswer(
+ sql("SELECT * FROM oneToTenDef"),
+ (1 to 10).map(Row(_)).toSeq)
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2e27817d60..dca5367f24 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -50,6 +50,7 @@ import org.apache.spark.sql.execution.ExtractPythonUdfs
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.{Command => PhysicalCommand}
import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand
+import org.apache.spark.sql.sources.DataSourceStrategy
/**
* DEPRECATED: Use HiveContext instead.
@@ -99,7 +100,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
if (dialect == "sql") {
super.sql(sqlText)
} else if (dialect == "hiveql") {
- new SchemaRDD(this, HiveQl.parseSql(sqlText))
+ new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText)))
} else {
sys.error(s"Unsupported SQL dialect: $dialect. Try 'sql' or 'hiveql'")
}
@@ -345,7 +346,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
val hivePlanner = new SparkPlanner with HiveStrategies {
val hiveContext = self
- override val strategies: Seq[Strategy] = Seq(
+ override val strategies: Seq[Strategy] = extraStrategies ++ Seq(
+ DataSourceStrategy,
CommandStrategy(self),
HiveCommandStrategy(self),
TakeOrdered,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 3207ad81d9..989740c8d4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan}
import org.apache.spark.sql.hive
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy}
import scala.collection.JavaConversions._