Diffstat (limited to 'sql/core/src')
26 files changed, 3385 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala b/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala new file mode 100644 index 0000000000..b8b9e5839d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.language.implicitConversions + +import scala.reflect._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark._ +import org.apache.spark.Aggregator +import org.apache.spark.SparkContext._ +import org.apache.spark.util.collection.AppendOnlyMap + +/** + * Extra functions on RDDs that perform only local operations. These can be used when data has + * already been partitioned correctly. + */ +private[spark] class PartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) + extends Logging + with Serializable { + + /** + * Cogroup corresponding partitions of `this` and `other`. These two RDDs should have + * the same number of partitions. Partitions of these two RDDs are cogrouped + * according to the indexes of partitions. If we have two RDDs and + * each of them has n partitions, we will cogroup the partition i from `this` + * with the partition i from `other`. + * This function will not introduce a shuffling operation. + */ + def cogroupLocally[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { + val cg = self.zipPartitions(other)((iter1:Iterator[(K, V)], iter2:Iterator[(K, W)]) => { + val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]] + + val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => { + if (hadVal) oldVal else Array.fill(2)(new ArrayBuffer[Any]) + } + + val getSeq = (k: K) => { + map.changeValue(k, update) + } + + iter1.foreach { kv => getSeq(kv._1)(0) += kv._2 } + iter2.foreach { kv => getSeq(kv._1)(1) += kv._2 } + + map.iterator + }).mapValues { case Seq(vs, ws) => (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])} + + cg + } + + /** + * Group the values for each key within a partition of the RDD into a single sequence. + * This function will not introduce a shuffling operation. + */ + def groupByKeyLocally(): RDD[(K, Seq[V])] = { + def createCombiner(v: V) = ArrayBuffer(v) + def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v + val aggregator = new Aggregator[K, V, ArrayBuffer[V]](createCombiner, mergeValue, _ ++ _) + val bufs = self.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) + }, preservesPartitioning = true) + bufs.asInstanceOf[RDD[(K, Seq[V])]] + } + + /** + * Join corresponding partitions of `this` and `other`. 
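 * A minimal usage sketch (hypothetical; `sc` is an existing SparkContext and both inputs are
 * assumed to already be hash-partitioned identically, which is the caller's responsibility
 * since these functions are `private[spark]` helpers for the execution operators):
 * {{{
 *   import org.apache.spark.HashPartitioner
 *   import org.apache.spark.SparkContext._
 *   import org.apache.spark.rdd.PartitionLocalRDDFunctions._
 *
 *   val left  = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(new HashPartitioner(4))
 *   val right = sc.parallelize(Seq(1 -> 1.0, 2 -> 2.0)).partitionBy(new HashPartitioner(4))
 *
 *   // Partition i of `left` is joined with partition i of `right`; no shuffle is introduced.
 *   val joined: RDD[(Int, (String, Double))] = left.joinLocally(right)
 * }}}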
+ * If we have two RDDs and each of them has n partitions, + * we will join the partition i from `this` with the partition i from `other`. + * This function will not introduce a shuffling operation. + */ + def joinLocally[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { + cogroupLocally(other).flatMapValues { + case (vs, ws) => for (v <- vs.iterator; w <- ws.iterator) yield (v, w) + } + } +} + +private[spark] object PartitionLocalRDDFunctions { + implicit def rddToPartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) = + new PartitionLocalRDDFunctions(rdd) +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala new file mode 100644 index 0000000000..587cc7487f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.planning.QueryPlanner +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, NativeCommand, WriteToFile} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.execution._ + +/** + * <span class="badge" style="float: right; background-color: darkblue;">ALPHA COMPONENT</span> + * + * The entry point for running relational queries using Spark. Allows the creation of [[SchemaRDD]] + * objects and the execution of SQL queries. + * + * @groupname userf Spark SQL Functions + * @groupname Ungrouped Support functions for language integrated queries. 
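 * A brief, hypothetical example of the intended entry-point workflow (`sc` is an existing
 * SparkContext; all other names are illustrative):
 * {{{
 *   val sqlContext = new SQLContext(sc)
 *   import sqlContext._   // brings createSchemaRDD and the other implicits into scope
 *
 *   case class Person(name: String, age: Int)
 *   val people = sc.parallelize(Seq(Person("Alice", 30), Person("Bob", 17)))
 *
 *   people.registerAsTable("people")   // via the createSchemaRDD implicit
 *   val adults = sql("SELECT name FROM people WHERE age >= 18")
 * }}}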
+ */ +class SQLContext(@transient val sparkContext: SparkContext) + extends Logging + with dsl.ExpressionConversions + with Serializable { + + self => + + @transient + protected[sql] lazy val catalog: Catalog = new SimpleCatalog + @transient + protected[sql] lazy val analyzer: Analyzer = + new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true) + @transient + protected[sql] val optimizer = Optimizer + @transient + protected[sql] val parser = new catalyst.SqlParser + + protected[sql] def parseSql(sql: String): LogicalPlan = parser(sql) + protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) + protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution { val logical = plan } + + /** + * <span class="badge badge-red" style="float: right;">EXPERIMENTAL</span> + * + * Allows catalyst LogicalPlans to be executed as a SchemaRDD. Note that the LogicalPlan + * interface is considered internal, and thus not guranteed to be stable. As a result, using + * them directly is not reccomended. + */ + implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = new SchemaRDD(this, plan) + + /** + * Creates a SchemaRDD from an RDD of case classes. + * + * @group userf + */ + implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))) + + /** + * Loads a parequet file, returning the result as a [[SchemaRDD]]. + * + * @group userf + */ + def parquetFile(path: String): SchemaRDD = + new SchemaRDD(this, parquet.ParquetRelation("ParquetFile", path)) + + + /** + * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only + * during the lifetime of this instance of SQLContext. + * + * @group userf + */ + def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { + catalog.registerTable(None, tableName, rdd.logicalPlan) + } + + /** + * Executes a SQL query using Spark, returning the result as a SchemaRDD. + * + * @group userf + */ + def sql(sqlText: String): SchemaRDD = { + val result = new SchemaRDD(this, parseSql(sqlText)) + // We force query optimization to happen right away instead of letting it happen lazily like + // when using the query DSL. This is so DDL commands behave as expected. This is only + // generates the RDD lineage for DML queries, but do not perform any execution. + result.queryExecution.toRdd + result + } + + protected[sql] class SparkPlanner extends SparkStrategies { + val sparkContext = self.sparkContext + + val strategies: Seq[Strategy] = + TopK :: + PartialAggregation :: + SparkEquiInnerJoin :: + BasicOperators :: + CartesianProduct :: + BroadcastNestedLoopJoin :: Nil + } + + @transient + protected[sql] val planner = new SparkPlanner + + /** + * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and + * inserting shuffle operations as needed. + */ + @transient + protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { + val batches = + Batch("Add exchange", Once, AddExchange) :: + Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil + } + + /** + * The primary workflow for executing relational queries using Spark. Designed to allow easy + * access to the intermediate phases of query execution for developers. 
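 * A hypothetical sketch of inspecting those phases (the members are package-private, so this
 * reflects code running inside org.apache.spark.sql rather than user code):
 * {{{
 *   val qe = sqlContext.executeSql("SELECT key FROM records WHERE key = 1")
 *   qe.analyzed       // logical plan with relations and attributes resolved
 *   qe.optimizedPlan  // logical plan after the optimizer has run
 *   qe.executedPlan   // physical SparkPlan after AddExchange and BindReferences
 *   println(qe)       // prints all phases, using stringOrError so failures are still readable
 * }}}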
+ */ + protected abstract class QueryExecution { + def logical: LogicalPlan + + lazy val analyzed = analyzer(logical) + lazy val optimizedPlan = optimizer(analyzed) + // TODO: Don't just pick the first one... + lazy val sparkPlan = planner(optimizedPlan).next() + lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) + + /** Internal version of the RDD. Avoids copies and has no schema */ + lazy val toRdd: RDD[Row] = executedPlan.execute() + + protected def stringOrError[A](f: => A): String = + try f.toString catch { case e: Throwable => e.toString } + + override def toString: String = + s"""== Logical Plan == + |${stringOrError(analyzed)} + |== Optimized Logical Plan + |${stringOrError(optimizedPlan)} + |== Physical Plan == + |${stringOrError(executedPlan)} + """.stripMargin.trim + + /** + * Runs the query after interposing operators that print the result of each intermediate step. + */ + def debugExec() = DebugQuery(executedPlan).execute().collect() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala new file mode 100644 index 0000000000..91c3aaa2b8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -0,0 +1,342 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.{OneToOneDependency, Dependency, Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.types.BooleanType + +/** + * <span class="badge" style="float: right; background-color: darkblue;">ALPHA COMPONENT</span> + * + * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions, + * SchemaRDDs can be used in relational queries, as shown in the examples below. + * + * Importing a SQLContext brings an implicit into scope that automatically converts a standard RDD + * whose elements are scala case classes into a SchemaRDD. This conversion can also be done + * explicitly using the `createSchemaRDD` function on a [[SQLContext]]. + * + * A `SchemaRDD` can also be created by loading data in from external sources, for example, + * by using the `parquetFile` method on [[SQLContext]]. + * + * == SQL Queries == + * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once + * an RDD has been registered as a table, it can be used in the FROM clause of SQL statements. 
+ * + * {{{ + * // One method for defining the schema of an RDD is to make a case class with the desired column + * // names and types. + * case class Record(key: Int, value: String) + * + * val sc: SparkContext // An existing spark context. + * val sqlContext = new SQLContext(sc) + * + * // Importing the SQL context gives access to all the SQL functions and implicit conversions. + * import sqlContext._ + * + * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_\$i"))) + * // Any RDD containing case classes can be registered as a table. The schema of the table is + * // automatically inferred using scala reflection. + * rdd.registerAsTable("records") + * + * val results: SchemaRDD = sql("SELECT * FROM records") + * }}} + * + * == Language Integrated Queries == + * + * {{{ + * + * case class Record(key: Int, value: String) + * + * val sc: SparkContext // An existing spark context. + * val sqlContext = new SQLContext(sc) + * + * // Importing the SQL context gives access to all the SQL functions and implicit conversions. + * import sqlContext._ + * + * val rdd = sc.parallelize((1 to 100).map(i => Record(i, "val_" + i))) + * + * // Example of language integrated queries. + * rdd.where('key === 1).orderBy('value.asc).select('key).collect() + * }}} + * + * @todo There is currently no support for creating SchemaRDDs from either Java or Python RDDs. + * + * @groupname Query Language Integrated Queries + * @groupdesc Query Functions that create new queries from SchemaRDDs. The + * result of all query functions is also a SchemaRDD, allowing multiple operations to be + * chained using a builder pattern. + * @groupprio Query -2 + * @groupname schema SchemaRDD Functions + * @groupprio schema -1 + * @groupname Ungrouped Base RDD Functions + */ +class SchemaRDD( + @transient val sqlContext: SQLContext, + @transient val logicalPlan: LogicalPlan) + extends RDD[Row](sqlContext.sparkContext, Nil) { + + /** + * A lazily computed query execution workflow. All other RDD operations are passed + * through to the RDD that is produced by this workflow. + * + * We want this to be lazy because invoking the whole query optimization pipeline can be + * expensive. + */ + @transient + protected[spark] lazy val queryExecution = sqlContext.executePlan(logicalPlan) + + override def toString = + s"""${super.toString} + |== Query Plan == + |${queryExecution.executedPlan}""".stripMargin.trim + + // ========================================================================================= + // RDD functions: Copy the interal row representation so we present immutable data to users. + // ========================================================================================= + + override def compute(split: Partition, context: TaskContext): Iterator[Row] = + firstParent[Row].compute(split, context).map(_.copy()) + + override def getPartitions: Array[Partition] = firstParent[Row].partitions + + override protected def getDependencies: Seq[Dependency[_]] = + List(new OneToOneDependency(queryExecution.toRdd)) + + + // ======================================================================= + // Query DSL + // ======================================================================= + + /** + * Changes the output of this relation to the given expressions, similar to the `SELECT` clause + * in SQL. + * + * {{{ + * schemaRDD.select('a, 'b + 'c, 'd as 'aliasedName) + * }}} + * + * @param exprs a set of logical expression that will be evaluated for each input row. 
+ * + * @group Query + */ + def select(exprs: NamedExpression*): SchemaRDD = + new SchemaRDD(sqlContext, Project(exprs, logicalPlan)) + + /** + * Filters the ouput, only returning those rows where `condition` evaluates to true. + * + * {{{ + * schemaRDD.where('a === 'b) + * schemaRDD.where('a === 1) + * schemaRDD.where('a + 'b > 10) + * }}} + * + * @group Query + */ + def where(condition: Expression): SchemaRDD = + new SchemaRDD(sqlContext, Filter(condition, logicalPlan)) + + /** + * Performs a relational join on two SchemaRDDs + * + * @param otherPlan the [[SchemaRDD]] that should be joined with this one. + * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner.` + * @param condition An optional condition for the join operation. This is equivilent to the `ON` + * clause in standard SQL. In the case of `Inner` joins, specifying a + * `condition` is equivilent to adding `where` clauses after the `join`. + * + * @group Query + */ + def join( + otherPlan: SchemaRDD, + joinType: JoinType = Inner, + condition: Option[Expression] = None): SchemaRDD = + new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, condition)) + + /** + * Sorts the results by the given expressions. + * {{{ + * schemaRDD.orderBy('a) + * schemaRDD.orderBy('a, 'b) + * schemaRDD.orderBy('a.asc, 'b.desc) + * }}} + * + * @group Query + */ + def orderBy(sortExprs: SortOrder*): SchemaRDD = + new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan)) + + /** + * Performs a grouping followed by an aggregation. + * + * {{{ + * schemaRDD.groupBy('year)(Sum('sales) as 'totalSales) + * }}} + * + * @group Query + */ + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): SchemaRDD = { + val aliasedExprs = aggregateExprs.map { + case ne: NamedExpression => ne + case e => Alias(e, e.toString)() + } + new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan)) + } + + /** + * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes + * with the same name, for example, when peforming self-joins. + * + * {{{ + * val x = schemaRDD.where('a === 1).subquery('x) + * val y = schemaRDD.where('a === 2).subquery('y) + * x.join(y).where("x.a".attr === "y.a".attr), + * }}} + * + * @group Query + */ + def subquery(alias: Symbol) = + new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan)) + + /** + * Combines the tuples of two RDDs with the same schema, keeping duplicates. + * + * @group Query + */ + def unionAll(otherPlan: SchemaRDD) = + new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan)) + + /** + * Filters tuples using a function over the value of the specified column. + * + * {{{ + * schemaRDD.sfilter('a)((a: Int) => ...) + * }}} + * + * @group Query + */ + def where[T1](arg1: Symbol)(udf: (T1) => Boolean) = + new SchemaRDD( + sqlContext, + Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) + + /** + * <span class="badge badge-red" style="float: right;">EXPERIMENTAL</span> + * + * Filters tuples using a function over a `Dynamic` version of a given Row. DynamicRows use + * scala's Dynamic trait to emulate an ORM of in a dynamically typed language. Since the type of + * the column is not known at compile time, all attributes are converted to strings before + * being passed to the function. 
+ * + * {{{ + * schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith") + * }}} + * + * @group Query + */ + def where(dynamicUdf: (DynamicRow) => Boolean) = + new SchemaRDD( + sqlContext, + Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)) + + /** + * <span class="badge badge-red" style="float: right;">EXPERIMENTAL</span> + * + * Returns a sampled version of the underlying dataset. + * + * @group Query + */ + def sample( + fraction: Double, + withReplacement: Boolean = true, + seed: Int = (math.random * 1000).toInt) = + new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan)) + + /** + * <span class="badge badge-red" style="float: right;">EXPERIMENTAL</span> + * + * Applies the given Generator, or table generating function, to this relation. + * + * @param generator A table generating function. The API for such functions is likely to change + * in future releases + * @param join when set to true, each output row of the generator is joined with the input row + * that produced it. + * @param outer when set to true, at least one row will be produced for each input row, similar to + * an `OUTER JOIN` in SQL. When no output rows are produced by the generator for a + * given row, a single row will be output, with `NULL` values for each of the + * generated columns. + * @param alias an optional alias that can be used as qualif for the attributes that are produced + * by this generate operation. + * + * @group Query + */ + def generate( + generator: Generator, + join: Boolean = false, + outer: Boolean = false, + alias: Option[String] = None) = + new SchemaRDD(sqlContext, Generate(generator, join, outer, None, logicalPlan)) + + /** + * <span class="badge badge-red" style="float: right;">EXPERIMENTAL</span> + * + * Adds the rows from this RDD to the specified table. Note in a standard [[SQLContext]] there is + * no notion of persistent tables, and thus queries that contain this operator will fail to + * optimize. When working with an extension of a SQLContext that has a persistent catalog, such + * as a `HiveContext`, this operation will result in insertions to the table specified. + * + * @group schema + */ + def insertInto(tableName: String, overwrite: Boolean = false) = + new SchemaRDD( + sqlContext, + InsertIntoTable(UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite)) + + /** + * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that + * are written out using this method can be read back in as a SchemaRDD using the ``function + * + * @group schema + */ + def saveAsParquetFile(path: String): Unit = { + sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + } + + /** + * Registers this RDD as a temporary table using the given name. The lifetime of this temporary + * table is tied to the [[SQLContext]] that was used to create this SchemaRDD. + * + * @group schema + */ + def registerAsTable(tableName: String): Unit = { + sqlContext.registerRDDAsTable(this, tableName) + } + + /** + * Returns this RDD as a SchemaRDD. 
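 * A hypothetical end-to-end sketch of the schema functions defined above (the path and table
 * names are illustrative):
 * {{{
 *   val records: SchemaRDD = sql("SELECT * FROM records")
 *   records.saveAsParquetFile("/tmp/records.parquet")       // written out with its schema
 *   val reloaded = sqlContext.parquetFile("/tmp/records.parquet")
 *   reloaded.registerAsTable("recordsParquet")              // usable in later SQL queries
 * }}}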
+ * @group schema + */ + def toSchemaRDD = this + + def analyze = sqlContext.analyzer(logicalPlan) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala new file mode 100644 index 0000000000..72dc5ec6ad --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +import java.nio.ByteBuffer + +import com.esotericsoftware.kryo.{Kryo, Serializer} +import com.esotericsoftware.kryo.io.{Output, Input} + +import org.apache.spark.{SparkConf, RangePartitioner, HashPartitioner} +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.util.MutablePair + +import catalyst.rules.Rule +import catalyst.errors._ +import catalyst.expressions._ +import catalyst.plans.physical._ + +private class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { + override def newKryo(): Kryo = { + val kryo = new Kryo + kryo.setRegistrationRequired(true) + kryo.register(classOf[MutablePair[_,_]]) + kryo.register(classOf[Array[Any]]) + kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) + kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) + kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]]) + kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer) + kryo.setReferences(false) + kryo.setClassLoader(this.getClass.getClassLoader) + kryo + } +} + +private class BigDecimalSerializer extends Serializer[BigDecimal] { + def write(kryo: Kryo, output: Output, bd: math.BigDecimal) { + // TODO: There are probably more efficient representations than strings... + output.writeString(bd.toString) + } + + def read(kryo: Kryo, input: Input, tpe: Class[BigDecimal]): BigDecimal = { + BigDecimal(input.readString()) + } +} + +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { + + override def outputPartitioning = newPartitioning + + def output = child.output + + def execute() = attachTree(this , "execute") { + newPartitioning match { + case HashPartitioning(expressions, numPartitions) => { + // TODO: Eliminate redundant expressions in grouping key and value. 
+ val rdd = child.execute().mapPartitions { iter => + val hashExpressions = new MutableProjection(expressions) + val mutablePair = new MutablePair[Row, Row]() + iter.map(r => mutablePair.update(hashExpressions(r), r)) + } + val part = new HashPartitioner(numPartitions) + val shuffled = new ShuffledRDD[Row, Row, MutablePair[Row, Row]](rdd, part) + shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.map(_._2) + } + case RangePartitioning(sortingExpressions, numPartitions) => { + // TODO: RangePartitioner should take an Ordering. + implicit val ordering = new RowOrdering(sortingExpressions) + + val rdd = child.execute().mapPartitions { iter => + val mutablePair = new MutablePair[Row, Null](null, null) + iter.map(row => mutablePair.update(row, null)) + } + val part = new RangePartitioner(numPartitions, rdd, ascending = true) + val shuffled = new ShuffledRDD[Row, Null, MutablePair[Row, Null]](rdd, part) + shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + + shuffled.map(_._1) + } + case SinglePartition => + child.execute().coalesce(1, true) + + case _ => sys.error(s"Exchange not implemented for $newPartitioning") + // TODO: Handle BroadcastPartitioning. + } + } +} + +/** + * Ensures that the [[catalyst.plans.physical.Partitioning Partitioning]] of input data meets the + * [[catalyst.plans.physical.Distribution Distribution]] requirements for each operator by inserting + * [[Exchange]] Operators where required. + */ +object AddExchange extends Rule[SparkPlan] { + // TODO: Determine the number of partitions. + val numPartitions = 8 + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator: SparkPlan => + // Check if every child's outputPartitioning satisfies the corresponding + // required data distribution. + def meetsRequirements = + !operator.requiredChildDistribution.zip(operator.children).map { + case (required, child) => + val valid = child.outputPartitioning.satisfies(required) + logger.debug( + s"${if (valid) "Valid" else "Invalid"} distribution," + + s"required: $required current: ${child.outputPartitioning}") + valid + }.exists(!_) + + // Check if outputPartitionings of children are compatible with each other. + // It is possible that every child satisfies its required data distribution + // but two children have incompatible outputPartitionings. For example, + // A dataset is range partitioned by "a.asc" (RangePartitioning) and another + // dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two + // datasets are both clustered by "a", but these two outputPartitionings are not + // compatible. + // TODO: ASSUMES TRANSITIVITY? + def compatible = + !operator.children + .map(_.outputPartitioning) + .sliding(2) + .map { + case Seq(a) => true + case Seq(a,b) => a compatibleWith b + }.exists(!_) + + // Check if the partitioning we want to ensure is the same as the child's output + // partitioning. If so, we do not need to add the Exchange operator. + def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan) = + if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child + + if (meetsRequirements && compatible) { + operator + } else { + // At least one child does not satisfies its required data distribution or + // at least one child's outputPartitioning is not compatible with another child's + // outputPartitioning. In this case, we need to add Exchange operators. 
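        // Illustrative (hypothetical plan): an Aggregate grouping on 'a whose child reports
        // UnknownPartitioning requires ClusteredDistribution('a :: Nil); the rewrite below then
        // produces Aggregate(..., Exchange(HashPartitioning('a :: Nil, numPartitions), child)).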
+ val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map { + case (AllTuples, child) => + addExchangeIfNecessary(SinglePartition, child) + case (ClusteredDistribution(clustering), child) => + addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) + case (OrderedDistribution(ordering), child) => + addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) + case (UnspecifiedDistribution, child) => child + case (dist, _) => sys.error(s"Don't know how to ensure $dist") + } + operator.withNewChildren(repartitionedChildren) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala new file mode 100644 index 0000000000..c1da3653c5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +import catalyst.expressions._ +import catalyst.types._ + +/** + * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the + * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional + * programming with one important additional feature, which allows the input rows to be joined with + * their output. + * @param join when true, each output row is implicitly joined with the input tuple that produced + * it. + * @param outer when true, each input row will be output at least once, even if the output of the + * given `generator` is empty. `outer` has no effect when `join` is false. + */ +case class Generate( + generator: Generator, + join: Boolean, + outer: Boolean, + child: SparkPlan) + extends UnaryNode { + + def output = + if (join) child.output ++ generator.output else generator.output + + def execute() = { + if (join) { + child.execute().mapPartitions { iter => + val nullValues = Seq.fill(generator.output.size)(Literal(null)) + // Used to produce rows with no matches when outer = true. 
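        // Illustrative semantics for a hypothetical generator that explodes a sequence column:
        //   input [1, Seq("a", "b")], join = true                =>  [1, Seq("a", "b"), "a"], [1, Seq("a", "b"), "b"]
        //   input [2, Seq()],         join = true, outer = true  =>  [2, Seq(), null]   (built by outerProjection)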
+ val outerProjection = + new Projection(child.output ++ nullValues, child.output) + + val joinProjection = + new Projection(child.output ++ generator.output, child.output ++ generator.output) + val joinedRow = new JoinedRow + + iter.flatMap {row => + val outputRows = generator(row) + if (outer && outputRows.isEmpty) { + outerProjection(row) :: Nil + } else { + outputRows.map(or => joinProjection(joinedRow(row, or))) + } + } + } + } else { + child.execute().mapPartitions(iter => iter.flatMap(generator)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala new file mode 100644 index 0000000000..7ce8608d20 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +class QueryExecutionException(message: String) extends Exception(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala new file mode 100644 index 0000000000..5626181d18 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +import org.apache.spark.rdd.RDD + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.trees + +abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { + self: Product => + + // TODO: Move to `DistributedPlan` + /** Specifies how data is partitioned across different nodes in the cluster. 
*/ + def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! + /** Specifies any partition requirements on the input data for this operator. */ + def requiredChildDistribution: Seq[Distribution] = + Seq.fill(children.size)(UnspecifiedDistribution) + + /** + * Runs this query returning the result as an RDD. + */ + def execute(): RDD[Row] + + /** + * Runs this query returning the result as an array. + */ + def executeCollect(): Array[Row] = execute().collect() + + protected def buildRow(values: Seq[Any]): Row = + new catalyst.expressions.GenericRow(values.toArray) +} + +/** + * Allows already planned SparkQueries to be linked into logical query plans. + * + * Note that in general it is not valid to use this class to link multiple copies of the same + * physical operator into the same query plan as this violates the uniqueness of expression ids. + * Special handling exists for ExistingRdd as these are already leaf operators and thus we can just + * replace the output attributes with new copies of themselves without breaking any attribute + * linking. + */ +case class SparkLogicalPlan(alreadyPlanned: SparkPlan) + extends logical.LogicalPlan with MultiInstanceRelation { + + def output = alreadyPlanned.output + def references = Set.empty + def children = Nil + + override final def newInstance: this.type = { + SparkLogicalPlan( + alreadyPlanned match { + case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) + case _ => sys.error("Multiple instance of the same relation detected.") + }).asInstanceOf[this.type] + } +} + +trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { + self: Product => +} + +trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] { + self: Product => + override def outputPartitioning: Partitioning = child.outputPartitioning +} + +trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] { + self: Product => +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala new file mode 100644 index 0000000000..85035b8118 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package execution + +import org.apache.spark.SparkContext + +import catalyst.expressions._ +import catalyst.planning._ +import catalyst.plans._ +import catalyst.plans.logical.LogicalPlan +import catalyst.plans.physical._ +import parquet.ParquetRelation +import parquet.InsertIntoParquetTable + +abstract class SparkStrategies extends QueryPlanner[SparkPlan] { + + val sparkContext: SparkContext + + object SparkEquiInnerJoin extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) => + logger.debug(s"Considering join: ${predicates ++ condition}") + // Find equi-join predicates that can be evaluated before the join, and thus can be used + // as join keys. Note we can only mix in the conditions with other predicates because the + // match above ensures that this is and Inner join. + val (joinPredicates, otherPredicates) = (predicates ++ condition).partition { + case Equals(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || + (canEvaluate(l, right) && canEvaluate(r, left)) => true + case _ => false + } + + val joinKeys = joinPredicates.map { + case Equals(l,r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) + case Equals(l,r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) + } + + // Do not consider this strategy if there are no join keys. + if (joinKeys.nonEmpty) { + val leftKeys = joinKeys.map(_._1) + val rightKeys = joinKeys.map(_._2) + + val joinOp = execution.SparkEquiInnerJoin( + leftKeys, rightKeys, planLater(left), planLater(right)) + + // Make sure other conditions are met if present. + if (otherPredicates.nonEmpty) { + execution.Filter(combineConjunctivePredicates(otherPredicates), joinOp) :: Nil + } else { + joinOp :: Nil + } + } else { + logger.debug(s"Avoiding spark join with no join keys.") + Nil + } + case _ => Nil + } + + private def combineConjunctivePredicates(predicates: Seq[Expression]) = + predicates.reduceLeft(And) + + /** Returns true if `expr` can be evaluated using only the output of `plan`. */ + protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean = + expr.references subsetOf plan.outputSet + } + + object PartialAggregation extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => + // Collect all aggregate expressions. + val allAggregates = + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + // Collect all aggregate expressions that can be computed partially. + val partialAggregates = + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + + // Only do partial aggregation if supported by all aggregate expressions. + if (allAggregates.size == partialAggregates.size) { + // Create a map of expressions to their partial evaluations for all aggregate expressions. + val partialEvaluations: Map[Long, SplitEvaluation] = + partialAggregates.map(a => (a.id, a.asPartial)).toMap + + // We need to pass all grouping expressions though so the grouping can happen a second + // time. However some of them might be unnamed so we alias them allowing them to be + // referenced in the second aggregation. 
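          // Worked illustration (hypothetical query): SELECT key, AVG(value) FROM t GROUP BY key
          //   partial pass : group by key locally, emitting key plus a partial sum and count of value
          //   final pass   : group by key again, combining the partial sums and counts and dividing
          // rewrittenAggregateExpressions below holds the final-pass expressions, while
          // partialComputation holds the output of the partial pass.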
+ val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + }.toMap + + // Replace aggregations with a new expression that computes the result from the already + // computed partial evaluations and grouping values. + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + case e: Expression if partialEvaluations.contains(e.id) => + partialEvaluations(e.id).finalEvaluation + case e: Expression if namedGroupingExpressions.contains(e) => + namedGroupingExpressions(e).toAttribute + }).asInstanceOf[Seq[NamedExpression]] + + val partialComputation = + (namedGroupingExpressions.values ++ + partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + + // Construct two phased aggregation. + execution.Aggregate( + partial = false, + namedGroupingExpressions.values.map(_.toAttribute).toSeq, + rewrittenAggregateExpressions, + execution.Aggregate( + partial = true, + groupingExpressions, + partialComputation, + planLater(child))(sparkContext))(sparkContext) :: Nil + } else { + Nil + } + case _ => Nil + } + } + + object BroadcastNestedLoopJoin extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Join(left, right, joinType, condition) => + execution.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joinType, condition)(sparkContext) :: Nil + case _ => Nil + } + } + + object CartesianProduct extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Join(left, right, _, None) => + execution.CartesianProduct(planLater(left), planLater(right)) :: Nil + case logical.Join(left, right, Inner, Some(condition)) => + execution.Filter(condition, + execution.CartesianProduct(planLater(left), planLater(right))) :: Nil + case _ => Nil + } + } + + protected lazy val singleRowRdd = + sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + + def convertToCatalyst(a: Any): Any = a match { + case s: Seq[Any] => s.map(convertToCatalyst) + case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) + case other => other + } + + object TopK extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.StopAfter(IntegerLiteral(limit), logical.Sort(order, child)) => + execution.TopK(limit, order, planLater(child))(sparkContext) :: Nil + case _ => Nil + } + } + + // Can we automate these 'pass through' operations? + object BasicOperators extends Strategy { + // TOOD: Set + val numPartitions = 200 + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Distinct(child) => + execution.Aggregate( + partial = false, child.output, child.output, planLater(child))(sparkContext) :: Nil + case logical.Sort(sortExprs, child) => + // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution. + execution.Sort(sortExprs, global = true, planLater(child)):: Nil + case logical.SortPartitions(sortExprs, child) => + // This sort only sorts tuples within a partition. Its requiredDistribution will be + // an UnspecifiedDistribution. 
+ execution.Sort(sortExprs, global = false, planLater(child)) :: Nil + case logical.Project(projectList, r: ParquetRelation) + if projectList.forall(_.isInstanceOf[Attribute]) => + + // simple projection of data loaded from Parquet file + parquet.ParquetTableScan( + projectList.asInstanceOf[Seq[Attribute]], + r, + None)(sparkContext) :: Nil + case logical.Project(projectList, child) => + execution.Project(projectList, planLater(child)) :: Nil + case logical.Filter(condition, child) => + execution.Filter(condition, planLater(child)) :: Nil + case logical.Aggregate(group, agg, child) => + execution.Aggregate(partial = false, group, agg, planLater(child))(sparkContext) :: Nil + case logical.Sample(fraction, withReplacement, seed, child) => + execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil + case logical.LocalRelation(output, data) => + val dataAsRdd = + sparkContext.parallelize(data.map(r => + new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)) + execution.ExistingRdd(output, dataAsRdd) :: Nil + case logical.StopAfter(IntegerLiteral(limit), child) => + execution.StopAfter(limit, planLater(child))(sparkContext) :: Nil + case Unions(unionChildren) => + execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil + case logical.Generate(generator, join, outer, _, child) => + execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil + case logical.NoRelation => + execution.ExistingRdd(Nil, singleRowRdd) :: Nil + case logical.Repartition(expressions, child) => + execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + case logical.WriteToFile(path, child) => + val relation = + ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, None) + InsertIntoParquetTable(relation, planLater(child))(sparkContext) :: Nil + case p: parquet.ParquetRelation => + parquet.ParquetTableScan(p.output, p, None)(sparkContext) :: Nil + case SparkLogicalPlan(existingPlan) => existingPlan :: Nil + case _ => Nil + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala new file mode 100644 index 0000000000..51889c1988 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package execution + +import org.apache.spark.SparkContext + +import catalyst.errors._ +import catalyst.expressions._ +import catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples} +import catalyst.types._ + +import org.apache.spark.rdd.PartitionLocalRDDFunctions._ + +/** + * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each + * group. + * + * @param partial if true then aggregation is done partially on local data without shuffling to + * ensure all values where `groupingExpressions` are equal are present. + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param aggregateExpressions expressions that are computed for each group. + * @param child the input data source. + */ +case class Aggregate( + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan)(@transient sc: SparkContext) + extends UnaryNode { + + override def requiredChildDistribution = + if (partial) { + UnspecifiedDistribution :: Nil + } else { + if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def otherCopyArgs = sc :: Nil + + def output = aggregateExpressions.map(_.toAttribute) + + /* Replace all aggregate expressions with spark functions that will compute the result. */ + def createAggregateImplementations() = aggregateExpressions.map { agg => + val impl = agg transform { + case a: AggregateExpression => a.newInstance + } + + val remainingAttributes = impl.collect { case a: Attribute => a } + // If any references exist that are not inside agg functions then the must be grouping exprs + // in this case we must rebind them to the grouping tuple. + if (remainingAttributes.nonEmpty) { + val unaliasedAggregateExpr = agg transform { case Alias(c, _) => c } + + // An exact match with a grouping expression + val exactGroupingExpr = groupingExpressions.indexOf(unaliasedAggregateExpr) match { + case -1 => None + case ordinal => Some(BoundReference(ordinal, Alias(impl, "AGGEXPR")().toAttribute)) + } + + exactGroupingExpr.getOrElse( + sys.error(s"$agg is not in grouping expressions: $groupingExpressions")) + } else { + impl + } + } + + def execute() = attachTree(this, "execute") { + // TODO: If the child of it is an [[catalyst.execution.Exchange]], + // do not evaluate the groupingExpressions again since we have evaluated it + // in the [[catalyst.execution.Exchange]]. + val grouped = child.execute().mapPartitions { iter => + val buildGrouping = new Projection(groupingExpressions) + iter.map(row => (buildGrouping(row), row.copy())) + }.groupByKeyLocally() + + val result = grouped.map { case (group, rows) => + val aggImplementations = createAggregateImplementations() + + // Pull out all the functions so we can feed each row into them. + val aggFunctions = aggImplementations.flatMap(_ collect { case f: AggregateFunction => f }) + + rows.foreach { row => + aggFunctions.foreach(_.update(row)) + } + buildRow(aggImplementations.map(_.apply(group))) + } + + // TODO: THIS BREAKS PIPELINING, DOUBLE COMPUTES THE ANSWER, AND USES TOO MUCH MEMORY... + if (groupingExpressions.isEmpty && result.count == 0) { + // When there there is no output to the Aggregate operator, we still output an empty row. 
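      // e.g. SELECT COUNT(*) FROM t over an empty, non-grouped input should still return a
      // single row containing 0, so one row is built by applying the aggregate implementations
      // to a null group.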
+ val aggImplementations = createAggregateImplementations() + sc.makeRDD(buildRow(aggImplementations.map(_.apply(null))) :: Nil) + } else { + result + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala new file mode 100644 index 0000000000..c6d31d9abc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext + +import catalyst.errors._ +import catalyst.expressions._ +import catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution} +import catalyst.plans.logical.LogicalPlan +import catalyst.ScalaReflection + +case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + def output = projectList.map(_.toAttribute) + + def execute() = child.execute().mapPartitions { iter => + @transient val resuableProjection = new MutableProjection(projectList) + iter.map(resuableProjection) + } +} + +case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { + def output = child.output + + def execute() = child.execute().mapPartitions { iter => + iter.filter(condition.apply(_).asInstanceOf[Boolean]) + } +} + +case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan) + extends UnaryNode { + + def output = child.output + + // TODO: How to pick seed? + def execute() = child.execute().sample(withReplacement, fraction, seed) +} + +case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan { + // TODO: attributes output by union should be distinct for nullability purposes + def output = children.head.output + def execute() = sc.union(children.map(_.execute())) + + override def otherCopyArgs = sc :: Nil +} + +case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode { + override def otherCopyArgs = sc :: Nil + + def output = child.output + + override def executeCollect() = child.execute().map(_.copy()).take(limit) + + // TODO: Terminal split should be implemented differently from non-terminal split. + // TODO: Pick num splits based on |limit|. 
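  // e.g. a hypothetical SELECT * FROM records LIMIT 10 plans to StopAfter(10, ...):
  // executeCollect() takes the first 10 rows from the child, and execute() re-parallelizes
  // that collected result as a single-partition RDD.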
+ def execute() = sc.makeRDD(executeCollect(), 1) +} + +case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) + (@transient sc: SparkContext) extends UnaryNode { + override def otherCopyArgs = sc :: Nil + + def output = child.output + + @transient + lazy val ordering = new RowOrdering(sortOrder) + + override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) + + // TODO: Terminal split should be implemented differently from non-terminal split. + // TODO: Pick num splits based on |limit|. + def execute() = sc.makeRDD(executeCollect(), 1) +} + + +case class Sort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + override def requiredChildDistribution = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + @transient + lazy val ordering = new RowOrdering(sortOrder) + + def execute() = attachTree(this, "sort") { + // TODO: Optimize sorting operation? + child.execute() + .mapPartitions( + iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, + preservesPartitioning = true) + } + + def output = child.output +} + +object ExistingRdd { + def convertToCatalyst(a: Any): Any = a match { + case s: Seq[Any] => s.map(convertToCatalyst) + case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) + case other => other + } + + def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { + // TODO: Reuse the row, don't use map on the product iterator. Maybe code gen? + data.map(r => new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row) + } + + def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = { + ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) + } +} + +case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { + def execute() = rdd +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala new file mode 100644 index 0000000000..db259b4c4b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package execution + +object DebugQuery { + def apply(plan: SparkPlan): SparkPlan = { + val visited = new collection.mutable.HashSet[Long]() + plan transform { + case s: SparkPlan if !visited.contains(s.id) => + visited += s.id + DebugNode(s) + } + } +} + +case class DebugNode(child: SparkPlan) extends UnaryNode { + def references = Set.empty + def output = child.output + def execute() = { + val childRdd = child.execute() + println( + s""" + |========================= + |${child.simpleString} + |========================= + """.stripMargin) + childRdd.foreach(println(_)) + childRdd + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala new file mode 100644 index 0000000000..5934fd1b03 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +import scala.collection.mutable + +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext + +import catalyst.errors._ +import catalyst.expressions._ +import catalyst.plans._ +import catalyst.plans.physical.{ClusteredDistribution, Partitioning} + +import org.apache.spark.rdd.PartitionLocalRDDFunctions._ + +case class SparkEquiInnerJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + def output = left.output ++ right.output + + def execute() = attachTree(this, "execute") { + val leftWithKeys = left.execute().mapPartitions { iter => + val generateLeftKeys = new Projection(leftKeys, left.output) + iter.map(row => (generateLeftKeys(row), row.copy())) + } + + val rightWithKeys = right.execute().mapPartitions { iter => + val generateRightKeys = new Projection(rightKeys, right.output) + iter.map(row => (generateRightKeys(row), row.copy())) + } + + // Do the join. + val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys)) + // Drop join keys and merge input tuples. + joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) } + } + + /** + * Filters any rows where the any of the join keys is null, ensuring three-valued + * logic for the equi-join conditions. 
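Under SQL's three-valued logic a comparison with NULL is never true, so a row whose join key contains a null can never satisfy an equi-join predicate; dropping such rows before the partition-local join, as filterNulls does, keeps the inner join consistent with that rule. A worked illustration with hypothetical keyed pairs (the key is the Row produced by the key Projections in execute() above):

    // left partition:  (key = [1], leftRowA),  (key = [null], leftRowB)
    // right partition: (key = [1], rightRowC), (key = [null], rightRowD)
    // after filterNulls on both sides only key [1] survives, so joinLocally emits
    // leftRowA ++ rightRowC and never pairs the null-keyed leftRowB with rightRowD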
+ */ + protected def filterNulls(rdd: RDD[(Row, Row)]) = + rdd.filter { + case (key: Seq[_], _) => !key.exists(_ == null) + } +} + +case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { + def output = left.output ++ right.output + + def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map { + case (l: Row, r: Row) => buildRow(l ++ r) + } +} + +case class BroadcastNestedLoopJoin( + streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) + (@transient sc: SparkContext) + extends BinaryNode { + // TODO: Override requiredChildDistribution. + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def otherCopyArgs = sc :: Nil + + def output = left.output ++ right.output + + /** The Streamed Relation */ + def left = streamed + /** The Broadcast relation */ + def right = broadcast + + @transient lazy val boundCondition = + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true)) + + + def execute() = { + val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => + val matchedRows = new mutable.ArrayBuffer[Row] + val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size) + val joinedRow = new JoinedRow + + streamedIter.foreach { streamedRow => + var i = 0 + var matched = false + + while (i < broadcastedRelation.value.size) { + // TODO: One bitset per partition instead of per row. + val broadcastedRow = broadcastedRelation.value(i) + if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) { + matchedRows += buildRow(streamedRow ++ broadcastedRow) + matched = true + includedBroadcastTuples += i + } + i += 1 + } + + if (!matched && (joinType == LeftOuter || joinType == FullOuter)) { + matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null)) + } + } + Iterator((matchedRows, includedBroadcastTuples)) + } + + val includedBroadcastTuples = streamedPlusMatches.map(_._2) + val allIncludedBroadcastTuples = + if (includedBroadcastTuples.count == 0) { + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + } else { + streamedPlusMatches.map(_._2).reduce(_ ++ _) + } + + val rightOuterMatches: Seq[Row] = + if (joinType == RightOuter || joinType == FullOuter) { + broadcastedRelation.value.zipWithIndex.filter { + case (row, i) => !allIncludedBroadcastTuples.contains(i) + }.map { + // TODO: Use projection. + case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row) + } + } else { + Vector() + } + + // TODO: Breaks lineage. + sc.union( + streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala new file mode 100644 index 0000000000..67f6f43f90 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * An execution engine for relational query plans that runs on top Spark and returns RDDs. + * + * Note that the operators in this package are created automatically by a query planner using a + * [[SQLContext]] and are not intended to be used directly by end users of Spark SQL. They are + * documented here in order to make it easier for others to understand the performance + * characteristics of query plans that are generated by Spark SQL. + */ +package object execution { +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala new file mode 100644 index 0000000000..e87561fe13 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.io.{IOException, FileNotFoundException} + +import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.fs.permission.FsAction + +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, BaseRelation} +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.ArrayType +import org.apache.spark.sql.catalyst.expressions.{Row, AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.analysis.UnresolvedException + +import parquet.schema.{MessageTypeParser, MessageType} +import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} +import parquet.schema.{PrimitiveType => ParquetPrimitiveType} +import parquet.schema.{Type => ParquetType} +import parquet.schema.Type.Repetition +import parquet.io.api.{Binary, RecordConsumer} +import parquet.hadoop.{Footer, ParquetFileWriter, ParquetFileReader} +import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} +import parquet.hadoop.util.ContextUtil + +import scala.collection.JavaConversions._ + +/** + * Relation that consists of data stored in a Parquet columnar format. + * + * Users should interact with parquet files though a SchemaRDD, created by a [[SQLContext]] instead + * of using this class directly. 
+ * + * {{{ + * val parquetRDD = sqlContext.parquetFile("path/to/parequet.file") + * }}} + * + * @param tableName The name of the relation that can be used in queries. + * @param path The path to the Parquet file. + */ +case class ParquetRelation(val tableName: String, val path: String) extends BaseRelation { + + /** Schema derived from ParquetFile **/ + def parquetSchema: MessageType = + ParquetTypesConverter + .readMetaData(new Path(path)) + .getFileMetaData + .getSchema + + /** Attributes **/ + val attributes = + ParquetTypesConverter + .convertToAttributes(parquetSchema) + + /** Output **/ + override val output = attributes + + // Parquet files have no concepts of keys, therefore no Partitioner + // Note: we could allow Block level access; needs to be thought through + override def isPartitioned = false +} + +object ParquetRelation { + + // The element type for the RDDs that this relation maps to. + type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow + + /** + * Creates a new ParquetRelation and underlying Parquetfile for the given + * LogicalPlan. Note that this is used inside [[SparkStrategies]] to + * create a resolved relation as a data sink for writing to a Parquetfile. + * The relation is empty but is initialized with ParquetMetadata and + * can be inserted into. + * + * @param pathString The directory the Parquetfile will be stored in. + * @param child The child node that will be used for extracting the schema. + * @param conf A configuration configuration to be used. + * @param tableName The name of the resulting relation. + * @return An empty ParquetRelation inferred metadata. + */ + def create(pathString: String, + child: LogicalPlan, + conf: Configuration, + tableName: Option[String]): ParquetRelation = { + if (!child.resolved) { + throw new UnresolvedException[LogicalPlan]( + child, + "Attempt to create Parquet table from unresolved child (when schema is not available)") + } + + val name = s"${tableName.getOrElse(child.nodeName)}_parquet" + val path = checkPath(pathString, conf) + ParquetTypesConverter.writeMetaData(child.output, path, conf) + new ParquetRelation(name, path.toString) + } + + private def checkPath(pathStr: String, conf: Configuration): Path = { + if (pathStr == null) { + throw new IllegalArgumentException("Unable to create ParquetRelation: path is null") + } + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException( + s"Unable to create ParquetRelation: incorrectly formatted path $pathStr") + } + val path = origPath.makeQualified(fs) + if (fs.exists(path) && + !fs.getFileStatus(path) + .getPermission + .getUserAction + .implies(FsAction.READ_WRITE)) { + throw new IOException( + s"Unable to create ParquetRelation: path $path not read-writable") + } + path + } +} + +object ParquetTypesConverter { + def toDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { + // for now map binary to string type + // TODO: figure out how Parquet uses strings or why we can't use them in a MessageType schema + case ParquetPrimitiveTypeName.BINARY => StringType + case ParquetPrimitiveTypeName.BOOLEAN => BooleanType + case ParquetPrimitiveTypeName.DOUBLE => DoubleType + case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType) + case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 => IntegerType + case ParquetPrimitiveTypeName.INT64 => LongType + case ParquetPrimitiveTypeName.INT96 => { + // TODO: add 
BigInteger type? TODO(andre) use DecimalType instead???? + sys.error("Warning: potential loss of precision: converting INT96 to long") + LongType + } + case _ => sys.error( + s"Unsupported parquet datatype $parquetType") + } + + def fromDataType(ctype: DataType): ParquetPrimitiveTypeName = ctype match { + case StringType => ParquetPrimitiveTypeName.BINARY + case BooleanType => ParquetPrimitiveTypeName.BOOLEAN + case DoubleType => ParquetPrimitiveTypeName.DOUBLE + case ArrayType(ByteType) => ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY + case FloatType => ParquetPrimitiveTypeName.FLOAT + case IntegerType => ParquetPrimitiveTypeName.INT32 + case LongType => ParquetPrimitiveTypeName.INT64 + case _ => sys.error(s"Unsupported datatype $ctype") + } + + def consumeType(consumer: RecordConsumer, ctype: DataType, record: Row, index: Int): Unit = { + ctype match { + case StringType => consumer.addBinary( + Binary.fromByteArray( + record(index).asInstanceOf[String].getBytes("utf-8") + ) + ) + case IntegerType => consumer.addInteger(record.getInt(index)) + case LongType => consumer.addLong(record.getLong(index)) + case DoubleType => consumer.addDouble(record.getDouble(index)) + case FloatType => consumer.addFloat(record.getFloat(index)) + case BooleanType => consumer.addBoolean(record.getBoolean(index)) + case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") + } + } + + def getSchema(schemaString : String) : MessageType = + MessageTypeParser.parseMessageType(schemaString) + + def convertToAttributes(parquetSchema: MessageType) : Seq[Attribute] = { + parquetSchema.getColumns.map { + case (desc) => { + val ctype = toDataType(desc.getType) + val name: String = desc.getPath.mkString(".") + new AttributeReference(name, ctype, false)() + } + } + } + + // TODO: allow nesting? + def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + val fields: Seq[ParquetType] = attributes.map { + a => new ParquetPrimitiveType(Repetition.OPTIONAL, fromDataType(a.dataType), a.name) + } + new MessageType("root", fields) + } + + def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration) { + if (origPath == null) { + throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") + } + val fs = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException( + s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") + } + val path = origPath.makeQualified(fs) + if (fs.exists(path) && !fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException(s"Expected to write to directory $path but found file") + } + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fs.exists(metadataPath)) { + try { + fs.delete(metadataPath, true) + } catch { + case e: IOException => + throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") + } + } + val extraMetadata = new java.util.HashMap[String, String]() + extraMetadata.put("path", path.toString) + // TODO: add extra data, e.g., table name, date, etc.? + + val parquetSchema: MessageType = + ParquetTypesConverter.convertFromAttributes(attributes) + val metaData: FileMetaData = new FileMetaData( + parquetSchema, + extraMetadata, + "Spark") + + ParquetFileWriter.writeMetadataFile( + conf, + path, + new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) + } + + /** + * Try to read Parquet metadata at the given Path. We first see if there is a summary file + * in the parent directory. 
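The two conversion directions in ParquetTypesConverter are intended to round-trip: convertFromAttributes turns a flat Catalyst schema into a MessageType of OPTIONAL primitive fields, and convertToAttributes maps such a schema back to AttributeReferences (currently always non-nullable, with binary columns read back as StringType). A small sketch with hypothetical attribute names:

    val attrs = Seq(
      new AttributeReference("name", StringType, false)(),
      new AttributeReference("age", IntegerType, false)())
    val schema: MessageType = ParquetTypesConverter.convertFromAttributes(attrs)
    // message root { optional binary name; optional int32 age; }
    val roundTripped: Seq[Attribute] = ParquetTypesConverter.convertToAttributes(schema)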
If so, this is used. Else we read the actual footer at the given + * location. + * @param path The path at which we expect one (or more) Parquet files. + * @return The `ParquetMetadata` containing among other things the schema. + */ + def readMetaData(origPath: Path): ParquetMetadata = { + if (origPath == null) { + throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") + } + val job = new Job() + // TODO: since this is called from ParquetRelation (LogicalPlan) we don't have access + // to SparkContext's hadoopConfig; in principle the default FileSystem may be different(?!) + val conf = ContextUtil.getConfiguration(job) + val fs: FileSystem = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") + } + val path = origPath.makeQualified(fs) + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { + // TODO: improve exception handling, etc. + ParquetFileReader.readFooter(conf, metadataPath) + } else { + if (!fs.exists(path) || !fs.isFile(path)) { + throw new FileNotFoundException( + s"Could not find file ${path.toString} when trying to read metadata") + } + ParquetFileReader.readFooter(conf, path) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala new file mode 100644 index 0000000000..61121103cb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import parquet.io.InvalidRecordException +import parquet.schema.MessageType +import parquet.hadoop.{ParquetOutputFormat, ParquetInputFormat} +import parquet.hadoop.util.ContextUtil + +import org.apache.spark.rdd.RDD +import org.apache.spark.{TaskContext, SerializableWritable, SparkContext} +import org.apache.spark.sql.catalyst.expressions.{Row, Attribute, Expression} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, LeafNode} + +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import java.io.IOException +import java.text.SimpleDateFormat +import java.util.Date + +/** + * Parquet table scan operator. Imports the file that backs the given + * [[ParquetRelation]] as a RDD[Row]. 
+ */ +case class ParquetTableScan( + @transient output: Seq[Attribute], + @transient relation: ParquetRelation, + @transient columnPruningPred: Option[Expression])( + @transient val sc: SparkContext) + extends LeafNode { + + override def execute(): RDD[Row] = { + val job = new Job(sc.hadoopConfiguration) + ParquetInputFormat.setReadSupportClass( + job, + classOf[org.apache.spark.sql.parquet.RowReadSupport]) + val conf: Configuration = ContextUtil.getConfiguration(job) + conf.set( + RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, + ParquetTypesConverter.convertFromAttributes(output).toString) + // TODO: think about adding record filters + /* Comments regarding record filters: it would be nice to push down as much filtering + to Parquet as possible. However, currently it seems we cannot pass enough information + to materialize an (arbitrary) Catalyst [[Predicate]] inside Parquet's + ``FilteredRecordReader`` (via Configuration, for example). Simple + filter-rows-by-column-values however should be supported. + */ + sc.newAPIHadoopFile( + relation.path, + classOf[ParquetInputFormat[Row]], + classOf[Void], classOf[Row], + conf) + .map(_._2) + } + + /** + * Applies a (candidate) projection. + * + * @param prunedAttributes The list of attributes to be used in the projection. + * @return Pruned TableScan. + */ + def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { + val success = validateProjection(prunedAttributes) + if (success) { + ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sc) + } else { + sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") + this + } + } + + /** + * Evaluates a candidate projection by checking whether the candidate is a subtype + * of the original type. + * + * @param projection The candidate projection. + * @return True if the projection is valid, false otherwise. + */ + private def validateProjection(projection: Seq[Attribute]): Boolean = { + val original: MessageType = relation.parquetSchema + val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection) + try { + original.checkContains(candidate) + true + } catch { + case e: InvalidRecordException => { + false + } + } + } +} + +case class InsertIntoParquetTable( + @transient relation: ParquetRelation, + @transient child: SparkPlan)( + @transient val sc: SparkContext) + extends UnaryNode with SparkHadoopMapReduceUtil { + + /** + * Inserts all the rows in the Parquet file. Note that OVERWRITE is implicit, since + * Parquet files are write-once. + */ + override def execute() = { + // TODO: currently we do not check whether the "schema"s are compatible + // That means if one first creates a table and then INSERTs data with + // and incompatible schema the execution will fail. It would be nice + // to catch this early one, maybe having the planner validate the schema + // before calling execute(). 
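ParquetTableScan above pushes the projection into the file format: the requested schema is written into the Hadoop Configuration under RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, and pruneColumns rebuilds the operator with a narrower attribute list once validateProjection has checked that the candidate is contained in the file's schema. A hypothetical sketch (relation and sc are assumed to exist):

    val scan = ParquetTableScan(relation.output, relation, None)(sc)
    val nameOnly = scan.pruneColumns(relation.output.filter(_.name == "name"))
    nameOnly.execute()   // only the surviving columns are requested from the Parquet file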
+ + val childRdd = child.execute() + assert(childRdd != null) + + val job = new Job(sc.hadoopConfiguration) + + ParquetOutputFormat.setWriteSupportClass( + job, + classOf[org.apache.spark.sql.parquet.RowWriteSupport]) + + // TODO: move that to function in object + val conf = job.getConfiguration + conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA, relation.parquetSchema.toString) + + val fspath = new Path(relation.path) + val fs = fspath.getFileSystem(conf) + + try { + fs.delete(fspath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${fspath.toString} prior" + + s" to InsertIntoParquetTable:\n${e.toString}") + } + saveAsHadoopFile(childRdd, relation.path.toString, conf) + + // We return the child RDD to allow chaining (alternatively, one could return nothing). + childRdd + } + + override def output = child.output + + // based on ``saveAsNewAPIHadoopFile`` in [[PairRDDFunctions]] + // TODO: Maybe PairRDDFunctions should use Product2 instead of Tuple2? + // .. then we could use the default one and could use [[MutablePair]] + // instead of ``Tuple2`` + private def saveAsHadoopFile( + rdd: RDD[Row], + path: String, + conf: Configuration) { + val job = new Job(conf) + val keyType = classOf[Void] + val outputFormatType = classOf[parquet.hadoop.ParquetOutputFormat[Row]] + job.setOutputKeyClass(keyType) + job.setOutputValueClass(classOf[Row]) + val wrappedConf = new SerializableWritable(job.getConfiguration) + NewFileOutputFormat.setOutputPath(job, new Path(path)) + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + val jobtrackerID = formatter.format(new Date()) + val stageId = sc.newRddId() + + def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val attemptNumber = (context.attemptId % Int.MaxValue).toInt + /* "reduce task" <split #> <attempt # = spark task #> */ + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, + attemptNumber) + val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val format = outputFormatType.newInstance + val committer = format.getOutputCommitter(hadoopContext) + committer.setupTask(hadoopContext) + val writer = format.getRecordWriter(hadoopContext) + while (iter.hasNext) { + val row = iter.next() + writer.write(null, row) + } + writer.close(hadoopContext) + committer.commitTask(hadoopContext) + return 1 + } + val jobFormat = outputFormatType.newInstance + /* apparently we need a TaskAttemptID to construct an OutputCommitter; + * however we're only going to use this local OutputCommitter for + * setupJob/commitJob, so we just use a dummy "map" task. 
+ */ + val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) + val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + jobCommitter.setupJob(jobTaskContext) + sc.runJob(rdd, writeShard _) + jobCommitter.commitJob(jobTaskContext) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala new file mode 100644 index 0000000000..c2ae18b882 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.Logging + +import parquet.io.api._ +import parquet.schema.{MessageTypeParser, MessageType} +import parquet.hadoop.api.{WriteSupport, ReadSupport} +import parquet.hadoop.api.ReadSupport.ReadContext +import parquet.hadoop.ParquetOutputFormat +import parquet.column.ParquetProperties + +import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} +import org.apache.spark.sql.catalyst.types._ + +/** + * A `parquet.io.api.RecordMaterializer` for Rows. + * + *@param root The root group converter for the record. + */ +class RowRecordMaterializer(root: CatalystGroupConverter) extends RecordMaterializer[Row] { + + def this(parquetSchema: MessageType) = + this(new CatalystGroupConverter(ParquetTypesConverter.convertToAttributes(parquetSchema))) + + override def getCurrentRecord: Row = root.getCurrentRecord + + override def getRootConverter: GroupConverter = root +} + +/** + * A `parquet.hadoop.api.ReadSupport` for Row objects. 
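Together, InsertIntoParquetTable and ParquetTableScan give a write path and a read path over the same ParquetRelation, with the Row read/write support classes in this file doing the per-record translation. A hypothetical end-to-end sketch (relation, childPlan, and sc are assumed to be in scope):

    val insert = InsertIntoParquetTable(relation, childPlan)(sc)
    insert.execute()                                          // overwrite the file with childPlan's rows
    val scan = ParquetTableScan(relation.output, relation, None)(sc)
    val rowsBack: RDD[Row] = scan.execute()                   // read them back via RowReadSupport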
+ */ +class RowReadSupport extends ReadSupport[Row] with Logging { + + override def prepareForRead( + conf: Configuration, + stringMap: java.util.Map[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[Row] = { + log.debug(s"preparing for read with schema ${fileSchema.toString}") + new RowRecordMaterializer(readContext.getRequestedSchema) + } + + override def init( + configuration: Configuration, + keyValueMetaData: java.util.Map[String, String], + fileSchema: MessageType): ReadContext = { + val requested_schema_string = + configuration.get(RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, fileSchema.toString) + val requested_schema = + MessageTypeParser.parseMessageType(requested_schema_string) + + log.debug(s"read support initialized for original schema ${requested_schema.toString}") + new ReadContext(requested_schema, keyValueMetaData) + } +} + +object RowReadSupport { + val PARQUET_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" +} + +/** + * A `parquet.hadoop.api.WriteSupport` for Row ojects. + */ +class RowWriteSupport extends WriteSupport[Row] with Logging { + def setSchema(schema: MessageType, configuration: Configuration) { + // for testing + this.schema = schema + // TODO: could use Attributes themselves instead of Parquet schema? + configuration.set( + RowWriteSupport.PARQUET_ROW_SCHEMA, + schema.toString) + configuration.set( + ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_1_0.toString) + } + + def getSchema(configuration: Configuration): MessageType = { + return MessageTypeParser.parseMessageType( + configuration.get(RowWriteSupport.PARQUET_ROW_SCHEMA)) + } + + private var schema: MessageType = null + private var writer: RecordConsumer = null + private var attributes: Seq[Attribute] = null + + override def init(configuration: Configuration): WriteSupport.WriteContext = { + schema = if (schema == null) getSchema(configuration) else schema + attributes = ParquetTypesConverter.convertToAttributes(schema) + new WriteSupport.WriteContext( + schema, + new java.util.HashMap[java.lang.String, java.lang.String]()); + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + writer = recordConsumer + } + + // TODO: add groups (nested fields) + override def write(record: Row): Unit = { + var index = 0 + writer.startMessage() + while(index < attributes.size) { + // null values indicate optional fields but we do not check currently + if (record(index) != null && record(index) != Nil) { + writer.startField(attributes(index).name, index) + ParquetTypesConverter.consumeType(writer, attributes(index).dataType, record, index) + writer.endField(attributes(index).name, index) + } + index = index + 1 + } + writer.endMessage() + } +} + +object RowWriteSupport { + val PARQUET_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.schema" +} + +/** + * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record + * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. + * + * @param schema The corresponding Catalyst schema in the form of a list of attributes. 
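Both support classes exchange the schema with the Hadoop layer through Configuration entries keyed by the constants in their companion objects; RowWriteSupport.setSchema additionally pins the Parquet writer version to PARQUET_1_0. A minimal sketch (the Configuration and MessageType values here are hypothetical):

    val conf = new Configuration()
    conf.set(RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, requestedSchema.toString)  // read side
    val writeSupport = new RowWriteSupport()
    writeSupport.setSchema(fileSchema, conf)   // write side: sets RowWriteSupport.PARQUET_ROW_SCHEMA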
+ */ +class CatalystGroupConverter( + schema: Seq[Attribute], + protected[parquet] val current: ParquetRelation.RowType) extends GroupConverter { + + def this(schema: Seq[Attribute]) = this(schema, new ParquetRelation.RowType(schema.length)) + + val converters: Array[Converter] = schema.map { + a => a.dataType match { + case ctype: NativeType => + // note: for some reason matching for StringType fails so use this ugly if instead + if (ctype == StringType) new CatalystPrimitiveStringConverter(this, schema.indexOf(a)) + else new CatalystPrimitiveConverter(this, schema.indexOf(a)) + case _ => throw new RuntimeException( + s"unable to convert datatype ${a.dataType.toString} in CatalystGroupConverter") + } + }.toArray + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + private[parquet] def getCurrentRecord: ParquetRelation.RowType = current + + override def start(): Unit = { + var i = 0 + while (i < schema.length) { + current.setNullAt(i) + i = i + 1 + } + } + + override def end(): Unit = {} +} + +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. + */ +class CatalystPrimitiveConverter( + parent: CatalystGroupConverter, + fieldIndex: Int) extends PrimitiveConverter { + // TODO: consider refactoring these together with ParquetTypesConverter + override def addBinary(value: Binary): Unit = + // TODO: fix this once a setBinary will become available in MutableRow + parent.getCurrentRecord.setByte(fieldIndex, value.getBytes.apply(0)) + + override def addBoolean(value: Boolean): Unit = + parent.getCurrentRecord.setBoolean(fieldIndex, value) + + override def addDouble(value: Double): Unit = + parent.getCurrentRecord.setDouble(fieldIndex, value) + + override def addFloat(value: Float): Unit = + parent.getCurrentRecord.setFloat(fieldIndex, value) + + override def addInt(value: Int): Unit = + parent.getCurrentRecord.setInt(fieldIndex, value) + + override def addLong(value: Long): Unit = + parent.getCurrentRecord.setLong(fieldIndex, value) +} + +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet strings (fixed-length byte arrays) + * into Catalyst Strings. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. + */ +class CatalystPrimitiveStringConverter( + parent: CatalystGroupConverter, + fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.getCurrentRecord.setString(fieldIndex, value.toStringUsingUTF8) +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala new file mode 100644 index 0000000000..bbe409fb9c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.Job + +import parquet.schema.{MessageTypeParser, MessageType} +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.ParquetWriter + +import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.catalyst.expressions.GenericRow +import java.nio.charset.Charset + +object ParquetTestData { + + val testSchema = + """message myrecord { + |optional boolean myboolean; + |optional int32 myint; + |optional binary mystring; + |optional int64 mylong; + |optional float myfloat; + |optional double mydouble; + |}""".stripMargin + + // field names for test assertion error messages + val testSchemaFieldNames = Seq( + "myboolean:Boolean", + "mtint:Int", + "mystring:String", + "mylong:Long", + "myfloat:Float", + "mydouble:Double" + ) + + val subTestSchema = + """ + |message myrecord { + |optional boolean myboolean; + |optional int64 mylong; + |} + """.stripMargin + + // field names for test assertion error messages + val subTestSchemaFieldNames = Seq( + "myboolean:Boolean", + "mylong:Long" + ) + + val testFile = getTempFilePath("testParquetFile").getCanonicalFile + + lazy val testData = new ParquetRelation("testData", testFile.toURI.toString) + + def writeFile = { + testFile.delete + val path: Path = new Path(testFile.toURI) + val job = new Job() + val configuration: Configuration = ContextUtil.getConfiguration(job) + val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) + + val writeSupport = new RowWriteSupport() + writeSupport.setSchema(schema, configuration) + val writer = new ParquetWriter(path, writeSupport) + for(i <- 0 until 15) { + val data = new Array[Any](6) + if (i % 3 == 0) { + data.update(0, true) + } else { + data.update(0, false) + } + if (i % 5 == 0) { + data.update(1, 5) + } else { + data.update(1, null) // optional + } + data.update(2, "abc") + data.update(3, i.toLong << 33) + data.update(4, 2.5F) + data.update(5, 4.5D) + writer.write(new GenericRow(data.toArray)) + } + writer.close() + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala new file mode 100644 index 0000000000..ca56c4476b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
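ParquetTestData above produces a small, self-describing Parquet file (15 rows over six optional columns) that the test suites can load back through the SQL context. A sketch of the intended test usage, relying only on names defined elsewhere in this patch:

    ParquetTestData.writeFile                                   // writes the temporary Parquet file
    val rdd = TestSQLContext.parquetFile(ParquetTestData.testFile.getCanonicalPath)
    rdd.registerAsTable("testsource")                           // queryable like any other test table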
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark +package sql +package test + +/** A SQLContext that can be used for local testing. */ +object TestSQLContext + extends SQLContext(new SparkContext("local", "TestSQLContext", new SparkConf())) diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties new file mode 100644 index 0000000000..7bb6789bd3 --- /dev/null +++ b/sql/core/src/test/resources/log4j.properties @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file core/target/unit-tests.log +log4j.rootLogger=DEBUG, CA, FA + +#Console Appender +log4j.appender.CA=org.apache.log4j.ConsoleAppender +log4j.appender.CA.layout=org.apache.log4j.PatternLayout +log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n +log4j.appender.CA.Threshold = WARN + + +#File Appender +log4j.appender.FA=org.apache.log4j.FileAppender +log4j.appender.FA.append=false +log4j.appender.FA.file=target/unit-tests.log +log4j.appender.FA.layout=org.apache.log4j.PatternLayout +log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n + +# Set the logger level of File Appender to WARN +log4j.appender.FA.Threshold = INFO + +# Some packages are noisy for no good reason. +log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false +log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF + +log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF + +log4j.additivity.hive.ql.metadata.Hive=false +log4j.logger.hive.ql.metadata.Hive=OFF + +# Parquet logging +parquet.hadoop.InternalParquetRecordReader=WARN +log4j.logger.parquet.hadoop.InternalParquetRecordReader=WARN +parquet.hadoop.ParquetInputFormat=WARN +log4j.logger.parquet.hadoop.ParquetInputFormat=WARN diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala new file mode 100644 index 0000000000..37c90a18a0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +class DslQuerySuite extends QueryTest { + import TestData._ + + test("table scan") { + checkAnswer( + testData, + testData.collect().toSeq) + } + + test("agg") { + checkAnswer( + testData2.groupBy('a)('a, Sum('b)), + Seq((1,3),(2,3),(3,3)) + ) + } + + test("select *") { + checkAnswer( + testData.select(Star(None)), + testData.collect().toSeq) + } + + test("simple select") { + checkAnswer( + testData.where('key === 1).select('value), + Seq(Seq("1"))) + } + + test("sorting") { + checkAnswer( + testData2.orderBy('a.asc, 'b.asc), + Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + + checkAnswer( + testData2.orderBy('a.asc, 'b.desc), + Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + + checkAnswer( + testData2.orderBy('a.desc, 'b.desc), + Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + + checkAnswer( + testData2.orderBy('a.desc, 'b.asc), + Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + } + + test("average") { + checkAnswer( + testData2.groupBy()(Average('a)), + 2.0) + } + + test("count") { + checkAnswer( + testData2.groupBy()(Count(1)), + testData2.count() + ) + } + + test("null count") { + checkAnswer( + testData3.groupBy('a)('a, Count('b)), + Seq((1,0), (2, 1)) + ) + + checkAnswer( + testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)), + (2, 1, 2, 2, 1) :: Nil + ) + } + + test("inner join where, one match per row") { + checkAnswer( + upperCaseData.join(lowerCaseData, Inner).where('n === 'N), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + )) + } + + test("inner join ON, one match per row") { + checkAnswer( + upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + )) + } + + test("inner join, where, multiple matches") { + val x = testData2.where('a === 1).subquery('x) + val y = testData2.where('a === 1).subquery('y) + checkAnswer( + x.join(y).where("x.a".attr === "y.a".attr), + (1,1,1,1) :: + (1,1,1,2) :: + (1,2,1,1) :: + (1,2,1,2) :: Nil + ) + } + + test("inner join, no matches") { + val x = testData2.where('a === 1).subquery('x) + val y = testData2.where('a === 2).subquery('y) + checkAnswer( + x.join(y).where("x.a".attr === "y.a".attr), + Nil) + } + + test("big inner join, 4 matches per row") { + val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) + val bigDataX = bigData.subquery('x) + val bigDataY = bigData.subquery('y) + + checkAnswer( 
+ bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), + testData.flatMap( + row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + } + + test("cartisian product join") { + checkAnswer( + testData3.join(testData3), + (1, null, 1, null) :: + (1, null, 2, 2) :: + (2, 2, 1, null) :: + (2, 2, 2, 2) :: Nil) + } + + test("left outer join") { + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + } + + test("right outer join") { + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + } + + test("full outer join") { + val left = upperCaseData.where('N <= 4).subquery('left) + val right = upperCaseData.where('N >= 3).subquery('right) + + checkAnswer( + left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + } +}
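All of the assertions in this suite (and in SQLQuerySuite below) go through QueryTest.checkAnswer, defined later in this patch, which accepts the expected result in several shapes and, unless the query itself sorts, compares the two sides order-insensitively. For example, where `query` stands for any SchemaRDD built in the tests above:

    checkAnswer(query, 2.0)                           // a single value
    checkAnswer(query, Seq((1, 3), (2, 3), (3, 3)))   // a sequence of tuples, one per row
    checkAnswer(query, Seq(Seq(1, 3), Seq(2, 3)))     // or explicit row sequences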
\ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlannerSuite.scala new file mode 100644 index 0000000000..83908edf5a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlannerSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.planner._ + +class PlannerSuite extends FunSuite { + + + test("unions are collapsed") { + val query = testData.unionAll(testData).unionAll(testData).logicalPlan + val planned = BasicOperators(query).head + val logicalUnions = query collect { case u: logical.Union => u} + val physicalUnions = planned collect { case u: execution.Union => u} + + assert(logicalUnions.size === 2) + assert(physicalUnions.size === 1) + } + + test("count is partially aggregated") { + val query = testData.groupBy('value)(Count('key)).analyze.logicalPlan + val planned = PartialAggregation(query).head + val aggregations = planned.collect { case a: Aggregate => a } + + assert(aggregations.size === 2) + } + + test("count distinct is not partially aggregated") { + val query = testData.groupBy('value)(CountDistinct('key :: Nil)).analyze.logicalPlan + val planned = PartialAggregation(query.logicalPlan) + assert(planned.isEmpty) + } + + test("mixed aggregates are not partially aggregated") { + val query = + testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).analyze.logicalPlan + val planned = PartialAggregation(query) + assert(planned.isEmpty) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala new file mode 100644 index 0000000000..728feceded --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
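The PlannerSuite assertions above encode the contract of PartialAggregation: an aggregate such as COUNT is planned as two physical Aggregate operators, a partial one that runs before the shuffle and a final one that merges the partial results, which is why the first aggregation test expects exactly two Aggregate nodes; COUNT(DISTINCT ...) cannot be split that way, so PartialAggregation returns no plan for it and those queries fall back to a single full aggregate. Conceptually (a rough shape, not the exact constructor arguments):

    //   Aggregate (final)       -- merges the partial results after the shuffle
    //     Aggregate (partial)   -- computes per-partition counts before the shuffle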
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +class QueryTest extends FunSuite { + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param plan the query to be executed + * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. + */ + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = { + val convertedAnswer = expectedAnswer match { + case s: Seq[_] if s.isEmpty => s + case s: Seq[_] if s.head.isInstanceOf[Product] && + !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq) + case s: Seq[_] => s + case singleItem => Seq(Seq(singleItem)) + } + + val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty + def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer + val sparkAnswer = try rdd.collect().toSeq catch { + case e: Exception => + fail( + s""" + |Exception thrown while executing query: + |${rdd.logicalPlan} + |== Exception == + |$e + """.stripMargin) + } + if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { + fail(s""" + |Results do not match for query: + |${rdd.logicalPlan} + |== Analyzed Plan == + |${rdd.queryExecution.analyzed} + |== RDD == + |$rdd + |== Results == + |${sideBySide( + prepareAnswer(convertedAnswer).map(_.toString), + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + """.stripMargin) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala new file mode 100644 index 0000000000..5728313d6d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.test._
+
+/* Implicits */
+import TestSQLContext._
+import TestData._
+
+class SQLQuerySuite extends QueryTest {
+  test("agg") {
+    checkAnswer(
+      sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"),
+      Seq((1,3),(2,3),(3,3))
+    )
+  }
+
+  test("select *") {
+    checkAnswer(
+      sql("SELECT * FROM testData"),
+      testData.collect().toSeq)
+  }
+
+  test("simple select") {
+    checkAnswer(
+      sql("SELECT value FROM testData WHERE key = 1"),
+      Seq(Seq("1")))
+  }
+
+  test("sorting") {
+    checkAnswer(
+      sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"),
+      Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+
+    checkAnswer(
+      sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"),
+      Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1)))
+
+    checkAnswer(
+      sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"),
+      Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1)))
+
+    checkAnswer(
+      sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
+      Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+  }
+
+  test("average") {
+    checkAnswer(
+      sql("SELECT AVG(a) FROM testData2"),
+      2.0)
+  }
+
+  test("count") {
+    checkAnswer(
+      sql("SELECT COUNT(*) FROM testData2"),
+      testData2.count()
+    )
+  }
+
+  // No support for primitive nulls yet.
+  ignore("null count") {
+    checkAnswer(
+      sql("SELECT a, COUNT(b) FROM testData3"),
+      Seq((1,0), (2, 1))
+    )
+
+    checkAnswer(
+      testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
+      (2, 1, 2, 2, 1) :: Nil
+    )
+  }
+
+  test("inner join where, one match per row") {
+    checkAnswer(
+      sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
+      Seq(
+        (1, "A", 1, "a"),
+        (2, "B", 2, "b"),
+        (3, "C", 3, "c"),
+        (4, "D", 4, "d")
+      ))
+  }
+
+  test("inner join ON, one match per row") {
+    checkAnswer(
+      sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"),
+      Seq(
+        (1, "A", 1, "a"),
+        (2, "B", 2, "b"),
+        (3, "C", 3, "c"),
+        (4, "D", 4, "d")
+      ))
+  }
+
+  test("inner join, where, multiple matches") {
+    checkAnswer(
+      sql("""
+        |SELECT * FROM
+        |  (SELECT * FROM testData2 WHERE a = 1) x JOIN
+        |  (SELECT * FROM testData2 WHERE a = 1) y
+        |WHERE x.a = y.a""".stripMargin),
+      (1,1,1,1) ::
+      (1,1,1,2) ::
+      (1,2,1,1) ::
+      (1,2,1,2) :: Nil
+    )
+  }
+
+  test("inner join, no matches") {
+    checkAnswer(
+      sql(
+        """
+          |SELECT * FROM
+          |  (SELECT * FROM testData2 WHERE a = 1) x JOIN
+          |  (SELECT * FROM testData2 WHERE a = 2) y
+          |WHERE x.a = y.a""".stripMargin),
+      Nil)
+  }
+
+  test("big inner join, 4 matches per row") {
+
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT * FROM
+          |  (SELECT * FROM testData UNION ALL
+          |   SELECT * FROM testData UNION ALL
+          |   SELECT * FROM testData UNION ALL
+          |   SELECT * FROM testData) x JOIN
+          |  (SELECT * FROM testData UNION ALL
+          |   SELECT * FROM testData UNION ALL
+          |   SELECT * FROM testData UNION ALL
+          |   SELECT * FROM testData) y
+          |WHERE x.key = y.key""".stripMargin),
+      testData.flatMap(
+        row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
+  }
+
+  ignore("cartesian product join") {
+    checkAnswer(
+      testData3.join(testData3),
+      (1, null, 1, null) ::
+      (1, null, 2, 2) ::
+      (2, 2, 1, null) ::
+      (2, 2, 2, 2) :: Nil)
+  }
+
+  test("left outer join") {
+    checkAnswer(
+      sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"),
+      (1, "A", 1, "a") ::
+      (2, "B", 2, "b") ::
+      (3, "C", 3, "c") ::
+      (4, "D", 4, "d") ::
+      (5, "E", null, null) ::
+      (6, "F", null, null) :: Nil)
+  }
+
+  test("right outer join") {
+    checkAnswer(
+      sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"),
+      (1, "a", 1, "A") ::
+      (2, "b", 2, "B") ::
+      (3, "c", 3, "C") ::
+      (4, "d", 4, "D") ::
+      (null, null, 5, "E") ::
+      (null, null, 6, "F") :: Nil)
+  }
+
+  test("full outer join") {
+    checkAnswer(
+      sql(
+        """
+          |SELECT * FROM
+          |  (SELECT * FROM upperCaseData WHERE N <= 4) left FULL OUTER JOIN
+          |  (SELECT * FROM upperCaseData WHERE N >= 3) right
+          |    ON left.N = right.N
+        """.stripMargin),
+      (1, "A", null, null) ::
+      (2, "B", null, null) ::
+      (3, "C", 3, "C") ::
+      (4, "D", 4, "D") ::
+      (null, null, 5, "E") ::
+      (null, null, 6, "F") :: Nil)
+  }
+}
\ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala new file mode 100644 index 0000000000..640292571b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +object TestData { + case class TestData(key: Int, value: String) + val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))) + testData.registerAsTable("testData") + + case class TestData2(a: Int, b: Int) + val testData2: SchemaRDD = + TestSQLContext.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil + ) + testData2.registerAsTable("testData2") + + // TODO: There is no way to express null primitives as case classes currently... + val testData3 = + logical.LocalRelation('a.int, 'b.int).loadData( + (1, null) :: + (2, 2) :: Nil + ) + + case class UpperCaseData(N: Int, L: String) + val upperCaseData = + TestSQLContext.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil + ) + upperCaseData.registerAsTable("upperCaseData") + + case class LowerCaseData(n: Int, l: String) + val lowerCaseData = + TestSQLContext.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil + ) + lowerCaseData.registerAsTable("lowerCaseData") +}
\ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala
new file mode 100644
index 0000000000..08265b7a6a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.test._
+
+
+import TestSQLContext._
+
+/**
+ * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns
+ * from the input data. These will be replaced during analysis with specific AttributeReferences
+ * and then bound to specific ordinals during query planning. While TGFs could also access specific
+ * columns using hand-coded ordinals, doing so violates data independence.
+ *
+ * Note: this is only a rough example of how TGFs can be expressed; the final version will likely
+ * involve a lot more sugar for cleaner use in Scala/Java/etc.
+ */
+case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator {
+  def children = input
+  protected def makeOutput() = 'nameAndAge.string :: Nil
+
+  val Seq(nameAttr, ageAttr) = input
+
+  override def apply(input: Row): TraversableOnce[Row] = {
+    val name = nameAttr.apply(input)
+    val age = ageAttr.apply(input).asInstanceOf[Int]
+
+    Iterator(
+      new GenericRow(Array[Any](s"$name is $age years old")),
+      new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old")))
+  }
+}
+
+class TgfSuite extends QueryTest {
+  val inputData =
+    logical.LocalRelation('name.string, 'age.int).loadData(
+      ("michael", 29) :: Nil
+    )
+
+  test("simple tgf example") {
+    checkAnswer(
+      inputData.generate(ExampleTGF()),
+      Seq(
+        "michael is 29 years old" :: Nil,
+        "Next year, michael will be 30 years old" :: Nil))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
new file mode 100644
index 0000000000..8b2ccb52d8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.test.TestSQLContext + +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.fs.{Path, FileSystem} + +import parquet.schema.MessageTypeParser +import parquet.hadoop.ParquetFileWriter +import parquet.hadoop.util.ContextUtil + +class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll { + override def beforeAll() { + ParquetTestData.writeFile + } + + override def afterAll() { + ParquetTestData.testFile.delete() + } + + test("Import of simple Parquet file") { + val result = getRDD(ParquetTestData.testData).collect() + assert(result.size === 15) + result.zipWithIndex.foreach { + case (row, index) => { + val checkBoolean = + if (index % 3 == 0) + row(0) == true + else + row(0) == false + assert(checkBoolean === true, s"boolean field value in line $index did not match") + if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") + assert(row(2) === "abc", s"string field value in line $index did not match") + assert(row(3) === (index.toLong << 33), s"long value in line $index did not match") + assert(row(4) === 2.5F, s"float field value in line $index did not match") + assert(row(5) === 4.5D, s"double field value in line $index did not match") + } + } + } + + test("Projection of simple Parquet file") { + val scanner = new ParquetTableScan( + ParquetTestData.testData.output, + ParquetTestData.testData, + None)(TestSQLContext.sparkContext) + val projected = scanner.pruneColumns(ParquetTypesConverter + .convertToAttributes(MessageTypeParser + .parseMessageType(ParquetTestData.subTestSchema))) + assert(projected.output.size === 2) + val result = projected + .execute() + .map(_.copy()) + .collect() + result.zipWithIndex.foreach { + case (row, index) => { + if (index % 3 == 0) + assert(row(0) === true, s"boolean field value in line $index did not match (every third row)") + else + assert(row(0) === false, s"boolean field value in line $index did not match") + assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match") + assert(row.size === 2, s"number of columns in projection in line $index is incorrect") + } + } + } + + test("Writing metadata from scratch for table CREATE") { + val job = new Job() + val path = new Path(getTempFilePath("testtable").getCanonicalFile.toURI.toString) + val fs: FileSystem = FileSystem.getLocal(ContextUtil.getConfiguration(job)) + ParquetTypesConverter.writeMetaData( + ParquetTestData.testData.output, + path, + TestSQLContext.sparkContext.hadoopConfiguration) + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) + val metaData = ParquetTypesConverter.readMetaData(path) + assert(metaData != null) + ParquetTestData + .testData + 
.parquetSchema
+      .checkContains(metaData.getFileMetaData.getSchema) // throws exception if incompatible
+    metaData
+      .getFileMetaData
+      .getSchema
+      .checkContains(ParquetTestData.testData.parquetSchema) // throws exception if incompatible
+    fs.delete(path, true)
+  }
+
+  /**
+   * Computes the given [[ParquetRelation]] and returns its RDD.
+   *
+   * @param parquetRelation The Parquet relation.
+   * @return An RDD of Rows.
+   */
+  private def getRDD(parquetRelation: ParquetRelation): RDD[Row] = {
+    val scanner = new ParquetTableScan(
+      parquetRelation.output,
+      parquetRelation,
+      None)(TestSQLContext.sparkContext)
+    scanner
+      .execute
+      .map(_.copy())
+  }
+}
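
Editor's note: the suites above (SQLQuerySuite, TgfSuite) funnel their assertions through QueryTest.checkAnswer, which is defined in a separate file of this change set and does not appear in the hunks shown here. The sketch below is only a hypothetical illustration of such a helper, written to match the calling conventions visible in these tests (a bare value for single-cell results, a Seq of tuples or of row-lists otherwise); it is not the implementation shipped in the patch, and the class and method bodies are the editor's assumptions.

package org.apache.spark.sql

import org.scalatest.FunSuite

// Hypothetical sketch only -- the real QueryTest lives elsewhere in this patch.
class QueryTest extends FunSuite {

  /** Compares the collected rows of `rdd` against `expectedAnswer`, ignoring row order. */
  protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = {
    // Normalize the expected answer into a Seq of row-Seqs, mirroring the call sites above.
    val expectedRows: Seq[Seq[Any]] = expectedAnswer match {
      case rows: Seq[_] => rows.map {
        case row: Seq[_] => row                    // e.g. "1" :: Nil is already a row
        case p: Product => p.productIterator.toSeq // e.g. (1, 3) becomes Seq(1, 3)
        case singleColumn => Seq(singleColumn)     // a single-column row
      }
      case singleValue => Seq(Seq(singleValue))    // e.g. 2.0 for "SELECT AVG(a) ..."
    }
    // Collect the actual rows; catalyst Rows behave as sequences of column values here.
    val actualRows: Seq[Seq[Any]] = rdd.collect().map(_.toSeq).toSeq
    // Order-insensitive comparison keeps the sketch simple; a fuller version would
    // preserve ordering for queries that contain an ORDER BY.
    assert(
      actualRows.sortBy(_.mkString("|")) === expectedRows.sortBy(_.mkString("|")),
      s"Answer did not match.\nExpected: $expectedRows\nActual:   $actualRows")
  }
}

With a helper of this shape, calls such as checkAnswer(sql("SELECT AVG(a) FROM testData2"), 2.0) and checkAnswer(sql("SELECT * FROM testData"), testData.collect().toSeq) both reduce to the same row-by-row comparison.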