diff options
author | Michael Armbrust <michael@databricks.com> | 2014-11-02 15:08:35 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-11-02 15:08:35 -0800 |
commit | 9c0eb57c737dd7d97d2cbd4516ddd2cf5d06e4b2 (patch) | |
tree | 285005822bdff25308e471738bda5e5e157cd9b8 /sql/core | |
parent | f0a4b630abf0766cc0c41e682691e0d435caca04 (diff) | |
download | spark-9c0eb57c737dd7d97d2cbd4516ddd2cf5d06e4b2.tar.gz spark-9c0eb57c737dd7d97d2cbd4516ddd2cf5d06e4b2.tar.bz2 spark-9c0eb57c737dd7d97d2cbd4516ddd2cf5d06e4b2.zip |
[SPARK-3247][SQL] An API for adding data sources to Spark SQL
This PR introduces a new set of APIs to Spark SQL to allow other developers to add support for reading data from new sources in `org.apache.spark.sql.sources`.
New sources must implement the interface `BaseRelation`, which is responsible for describing the schema of the data. BaseRelations have three `Scan` subclasses, which are responsible for producing an RDD containing row objects. The [various Scan interfaces](https://github.com/marmbrus/spark/blob/foreign/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala#L50) allow for optimizations such as column pruning and filter push down, when the underlying data source can handle these operations.
By implementing a class that inherits from RelationProvider these data sources can be accessed using using pure SQL. I've used the functionality to update the JSON support so it can now be used in this way as follows:
```sql
CREATE TEMPORARY TABLE jsonTableSQL
USING org.apache.spark.sql.json
OPTIONS (
path '/home/michael/data.json'
)
```
Further example usage can be found in the test cases: https://github.com/marmbrus/spark/tree/foreign/sql/core/src/test/scala/org/apache/spark/sql/sources
There is also a library that uses this new API to read avro data available here:
https://github.com/marmbrus/sql-avro
Author: Michael Armbrust <michael@databricks.com>
Closes #2475 from marmbrus/foreign and squashes the following commits:
1ed6010 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into foreign
ab2c31f [Michael Armbrust] fix test
1d41bb5 [Michael Armbrust] unify argument names
5b47901 [Michael Armbrust] Remove sealed, more filter types
fab154a [Michael Armbrust] Merge remote-tracking branch 'origin/master' into foreign
e3e690e [Michael Armbrust] Add hook for extraStrategies
a70d602 [Michael Armbrust] Fix style, more tests, FilteredSuite => PrunedFilteredSuite
70da6d9 [Michael Armbrust] Modify API to ease binary compatibility and interop with Java
7d948ae [Michael Armbrust] Fix equality of AttributeReference.
5545491 [Michael Armbrust] Address comments
5031ac3 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into foreign
22963ef [Michael Armbrust] package objects compile wierdly...
b069146 [Michael Armbrust] traits => abstract classes
34f836a [Michael Armbrust] Make @DeveloperApi
0d74bcf [Michael Armbrust] Add documention on object life cycle
3e06776 [Michael Armbrust] remove line wraps
de3b68c [Michael Armbrust] Remove empty file
360cb30 [Michael Armbrust] style and java api
2957875 [Michael Armbrust] add override
0fd3a07 [Michael Armbrust] Draft of data sources API
Diffstat (limited to 'sql/core')
20 files changed, 1053 insertions, 27 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4953f8399a..4cded98c80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources.{DataSourceStrategy, BaseRelation, DDLParser, LogicalRelation} /** * :: AlphaComponent :: @@ -69,12 +70,18 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient + protected[sql] val ddlParser = new DDLParser + + @transient protected[sql] val sqlParser = { val fallback = new catalyst.SqlParser new catalyst.SparkSQLParser(fallback(_)) } - protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = { + ddlParser(sql).getOrElse(sqlParser(sql)) + } + protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -104,6 +111,10 @@ class SQLContext(@transient val sparkContext: SparkContext) LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) } + implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { + logicalPlanToSparkQuery(LogicalRelation(baseRelation)) + } + /** * :: DeveloperApi :: * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. @@ -283,6 +294,14 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): SchemaRDD = new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + /** + * :: DeveloperApi :: + * Allows extra strategies to be injected into the query planner at runtime. Note this API + * should be consider experimental and is not intended to be stable across releases. + */ + @DeveloperApi + var extraStrategies: Seq[Strategy] = Nil + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -293,7 +312,9 @@ class SQLContext(@transient val sparkContext: SparkContext) def numPartitions = self.numShufflePartitions val strategies: Seq[Strategy] = + extraStrategies ++ ( CommandStrategy(self) :: + DataSourceStrategy :: TakeOrdered :: HashAggregation :: LeftSemiJoin :: @@ -302,7 +323,7 @@ class SQLContext(@transient val sparkContext: SparkContext) ParquetOperations :: BasicOperators :: CartesianProduct :: - BroadcastNestedLoopJoin :: Nil + BroadcastNestedLoopJoin :: Nil) /** * Used to build table scan operators where complex projection and filtering are done using diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 876b1c6ede..60065509bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation} import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} @@ -39,6 +40,10 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) + def baseRelationToSchemaRDD(baseRelation: BaseRelation): JavaSchemaRDD = { + new JavaSchemaRDD(sqlContext, LogicalRelation(baseRelation)) + } + /** * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 04c51a1ee4..d64c5af89e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -50,12 +50,6 @@ object RDDConversions { } } } - - /* - def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = { - LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) - } - */ } case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 79e4ddb8c4..2cd3063bc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, execution} +import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -304,6 +304,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.SetCommand(kv) => Seq(execution.SetCommand(kv, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 5859eba408..e658e6fc4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,10 +21,12 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} +import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, SQLConf, SQLContext} +import org.apache.spark.sql.{SQLConf, SQLContext} +// TODO: DELETE ME... trait Command { this: SparkPlan => @@ -44,6 +46,35 @@ trait Command { override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) } +// TODO: Replace command with runnable command. +trait RunnableCommand extends logical.Command { + self: Product => + + def output: Seq[Attribute] + def run(sqlContext: SQLContext): Seq[Row] +} + +case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { + /** + * A concrete command should override this lazy field to wrap up any side effects caused by the + * command or any other computation that should be evaluated exactly once. The value of this field + * can be used as the contents of the corresponding RDD generated from the physical plan of this + * command. + * + * The `execute()` method of all the physical command classes should reference `sideEffectResult` + * so that the command can be executed eagerly right after the command query is created. + */ + protected[sql] lazy val sideEffectResult: Seq[Row] = cmd.run(sqlContext) + + override def output = cmd.output + + override def children = Nil + + override def executeCollect(): Array[Row] = sideEffectResult.toArray + + override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) +} + /** * :: DeveloperApi :: */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala new file mode 100644 index 0000000000..fc70c18343 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources._ + +private[sql] class DefaultSource extends RelationProvider { + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + + JSONRelation(fileName, samplingRatio)(sqlContext) + } +} + +private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)( + @transient val sqlContext: SQLContext) + extends TableScan { + + private def baseRDD = sqlContext.sparkContext.textFile(fileName) + + override val schema = + JsonRDD.inferSchema( + baseRDD, + samplingRatio, + sqlContext.columnNameOfCorruptRecord) + + override def buildScan() = + JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 05926a24c5..51dad54f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.SparkPlan /** * Allows the execution of relational queries, including those expressed in SQL using Spark. @@ -433,6 +434,12 @@ package object sql { val StructField = catalyst.types.StructField /** + * Converts a logical plan into zero or more SparkPlans. + */ + @DeveloperApi + type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + + /** * :: DeveloperApi :: * * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, @@ -448,7 +455,9 @@ package object sql { type Metadata = catalyst.util.Metadata /** + * :: DeveloperApi :: * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. */ + @DeveloperApi type MetadataBuilder = catalyst.util.MetadataBuilder } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala new file mode 100644 index 0000000000..9b8c6a56b9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan + +/** + * A Strategy for planning scans over data sources defined using the sources API. + */ +private[sql] object DataSourceStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => + pruneFilterProject( + l, + projectList, + filters, + (a, f) => t.buildScan(a, f)) :: Nil + + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) => + pruneFilterProject( + l, + projectList, + filters, + (a, _) => t.buildScan(a)) :: Nil + + case l @ LogicalRelation(t: TableScan) => + execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + + case _ => Nil + } + + protected def pruneFilterProject( + relation: LogicalRelation, + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = { + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + val filterCondition = filterPredicates.reduceLeftOption(And) + + val pushedFilters = selectFilters(filterPredicates.map { _ transform { + case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. + }}).toArray + + if (projectList.map(_.toAttribute) == projectList && + projectSet.size == projectList.size && + filterSet.subsetOf(projectSet)) { + // When it is possible to just use column pruning to get the right projection and + // when the columns of this projection are enough to evaluate all filter conditions, + // just do a scan followed by a filter, with no extra project. + val requestedColumns = + projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. + .map(relation.attributeMap) // Match original case of attributes. + .map(_.name) + .toArray + + val scan = + execution.PhysicalRDD( + projectList.map(_.toAttribute), + scanBuilder(requestedColumns, pushedFilters)) + filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) + } else { + val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq + val columnNames = requestedColumns.map(_.name).toArray + + val scan = execution.PhysicalRDD(requestedColumns, scanBuilder(columnNames, pushedFilters)) + execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + } + } + + protected def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { + case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) + case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) + + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) + case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) + + case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) + case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + GreaterThanOrEqual(a.name, v) + case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + LessThanOrEqual(a.name, v) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala new file mode 100644 index 0000000000..82a2cf8402 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, LogicalPlan} + +/** + * Used to link a [[BaseRelation]] in to a logical query plan. + */ +private[sql] case class LogicalRelation(relation: BaseRelation) + extends LeafNode + with MultiInstanceRelation { + + override val output = relation.schema.toAttributes + + // Logical Relations are distinct if they have different output for the sake of transformations. + override def equals(other: Any) = other match { + case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output + case _ => false + } + + override def sameResult(otherPlan: LogicalPlan) = otherPlan match { + case LogicalRelation(otherRelation) => relation == otherRelation + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Allow datasources to provide statistics as well. + sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes) + ) + + /** Used to lookup original attribute capitalization */ + val attributeMap = AttributeMap(output.map(o => (o, o))) + + def newInstance() = LogicalRelation(relation).asInstanceOf[this.type] + + override def simpleString = s"Relation[${output.mkString(",")}] $relation" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala new file mode 100644 index 0000000000..9168ca2fc6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.util.Utils + +import scala.language.implicitConversions +import scala.util.parsing.combinator.lexical.StdLexical +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.PackratParsers + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SqlLexical + +/** + * A parser for foreign DDL commands. + */ +private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging { + + def apply(input: String): Option[LogicalPlan] = { + phrase(ddl)(new lexical.Scanner(input)) match { + case Success(r, x) => Some(r) + case x => + logDebug(s"Not recognized as DDL: $x") + None + } + } + + protected case class Keyword(str: String) + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + + // Use reflection to find the reserved words defined in this class. + protected val reservedWords = + this.getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + protected lazy val ddl: Parser[LogicalPlan] = createTable + + /** + * CREATE FOREIGN TEMPORARY TABLE avroTable + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") + */ + protected lazy val createTable: Parser[LogicalPlan] = + CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { + case tableName ~ provider ~ opts => + CreateTableUsing(tableName, provider, opts) + } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } +} + +private[sql] case class CreateTableUsing( + tableName: String, + provider: String, + options: Map[String, String]) extends RunnableCommand { + + def run(sqlContext: SQLContext) = { + val loader = Utils.getContextOrSparkClassLoader + val clazz: Class[_] = try loader.loadClass(provider) catch { + case cnf: java.lang.ClassNotFoundException => + try loader.loadClass(provider + ".DefaultSource") catch { + case cnf: java.lang.ClassNotFoundException => + sys.error(s"Failed to load class for data source: $provider") + } + } + val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider] + val relation = dataSource.createRelation(sqlContext, options) + + sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName) + Seq.empty + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala new file mode 100644 index 0000000000..e72a2aeb8f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +abstract class Filter + +case class EqualTo(attribute: String, value: Any) extends Filter +case class GreaterThan(attribute: String, value: Any) extends Filter +case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter +case class LessThan(attribute: String, value: Any) extends Filter +case class LessThanOrEqual(attribute: String, value: Any) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala new file mode 100644 index 0000000000..ac3bf9d8e1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -0,0 +1,86 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.apache.spark.sql.sources + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, StructType} +import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} + +/** + * Implemented by objects that produce relations for a specific kind of data source. When + * Spark SQL is given a DDL operation with a USING clause specified, this interface is used to + * pass in the parameters specified by a user. + * + * Users may specify the fully qualified class name of a given data source. When that class is + * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for + * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the + * data source 'org.apache.spark.sql.json.DefaultSource' + * + * A new instance of this class with be instantiated each time a DDL call is made. + */ +@DeveloperApi +trait RelationProvider { + /** Returns a new base relation with the given parameters. */ + def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation +} + +/** + * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must + * be able to produce the schema of their data in the form of a [[StructType]] Concrete + * implementation should inherit from one of the descendant `Scan` classes, which define various + * abstract methods for execution. + * + * BaseRelations must also define a equality function that only returns true when the two + * instances will return the same data. This equality function is used when determining when + * it is safe to substitute cached results for a given relation. + */ +@DeveloperApi +abstract class BaseRelation { + def sqlContext: SQLContext + def schema: StructType +} + +/** + * A BaseRelation that can produce all of its tuples as an RDD of Row objects. + */ +@DeveloperApi +abstract class TableScan extends BaseRelation { + def buildScan(): RDD[Row] +} + +/** + * A BaseRelation that can eliminate unneeded columns before producing an RDD + * containing all of its tuples as Row objects. + */ +@DeveloperApi +abstract class PrunedScan extends BaseRelation { + def buildScan(requiredColumns: Array[String]): RDD[Row] +} + +/** + * A BaseRelation that can eliminate unneeded columns and filter using selected + * predicates before producing an RDD containing all matching tuples as Row objects. + * + * The pushed down filters are currently purely an optimization as they will all be evaluated + * again. This means it is safe to use them with methods that produce false positives such + * as filtering partitions based on a bloom filter. + */ +@DeveloperApi +abstract class PrunedFilteredScan extends BaseRelation { + def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala new file mode 100644 index 0000000000..8393c510f4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +/** + * A set of APIs for adding data sources to Spark SQL. + */ +package object sources diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1a5d87d524..44a2961b27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,18 +27,6 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } - def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan executedPlan.collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 042f61f5a4..3d9f0cbf80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.columnar.InMemoryRelation class QueryTest extends PlanTest { + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer @@ -78,11 +80,31 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + s"== Correct Answer - ${convertedAnswer.size} ==" +: + prepareAnswer(convertedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } + + def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + + /** Asserts that a given SchemaRDD will be executed using the given number of cached results. */ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 1cb6c23c58..362c7e1a52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -549,6 +549,32 @@ class JsonSuite extends QueryTest { ) } + test("Loading a JSON dataset from a text file with SQL") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + sql( + s""" + |CREATE TEMPORARY TABLE jsonTableSQL + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + + checkAnswer( + sql("select * from jsonTableSQL"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } + test("Applying schemas") { val file = getTempFilePath("json") val path = file.toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala new file mode 100644 index 0000000000..9626252e74 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -0,0 +1,34 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.test.TestSQLContext +import org.scalatest.BeforeAndAfter + +abstract class DataSourceTest extends QueryTest with BeforeAndAfter { + // Case sensitivity is not configurable yet, but we want to test some edge cases. + // TODO: Remove when it is configurable + implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) { + @transient + override protected[sql] lazy val analyzer: Analyzer = + new Analyzer(catalog, functionRegistry, caseSensitive = false) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala new file mode 100644 index 0000000000..8b2f1591d5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -0,0 +1,176 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import scala.language.existentials + +import org.apache.spark.sql._ + +class FilteredScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends PrunedFilteredScan { + + override def schema = + StructType( + StructField("a", IntegerType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: Nil) + + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + val rowBuilders = requiredColumns.map { + case "a" => (i: Int) => Seq(i) + case "b" => (i: Int) => Seq(i * 2) + } + + FiltersPushed.list = filters + + val filterFunctions = filters.collect { + case EqualTo("a", v) => (a: Int) => a == v + case LessThan("a", v: Int) => (a: Int) => a < v + case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v + case GreaterThan("a", v: Int) => (a: Int) => a > v + case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v + } + + def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) + + sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => + Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) + } +} + +// A hack for better error messages when filter pushdown fails. +object FiltersPushed { + var list: Seq[Filter] = Nil +} + +class FilteredScanSuite extends DataSourceTest { + + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenFiltered + |USING org.apache.spark.sql.sources.FilteredScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTenFiltered", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT b, a FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2, i)).toSeq) + + sqlTest( + "SELECT a FROM oneToTenFiltered", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a * 2 FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT A AS b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT x.b, y.a FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b", + (1 to 5).map(i => Row(i * 4, i)).toSeq) + + sqlTest( + "SELECT x.a, y.b FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b", + (2 to 10 by 2).map(i => Row(i, i)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) + + def testPushDown(sqlString: String, expectedCount: Int): Unit = { + test(s"PushDown Returns $expectedCount: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawCount = rawPlan.execute().count() + + if (rawCount != expectedCount) { + fail( + s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + } + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala new file mode 100644 index 0000000000..fee2e22611 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -0,0 +1,137 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ + +class PrunedScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends PrunedScan { + + override def schema = + StructType( + StructField("a", IntegerType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: Nil) + + override def buildScan(requiredColumns: Array[String]) = { + val rowBuilders = requiredColumns.map { + case "a" => (i: Int) => Seq(i) + case "b" => (i: Int) => Seq(i * 2) + } + + sqlContext.sparkContext.parallelize(from to to).map(i => + Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) + } +} + +class PrunedScanSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenPruned + |USING org.apache.spark.sql.sources.PrunedScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT b, a FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2, i)).toSeq) + + sqlTest( + "SELECT a FROM oneToTenPruned", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT a, a FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i)).toSeq) + + sqlTest( + "SELECT b FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a * 2 FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT A AS b FROM oneToTenPruned", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT x.b, y.a FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", + (1 to 5).map(i => Row(i * 4, i)).toSeq) + + sqlTest( + "SELECT x.a, y.b FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", + (2 to 10 by 2).map(i => Row(i, i)).toSeq) + + testPruning("SELECT * FROM oneToTenPruned", "a", "b") + testPruning("SELECT a, b FROM oneToTenPruned", "a", "b") + testPruning("SELECT b, a FROM oneToTenPruned", "b", "a") + testPruning("SELECT b, b FROM oneToTenPruned", "b") + testPruning("SELECT a FROM oneToTenPruned", "a") + testPruning("SELECT b FROM oneToTenPruned", "b") + + def testPruning(sqlString: String, expectedColumns: String*): Unit = { + test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawColumns = rawPlan.output.map(_.name) + val rawOutput = rawPlan.execute().first() + + if (rawColumns != expectedColumns) { + fail( + s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + + if (rawOutput.size != expectedColumns.size) { + fail(s"Wrong output row. Got $rawOutput\n$queryExecution") + } + } + } + +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala new file mode 100644 index 0000000000..b254b0620c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -0,0 +1,125 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ + +class DefaultSource extends SimpleScanSource + +class SimpleScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends TableScan { + + override def schema = + StructType(StructField("i", IntegerType, nullable = false) :: Nil) + + override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) +} + +class TableScanSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTen + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTen", + (1 to 10).map(Row(_)).toSeq) + + sqlTest( + "SELECT i FROM oneToTen", + (1 to 10).map(Row(_)).toSeq) + + sqlTest( + "SELECT i FROM oneToTen WHERE i < 5", + (1 to 4).map(Row(_)).toSeq) + + sqlTest( + "SELECT i * 2 FROM oneToTen", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1", + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + + test("Caching") { + // Cached Query Execution + cacheTable("oneToTen") + assertCached(sql("SELECT * FROM oneToTen")) + checkAnswer( + sql("SELECT * FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT i FROM oneToTen")) + checkAnswer( + sql("SELECT i FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT i FROM oneToTen WHERE i < 5")) + checkAnswer( + sql("SELECT i FROM oneToTen WHERE i < 5"), + (1 to 4).map(Row(_)).toSeq) + + assertCached(sql("SELECT i * 2 FROM oneToTen")) + checkAnswer( + sql("SELECT i * 2 FROM oneToTen"), + (1 to 10).map(i => Row(i * 2)).toSeq) + + assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer( + sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + // Verify uncaching + uncacheTable("oneToTen") + assertCached(sql("SELECT * FROM oneToTen"), 0) + } + + test("defaultSource") { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenDef + |USING org.apache.spark.sql.sources + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM oneToTenDef"), + (1 to 10).map(Row(_)).toSeq) + } +} |