path: root/sql/core/src
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala   99
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala  174
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala  342
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala  170
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala   69
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala   21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala   89
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala  229
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala  117
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala  137
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala   46
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala  158
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala   29
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala  276
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala  212
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala  220
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala  103
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala   24
-rw-r--r--  sql/core/src/test/resources/log4j.properties   52
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala  201
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/PlannerSuite.scala   62
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala   75
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  211
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala   72
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala   71
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala  126
26 files changed, 3385 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala b/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala
new file mode 100644
index 0000000000..b8b9e5839d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.language.implicitConversions
+
+import scala.reflect._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark._
+import org.apache.spark.Aggregator
+import org.apache.spark.SparkContext._
+import org.apache.spark.util.collection.AppendOnlyMap
+
+/**
+ * Extra functions on RDDs that perform only local operations. These can be used when data has
+ * already been partitioned correctly.
+ */
+private[spark] class PartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
+ extends Logging
+ with Serializable {
+
+ /**
+ * Cogroup corresponding partitions of `this` and `other`. These two RDDs should have
+ * the same number of partitions. Partitions of these two RDDs are cogrouped
+ * according to the indexes of partitions. If we have two RDDs and
+ * each of them has n partitions, we will cogroup the partition i from `this`
+ * with the partition i from `other`.
+ * This function will not introduce a shuffling operation.
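+ * A usage sketch (illustrative only: `pairs1` and `pairs2` are hypothetical pair RDDs that
+ * have already been partitioned by the same partitioner):
+ * {{{
+ *   import org.apache.spark.HashPartitioner
+ *   import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+ *
+ *   val left = pairs1.partitionBy(new HashPartitioner(8))
+ *   val right = pairs2.partitionBy(new HashPartitioner(8))
+ *   // No shuffle is performed: partition i of `left` is cogrouped with partition i of `right`.
+ *   val cogrouped = left.cogroupLocally(right)
+ * }}}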
+ */
+ def cogroupLocally[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
+ val cg = self.zipPartitions(other)((iter1: Iterator[(K, V)], iter2: Iterator[(K, W)]) => {
+ val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
+
+ val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
+ if (hadVal) oldVal else Array.fill(2)(new ArrayBuffer[Any])
+ }
+
+ val getSeq = (k: K) => {
+ map.changeValue(k, update)
+ }
+
+ iter1.foreach { kv => getSeq(kv._1)(0) += kv._2 }
+ iter2.foreach { kv => getSeq(kv._1)(1) += kv._2 }
+
+ map.iterator
+ }).mapValues { case Seq(vs, ws) => (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])}
+
+ cg
+ }
+
+ /**
+ * Group the values for each key within a partition of the RDD into a single sequence.
+ * This function will not introduce a shuffling operation.
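+ * A usage sketch (illustrative: `pairs` is a hypothetical pair RDD that has already been
+ * partitioned by key):
+ * {{{
+ *   import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+ *
+ *   // Groups values within each partition only; no shuffle is introduced.
+ *   val grouped = pairs.groupByKeyLocally()
+ * }}}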
+ */
+ def groupByKeyLocally(): RDD[(K, Seq[V])] = {
+ def createCombiner(v: V) = ArrayBuffer(v)
+ def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
+ val aggregator = new Aggregator[K, V, ArrayBuffer[V]](createCombiner, mergeValue, _ ++ _)
+ val bufs = self.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
+ }, preservesPartitioning = true)
+ bufs.asInstanceOf[RDD[(K, Seq[V])]]
+ }
+
+ /**
+ * Join corresponding partitions of `this` and `other`.
+ * If we have two RDDs and each of them has n partitions,
+ * we will join the partition i from `this` with the partition i from `other`.
+ * This function will not introduce a shuffling operation.
+ */
+ def joinLocally[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = {
+ cogroupLocally(other).flatMapValues {
+ case (vs, ws) => for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
+ }
+ }
+}
+
+private[spark] object PartitionLocalRDDFunctions {
+ implicit def rddToPartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
+ new PartitionLocalRDDFunctions(rdd)
+}
+
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
new file mode 100644
index 0000000000..587cc7487f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.{SparkContext, SparkConf}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.dsl
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.catalyst.planning.QueryPlanner
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, NativeCommand, WriteToFile}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.execution._
+
+/**
+ * <span class="badge" style="float: right; background-color: darkblue;">ALPHA COMPONENT</span>
+ *
+ * The entry point for running relational queries using Spark. Allows the creation of [[SchemaRDD]]
+ * objects and the execution of SQL queries.
+ *
+ * @groupname userf Spark SQL Functions
+ * @groupname Ungrouped Support functions for language integrated queries.
+ */
+class SQLContext(@transient val sparkContext: SparkContext)
+ extends Logging
+ with dsl.ExpressionConversions
+ with Serializable {
+
+ self =>
+
+ @transient
+ protected[sql] lazy val catalog: Catalog = new SimpleCatalog
+ @transient
+ protected[sql] lazy val analyzer: Analyzer =
+ new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true)
+ @transient
+ protected[sql] val optimizer = Optimizer
+ @transient
+ protected[sql] val parser = new catalyst.SqlParser
+
+ protected[sql] def parseSql(sql: String): LogicalPlan = parser(sql)
+ protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))
+ protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
+ new this.QueryExecution { val logical = plan }
+
+ /**
+ * <span class="badge badge-red" style="float: right;">EXPERIMENTAL</span>
+ *
+ * Allows catalyst LogicalPlans to be executed as a SchemaRDD. Note that the LogicalPlan
+ * interface is considered internal, and thus not guaranteed to be stable. As a result, using
+ * it directly is not recommended.
+ */
+ implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = new SchemaRDD(this, plan)
+
+ /**
+ * Creates a SchemaRDD from an RDD of case classes.
+ *
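+ * For illustration, a minimal sketch (the case class and variable names are hypothetical):
+ * {{{
+ *   case class Person(name: String, age: Int)
+ *   val people = sparkContext.parallelize(Seq(Person("Alice", 30), Person("Bob", 25)))
+ *   val schemaRDD: SchemaRDD = createSchemaRDD(people)  // or rely on the implicit conversion
+ * }}}
+ *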
+ * @group userf
+ */
+ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) =
+ new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd)))
+
+ /**
+ * Loads a Parquet file, returning the result as a [[SchemaRDD]].
+ *
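+ * A usage sketch (the path is purely illustrative):
+ * {{{
+ *   val parquetRDD = parquetFile("/tmp/people.parquet")
+ *   parquetRDD.registerAsTable("people")
+ * }}}
+ *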
+ * @group userf
+ */
+ def parquetFile(path: String): SchemaRDD =
+ new SchemaRDD(this, parquet.ParquetRelation("ParquetFile", path))
+
+
+ /**
+ * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
+ * during the lifetime of this instance of SQLContext.
+ *
+ * @group userf
+ */
+ def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
+ catalog.registerTable(None, tableName, rdd.logicalPlan)
+ }
+
+ /**
+ * Executes a SQL query using Spark, returning the result as a SchemaRDD.
+ *
+ * @group userf
+ */
+ def sql(sqlText: String): SchemaRDD = {
+ val result = new SchemaRDD(this, parseSql(sqlText))
+ // We force query optimization to happen right away instead of letting it happen lazily like
+ // when using the query DSL. This is so DDL commands behave as expected. This only generates
+ // the RDD lineage for DML queries, but does not perform any execution.
+ result.queryExecution.toRdd
+ result
+ }
+
+ protected[sql] class SparkPlanner extends SparkStrategies {
+ val sparkContext = self.sparkContext
+
+ val strategies: Seq[Strategy] =
+ TopK ::
+ PartialAggregation ::
+ SparkEquiInnerJoin ::
+ BasicOperators ::
+ CartesianProduct ::
+ BroadcastNestedLoopJoin :: Nil
+ }
+
+ @transient
+ protected[sql] val planner = new SparkPlanner
+
+ /**
+ * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and
+ * inserting shuffle operations as needed.
+ */
+ @transient
+ protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
+ val batches =
+ Batch("Add exchange", Once, AddExchange) ::
+ Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil
+ }
+
+ /**
+ * The primary workflow for executing relational queries using Spark. Designed to allow easy
+ * access to the intermediate phases of query execution for developers.
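+ * A sketch of inspecting the intermediate phases (only possible from code inside the
+ * org.apache.spark.sql package, since these members are not public; the query text is
+ * illustrative):
+ * {{{
+ *   val qe = executeSql("SELECT key FROM records")
+ *   qe.analyzed       // logical plan after analysis
+ *   qe.optimizedPlan  // logical plan after optimization
+ *   qe.executedPlan   // physical plan that will actually be run
+ * }}}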
+ */
+ protected abstract class QueryExecution {
+ def logical: LogicalPlan
+
+ lazy val analyzed = analyzer(logical)
+ lazy val optimizedPlan = optimizer(analyzed)
+ // TODO: Don't just pick the first one...
+ lazy val sparkPlan = planner(optimizedPlan).next()
+ lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
+
+ /** Internal version of the RDD. Avoids copies and has no schema */
+ lazy val toRdd: RDD[Row] = executedPlan.execute()
+
+ protected def stringOrError[A](f: => A): String =
+ try f.toString catch { case e: Throwable => e.toString }
+
+ override def toString: String =
+ s"""== Logical Plan ==
+ |${stringOrError(analyzed)}
+ |== Optimized Logical Plan ==
+ |${stringOrError(optimizedPlan)}
+ |== Physical Plan ==
+ |${stringOrError(executedPlan)}
+ """.stripMargin.trim
+
+ /**
+ * Runs the query after interposing operators that print the result of each intermediate step.
+ */
+ def debugExec() = DebugQuery(executedPlan).execute().collect()
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
new file mode 100644
index 0000000000..91c3aaa2b8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -0,0 +1,342 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import org.apache.spark.{OneToOneDependency, Dependency, Partition, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types.BooleanType
+
+/**
+ * <span class="badge" style="float: right; background-color: darkblue;">ALPHA COMPONENT</span>
+ *
+ * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions,
+ * SchemaRDDs can be used in relational queries, as shown in the examples below.
+ *
+ * Importing a SQLContext brings an implicit into scope that automatically converts a standard RDD
+ * whose elements are scala case classes into a SchemaRDD. This conversion can also be done
+ * explicitly using the `createSchemaRDD` function on a [[SQLContext]].
+ *
+ * A `SchemaRDD` can also be created by loading data in from external sources, for example,
+ * by using the `parquetFile` method on [[SQLContext]].
+ *
+ * == SQL Queries ==
+ * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once
+ * an RDD has been registered as a table, it can be used in the FROM clause of SQL statements.
+ *
+ * {{{
+ * // One method for defining the schema of an RDD is to make a case class with the desired column
+ * // names and types.
+ * case class Record(key: Int, value: String)
+ *
+ * val sc: SparkContext // An existing spark context.
+ * val sqlContext = new SQLContext(sc)
+ *
+ * // Importing the SQL context gives access to all the SQL functions and implicit conversions.
+ * import sqlContext._
+ *
+ * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_\$i")))
+ * // Any RDD containing case classes can be registered as a table. The schema of the table is
+ * // automatically inferred using scala reflection.
+ * rdd.registerAsTable("records")
+ *
+ * val results: SchemaRDD = sql("SELECT * FROM records")
+ * }}}
+ *
+ * == Language Integrated Queries ==
+ *
+ * {{{
+ *
+ * case class Record(key: Int, value: String)
+ *
+ * val sc: SparkContext // An existing spark context.
+ * val sqlContext = new SQLContext(sc)
+ *
+ * // Importing the SQL context gives access to all the SQL functions and implicit conversions.
+ * import sqlContext._
+ *
+ * val rdd = sc.parallelize((1 to 100).map(i => Record(i, "val_" + i)))
+ *
+ * // Example of language integrated queries.
+ * rdd.where('key === 1).orderBy('value.asc).select('key).collect()
+ * }}}
+ *
+ * @todo There is currently no support for creating SchemaRDDs from either Java or Python RDDs.
+ *
+ * @groupname Query Language Integrated Queries
+ * @groupdesc Query Functions that create new queries from SchemaRDDs. The
+ * result of all query functions is also a SchemaRDD, allowing multiple operations to be
+ * chained using a builder pattern.
+ * @groupprio Query -2
+ * @groupname schema SchemaRDD Functions
+ * @groupprio schema -1
+ * @groupname Ungrouped Base RDD Functions
+ */
+class SchemaRDD(
+ @transient val sqlContext: SQLContext,
+ @transient val logicalPlan: LogicalPlan)
+ extends RDD[Row](sqlContext.sparkContext, Nil) {
+
+ /**
+ * A lazily computed query execution workflow. All other RDD operations are passed
+ * through to the RDD that is produced by this workflow.
+ *
+ * We want this to be lazy because invoking the whole query optimization pipeline can be
+ * expensive.
+ */
+ @transient
+ protected[spark] lazy val queryExecution = sqlContext.executePlan(logicalPlan)
+
+ override def toString =
+ s"""${super.toString}
+ |== Query Plan ==
+ |${queryExecution.executedPlan}""".stripMargin.trim
+
+ // =========================================================================================
+ // RDD functions: Copy the internal row representation so we present immutable data to users.
+ // =========================================================================================
+
+ override def compute(split: Partition, context: TaskContext): Iterator[Row] =
+ firstParent[Row].compute(split, context).map(_.copy())
+
+ override def getPartitions: Array[Partition] = firstParent[Row].partitions
+
+ override protected def getDependencies: Seq[Dependency[_]] =
+ List(new OneToOneDependency(queryExecution.toRdd))
+
+
+ // =======================================================================
+ // Query DSL
+ // =======================================================================
+
+ /**
+ * Changes the output of this relation to the given expressions, similar to the `SELECT` clause
+ * in SQL.
+ *
+ * {{{
+ * schemaRDD.select('a, 'b + 'c, 'd as 'aliasedName)
+ * }}}
+ *
+ * @param exprs a set of logical expressions that will be evaluated for each input row.
+ *
+ * @group Query
+ */
+ def select(exprs: NamedExpression*): SchemaRDD =
+ new SchemaRDD(sqlContext, Project(exprs, logicalPlan))
+
+ /**
+ * Filters the output, only returning those rows where `condition` evaluates to true.
+ *
+ * {{{
+ * schemaRDD.where('a === 'b)
+ * schemaRDD.where('a === 1)
+ * schemaRDD.where('a + 'b > 10)
+ * }}}
+ *
+ * @group Query
+ */
+ def where(condition: Expression): SchemaRDD =
+ new SchemaRDD(sqlContext, Filter(condition, logicalPlan))
+
+ /**
+ * Performs a relational join on two SchemaRDDs
+ *
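+ * An illustrative sketch (`otherSchemaRDD` is hypothetical; `Inner` comes from
+ * `org.apache.spark.sql.catalyst.plans`):
+ * {{{
+ *   val x = schemaRDD.subquery('x)
+ *   val y = otherSchemaRDD.subquery('y)
+ *   x.join(y, Inner, Some("x.a".attr === "y.a".attr))
+ * }}}
+ *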
+ * @param otherPlan the [[SchemaRDD]] that should be joined with this one.
+ * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner`.
+ * @param condition An optional condition for the join operation. This is equivalent to the `ON`
+ * clause in standard SQL. In the case of `Inner` joins, specifying a
+ * `condition` is equivalent to adding `where` clauses after the `join`.
+ *
+ * @group Query
+ */
+ def join(
+ otherPlan: SchemaRDD,
+ joinType: JoinType = Inner,
+ condition: Option[Expression] = None): SchemaRDD =
+ new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, condition))
+
+ /**
+ * Sorts the results by the given expressions.
+ * {{{
+ * schemaRDD.orderBy('a)
+ * schemaRDD.orderBy('a, 'b)
+ * schemaRDD.orderBy('a.asc, 'b.desc)
+ * }}}
+ *
+ * @group Query
+ */
+ def orderBy(sortExprs: SortOrder*): SchemaRDD =
+ new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan))
+
+ /**
+ * Performs a grouping followed by an aggregation.
+ *
+ * {{{
+ * schemaRDD.groupBy('year)(Sum('sales) as 'totalSales)
+ * }}}
+ *
+ * @group Query
+ */
+ def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): SchemaRDD = {
+ val aliasedExprs = aggregateExprs.map {
+ case ne: NamedExpression => ne
+ case e => Alias(e, e.toString)()
+ }
+ new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan))
+ }
+
+ /**
+ * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes
+ * with the same name, for example, when performing self-joins.
+ *
+ * {{{
+ * val x = schemaRDD.where('a === 1).subquery('x)
+ * val y = schemaRDD.where('a === 2).subquery('y)
+ * x.join(y).where("x.a".attr === "y.a".attr)
+ * }}}
+ *
+ * @group Query
+ */
+ def subquery(alias: Symbol) =
+ new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan))
+
+ /**
+ * Combines the tuples of two RDDs with the same schema, keeping duplicates.
+ *
+ * @group Query
+ */
+ def unionAll(otherPlan: SchemaRDD) =
+ new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan))
+
+ /**
+ * Filters tuples using a function over the value of the specified column.
+ *
+ * {{{
+ * schemaRDD.where('a)((a: Int) => ...)
+ * }}}
+ *
+ * @group Query
+ */
+ def where[T1](arg1: Symbol)(udf: (T1) => Boolean) =
+ new SchemaRDD(
+ sqlContext,
+ Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan))
+
+ /**
+ * <span class="badge badge-red" style="float: right;">EXPERIMENTAL</span>
+ *
+ * Filters tuples using a function over a `Dynamic` version of a given Row. DynamicRows use
+ * Scala's Dynamic trait to emulate an ORM in a dynamically typed language. Since the type of
+ * the column is not known at compile time, all attributes are converted to strings before
+ * being passed to the function.
+ *
+ * {{{
+ * schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith")
+ * }}}
+ *
+ * @group Query
+ */
+ def where(dynamicUdf: (DynamicRow) => Boolean) =
+ new SchemaRDD(
+ sqlContext,
+ Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan))
+
+ /**
+ * <span class="badge badge-red" style="float: right;">EXPERIMENTAL</span>
+ *
+ * Returns a sampled version of the underlying dataset.
+ *
+ * @group Query
+ */
+ def sample(
+ fraction: Double,
+ withReplacement: Boolean = true,
+ seed: Int = (math.random * 1000).toInt) =
+ new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan))
+
+ /**
+ * <span class="badge badge-red" style="float: right;">EXPERIMENTAL</span>
+ *
+ * Applies the given Generator, or table generating function, to this relation.
+ *
+ * @param generator A table generating function. The API for such functions is likely to change
+ * in future releases
+ * @param join when set to true, each output row of the generator is joined with the input row
+ * that produced it.
+ * @param outer when set to true, at least one row will be produced for each input row, similar to
+ * an `OUTER JOIN` in SQL. When no output rows are produced by the generator for a
+ * given row, a single row will be output, with `NULL` values for each of the
+ * generated columns.
+ * @param alias an optional alias that can be used as a qualifier for the attributes that are
+ * produced by this generate operation.
+ *
+ * @group Query
+ */
+ def generate(
+ generator: Generator,
+ join: Boolean = false,
+ outer: Boolean = false,
+ alias: Option[String] = None) =
+ new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan))
+
+ /**
+ * <span class="badge badge-red" style="float: right;">EXPERIMENTAL</span>
+ *
+ * Adds the rows from this RDD to the specified table. Note in a standard [[SQLContext]] there is
+ * no notion of persistent tables, and thus queries that contain this operator will fail to
+ * optimize. When working with an extension of a SQLContext that has a persistent catalog, such
+ * as a `HiveContext`, this operation will result in insertions to the table specified.
+ *
+ * @group schema
+ */
+ def insertInto(tableName: String, overwrite: Boolean = false) =
+ new SchemaRDD(
+ sqlContext,
+ InsertIntoTable(UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite))
+
+ /**
+ * Saves the contents of this `SchemaRDD` as a Parquet file, preserving the schema. Files that
+ * are written out using this method can be read back in as a SchemaRDD using the `parquetFile`
+ * function on [[SQLContext]].
+ *
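+ * A usage sketch (the path below is purely illustrative):
+ * {{{
+ *   schemaRDD.saveAsParquetFile("/tmp/records.parquet")
+ *   val restored = sqlContext.parquetFile("/tmp/records.parquet")
+ * }}}
+ *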
+ * @group schema
+ */
+ def saveAsParquetFile(path: String): Unit = {
+ sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
+ }
+
+ /**
+ * Registers this RDD as a temporary table using the given name. The lifetime of this temporary
+ * table is tied to the [[SQLContext]] that was used to create this SchemaRDD.
+ *
+ * @group schema
+ */
+ def registerAsTable(tableName: String): Unit = {
+ sqlContext.registerRDDAsTable(this, tableName)
+ }
+
+ /**
+ * Returns this RDD as a SchemaRDD.
+ * @group schema
+ */
+ def toSchemaRDD = this
+
+ def analyze = sqlContext.analyzer(logicalPlan)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
new file mode 100644
index 0000000000..72dc5ec6ad
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import java.nio.ByteBuffer
+
+import com.esotericsoftware.kryo.{Kryo, Serializer}
+import com.esotericsoftware.kryo.io.{Output, Input}
+
+import org.apache.spark.{SparkConf, RangePartitioner, HashPartitioner}
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.util.MutablePair
+
+import catalyst.rules.Rule
+import catalyst.errors._
+import catalyst.expressions._
+import catalyst.plans.physical._
+
+private class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
+ override def newKryo(): Kryo = {
+ val kryo = new Kryo
+ kryo.setRegistrationRequired(true)
+ kryo.register(classOf[MutablePair[_,_]])
+ kryo.register(classOf[Array[Any]])
+ kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
+ kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
+ kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
+ kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
+ kryo.setReferences(false)
+ kryo.setClassLoader(this.getClass.getClassLoader)
+ kryo
+ }
+}
+
+private class BigDecimalSerializer extends Serializer[BigDecimal] {
+ def write(kryo: Kryo, output: Output, bd: math.BigDecimal) {
+ // TODO: There are probably more efficient representations than strings...
+ output.writeString(bd.toString)
+ }
+
+ def read(kryo: Kryo, input: Input, tpe: Class[BigDecimal]): BigDecimal = {
+ BigDecimal(input.readString())
+ }
+}
+
+case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
+
+ override def outputPartitioning = newPartitioning
+
+ def output = child.output
+
+ def execute() = attachTree(this, "execute") {
+ newPartitioning match {
+ case HashPartitioning(expressions, numPartitions) => {
+ // TODO: Eliminate redundant expressions in grouping key and value.
+ val rdd = child.execute().mapPartitions { iter =>
+ val hashExpressions = new MutableProjection(expressions)
+ val mutablePair = new MutablePair[Row, Row]()
+ iter.map(r => mutablePair.update(hashExpressions(r), r))
+ }
+ val part = new HashPartitioner(numPartitions)
+ val shuffled = new ShuffledRDD[Row, Row, MutablePair[Row, Row]](rdd, part)
+ shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+ shuffled.map(_._2)
+ }
+ case RangePartitioning(sortingExpressions, numPartitions) => {
+ // TODO: RangePartitioner should take an Ordering.
+ implicit val ordering = new RowOrdering(sortingExpressions)
+
+ val rdd = child.execute().mapPartitions { iter =>
+ val mutablePair = new MutablePair[Row, Null](null, null)
+ iter.map(row => mutablePair.update(row, null))
+ }
+ val part = new RangePartitioner(numPartitions, rdd, ascending = true)
+ val shuffled = new ShuffledRDD[Row, Null, MutablePair[Row, Null]](rdd, part)
+ shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+
+ shuffled.map(_._1)
+ }
+ case SinglePartition =>
+ child.execute().coalesce(1, true)
+
+ case _ => sys.error(s"Exchange not implemented for $newPartitioning")
+ // TODO: Handle BroadcastPartitioning.
+ }
+ }
+}
+
+/**
+ * Ensures that the [[catalyst.plans.physical.Partitioning Partitioning]] of input data meets the
+ * [[catalyst.plans.physical.Distribution Distribution]] requirements for each operator by inserting
+ * [[Exchange]] Operators where required.
+ */
+object AddExchange extends Rule[SparkPlan] {
+ // TODO: Determine the number of partitions.
+ val numPartitions = 8
+
+ def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+ case operator: SparkPlan =>
+ // Check if every child's outputPartitioning satisfies the corresponding
+ // required data distribution.
+ def meetsRequirements =
+ !operator.requiredChildDistribution.zip(operator.children).map {
+ case (required, child) =>
+ val valid = child.outputPartitioning.satisfies(required)
+ logger.debug(
+ s"${if (valid) "Valid" else "Invalid"} distribution," +
+ s"required: $required current: ${child.outputPartitioning}")
+ valid
+ }.exists(!_)
+
+ // Check if outputPartitionings of children are compatible with each other.
+ // It is possible that every child satisfies its required data distribution
+ // but two children have incompatible outputPartitionings. For example,
+ // A dataset is range partitioned by "a.asc" (RangePartitioning) and another
+ // dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two
+ // datasets are both clustered by "a", but these two outputPartitionings are not
+ // compatible.
+ // TODO: ASSUMES TRANSITIVITY?
+ def compatible =
+ !operator.children
+ .map(_.outputPartitioning)
+ .sliding(2)
+ .map {
+ case Seq(a) => true
+ case Seq(a,b) => a compatibleWith b
+ }.exists(!_)
+
+ // Check if the partitioning we want to ensure is the same as the child's output
+ // partitioning. If so, we do not need to add the Exchange operator.
+ def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan) =
+ if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child
+
+ if (meetsRequirements && compatible) {
+ operator
+ } else {
+ // At least one child does not satisfy its required data distribution, or
+ // at least one child's outputPartitioning is not compatible with another child's
+ // outputPartitioning. In this case, we need to add Exchange operators.
+ val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map {
+ case (AllTuples, child) =>
+ addExchangeIfNecessary(SinglePartition, child)
+ case (ClusteredDistribution(clustering), child) =>
+ addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
+ case (OrderedDistribution(ordering), child) =>
+ addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
+ case (UnspecifiedDistribution, child) => child
+ case (dist, _) => sys.error(s"Don't know how to ensure $dist")
+ }
+ operator.withNewChildren(repartitionedChildren)
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
new file mode 100644
index 0000000000..c1da3653c5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import catalyst.expressions._
+import catalyst.types._
+
+/**
+ * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the
+ * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
+ * programming with one important additional feature, which allows the input rows to be joined with
+ * their output.
+ * @param join when true, each output row is implicitly joined with the input tuple that produced
+ * it.
+ * @param outer when true, each input row will be output at least once, even if the output of the
+ * given `generator` is empty. `outer` has no effect when `join` is false.
+ */
+case class Generate(
+ generator: Generator,
+ join: Boolean,
+ outer: Boolean,
+ child: SparkPlan)
+ extends UnaryNode {
+
+ def output =
+ if (join) child.output ++ generator.output else generator.output
+
+ def execute() = {
+ if (join) {
+ child.execute().mapPartitions { iter =>
+ val nullValues = Seq.fill(generator.output.size)(Literal(null))
+ // Used to produce rows with no matches when outer = true.
+ val outerProjection =
+ new Projection(child.output ++ nullValues, child.output)
+
+ val joinProjection =
+ new Projection(child.output ++ generator.output, child.output ++ generator.output)
+ val joinedRow = new JoinedRow
+
+ iter.flatMap {row =>
+ val outputRows = generator(row)
+ if (outer && outputRows.isEmpty) {
+ outerProjection(row) :: Nil
+ } else {
+ outputRows.map(or => joinProjection(joinedRow(row, or)))
+ }
+ }
+ }
+ } else {
+ child.execute().mapPartitions(iter => iter.flatMap(generator))
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala
new file mode 100644
index 0000000000..7ce8608d20
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+class QueryExecutionException(message: String) extends Exception(message)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
new file mode 100644
index 0000000000..5626181d18
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import org.apache.spark.rdd.RDD
+
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.trees
+
+abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
+ self: Product =>
+
+ // TODO: Move to `DistributedPlan`
+ /** Specifies how data is partitioned across different nodes in the cluster. */
+ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
+ /** Specifies any partition requirements on the input data for this operator. */
+ def requiredChildDistribution: Seq[Distribution] =
+ Seq.fill(children.size)(UnspecifiedDistribution)
+
+ /**
+ * Runs this query returning the result as an RDD.
+ */
+ def execute(): RDD[Row]
+
+ /**
+ * Runs this query returning the result as an array.
+ */
+ def executeCollect(): Array[Row] = execute().collect()
+
+ protected def buildRow(values: Seq[Any]): Row =
+ new catalyst.expressions.GenericRow(values.toArray)
+}
+
+/**
+ * Allows already planned SparkQueries to be linked into logical query plans.
+ *
+ * Note that in general it is not valid to use this class to link multiple copies of the same
+ * physical operator into the same query plan as this violates the uniqueness of expression ids.
+ * Special handling exists for ExistingRdd as these are already leaf operators and thus we can just
+ * replace the output attributes with new copies of themselves without breaking any attribute
+ * linking.
+ */
+case class SparkLogicalPlan(alreadyPlanned: SparkPlan)
+ extends logical.LogicalPlan with MultiInstanceRelation {
+
+ def output = alreadyPlanned.output
+ def references = Set.empty
+ def children = Nil
+
+ override final def newInstance: this.type = {
+ SparkLogicalPlan(
+ alreadyPlanned match {
+ case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
+ case _ => sys.error("Multiple instance of the same relation detected.")
+ }).asInstanceOf[this.type]
+ }
+}
+
+trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] {
+ self: Product =>
+}
+
+trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] {
+ self: Product =>
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+}
+
+trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] {
+ self: Product =>
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
new file mode 100644
index 0000000000..85035b8118
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import org.apache.spark.SparkContext
+
+import catalyst.expressions._
+import catalyst.planning._
+import catalyst.plans._
+import catalyst.plans.logical.LogicalPlan
+import catalyst.plans.physical._
+import parquet.ParquetRelation
+import parquet.InsertIntoParquetTable
+
+abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
+
+ val sparkContext: SparkContext
+
+ object SparkEquiInnerJoin extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) =>
+ logger.debug(s"Considering join: ${predicates ++ condition}")
+ // Find equi-join predicates that can be evaluated before the join, and thus can be used
+ // as join keys. Note we can only mix in the conditions with other predicates because the
+ // match above ensures that this is an Inner join.
+ val (joinPredicates, otherPredicates) = (predicates ++ condition).partition {
+ case Equals(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
+ (canEvaluate(l, right) && canEvaluate(r, left)) => true
+ case _ => false
+ }
+
+ val joinKeys = joinPredicates.map {
+ case Equals(l,r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r)
+ case Equals(l,r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l)
+ }
+
+ // Do not consider this strategy if there are no join keys.
+ if (joinKeys.nonEmpty) {
+ val leftKeys = joinKeys.map(_._1)
+ val rightKeys = joinKeys.map(_._2)
+
+ val joinOp = execution.SparkEquiInnerJoin(
+ leftKeys, rightKeys, planLater(left), planLater(right))
+
+ // Make sure other conditions are met if present.
+ if (otherPredicates.nonEmpty) {
+ execution.Filter(combineConjunctivePredicates(otherPredicates), joinOp) :: Nil
+ } else {
+ joinOp :: Nil
+ }
+ } else {
+ logger.debug(s"Avoiding spark join with no join keys.")
+ Nil
+ }
+ case _ => Nil
+ }
+
+ private def combineConjunctivePredicates(predicates: Seq[Expression]) =
+ predicates.reduceLeft(And)
+
+ /** Returns true if `expr` can be evaluated using only the output of `plan`. */
+ protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
+ expr.references subsetOf plan.outputSet
+ }
+
+ object PartialAggregation extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
+ // Collect all aggregate expressions.
+ val allAggregates =
+ aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
+ // Collect all aggregate expressions that can be computed partially.
+ val partialAggregates =
+ aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
+
+ // Only do partial aggregation if supported by all aggregate expressions.
+ if (allAggregates.size == partialAggregates.size) {
+ // Create a map of expressions to their partial evaluations for all aggregate expressions.
+ val partialEvaluations: Map[Long, SplitEvaluation] =
+ partialAggregates.map(a => (a.id, a.asPartial)).toMap
+
+ // We need to pass all grouping expressions through so the grouping can happen a second
+ // time. However, some of them might be unnamed, so we alias them, allowing them to be
+ // referenced in the second aggregation.
+ val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
+ case n: NamedExpression => (n, n)
+ case other => (other, Alias(other, "PartialGroup")())
+ }.toMap
+
+ // Replace aggregations with a new expression that computes the result from the already
+ // computed partial evaluations and grouping values.
+ val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
+ case e: Expression if partialEvaluations.contains(e.id) =>
+ partialEvaluations(e.id).finalEvaluation
+ case e: Expression if namedGroupingExpressions.contains(e) =>
+ namedGroupingExpressions(e).toAttribute
+ }).asInstanceOf[Seq[NamedExpression]]
+
+ val partialComputation =
+ (namedGroupingExpressions.values ++
+ partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
+
+ // Construct a two-phase aggregation.
+ execution.Aggregate(
+ partial = false,
+ namedGroupingExpressions.values.map(_.toAttribute).toSeq,
+ rewrittenAggregateExpressions,
+ execution.Aggregate(
+ partial = true,
+ groupingExpressions,
+ partialComputation,
+ planLater(child))(sparkContext))(sparkContext) :: Nil
+ } else {
+ Nil
+ }
+ case _ => Nil
+ }
+ }
+
+ object BroadcastNestedLoopJoin extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.Join(left, right, joinType, condition) =>
+ execution.BroadcastNestedLoopJoin(
+ planLater(left), planLater(right), joinType, condition)(sparkContext) :: Nil
+ case _ => Nil
+ }
+ }
+
+ object CartesianProduct extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.Join(left, right, _, None) =>
+ execution.CartesianProduct(planLater(left), planLater(right)) :: Nil
+ case logical.Join(left, right, Inner, Some(condition)) =>
+ execution.Filter(condition,
+ execution.CartesianProduct(planLater(left), planLater(right))) :: Nil
+ case _ => Nil
+ }
+ }
+
+ protected lazy val singleRowRdd =
+ sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
+
+ def convertToCatalyst(a: Any): Any = a match {
+ case s: Seq[Any] => s.map(convertToCatalyst)
+ case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+ case other => other
+ }
+
+ object TopK extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.StopAfter(IntegerLiteral(limit), logical.Sort(order, child)) =>
+ execution.TopK(limit, order, planLater(child))(sparkContext) :: Nil
+ case _ => Nil
+ }
+ }
+
+ // Can we automate these 'pass through' operations?
+ object BasicOperators extends Strategy {
+ // TODO: Set
+ val numPartitions = 200
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.Distinct(child) =>
+ execution.Aggregate(
+ partial = false, child.output, child.output, planLater(child))(sparkContext) :: Nil
+ case logical.Sort(sortExprs, child) =>
+ // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
+ execution.Sort(sortExprs, global = true, planLater(child)):: Nil
+ case logical.SortPartitions(sortExprs, child) =>
+ // This sort only sorts tuples within a partition. Its requiredDistribution will be
+ // an UnspecifiedDistribution.
+ execution.Sort(sortExprs, global = false, planLater(child)) :: Nil
+ case logical.Project(projectList, r: ParquetRelation)
+ if projectList.forall(_.isInstanceOf[Attribute]) =>
+
+ // simple projection of data loaded from Parquet file
+ parquet.ParquetTableScan(
+ projectList.asInstanceOf[Seq[Attribute]],
+ r,
+ None)(sparkContext) :: Nil
+ case logical.Project(projectList, child) =>
+ execution.Project(projectList, planLater(child)) :: Nil
+ case logical.Filter(condition, child) =>
+ execution.Filter(condition, planLater(child)) :: Nil
+ case logical.Aggregate(group, agg, child) =>
+ execution.Aggregate(partial = false, group, agg, planLater(child))(sparkContext) :: Nil
+ case logical.Sample(fraction, withReplacement, seed, child) =>
+ execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+ case logical.LocalRelation(output, data) =>
+ val dataAsRdd =
+ sparkContext.parallelize(data.map(r =>
+ new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
+ execution.ExistingRdd(output, dataAsRdd) :: Nil
+ case logical.StopAfter(IntegerLiteral(limit), child) =>
+ execution.StopAfter(limit, planLater(child))(sparkContext) :: Nil
+ case Unions(unionChildren) =>
+ execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
+ case logical.Generate(generator, join, outer, _, child) =>
+ execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
+ case logical.NoRelation =>
+ execution.ExistingRdd(Nil, singleRowRdd) :: Nil
+ case logical.Repartition(expressions, child) =>
+ execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
+ case logical.WriteToFile(path, child) =>
+ val relation =
+ ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, None)
+ InsertIntoParquetTable(relation, planLater(child))(sparkContext) :: Nil
+ case p: parquet.ParquetRelation =>
+ parquet.ParquetTableScan(p.output, p, None)(sparkContext) :: Nil
+ case SparkLogicalPlan(existingPlan) => existingPlan :: Nil
+ case _ => Nil
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
new file mode 100644
index 0000000000..51889c1988
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import org.apache.spark.SparkContext
+
+import catalyst.errors._
+import catalyst.expressions._
+import catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples}
+import catalyst.types._
+
+import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+
+/**
+ * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each
+ * group.
+ *
+ * @param partial if true, aggregation is done partially on local data, without first shuffling
+ * to ensure that all rows with equal `groupingExpressions` values are co-located.
+ * @param groupingExpressions expressions that are evaluated to determine grouping.
+ * @param aggregateExpressions expressions that are computed for each group.
+ * @param child the input data source.
+ */
+case class Aggregate(
+ partial: Boolean,
+ groupingExpressions: Seq[Expression],
+ aggregateExpressions: Seq[NamedExpression],
+ child: SparkPlan)(@transient sc: SparkContext)
+ extends UnaryNode {
+
+ override def requiredChildDistribution =
+ if (partial) {
+ UnspecifiedDistribution :: Nil
+ } else {
+ if (groupingExpressions == Nil) {
+ AllTuples :: Nil
+ } else {
+ ClusteredDistribution(groupingExpressions) :: Nil
+ }
+ }
+
+ override def otherCopyArgs = sc :: Nil
+
+ def output = aggregateExpressions.map(_.toAttribute)
+
+ /* Replace all aggregate expressions with spark functions that will compute the result. */
+ def createAggregateImplementations() = aggregateExpressions.map { agg =>
+ val impl = agg transform {
+ case a: AggregateExpression => a.newInstance
+ }
+
+ val remainingAttributes = impl.collect { case a: Attribute => a }
+ // If any references exist that are not inside agg functions, they must be grouping exprs;
+ // in this case we must rebind them to the grouping tuple.
+ if (remainingAttributes.nonEmpty) {
+ val unaliasedAggregateExpr = agg transform { case Alias(c, _) => c }
+
+ // An exact match with a grouping expression
+ val exactGroupingExpr = groupingExpressions.indexOf(unaliasedAggregateExpr) match {
+ case -1 => None
+ case ordinal => Some(BoundReference(ordinal, Alias(impl, "AGGEXPR")().toAttribute))
+ }
+
+ exactGroupingExpr.getOrElse(
+ sys.error(s"$agg is not in grouping expressions: $groupingExpressions"))
+ } else {
+ impl
+ }
+ }
+
+ def execute() = attachTree(this, "execute") {
+ // TODO: If the child of this operator is an [[catalyst.execution.Exchange]],
+ // do not evaluate the groupingExpressions again, since they have already been evaluated
+ // in the [[catalyst.execution.Exchange]].
+ val grouped = child.execute().mapPartitions { iter =>
+ val buildGrouping = new Projection(groupingExpressions)
+ iter.map(row => (buildGrouping(row), row.copy()))
+ }.groupByKeyLocally()
+
+ val result = grouped.map { case (group, rows) =>
+ val aggImplementations = createAggregateImplementations()
+
+ // Pull out all the functions so we can feed each row into them.
+ val aggFunctions = aggImplementations.flatMap(_ collect { case f: AggregateFunction => f })
+
+ rows.foreach { row =>
+ aggFunctions.foreach(_.update(row))
+ }
+ buildRow(aggImplementations.map(_.apply(group)))
+ }
+
+ // TODO: THIS BREAKS PIPELINING, DOUBLE COMPUTES THE ANSWER, AND USES TOO MUCH MEMORY...
+ if (groupingExpressions.isEmpty && result.count == 0) {
+ // When there is no output to the Aggregate operator, we still output an empty row.
+ val aggImplementations = createAggregateImplementations()
+ sc.makeRDD(buildRow(aggImplementations.map(_.apply(null))) :: Nil)
+ } else {
+ result
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
new file mode 100644
index 0000000000..c6d31d9abc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+
+import catalyst.errors._
+import catalyst.expressions._
+import catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution}
+import catalyst.plans.logical.LogicalPlan
+import catalyst.ScalaReflection
+
+case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
+ def output = projectList.map(_.toAttribute)
+
+ def execute() = child.execute().mapPartitions { iter =>
+ @transient val reusableProjection = new MutableProjection(projectList)
+ iter.map(reusableProjection)
+ }
+}
+
+case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
+ def output = child.output
+
+ def execute() = child.execute().mapPartitions { iter =>
+ iter.filter(condition.apply(_).asInstanceOf[Boolean])
+ }
+}
+
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
+ extends UnaryNode {
+
+ def output = child.output
+
+ // TODO: How to pick seed?
+ def execute() = child.execute().sample(withReplacement, fraction, seed)
+}
+
+case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
+ // TODO: attributes output by union should be distinct for nullability purposes
+ def output = children.head.output
+ def execute() = sc.union(children.map(_.execute()))
+
+ override def otherCopyArgs = sc :: Nil
+}
+
+case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+ override def otherCopyArgs = sc :: Nil
+
+ def output = child.output
+
+ override def executeCollect() = child.execute().map(_.copy()).take(limit)
+
+ // TODO: Terminal split should be implemented differently from non-terminal split.
+ // TODO: Pick num splits based on |limit|.
+ def execute() = sc.makeRDD(executeCollect(), 1)
+}
+
+case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
+ (@transient sc: SparkContext) extends UnaryNode {
+ override def otherCopyArgs = sc :: Nil
+
+ def output = child.output
+
+ @transient
+ lazy val ordering = new RowOrdering(sortOrder)
+
+ override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering)
+
+ // TODO: Terminal split should be implemented differently from non-terminal split.
+ // TODO: Pick num splits based on |limit|.
+ def execute() = sc.makeRDD(executeCollect(), 1)
+}
+
+case class Sort(
+ sortOrder: Seq[SortOrder],
+ global: Boolean,
+ child: SparkPlan)
+ extends UnaryNode {
+ override def requiredChildDistribution =
+ if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+
+ @transient
+ lazy val ordering = new RowOrdering(sortOrder)
+
+ def execute() = attachTree(this, "sort") {
+ // TODO: Optimize sorting operation?
+ child.execute()
+ .mapPartitions(
+ iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator,
+ preservesPartitioning = true)
+ }
+
+ def output = child.output
+}
+
+object ExistingRdd {
+ def convertToCatalyst(a: Any): Any = a match {
+ case s: Seq[Any] => s.map(convertToCatalyst)
+ case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+ case other => other
+ }
+
+ def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
+ // TODO: Reuse the row, don't use map on the product iterator. Maybe code gen?
+ data.map(r => new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)
+ }
+
+ def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = {
+ ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))
+ }
+}
+
+case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
+ def execute() = rdd
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala
new file mode 100644
index 0000000000..db259b4c4b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+object DebugQuery {
+ def apply(plan: SparkPlan): SparkPlan = {
+ val visited = new collection.mutable.HashSet[Long]()
+ plan transform {
+ case s: SparkPlan if !visited.contains(s.id) =>
+ visited += s.id
+ DebugNode(s)
+ }
+ }
+}
+
+case class DebugNode(child: SparkPlan) extends UnaryNode {
+ def references = Set.empty
+ def output = child.output
+ def execute() = {
+ val childRdd = child.execute()
+ println(
+ s"""
+ |=========================
+ |${child.simpleString}
+ |=========================
+ """.stripMargin)
+ childRdd.foreach(println(_))
+ childRdd
+ }
+}
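A minimal sketch of wrapping a physical plan for debugging. The `sqlContext` and `testData` names
and the `queryExecution.executedPlan` accessor are assumptions for illustration only:

    // Hypothetical usage: wrap every operator so that executing the plan prints each node
    // followed by the rows it produces.
    val plan = sqlContext.sql("SELECT key FROM testData").queryExecution.executedPlan
    val debugged = DebugQuery(plan)
    debugged.execute().collect()  // each operator's simpleString is printed, then its rows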
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
new file mode 100644
index 0000000000..5934fd1b03
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import scala.collection.mutable
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+
+import catalyst.errors._
+import catalyst.expressions._
+import catalyst.plans._
+import catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+
+import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+
+case class SparkEquiInnerJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode {
+
+ override def outputPartitioning: Partitioning = left.outputPartitioning
+
+ override def requiredChildDistribution =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ def output = left.output ++ right.output
+
+ def execute() = attachTree(this, "execute") {
+ val leftWithKeys = left.execute().mapPartitions { iter =>
+ val generateLeftKeys = new Projection(leftKeys, left.output)
+ iter.map(row => (generateLeftKeys(row), row.copy()))
+ }
+
+ val rightWithKeys = right.execute().mapPartitions { iter =>
+ val generateRightKeys = new Projection(rightKeys, right.output)
+ iter.map(row => (generateRightKeys(row), row.copy()))
+ }
+
+ // Do the join.
+ val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
+ // Drop join keys and merge input tuples.
+ joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
+ }
+
+ /**
+ * Filters out any rows where any of the join keys is null, ensuring three-valued
+ * logic for the equi-join conditions.
+ */
+ protected def filterNulls(rdd: RDD[(Row, Row)]) =
+ rdd.filter {
+ case (key: Seq[_], _) => !key.exists(_ == null)
+ }
+}
+
+case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
+ def output = left.output ++ right.output
+
+ def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map {
+ case (l: Row, r: Row) => buildRow(l ++ r)
+ }
+}
+
+case class BroadcastNestedLoopJoin(
+ streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
+ (@transient sc: SparkContext)
+ extends BinaryNode {
+ // TODO: Override requiredChildDistribution.
+
+ override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+ override def otherCopyArgs = sc :: Nil
+
+ def output = left.output ++ right.output
+
+ /** The streamed relation. */
+ def left = streamed
+ /** The broadcast relation. */
+ def right = broadcast
+
+ @transient lazy val boundCondition =
+ condition
+ .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+ .getOrElse(Literal(true))
+
+ def execute() = {
+ val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+ val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
+ val matchedRows = new mutable.ArrayBuffer[Row]
+ val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size)
+ val joinedRow = new JoinedRow
+
+ streamedIter.foreach { streamedRow =>
+ var i = 0
+ var matched = false
+
+ while (i < broadcastedRelation.value.size) {
+ // TODO: One bitset per partition instead of per row.
+ val broadcastedRow = broadcastedRelation.value(i)
+ if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) {
+ matchedRows += buildRow(streamedRow ++ broadcastedRow)
+ matched = true
+ includedBroadcastTuples += i
+ }
+ i += 1
+ }
+
+ if (!matched && (joinType == LeftOuter || joinType == FullOuter)) {
+ matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null))
+ }
+ }
+ Iterator((matchedRows, includedBroadcastTuples))
+ }
+
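+ // Combine the per-partition bitsets to determine which broadcast-side rows matched at
+ // least once; rows that never matched must still be emitted for RIGHT OUTER / FULL OUTER.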
+ val includedBroadcastTuples = streamedPlusMatches.map(_._2)
+ val allIncludedBroadcastTuples =
+ if (includedBroadcastTuples.count == 0) {
+ new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+ } else {
+ streamedPlusMatches.map(_._2).reduce(_ ++ _)
+ }
+
+ val rightOuterMatches: Seq[Row] =
+ if (joinType == RightOuter || joinType == FullOuter) {
+ broadcastedRelation.value.zipWithIndex.filter {
+ case (row, i) => !allIncludedBroadcastTuples.contains(i)
+ }.map {
+ // TODO: Use projection.
+ case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row)
+ }
+ } else {
+ Vector()
+ }
+
+ // TODO: Breaks lineage.
+ sc.union(
+ streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches))
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala
new file mode 100644
index 0000000000..67f6f43f90
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+/**
+ * An execution engine for relational query plans that runs on top of Spark and returns RDDs.
+ *
+ * Note that the operators in this package are created automatically by a query planner using a
+ * [[SQLContext]] and are not intended to be used directly by end users of Spark SQL. They are
+ * documented here in order to make it easier for others to understand the performance
+ * characteristics of query plans that are generated by Spark SQL.
+ */
+package object execution {
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
new file mode 100644
index 0000000000..e87561fe13
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import java.io.{IOException, FileNotFoundException}
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.fs.permission.FsAction
+
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, BaseRelation}
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.ArrayType
+import org.apache.spark.sql.catalyst.expressions.{Row, AttributeReference, Attribute}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+
+import parquet.schema.{MessageTypeParser, MessageType}
+import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName}
+import parquet.schema.{PrimitiveType => ParquetPrimitiveType}
+import parquet.schema.{Type => ParquetType}
+import parquet.schema.Type.Repetition
+import parquet.io.api.{Binary, RecordConsumer}
+import parquet.hadoop.{Footer, ParquetFileWriter, ParquetFileReader}
+import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
+import parquet.hadoop.util.ContextUtil
+
+import scala.collection.JavaConversions._
+
+/**
+ * Relation that consists of data stored in a Parquet columnar format.
+ *
+ * Users should interact with Parquet files through a SchemaRDD, created by a [[SQLContext]],
+ * instead of using this class directly.
+ *
+ * {{{
+ * val parquetRDD = sqlContext.parquetFile("path/to/parquet.file")
+ * }}}
+ *
+ * @param tableName The name of the relation that can be used in queries.
+ * @param path The path to the Parquet file.
+ */
+case class ParquetRelation(val tableName: String, val path: String) extends BaseRelation {
+
+ /** The schema derived from the Parquet file. */
+ def parquetSchema: MessageType =
+ ParquetTypesConverter
+ .readMetaData(new Path(path))
+ .getFileMetaData
+ .getSchema
+
+ /** Attributes **/
+ val attributes =
+ ParquetTypesConverter
+ .convertToAttributes(parquetSchema)
+
+ /** Output **/
+ override val output = attributes
+
+ // Parquet files have no concept of keys, therefore no Partitioner.
+ // Note: we could allow Block level access; needs to be thought through
+ override def isPartitioned = false
+}
+
+object ParquetRelation {
+
+ // The element type for the RDDs that this relation maps to.
+ type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+
+ /**
+ * Creates a new ParquetRelation and underlying Parquet file for the given
+ * LogicalPlan. Note that this is used inside [[SparkStrategies]] to
+ * create a resolved relation as a data sink for writing to a Parquet file.
+ * The relation is empty but is initialized with ParquetMetadata and
+ * can be inserted into.
+ *
+ * @param pathString The directory the Parquet file will be stored in.
+ * @param child The child node that will be used for extracting the schema.
+ * @param conf A configuration to be used.
+ * @param tableName The name of the resulting relation.
+ * @return An empty ParquetRelation with inferred metadata.
+ */
+ def create(pathString: String,
+ child: LogicalPlan,
+ conf: Configuration,
+ tableName: Option[String]): ParquetRelation = {
+ if (!child.resolved) {
+ throw new UnresolvedException[LogicalPlan](
+ child,
+ "Attempt to create Parquet table from unresolved child (when schema is not available)")
+ }
+
+ val name = s"${tableName.getOrElse(child.nodeName)}_parquet"
+ val path = checkPath(pathString, conf)
+ ParquetTypesConverter.writeMetaData(child.output, path, conf)
+ new ParquetRelation(name, path.toString)
+ }
+
+ private def checkPath(pathStr: String, conf: Configuration): Path = {
+ if (pathStr == null) {
+ throw new IllegalArgumentException("Unable to create ParquetRelation: path is null")
+ }
+ val origPath = new Path(pathStr)
+ val fs = origPath.getFileSystem(conf)
+ if (fs == null) {
+ throw new IllegalArgumentException(
+ s"Unable to create ParquetRelation: incorrectly formatted path $pathStr")
+ }
+ val path = origPath.makeQualified(fs)
+ if (fs.exists(path) &&
+ !fs.getFileStatus(path)
+ .getPermission
+ .getUserAction
+ .implies(FsAction.READ_WRITE)) {
+ throw new IOException(
+ s"Unable to create ParquetRelation: path $path not read-writable")
+ }
+ path
+ }
+}
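An illustrative sketch of how the planner-facing `create` method could be used; the resolved
`analyzedPlan` and the Hadoop `conf` value are assumptions, not part of this patch:

    // Hypothetical: write only the Parquet metadata for a resolved plan's schema,
    // producing an empty relation that can later be inserted into.
    val sink = ParquetRelation.create("/tmp/parquet-sink", analyzedPlan, conf, Some("sink"))
    // sink.parquetSchema now reflects analyzedPlan.output; the directory holds only metadata.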
+
+object ParquetTypesConverter {
+ def toDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match {
+ // for now map binary to string type
+ // TODO: figure out how Parquet uses strings or why we can't use them in a MessageType schema
+ case ParquetPrimitiveTypeName.BINARY => StringType
+ case ParquetPrimitiveTypeName.BOOLEAN => BooleanType
+ case ParquetPrimitiveTypeName.DOUBLE => DoubleType
+ case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType)
+ case ParquetPrimitiveTypeName.FLOAT => FloatType
+ case ParquetPrimitiveTypeName.INT32 => IntegerType
+ case ParquetPrimitiveTypeName.INT64 => LongType
+ case ParquetPrimitiveTypeName.INT96 => {
+ // TODO: add a BigInteger type? TODO(andre): use DecimalType instead?
+ sys.error("Warning: potential loss of precision: converting INT96 to long")
+ LongType
+ }
+ case _ => sys.error(
+ s"Unsupported parquet datatype $parquetType")
+ }
+
+ def fromDataType(ctype: DataType): ParquetPrimitiveTypeName = ctype match {
+ case StringType => ParquetPrimitiveTypeName.BINARY
+ case BooleanType => ParquetPrimitiveTypeName.BOOLEAN
+ case DoubleType => ParquetPrimitiveTypeName.DOUBLE
+ case ArrayType(ByteType) => ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY
+ case FloatType => ParquetPrimitiveTypeName.FLOAT
+ case IntegerType => ParquetPrimitiveTypeName.INT32
+ case LongType => ParquetPrimitiveTypeName.INT64
+ case _ => sys.error(s"Unsupported datatype $ctype")
+ }
+
+ def consumeType(consumer: RecordConsumer, ctype: DataType, record: Row, index: Int): Unit = {
+ ctype match {
+ case StringType => consumer.addBinary(
+ Binary.fromByteArray(
+ record(index).asInstanceOf[String].getBytes("utf-8")
+ )
+ )
+ case IntegerType => consumer.addInteger(record.getInt(index))
+ case LongType => consumer.addLong(record.getLong(index))
+ case DoubleType => consumer.addDouble(record.getDouble(index))
+ case FloatType => consumer.addFloat(record.getFloat(index))
+ case BooleanType => consumer.addBoolean(record.getBoolean(index))
+ case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
+ }
+ }
+
+ def getSchema(schemaString : String) : MessageType =
+ MessageTypeParser.parseMessageType(schemaString)
+
+ def convertToAttributes(parquetSchema: MessageType) : Seq[Attribute] = {
+ parquetSchema.getColumns.map { desc =>
+ val ctype = toDataType(desc.getType)
+ val name: String = desc.getPath.mkString(".")
+ new AttributeReference(name, ctype, false)()
+ }
+ }
+
+ // TODO: allow nesting?
+ def convertFromAttributes(attributes: Seq[Attribute]): MessageType = {
+ val fields: Seq[ParquetType] = attributes.map {
+ a => new ParquetPrimitiveType(Repetition.OPTIONAL, fromDataType(a.dataType), a.name)
+ }
+ new MessageType("root", fields)
+ }
+
+ def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration) {
+ if (origPath == null) {
+ throw new IllegalArgumentException("Unable to write Parquet metadata: path is null")
+ }
+ val fs = origPath.getFileSystem(conf)
+ if (fs == null) {
+ throw new IllegalArgumentException(
+ s"Unable to write Parquet metadata: path $origPath is incorrectly formatted")
+ }
+ val path = origPath.makeQualified(fs)
+ if (fs.exists(path) && !fs.getFileStatus(path).isDir) {
+ throw new IllegalArgumentException(s"Expected to write to directory $path but found file")
+ }
+ val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE)
+ if (fs.exists(metadataPath)) {
+ try {
+ fs.delete(metadataPath, true)
+ } catch {
+ case e: IOException =>
+ throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath")
+ }
+ }
+ val extraMetadata = new java.util.HashMap[String, String]()
+ extraMetadata.put("path", path.toString)
+ // TODO: add extra data, e.g., table name, date, etc.?
+
+ val parquetSchema: MessageType =
+ ParquetTypesConverter.convertFromAttributes(attributes)
+ val metaData: FileMetaData = new FileMetaData(
+ parquetSchema,
+ extraMetadata,
+ "Spark")
+
+ ParquetFileWriter.writeMetadataFile(
+ conf,
+ path,
+ new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil)
+ }
+
+ /**
+ * Try to read Parquet metadata at the given Path. We first see if there is a summary file
+ * in the parent directory. If so, this is used. Else we read the actual footer at the given
+ * location.
+ * @param origPath The path at which we expect one (or more) Parquet files.
+ * @return The `ParquetMetadata` containing among other things the schema.
+ */
+ def readMetaData(origPath: Path): ParquetMetadata = {
+ if (origPath == null) {
+ throw new IllegalArgumentException("Unable to read Parquet metadata: path is null")
+ }
+ val job = new Job()
+ // TODO: since this is called from ParquetRelation (LogicalPlan) we don't have access
+ // to SparkContext's hadoopConfiguration; in principle the default FileSystem may differ.
+ val conf = ContextUtil.getConfiguration(job)
+ val fs: FileSystem = origPath.getFileSystem(conf)
+ if (fs == null) {
+ throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath")
+ }
+ val path = origPath.makeQualified(fs)
+ val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE)
+ if (fs.exists(metadataPath) && fs.isFile(metadataPath)) {
+ // TODO: improve exception handling, etc.
+ ParquetFileReader.readFooter(conf, metadataPath)
+ } else {
+ if (!fs.exists(path) || !fs.isFile(path)) {
+ throw new FileNotFoundException(
+ s"Could not find file ${path.toString} when trying to read metadata")
+ }
+ ParquetFileReader.readFooter(conf, path)
+ }
+ }
+}
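A small sketch of the attribute/schema round trip implemented above; the attribute names and
types are arbitrary examples:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.types.{IntegerType, StringType}

    val attrs = Seq(
      new AttributeReference("key", IntegerType, false)(),
      new AttributeReference("value", StringType, false)())
    // Strings map to BINARY, ints to INT32; all fields are emitted as OPTIONAL.
    val messageType = ParquetTypesConverter.convertFromAttributes(attrs)
    // Converting back yields attributes with the same names and Catalyst types.
    val roundTripped = ParquetTypesConverter.convertToAttributes(messageType)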
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
new file mode 100644
index 0000000000..61121103cb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import parquet.io.InvalidRecordException
+import parquet.schema.MessageType
+import parquet.hadoop.{ParquetOutputFormat, ParquetInputFormat}
+import parquet.hadoop.util.ContextUtil
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.{TaskContext, SerializableWritable, SparkContext}
+import org.apache.spark.sql.catalyst.expressions.{Row, Attribute, Expression}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, LeafNode}
+
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import java.io.IOException
+import java.text.SimpleDateFormat
+import java.util.Date
+
+/**
+ * Parquet table scan operator. Imports the file that backs the given
+ * [[ParquetRelation]] as an RDD[Row].
+ */
+case class ParquetTableScan(
+ @transient output: Seq[Attribute],
+ @transient relation: ParquetRelation,
+ @transient columnPruningPred: Option[Expression])(
+ @transient val sc: SparkContext)
+ extends LeafNode {
+
+ override def execute(): RDD[Row] = {
+ val job = new Job(sc.hadoopConfiguration)
+ ParquetInputFormat.setReadSupportClass(
+ job,
+ classOf[org.apache.spark.sql.parquet.RowReadSupport])
+ val conf: Configuration = ContextUtil.getConfiguration(job)
+ conf.set(
+ RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA,
+ ParquetTypesConverter.convertFromAttributes(output).toString)
+ // TODO: think about adding record filters
+ /* Comments regarding record filters: it would be nice to push down as much filtering
+ to Parquet as possible. However, currently it seems we cannot pass enough information
+ to materialize an (arbitrary) Catalyst [[Predicate]] inside Parquet's
+ ``FilteredRecordReader`` (via Configuration, for example). Simple filtering of rows
+ by column values, however, should be supported.
+ */
+ sc.newAPIHadoopFile(
+ relation.path,
+ classOf[ParquetInputFormat[Row]],
+ classOf[Void], classOf[Row],
+ conf)
+ .map(_._2)
+ }
+
+ /**
+ * Applies a (candidate) projection.
+ *
+ * @param prunedAttributes The list of attributes to be used in the projection.
+ * @return Pruned TableScan.
+ */
+ def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
+ val success = validateProjection(prunedAttributes)
+ if (success) {
+ ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sc)
+ } else {
+ sys.error("Warning: Could not validate Parquet schema projection in pruneColumns")
+ this
+ }
+ }
+
+ /**
+ * Evaluates a candidate projection by checking whether the candidate is a subtype
+ * of the original type.
+ *
+ * @param projection The candidate projection.
+ * @return True if the projection is valid, false otherwise.
+ */
+ private def validateProjection(projection: Seq[Attribute]): Boolean = {
+ val original: MessageType = relation.parquetSchema
+ val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection)
+ try {
+ original.checkContains(candidate)
+ true
+ } catch {
+ case e: InvalidRecordException => false
+ }
+ }
+}
+
+case class InsertIntoParquetTable(
+ @transient relation: ParquetRelation,
+ @transient child: SparkPlan)(
+ @transient val sc: SparkContext)
+ extends UnaryNode with SparkHadoopMapReduceUtil {
+
+ /**
+ * Inserts all the rows into the Parquet file. Note that OVERWRITE is implicit, since
+ * Parquet files are write-once.
+ */
+ override def execute() = {
+ // TODO: currently we do not check whether the schemas are compatible.
+ // That means if one first creates a table and then INSERTs data with
+ // an incompatible schema the execution will fail. It would be nice
+ // to catch this early on, maybe by having the planner validate the schema
+ // before calling execute().
+
+ val childRdd = child.execute()
+ assert(childRdd != null)
+
+ val job = new Job(sc.hadoopConfiguration)
+
+ ParquetOutputFormat.setWriteSupportClass(
+ job,
+ classOf[org.apache.spark.sql.parquet.RowWriteSupport])
+
+ // TODO: move that to function in object
+ val conf = job.getConfiguration
+ conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA, relation.parquetSchema.toString)
+
+ val fspath = new Path(relation.path)
+ val fs = fspath.getFileSystem(conf)
+
+ try {
+ fs.delete(fspath, true)
+ } catch {
+ case e: IOException =>
+ throw new IOException(
+ s"Unable to clear output directory ${fspath.toString} prior"
+ + s" to InsertIntoParquetTable:\n${e.toString}")
+ }
+ saveAsHadoopFile(childRdd, relation.path.toString, conf)
+
+ // We return the child RDD to allow chaining (alternatively, one could return nothing).
+ childRdd
+ }
+
+ override def output = child.output
+
+ // based on ``saveAsNewAPIHadoopFile`` in [[PairRDDFunctions]]
+ // TODO: Maybe PairRDDFunctions should use Product2 instead of Tuple2?
+ // ... then we could use the default implementation and [[MutablePair]]
+ // instead of ``Tuple2``
+ private def saveAsHadoopFile(
+ rdd: RDD[Row],
+ path: String,
+ conf: Configuration) {
+ val job = new Job(conf)
+ val keyType = classOf[Void]
+ val outputFormatType = classOf[parquet.hadoop.ParquetOutputFormat[Row]]
+ job.setOutputKeyClass(keyType)
+ job.setOutputValueClass(classOf[Row])
+ val wrappedConf = new SerializableWritable(job.getConfiguration)
+ NewFileOutputFormat.setOutputPath(job, new Path(path))
+ val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+ val jobtrackerID = formatter.format(new Date())
+ val stageId = sc.newRddId()
+
+ def writeShard(context: TaskContext, iter: Iterator[Row]): Int = {
+ // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+ // around by taking a mod. We expect that no task will be attempted 2 billion times.
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ /* "reduce task" <split #> <attempt # = spark task #> */
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
+ attemptNumber)
+ val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
+ val format = outputFormatType.newInstance
+ val committer = format.getOutputCommitter(hadoopContext)
+ committer.setupTask(hadoopContext)
+ val writer = format.getRecordWriter(hadoopContext)
+ while (iter.hasNext) {
+ val row = iter.next()
+ writer.write(null, row)
+ }
+ writer.close(hadoopContext)
+ committer.commitTask(hadoopContext)
+ 1
+ }
+ val jobFormat = outputFormatType.newInstance
+ /* apparently we need a TaskAttemptID to construct an OutputCommitter;
+ * however we're only going to use this local OutputCommitter for
+ * setupJob/commitJob, so we just use a dummy "map" task.
+ */
+ val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0)
+ val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
+ val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
+ jobCommitter.setupJob(jobTaskContext)
+ sc.runJob(rdd, writeShard _)
+ jobCommitter.commitJob(jobTaskContext)
+ }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
new file mode 100644
index 0000000000..c2ae18b882
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.Logging
+
+import parquet.io.api._
+import parquet.schema.{MessageTypeParser, MessageType}
+import parquet.hadoop.api.{WriteSupport, ReadSupport}
+import parquet.hadoop.api.ReadSupport.ReadContext
+import parquet.hadoop.ParquetOutputFormat
+import parquet.column.ParquetProperties
+
+import org.apache.spark.sql.catalyst.expressions.{Row, Attribute}
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * A `parquet.io.api.RecordMaterializer` for Rows.
+ *
+ * @param root The root group converter for the record.
+ */
+class RowRecordMaterializer(root: CatalystGroupConverter) extends RecordMaterializer[Row] {
+
+ def this(parquetSchema: MessageType) =
+ this(new CatalystGroupConverter(ParquetTypesConverter.convertToAttributes(parquetSchema)))
+
+ override def getCurrentRecord: Row = root.getCurrentRecord
+
+ override def getRootConverter: GroupConverter = root
+}
+
+/**
+ * A `parquet.hadoop.api.ReadSupport` for Row objects.
+ */
+class RowReadSupport extends ReadSupport[Row] with Logging {
+
+ override def prepareForRead(
+ conf: Configuration,
+ stringMap: java.util.Map[String, String],
+ fileSchema: MessageType,
+ readContext: ReadContext): RecordMaterializer[Row] = {
+ log.debug(s"preparing for read with schema ${fileSchema.toString}")
+ new RowRecordMaterializer(readContext.getRequestedSchema)
+ }
+
+ override def init(
+ configuration: Configuration,
+ keyValueMetaData: java.util.Map[String, String],
+ fileSchema: MessageType): ReadContext = {
+ val requestedSchemaString =
+ configuration.get(RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, fileSchema.toString)
+ val requestedSchema = MessageTypeParser.parseMessageType(requestedSchemaString)
+
+ log.debug(s"read support initialized for original schema ${requestedSchema.toString}")
+ new ReadContext(requestedSchema, keyValueMetaData)
+ }
+}
+
+object RowReadSupport {
+ val PARQUET_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema"
+}
+
+/**
+ * A `parquet.hadoop.api.WriteSupport` for Row objects.
+ */
+class RowWriteSupport extends WriteSupport[Row] with Logging {
+ def setSchema(schema: MessageType, configuration: Configuration) {
+ // for testing
+ this.schema = schema
+ // TODO: could use Attributes themselves instead of Parquet schema?
+ configuration.set(
+ RowWriteSupport.PARQUET_ROW_SCHEMA,
+ schema.toString)
+ configuration.set(
+ ParquetOutputFormat.WRITER_VERSION,
+ ParquetProperties.WriterVersion.PARQUET_1_0.toString)
+ }
+
+ def getSchema(configuration: Configuration): MessageType = {
+ MessageTypeParser.parseMessageType(
+ configuration.get(RowWriteSupport.PARQUET_ROW_SCHEMA))
+ }
+
+ private var schema: MessageType = null
+ private var writer: RecordConsumer = null
+ private var attributes: Seq[Attribute] = null
+
+ override def init(configuration: Configuration): WriteSupport.WriteContext = {
+ schema = if (schema == null) getSchema(configuration) else schema
+ attributes = ParquetTypesConverter.convertToAttributes(schema)
+ new WriteSupport.WriteContext(
+ schema,
+ new java.util.HashMap[java.lang.String, java.lang.String]());
+ }
+
+ override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
+ writer = recordConsumer
+ }
+
+ // TODO: add groups (nested fields)
+ override def write(record: Row): Unit = {
+ var index = 0
+ writer.startMessage()
+ while (index < attributes.size) {
+ // Null values indicate optional fields, but we do not currently check for them.
+ if (record(index) != null && record(index) != Nil) {
+ writer.startField(attributes(index).name, index)
+ ParquetTypesConverter.consumeType(writer, attributes(index).dataType, record, index)
+ writer.endField(attributes(index).name, index)
+ }
+ index = index + 1
+ }
+ writer.endMessage()
+ }
+}
+
+object RowWriteSupport {
+ val PARQUET_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.schema"
+}
+
+/**
+ * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record
+ * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object.
+ *
+ * @param schema The corresponding Catalyst schema in the form of a list of attributes.
+ */
+class CatalystGroupConverter(
+ schema: Seq[Attribute],
+ protected[parquet] val current: ParquetRelation.RowType) extends GroupConverter {
+
+ def this(schema: Seq[Attribute]) = this(schema, new ParquetRelation.RowType(schema.length))
+
+ val converters: Array[Converter] = schema.map {
+ a => a.dataType match {
+ case ctype: NativeType =>
+ // Note: for some reason matching on StringType fails, so use this if instead.
+ if (ctype == StringType) new CatalystPrimitiveStringConverter(this, schema.indexOf(a))
+ else new CatalystPrimitiveConverter(this, schema.indexOf(a))
+ case _ => throw new RuntimeException(
+ s"unable to convert datatype ${a.dataType.toString} in CatalystGroupConverter")
+ }
+ }.toArray
+
+ override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
+
+ private[parquet] def getCurrentRecord: ParquetRelation.RowType = current
+
+ override def start(): Unit = {
+ var i = 0
+ while (i < schema.length) {
+ current.setNullAt(i)
+ i = i + 1
+ }
+ }
+
+ override def end(): Unit = {}
+}
+
+/**
+ * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types.
+ *
+ * @param parent The parent group converter.
+ * @param fieldIndex The index inside the record.
+ */
+class CatalystPrimitiveConverter(
+ parent: CatalystGroupConverter,
+ fieldIndex: Int) extends PrimitiveConverter {
+ // TODO: consider refactoring these together with ParquetTypesConverter
+ override def addBinary(value: Binary): Unit =
+ // TODO: fix this once a setBinary will become available in MutableRow
+ parent.getCurrentRecord.setByte(fieldIndex, value.getBytes.apply(0))
+
+ override def addBoolean(value: Boolean): Unit =
+ parent.getCurrentRecord.setBoolean(fieldIndex, value)
+
+ override def addDouble(value: Double): Unit =
+ parent.getCurrentRecord.setDouble(fieldIndex, value)
+
+ override def addFloat(value: Float): Unit =
+ parent.getCurrentRecord.setFloat(fieldIndex, value)
+
+ override def addInt(value: Int): Unit =
+ parent.getCurrentRecord.setInt(fieldIndex, value)
+
+ override def addLong(value: Long): Unit =
+ parent.getCurrentRecord.setLong(fieldIndex, value)
+}
+
+/**
+ * A `parquet.io.api.PrimitiveConverter` that converts Parquet strings (fixed-length byte arrays)
+ * into Catalyst Strings.
+ *
+ * @param parent The parent group converter.
+ * @param fieldIndex The index inside the record.
+ */
+class CatalystPrimitiveStringConverter(
+ parent: CatalystGroupConverter,
+ fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) {
+ override def addBinary(value: Binary): Unit =
+ parent.getCurrentRecord.setString(fieldIndex, value.toStringUsingUTF8)
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
new file mode 100644
index 0000000000..bbe409fb9c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.Job
+
+import parquet.schema.{MessageTypeParser, MessageType}
+import parquet.hadoop.util.ContextUtil
+import parquet.hadoop.ParquetWriter
+
+import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.catalyst.expressions.GenericRow
+import java.nio.charset.Charset
+
+object ParquetTestData {
+
+ val testSchema =
+ """message myrecord {
+ |optional boolean myboolean;
+ |optional int32 myint;
+ |optional binary mystring;
+ |optional int64 mylong;
+ |optional float myfloat;
+ |optional double mydouble;
+ |}""".stripMargin
+
+ // field names for test assertion error messages
+ val testSchemaFieldNames = Seq(
+ "myboolean:Boolean",
+ "mtint:Int",
+ "mystring:String",
+ "mylong:Long",
+ "myfloat:Float",
+ "mydouble:Double"
+ )
+
+ val subTestSchema =
+ """
+ |message myrecord {
+ |optional boolean myboolean;
+ |optional int64 mylong;
+ |}
+ """.stripMargin
+
+ // field names for test assertion error messages
+ val subTestSchemaFieldNames = Seq(
+ "myboolean:Boolean",
+ "mylong:Long"
+ )
+
+ val testFile = getTempFilePath("testParquetFile").getCanonicalFile
+
+ lazy val testData = new ParquetRelation("testData", testFile.toURI.toString)
+
+ def writeFile = {
+ testFile.delete
+ val path: Path = new Path(testFile.toURI)
+ val job = new Job()
+ val configuration: Configuration = ContextUtil.getConfiguration(job)
+ val schema: MessageType = MessageTypeParser.parseMessageType(testSchema)
+
+ val writeSupport = new RowWriteSupport()
+ writeSupport.setSchema(schema, configuration)
+ val writer = new ParquetWriter(path, writeSupport)
+ for(i <- 0 until 15) {
+ val data = new Array[Any](6)
+ if (i % 3 == 0) {
+ data.update(0, true)
+ } else {
+ data.update(0, false)
+ }
+ if (i % 5 == 0) {
+ data.update(1, 5)
+ } else {
+ data.update(1, null) // optional
+ }
+ data.update(2, "abc")
+ data.update(3, i.toLong << 33)
+ data.update(4, 2.5F)
+ data.update(5, 4.5D)
+ writer.write(new GenericRow(data.toArray))
+ }
+ writer.close()
+ }
+}
+
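A hedged sketch of how this sample file might be consumed in a test; the `parquetFile` method on
the context (shown in the ParquetRelation scaladoc above) and the use of `TestSQLContext` here are
assumptions for illustration:

    // Hypothetical test setup: write the sample file, then load it through the SQL context.
    ParquetTestData.writeFile
    val parquetRdd = TestSQLContext.parquetFile(ParquetTestData.testFile.toString)
    parquetRdd.collect().foreach(println)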
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
new file mode 100644
index 0000000000..ca56c4476b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+package sql
+package test
+
+/** A SQLContext that can be used for local testing. */
+object TestSQLContext
+ extends SQLContext(new SparkContext("local", "TestSQLContext", new SparkConf()))
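A minimal sketch of using the test context in a suite, mirroring the suites that follow; the
`TestData` import is assumed to register the referenced table:

    import org.apache.spark.sql.test.TestSQLContext._
    import org.apache.spark.sql.TestData._

    // Run a query against the locally registered test data and materialize the result.
    val rows = sql("SELECT key, value FROM testData").collect()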
diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..7bb6789bd3
--- /dev/null
+++ b/sql/core/src/test/resources/log4j.properties
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file core/target/unit-tests.log
+log4j.rootLogger=DEBUG, CA, FA
+
+#Console Appender
+log4j.appender.CA=org.apache.log4j.ConsoleAppender
+log4j.appender.CA.layout=org.apache.log4j.PatternLayout
+log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n
+log4j.appender.CA.Threshold = WARN
+
+
+#File Appender
+log4j.appender.FA=org.apache.log4j.FileAppender
+log4j.appender.FA.append=false
+log4j.appender.FA.file=target/unit-tests.log
+log4j.appender.FA.layout=org.apache.log4j.PatternLayout
+log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n
+
+# Set the threshold of the File Appender to INFO
+log4j.appender.FA.Threshold = INFO
+
+# Some packages are noisy for no good reason.
+log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false
+log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF
+
+log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false
+log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF
+
+log4j.additivity.hive.ql.metadata.Hive=false
+log4j.logger.hive.ql.metadata.Hive=OFF
+
+# Parquet logging
+parquet.hadoop.InternalParquetRecordReader=WARN
+log4j.logger.parquet.hadoop.InternalParquetRecordReader=WARN
+parquet.hadoop.ParquetInputFormat=WARN
+log4j.logger.parquet.hadoop.ParquetInputFormat=WARN
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
new file mode 100644
index 0000000000..37c90a18a0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.test._
+
+/* Implicits */
+import TestSQLContext._
+
+class DslQuerySuite extends QueryTest {
+ import TestData._
+
+ test("table scan") {
+ checkAnswer(
+ testData,
+ testData.collect().toSeq)
+ }
+
+ test("agg") {
+ checkAnswer(
+ testData2.groupBy('a)('a, Sum('b)),
+ Seq((1,3),(2,3),(3,3))
+ )
+ }
+
+ test("select *") {
+ checkAnswer(
+ testData.select(Star(None)),
+ testData.collect().toSeq)
+ }
+
+ test("simple select") {
+ checkAnswer(
+ testData.where('key === 1).select('value),
+ Seq(Seq("1")))
+ }
+
+ test("sorting") {
+ checkAnswer(
+ testData2.orderBy('a.asc, 'b.asc),
+ Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+
+ checkAnswer(
+ testData2.orderBy('a.asc, 'b.desc),
+ Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1)))
+
+ checkAnswer(
+ testData2.orderBy('a.desc, 'b.desc),
+ Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1)))
+
+ checkAnswer(
+ testData2.orderBy('a.desc, 'b.asc),
+ Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+ }
+
+ test("average") {
+ checkAnswer(
+ testData2.groupBy()(Average('a)),
+ 2.0)
+ }
+
+ test("count") {
+ checkAnswer(
+ testData2.groupBy()(Count(1)),
+ testData2.count()
+ )
+ }
+
+ test("null count") {
+ checkAnswer(
+ testData3.groupBy('a)('a, Count('b)),
+ Seq((1,0), (2, 1))
+ )
+
+ checkAnswer(
+ testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
+ (2, 1, 2, 2, 1) :: Nil
+ )
+ }
+
+ test("inner join where, one match per row") {
+ checkAnswer(
+ upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
+ Seq(
+ (1, "A", 1, "a"),
+ (2, "B", 2, "b"),
+ (3, "C", 3, "c"),
+ (4, "D", 4, "d")
+ ))
+ }
+
+ test("inner join ON, one match per row") {
+ checkAnswer(
+ upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
+ Seq(
+ (1, "A", 1, "a"),
+ (2, "B", 2, "b"),
+ (3, "C", 3, "c"),
+ (4, "D", 4, "d")
+ ))
+ }
+
+ test("inner join, where, multiple matches") {
+ val x = testData2.where('a === 1).subquery('x)
+ val y = testData2.where('a === 1).subquery('y)
+ checkAnswer(
+ x.join(y).where("x.a".attr === "y.a".attr),
+ (1,1,1,1) ::
+ (1,1,1,2) ::
+ (1,2,1,1) ::
+ (1,2,1,2) :: Nil
+ )
+ }
+
+ test("inner join, no matches") {
+ val x = testData2.where('a === 1).subquery('x)
+ val y = testData2.where('a === 2).subquery('y)
+ checkAnswer(
+ x.join(y).where("x.a".attr === "y.a".attr),
+ Nil)
+ }
+
+ test("big inner join, 4 matches per row") {
+ val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
+ val bigDataX = bigData.subquery('x)
+ val bigDataY = bigData.subquery('y)
+
+ checkAnswer(
+ bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
+ testData.flatMap(
+ row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
+ }
+
+ test("cartisian product join") {
+ checkAnswer(
+ testData3.join(testData3),
+ (1, null, 1, null) ::
+ (1, null, 2, 2) ::
+ (2, 2, 1, null) ::
+ (2, 2, 2, 2) :: Nil)
+ }
+
+ test("left outer join") {
+ checkAnswer(
+ upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
+ (1, "A", 1, "a") ::
+ (2, "B", 2, "b") ::
+ (3, "C", 3, "c") ::
+ (4, "D", 4, "d") ::
+ (5, "E", null, null) ::
+ (6, "F", null, null) :: Nil)
+ }
+
+ test("right outer join") {
+ checkAnswer(
+ lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
+ (1, "a", 1, "A") ::
+ (2, "b", 2, "B") ::
+ (3, "c", 3, "C") ::
+ (4, "d", 4, "D") ::
+ (null, null, 5, "E") ::
+ (null, null, 6, "F") :: Nil)
+ }
+
+ test("full outer join") {
+ val left = upperCaseData.where('N <= 4).subquery('left)
+ val right = upperCaseData.where('N >= 3).subquery('right)
+
+ checkAnswer(
+ left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
+ (1, "A", null, null) ::
+ (2, "B", null, null) ::
+ (3, "C", 3, "C") ::
+ (4, "D", 4, "D") ::
+ (null, null, 5, "E") ::
+ (null, null, 6, "F") :: Nil)
+ }
+} \ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlannerSuite.scala
new file mode 100644
index 0000000000..83908edf5a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PlannerSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.planner._
+
+class PlannerSuite extends FunSuite {
+
+ test("unions are collapsed") {
+ val query = testData.unionAll(testData).unionAll(testData).logicalPlan
+ val planned = BasicOperators(query).head
+ val logicalUnions = query collect { case u: logical.Union => u}
+ val physicalUnions = planned collect { case u: execution.Union => u}
+
+ assert(logicalUnions.size === 2)
+ assert(physicalUnions.size === 1)
+ }
+
+ test("count is partially aggregated") {
+ val query = testData.groupBy('value)(Count('key)).analyze.logicalPlan
+ val planned = PartialAggregation(query).head
+ val aggregations = planned.collect { case a: Aggregate => a }
+
+ assert(aggregations.size === 2)
+ }
+
+ test("count distinct is not partially aggregated") {
+ val query = testData.groupBy('value)(CountDistinct('key :: Nil)).analyze.logicalPlan
+ val planned = PartialAggregation(query)
+ assert(planned.isEmpty)
+ }
+
+ test("mixed aggregates are not partially aggregated") {
+ val query =
+ testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).analyze.logicalPlan
+ val planned = PartialAggregation(query)
+ assert(planned.isEmpty)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
new file mode 100644
index 0000000000..728feceded
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.test._
+
+/* Implicits */
+import TestSQLContext._
+
+class QueryTest extends FunSuite {
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * @param rdd the query to be executed as a [[SchemaRDD]]
+ * @param expectedAnswer the expected result, either an Any, a Seq[Product], or a Seq[Seq[Any]].
+ */
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = {
+ val convertedAnswer = expectedAnswer match {
+ case s: Seq[_] if s.isEmpty => s
+ case s: Seq[_] if s.head.isInstanceOf[Product] &&
+ !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq)
+ case s: Seq[_] => s
+ case singleItem => Seq(Seq(singleItem))
+ }
+
+ val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty
+ def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer
+ val sparkAnswer = try rdd.collect().toSeq catch {
+ case e: Exception =>
+ fail(
+ s"""
+ |Exception thrown while executing query:
+ |${rdd.logicalPlan}
+ |== Exception ==
+ |$e
+ """.stripMargin)
+ }
+ if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
+ fail(s"""
+ |Results do not match for query:
+ |${rdd.logicalPlan}
+ |== Analyzed Plan ==
+ |${rdd.queryExecution.analyzed}
+ |== RDD ==
+ |$rdd
+ |== Results ==
+ |${sideBySide(
+ prepareAnswer(convertedAnswer).map(_.toString),
+ prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
+ """.stripMargin)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
new file mode 100644
index 0000000000..5728313d6d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.test._
+
+/* Implicits */
+import TestSQLContext._
+import TestData._
+
+class SQLQuerySuite extends QueryTest {
+ test("agg") {
+ checkAnswer(
+ sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"),
+ Seq((1,3),(2,3),(3,3))
+ )
+ }
+
+ test("select *") {
+ checkAnswer(
+ sql("SELECT * FROM testData"),
+ testData.collect().toSeq)
+ }
+
+ test("simple select") {
+ checkAnswer(
+ sql("SELECT value FROM testData WHERE key = 1"),
+ Seq(Seq("1")))
+ }
+
+ test("sorting") {
+ checkAnswer(
+ sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"),
+ Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+
+ checkAnswer(
+ sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"),
+ Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1)))
+
+ checkAnswer(
+ sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"),
+ Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1)))
+
+ checkAnswer(
+ sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
+ Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+ }
+
+ test("average") {
+ checkAnswer(
+ sql("SELECT AVG(a) FROM testData2"),
+ 2.0)
+ }
+
+ test("count") {
+ checkAnswer(
+ sql("SELECT COUNT(*) FROM testData2"),
+ testData2.count()
+ )
+ }
+
+ // No support for primitive nulls yet.
+ ignore("null count") {
+ checkAnswer(
+ sql("SELECT a, COUNT(b) FROM testData3"),
+ Seq((1,0), (2, 1))
+ )
+
+ checkAnswer(
+ testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
+ (2, 1, 2, 2, 1) :: Nil
+ )
+ }
+
+ test("inner join where, one match per row") {
+ checkAnswer(
+ sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
+ Seq(
+ (1, "A", 1, "a"),
+ (2, "B", 2, "b"),
+ (3, "C", 3, "c"),
+ (4, "D", 4, "d")
+ ))
+ }
+
+ test("inner join ON, one match per row") {
+ checkAnswer(
+ sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"),
+ Seq(
+ (1, "A", 1, "a"),
+ (2, "B", 2, "b"),
+ (3, "C", 3, "c"),
+ (4, "D", 4, "d")
+ ))
+ }
+
+ test("inner join, where, multiple matches") {
+ checkAnswer(
+ sql("""
+ |SELECT * FROM
+ | (SELECT * FROM testData2 WHERE a = 1) x JOIN
+ | (SELECT * FROM testData2 WHERE a = 1) y
+ |WHERE x.a = y.a""".stripMargin),
+ (1,1,1,1) ::
+ (1,1,1,2) ::
+ (1,2,1,1) ::
+ (1,2,1,2) :: Nil
+ )
+ }
+
+ test("inner join, no matches") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT * FROM
+ | (SELECT * FROM testData2 WHERE a = 1) x JOIN
+ | (SELECT * FROM testData2 WHERE a = 2) y
+ |WHERE x.a = y.a""".stripMargin),
+ Nil)
+ }
+
+ test("big inner join, 4 matches per row") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT * FROM
+ | (SELECT * FROM testData UNION ALL
+ | SELECT * FROM testData UNION ALL
+ | SELECT * FROM testData UNION ALL
+ | SELECT * FROM testData) x JOIN
+ | (SELECT * FROM testData UNION ALL
+ | SELECT * FROM testData UNION ALL
+ | SELECT * FROM testData UNION ALL
+ | SELECT * FROM testData) y
+ |WHERE x.key = y.key""".stripMargin),
+ testData.flatMap(
+ row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
+ }
+
+ ignore("cartisian product join") {
+ checkAnswer(
+ testData3.join(testData3),
+ (1, null, 1, null) ::
+ (1, null, 2, 2) ::
+ (2, 2, 1, null) ::
+ (2, 2, 2, 2) :: Nil)
+ }
+
+ test("left outer join") {
+ checkAnswer(
+ sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"),
+ (1, "A", 1, "a") ::
+ (2, "B", 2, "b") ::
+ (3, "C", 3, "c") ::
+ (4, "D", 4, "d") ::
+ (5, "E", null, null) ::
+ (6, "F", null, null) :: Nil)
+ }
+
+ test("right outer join") {
+ checkAnswer(
+ sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"),
+ (1, "a", 1, "A") ::
+ (2, "b", 2, "B") ::
+ (3, "c", 3, "C") ::
+ (4, "d", 4, "D") ::
+ (null, null, 5, "E") ::
+ (null, null, 6, "F") :: Nil)
+ }
+
+ test("full outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT * FROM
+ | (SELECT * FROM upperCaseData WHERE N <= 4) left FULL OUTER JOIN
+ | (SELECT * FROM upperCaseData WHERE N >= 3) right
+ | ON left.N = right.N
+ """.stripMargin),
+ (1, "A", null, null) ::
+ (2, "B", null, null) ::
+ (3, "C", 3, "C") ::
+ (4, "D", 4, "D") ::
+ (null, null, 5, "E") ::
+ (null, null, 6, "F") :: Nil)
+ }
+}
\ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
new file mode 100644
index 0000000000..640292571b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.test._
+
+/* Implicits */
+import TestSQLContext._
+
+object TestData {
+ case class TestData(key: Int, value: String)
+ val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString)))
+ testData.registerAsTable("testData")
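+ // Registered tables can then be referenced by name in SQL, e.g. (illustrative):
+ //   sql("SELECT value FROM testData WHERE key = 1")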
+
+ case class TestData2(a: Int, b: Int)
+ val testData2: SchemaRDD =
+ TestSQLContext.sparkContext.parallelize(
+ TestData2(1, 1) ::
+ TestData2(1, 2) ::
+ TestData2(2, 1) ::
+ TestData2(2, 2) ::
+ TestData2(3, 1) ::
+ TestData2(3, 2) :: Nil
+ )
+ testData2.registerAsTable("testData2")
+
+ // TODO: There is no way to express null primitives as case classes currently...
+ val testData3 =
+ logical.LocalRelation('a.int, 'b.int).loadData(
+ (1, null) ::
+ (2, 2) :: Nil
+ )
+
+ case class UpperCaseData(N: Int, L: String)
+ val upperCaseData =
+ TestSQLContext.sparkContext.parallelize(
+ UpperCaseData(1, "A") ::
+ UpperCaseData(2, "B") ::
+ UpperCaseData(3, "C") ::
+ UpperCaseData(4, "D") ::
+ UpperCaseData(5, "E") ::
+ UpperCaseData(6, "F") :: Nil
+ )
+ upperCaseData.registerAsTable("upperCaseData")
+
+ case class LowerCaseData(n: Int, l: String)
+ val lowerCaseData =
+ TestSQLContext.sparkContext.parallelize(
+ LowerCaseData(1, "a") ::
+ LowerCaseData(2, "b") ::
+ LowerCaseData(3, "c") ::
+ LowerCaseData(4, "d") :: Nil
+ )
+ lowerCaseData.registerAsTable("lowerCaseData")
+}
\ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala
new file mode 100644
index 0000000000..08265b7a6a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.test._
+
+
+import TestSQLContext._
+
+/**
+ * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns
+ * from the input data. These will be replaced during analysis with specific AttributeReferences
+ * and then bound to specific ordinals during query planning. While TGFs could also access specific
+ * columns using hand-coded ordinals, doing so violates data independence.
+ *
+ * Note: this is only a rough example of how TGFs can be expressed; the final version will likely
+ * involve a lot more sugar for cleaner use in Scala/Java/etc.
+ */
+case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator {
+ def children = input
+ protected def makeOutput() = 'nameAndAge.string :: Nil
+
+ val Seq(nameAttr, ageAttr) = input
+
+ override def apply(input: Row): TraversableOnce[Row] = {
+ val name = nameAttr.apply(input)
+ val age = ageAttr.apply(input).asInstanceOf[Int]
+
+ Iterator(
+ new GenericRow(Array[Any](s"$name is $age years old")),
+ new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old")))
+ }
+}
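+
+ // A TGF is invoked through the DSL's generate operator, e.g. (illustrative; see the test below):
+ //   inputData.generate(ExampleTGF())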
+
+class TgfSuite extends QueryTest {
+ val inputData =
+ logical.LocalRelation('name.string, 'age.int).loadData(
+ ("michael", 29) :: Nil
+ )
+
+ test("simple tgf example") {
+ checkAnswer(
+ inputData.generate(ExampleTGF()),
+ Seq(
+ "michael is 29 years old" :: Nil,
+ "Next year, michael will be 30 years old" :: Nil))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
new file mode 100644
index 0000000000..8b2ccb52d8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.test.TestSQLContext
+
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import parquet.schema.MessageTypeParser
+import parquet.hadoop.ParquetFileWriter
+import parquet.hadoop.util.ContextUtil
+
+class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll {
+ override def beforeAll() {
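+ // Writes the Parquet test file (15 rows) used by the tests below; afterAll deletes it again.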
+ ParquetTestData.writeFile
+ }
+
+ override def afterAll() {
+ ParquetTestData.testFile.delete()
+ }
+
+ test("Import of simple Parquet file") {
+ val result = getRDD(ParquetTestData.testData).collect()
+ assert(result.size === 15)
+ result.zipWithIndex.foreach {
+ case (row, index) =>
+ // Every third row (index % 3 == 0) stores true in the boolean column, all others store false.
+ assert(row(0) === (index % 3 == 0), s"boolean field value in line $index did not match")
+ if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match")
+ assert(row(2) === "abc", s"string field value in line $index did not match")
+ assert(row(3) === (index.toLong << 33), s"long field value in line $index did not match")
+ assert(row(4) === 2.5F, s"float field value in line $index did not match")
+ assert(row(5) === 4.5D, s"double field value in line $index did not match")
+ }
+ }
+
+ test("Projection of simple Parquet file") {
+ val scanner = new ParquetTableScan(
+ ParquetTestData.testData.output,
+ ParquetTestData.testData,
+ None)(TestSQLContext.sparkContext)
+ val projected = scanner.pruneColumns(ParquetTypesConverter
+ .convertToAttributes(MessageTypeParser
+ .parseMessageType(ParquetTestData.subTestSchema)))
+ assert(projected.output.size === 2)
+ val result = projected
+ .execute()
+ .map(_.copy())
+ .collect()
+ result.zipWithIndex.foreach {
+ case (row, index) => {
+ if (index % 3 == 0)
+ assert(row(0) === true, s"boolean field value in line $index did not match (every third row)")
+ else
+ assert(row(0) === false, s"boolean field value in line $index did not match")
+ assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match")
+ assert(row.size === 2, s"number of columns in projection in line $index is incorrect")
+ }
+ }
+ }
+
+ test("Writing metadata from scratch for table CREATE") {
+ val job = new Job()
+ val path = new Path(getTempFilePath("testtable").getCanonicalFile.toURI.toString)
+ val fs: FileSystem = FileSystem.getLocal(ContextUtil.getConfiguration(job))
+ ParquetTypesConverter.writeMetaData(
+ ParquetTestData.testData.output,
+ path,
+ TestSQLContext.sparkContext.hadoopConfiguration)
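+ // writeMetaData should have created a Parquet metadata file under `path`; readMetaData parses
+ // it back so the written and original schemas can be checked against each other below.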
+ assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE)))
+ val metaData = ParquetTypesConverter.readMetaData(path)
+ assert(metaData != null)
+ ParquetTestData
+ .testData
+ .parquetSchema
+ .checkContains(metaData.getFileMetaData.getSchema) // throws exception if incompatible
+ metaData
+ .getFileMetaData
+ .getSchema
+ .checkContains(ParquetTestData.testData.parquetSchema) // throws exception if incompatible
+ fs.delete(path, true)
+ }
+
+ /**
+ * Computes the given [[ParquetRelation]] and returns its RDD.
+ *
+ * @param parquetRelation The Parquet relation.
+ * @return An RDD of Rows.
+ */
+ private def getRDD(parquetRelation: ParquetRelation): RDD[Row] = {
+ val scanner = new ParquetTableScan(
+ parquetRelation.output,
+ parquetRelation,
+ None)(TestSQLContext.sparkContext)
+ scanner
+ .execute
+ .map(_.copy())
+ }
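+
+ // Used above as, e.g., getRDD(ParquetTestData.testData).collect() to materialize all rows.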
+}
+