From 6a1d48f4f02c4498b64439c3dd5f671286a90e30 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 3 Oct 2014 12:34:27 -0700 Subject: [SPARK-3212][SQL] Use logical plan matching instead of temporary tables for table caching _Also addresses: SPARK-1671, SPARK-1379 and SPARK-3641_ This PR introduces a new trait, `CacheManger`, which replaces the previous temporary table based caching system. Instead of creating a temporary table that shadows an existing table with and equivalent cached representation, the cached manager maintains a separate list of logical plans and their cached data. After optimization, this list is searched for any matching plan fragments. When a matching plan fragment is found it is replaced with the cached data. There are several advantages to this approach: - Calling .cache() on a SchemaRDD now works as you would expect, and uses the more efficient columnar representation. - Its now possible to provide a list of temporary tables, without having to decide if a given table is actually just a cached persistent table. (To be done in a follow-up PR) - In some cases it is possible that cached data will be used, even if a cached table was not explicitly requested. This is because we now look at the logical structure instead of the table name. - We now correctly invalidate when data is inserted into a hive table. Author: Michael Armbrust Closes #2501 from marmbrus/caching and squashes the following commits: 63fbc2c [Michael Armbrust] Merge remote-tracking branch 'origin/master' into caching. 0ea889e [Michael Armbrust] Address comments. 1e23287 [Michael Armbrust] Add support for cache invalidation for hive inserts. 65ed04a [Michael Armbrust] fix tests. bdf9a3f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into caching b4b77f2 [Michael Armbrust] Address comments 6923c9d [Michael Armbrust] More comments / tests 80f26ac [Michael Armbrust] First draft of improved semantics for Spark SQL caching. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 + .../catalyst/expressions/namedExpressions.scala | 4 +- .../sql/catalyst/plans/logical/LogicalPlan.scala | 42 +++++++ .../sql/catalyst/plans/logical/TestRelation.scala | 6 + .../catalyst/plans/logical/basicOperators.scala | 4 +- .../spark/sql/catalyst/plans/SameResultSuite.scala | 62 +++++++++ .../scala/org/apache/spark/sql/CacheManager.scala | 139 +++++++++++++++++++++ .../scala/org/apache/spark/sql/SQLContext.scala | 51 ++------ .../scala/org/apache/spark/sql/SchemaRDD.scala | 23 +++- .../scala/org/apache/spark/sql/SchemaRDDLike.scala | 5 +- .../apache/spark/sql/api/java/JavaSQLContext.scala | 10 +- .../sql/columnar/InMemoryColumnarTableScan.scala | 28 ++++- .../apache/spark/sql/execution/ExistingRDD.scala | 119 ++++++++++++++++++ .../org/apache/spark/sql/execution/SparkPlan.scala | 33 ----- .../spark/sql/execution/SparkStrategies.scala | 9 +- .../spark/sql/execution/basicOperators.scala | 39 ------ .../org/apache/spark/sql/CachedTableSuite.scala | 103 ++++++++------- .../sql/columnar/InMemoryColumnarQuerySuite.scala | 7 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 +- .../org/apache/spark/sql/hive/HiveStrategies.scala | 6 +- .../scala/org/apache/spark/sql/hive/TestHive.scala | 5 +- .../sql/hive/execution/InsertIntoHiveTable.scala | 3 + .../apache/spark/sql/hive/CachedTableSuite.scala | 100 +++++++++------ 23 files changed, 567 insertions(+), 241 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 71810b798b..fe83eb1250 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -93,6 +93,9 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool */ object ResolveRelations extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case i @ InsertIntoTable(UnresolvedRelation(databaseName, name, alias), _, _, _) => + i.copy( + table = EliminateAnalysisOperators(catalog.lookupRelation(databaseName, name, alias))) case UnresolvedRelation(databaseName, name, alias) => catalog.lookupRelation(databaseName, name, alias) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 59fb0311a9..e5a958d599 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression { def withName(newName: String): Attribute def toAttribute = this - def newInstance: Attribute + def newInstance(): Attribute } @@ -131,7 +131,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea h } - override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) + override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) /** * Returns a copy of this [[AttributeReference]] with changed nullability. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 28d863e58b..4f8ad8a7e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.catalyst.trees @@ -72,6 +73,47 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = !children.exists(!_.resolved) + /** + * Returns true when the given logical plan will return the same results as this logical plan. + * + * Since its likely undecideable to generally determine if two given plans will produce the same + * results, it is okay for this function to return false, even if the results are actually + * the same. Such behavior will not affect correctness, only the application of performance + * enhancements like caching. However, it is not acceptable to return true if the results could + * possibly be different. + * + * By default this function performs a modified version of equality that is tolerant of cosmetic + * differences like attribute naming and or expression id differences. Logical operators that + * can do better should override this function. + */ + def sameResult(plan: LogicalPlan): Boolean = { + plan.getClass == this.getClass && + plan.children.size == children.size && { + logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]") + cleanArgs == plan.cleanArgs + } && + (plan.children, children).zipped.forall(_ sameResult _) + } + + /** Args that have cleaned such that differences in expression id should not affect equality */ + protected lazy val cleanArgs: Seq[Any] = { + val input = children.flatMap(_.output) + productIterator.map { + // Children are checked using sameResult above. + case tn: TreeNode[_] if children contains tn => null + case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case s: Option[_] => s.map { + case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case other => other + } + case s: Seq[_] => s.map { + case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case other => other + } + case other => other + }.toSeq + } + /** * Optionally resolves the given string to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala index f8fe558511..19769986ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala @@ -41,4 +41,10 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) } override protected def stringArgs = Iterator(output) + + override def sameResult(plan: LogicalPlan): Boolean = plan match { + case LocalRelation(otherOutput, otherData) => + otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data + case _ => false + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 391508279b..f8e9930ac2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -105,8 +105,8 @@ case class InsertIntoTable( child: LogicalPlan, overwrite: Boolean) extends LogicalPlan { - // The table being inserted into is a child for the purposes of transformations. - override def children = table :: child :: Nil + + override def children = child :: Nil override def output = child.output override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala new file mode 100644 index 0000000000..e8a793d107 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.util._ + +/** + * Provides helper methods for comparing plans. + */ +class SameResultSuite extends FunSuite { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) + + def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true) = { + val aAnalyzed = a.analyze + val bAnalyzed = b.analyze + + if (aAnalyzed.sameResult(bAnalyzed) != result) { + val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n") + fail(s"Plans should return sameResult = $result\n$comparison") + } + } + + test("relations") { + assertSameResult(testRelation, testRelation2) + } + + test("projections") { + assertSameResult(testRelation.select('a), testRelation2.select('a)) + assertSameResult(testRelation.select('b), testRelation2.select('b)) + assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b)) + assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a)) + + assertSameResult(testRelation, testRelation2.select('a), false) + assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), false) + } + + test("filters") { + assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala new file mode 100644 index 0000000000..aebdbb68e4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.locks.ReentrantReadWriteLock + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StorageLevel.MEMORY_ONLY + +/** Holds a cached logical plan and its data */ +private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) + +/** + * Provides support in a SQLContext for caching query results and automatically using these cached + * results when subsequent queries are executed. Data is cached using byte buffers stored in an + * InMemoryRelation. This relation is automatically substituted query plans that return the + * `sameResult` as the originally cached query. + */ +private[sql] trait CacheManager { + self: SQLContext => + + @transient + private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] + + @transient + private val cacheLock = new ReentrantReadWriteLock + + /** Returns true if the table is currently cached in-memory. */ + def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty + + /** Caches the specified table in-memory. */ + def cacheTable(tableName: String): Unit = cacheQuery(table(tableName)) + + /** Removes the specified table from the in-memory cache. */ + def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName)) + + /** Acquires a read lock on the cache for the duration of `f`. */ + private def readLock[A](f: => A): A = { + val lock = cacheLock.readLock() + lock.lock() + try f finally { + lock.unlock() + } + } + + /** Acquires a write lock on the cache for the duration of `f`. */ + private def writeLock[A](f: => A): A = { + val lock = cacheLock.writeLock() + lock.lock() + try f finally { + lock.unlock() + } + } + + private[sql] def clearCache(): Unit = writeLock { + cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.clear() + } + + /** Caches the data produced by the logical representation of the given schema rdd. */ + private[sql] def cacheQuery( + query: SchemaRDD, + storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock { + val planToCache = query.queryExecution.optimizedPlan + if (lookupCachedData(planToCache).nonEmpty) { + logWarning("Asked to cache already cached data.") + } else { + cachedData += + CachedData( + planToCache, + InMemoryRelation( + useCompression, columnBatchSize, storageLevel, query.queryExecution.executedPlan)) + } + } + + /** Removes the data for the given SchemaRDD from the cache */ + private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = false): Unit = writeLock { + val planToCache = query.queryExecution.optimizedPlan + val dataIndex = cachedData.indexWhere(_.plan.sameResult(planToCache)) + + if (dataIndex < 0) { + throw new IllegalArgumentException(s"Table $query is not cached.") + } + + cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cachedData.remove(dataIndex) + } + + + /** Optionally returns cached data for the given SchemaRDD */ + private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock { + lookupCachedData(query.queryExecution.optimizedPlan) + } + + /** Optionally returns cached data for the given LogicalPlan. */ + private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { + cachedData.find(_.plan.sameResult(plan)) + } + + /** Replaces segments of the given logical plan with cached versions where possible. */ + private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = { + plan transformDown { + case currentFragment => + lookupCachedData(currentFragment) + .map(_.cachedRepresentation.withOutput(currentFragment.output)) + .getOrElse(currentFragment) + } + } + + /** + * Invalidates the cache of any data that contains `plan`. Note that it is possible that this + * function will over invalidate. + */ + private[sql] def invalidateCache(plan: LogicalPlan): Unit = writeLock { + cachedData.foreach { + case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty => + data.cachedRepresentation.recache() + case _ => + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a42bedbe6c..7a55c5bf97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -50,6 +50,7 @@ import org.apache.spark.{Logging, SparkContext} class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging with SQLConf + with CacheManager with ExpressionConversions with UDFRegistration with Serializable { @@ -96,7 +97,8 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self)) + new SchemaRDD(this, + LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) } /** @@ -133,7 +135,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied // schema differs from the existing schema on any field data type. - val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))(self) + val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) new SchemaRDD(this, logicalPlan) } @@ -272,45 +274,6 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): SchemaRDD = new SchemaRDD(this, catalog.lookupRelation(None, tableName)) - /** Caches the specified table in-memory. */ - def cacheTable(tableName: String): Unit = { - val currentTable = table(tableName).queryExecution.analyzed - val asInMemoryRelation = currentTable match { - case _: InMemoryRelation => - currentTable - - case _ => - InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan) - } - - catalog.registerTable(None, tableName, asInMemoryRelation) - } - - /** Removes the specified table from the in-memory cache. */ - def uncacheTable(tableName: String): Unit = { - table(tableName).queryExecution.analyzed match { - // This is kind of a hack to make sure that if this was just an RDD registered as a table, - // we reregister the RDD as a table. - case inMem @ InMemoryRelation(_, _, _, e: ExistingRdd) => - inMem.cachedColumnBuffers.unpersist() - catalog.unregisterTable(None, tableName) - catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self)) - case inMem: InMemoryRelation => - inMem.cachedColumnBuffers.unpersist() - catalog.unregisterTable(None, tableName) - case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan") - } - } - - /** Returns true if the table is currently cached in-memory. */ - def isCached(tableName: String): Boolean = { - val relation = table(tableName).queryExecution.analyzed - relation match { - case _: InMemoryRelation => true - case _ => false - } - } - protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -401,10 +364,12 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) lazy val optimizedPlan = optimizer(analyzed) + lazy val withCachedData = useCachedData(optimizedPlan) + // TODO: Don't just pick the first one... lazy val sparkPlan = { SparkPlan.currentContext.set(self) - planner(optimizedPlan).next() + planner(withCachedData).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. @@ -526,6 +491,6 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { m => new GenericRow(m): Row} } - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self)) + new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 3b873f7c62..594bf8ffc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.{Map => JMap, List => JList} +import org.apache.spark.storage.StorageLevel + import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.api.java.JavaRDD /** @@ -442,8 +444,7 @@ class SchemaRDD( */ private def applySchema(rdd: RDD[Row]): SchemaRDD = { new SchemaRDD(sqlContext, - SparkLogicalPlan( - ExistingRdd(queryExecution.analyzed.output.map(_.newInstance), rdd))(sqlContext)) + LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext)) } // ======================================================================= @@ -497,4 +498,20 @@ class SchemaRDD( override def subtract(other: RDD[Row], p: Partitioner) (implicit ord: Ordering[Row] = null): SchemaRDD = applySchema(super.subtract(other, p)(ord)) + + /** Overridden cache function will always use the in-memory columnar caching. */ + override def cache(): this.type = { + sqlContext.cacheQuery(this) + this + } + + override def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheQuery(this, newLevel) + this + } + + override def unpersist(blocking: Boolean): this.type = { + sqlContext.uncacheQuery(this, blocking) + this + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index e52eeb3e1c..25ba7d88ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.execution.LogicalRDD /** * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) @@ -55,8 +55,7 @@ private[sql] trait SchemaRDDLike { // For various commands (like DDL) and queries with side effects, we force query optimization to // happen right away to let these side effects take place eagerly. case _: Command | _: InsertIntoTable | _: CreateTableAsSelect |_: WriteToFile => - queryExecution.toRdd - SparkLogicalPlan(queryExecution.executedPlan)(sqlContext) + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) case _ => baseLogicalPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 150ff8a420..c006c4330f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils @@ -100,7 +100,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow } } - new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(sqlContext)) + new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext)) } /** @@ -114,7 +114,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { val scalaRowRDD = rowRDD.rdd.map(r => r.row) val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType] val logicalPlan = - SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext) + LogicalRDD(scalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) } @@ -151,7 +151,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) val logicalPlan = - SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) } @@ -167,7 +167,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) val logicalPlan = - SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 8a3612cdf1..cec82a7f2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -27,10 +27,15 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{LeafNode, SparkPlan} +import org.apache.spark.storage.StorageLevel private[sql] object InMemoryRelation { - def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, child)() + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)() } private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row) @@ -39,6 +44,7 @@ private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, batchSize: Int, + storageLevel: StorageLevel, child: SparkPlan) (private var _cachedColumnBuffers: RDD[CachedBatch] = null) extends LogicalPlan with MultiInstanceRelation { @@ -51,6 +57,16 @@ private[sql] case class InMemoryRelation( // If the cached column buffers were not passed in, we calculate them in the constructor. // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { + buildBuffers() + } + + def recache() = { + _cachedColumnBuffers.unpersist() + _cachedColumnBuffers = null + buildBuffers() + } + + private def buildBuffers(): Unit = { val output = child.output val cached = child.execute().mapPartitions { rowIterator => new Iterator[CachedBatch] { @@ -80,12 +96,17 @@ private[sql] case class InMemoryRelation( def hasNext = rowIterator.hasNext } - }.cache() + }.persist(storageLevel) cached.setName(child.toString) _cachedColumnBuffers = cached } + def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { + InMemoryRelation( + newOutput, useCompression, batchSize, storageLevel, child)(_cachedColumnBuffers) + } + override def children = Seq.empty override def newInstance() = { @@ -93,6 +114,7 @@ private[sql] case class InMemoryRelation( output.map(_.newInstance), useCompression, batchSize, + storageLevel, child)( _cachedColumnBuffers).asInstanceOf[this.type] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala new file mode 100644 index 0000000000..2ddf513b6f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +object RDDConversions { + def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { + data.mapPartitions { iterator => + if (iterator.isEmpty) { + Iterator.empty + } else { + val bufferedIterator = iterator.buffered + val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) + + bufferedIterator.map { r => + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) + i += 1 + } + + mutableRow + } + } + } + } + + /* + def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = { + LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) + } + */ +} + +case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { + + def children = Nil + + def newInstance() = + LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] + + override def sameResult(plan: LogicalPlan) = plan match { + case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. + sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + ) +} + +case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { + override def execute() = rdd +} + +@deprecated("Use LogicalRDD", "1.2.0") +case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { + override def execute() = rdd +} + +@deprecated("Use LogicalRDD", "1.2.0") +case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { + + def output = alreadyPlanned.output + override def children = Nil + + override final def newInstance(): this.type = { + SparkLogicalPlan( + alreadyPlanned match { + case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) + case _ => sys.error("Multiple instance of the same relation detected.") + })(sqlContext).asInstanceOf[this.type] + } + + override def sameResult(plan: LogicalPlan) = plan match { + case SparkLogicalPlan(ExistingRdd(_, rdd)) => + rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. + sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2b8913985b..b1a7948b66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -126,39 +126,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } -/** - * :: DeveloperApi :: - * Allows already planned SparkQueries to be linked into logical query plans. - * - * Note that in general it is not valid to use this class to link multiple copies of the same - * physical operator into the same query plan as this violates the uniqueness of expression ids. - * Special handling exists for ExistingRdd as these are already leaf operators and thus we can just - * replace the output attributes with new copies of themselves without breaking any attribute - * linking. - */ -@DeveloperApi -case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { - - def output = alreadyPlanned.output - override def children = Nil - - override final def newInstance(): this.type = { - SparkLogicalPlan( - alreadyPlanned match { - case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) - case _ => sys.error("Multiple instance of the same relation detected.") - })(sqlContext).asInstanceOf[this.type] - } - - @transient override lazy val statistics = Statistics( - // TODO: Instead of returning a default value here, find a way to return a meaningful size - // estimate for RDDs. See PR 1238 for more discussions. - sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) - ) - -} - private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { self: Product => } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 45687d9604..cf93d5ad7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -272,10 +272,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil + case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => - ExistingRdd( + PhysicalRDD( output, - ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => @@ -287,12 +288,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Generate(generator, join, outer, _, child) => execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil case logical.NoRelation => - execution.ExistingRdd(Nil, singleRowRdd) :: Nil + execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case SparkLogicalPlan(existingPlan) => existingPlan :: Nil + case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index cac376608b..977f3c9f32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -210,45 +210,6 @@ case class Sort( override def output = child.output } -/** - * :: DeveloperApi :: - */ -@DeveloperApi -object ExistingRdd { - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { - data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) - - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) - i += 1 - } - - mutableRow - } - } - } - } - - def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = { - ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - override def execute() = rdd -} - /** * :: DeveloperApi :: * Computes the set of distinct input rows using a HashSet. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 591592841e..957388e99b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,13 +20,30 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ case class BigData(s: String) class CachedTableSuite extends QueryTest { + import TestSQLContext._ TestData // Load test tables. + /** + * Throws a test failed exception when the number of cached tables differs from the expected + * number. + */ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + if (cachedData.size != numCachedTables) { + fail( + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } + test("too big for memory") { val data = "*" * 10000 sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") @@ -35,19 +52,21 @@ class CachedTableSuite extends QueryTest { uncacheTable("bigData") } + test("calling .cache() should use inmemory columnar caching") { + table("testData").cache() + + assertCached(table("testData")) + } + test("SPARK-1669: cacheTable should be idempotent") { assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) cacheTable("testData") - table("testData").queryExecution.analyzed match { - case _: InMemoryRelation => - case _ => - fail("testData should be cached") - } + assertCached(table("testData")) cacheTable("testData") table("testData").queryExecution.analyzed match { - case InMemoryRelation(_, _, _, _: InMemoryColumnarTableScan) => + case InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) => fail("cacheTable is not idempotent") case _ => @@ -55,81 +74,69 @@ class CachedTableSuite extends QueryTest { } test("read from cached table and uncache") { - TestSQLContext.cacheTable("testData") + cacheTable("testData") checkAnswer( - TestSQLContext.table("testData"), + table("testData"), testData.collect().toSeq ) - TestSQLContext.table("testData").queryExecution.analyzed match { - case _ : InMemoryRelation => // Found evidence of caching - case noCache => fail(s"No cache node found in plan $noCache") - } + assertCached(table("testData")) - TestSQLContext.uncacheTable("testData") + uncacheTable("testData") checkAnswer( - TestSQLContext.table("testData"), + table("testData"), testData.collect().toSeq ) - TestSQLContext.table("testData").queryExecution.analyzed match { - case cachePlan: InMemoryRelation => - fail(s"Table still cached after uncache: $cachePlan") - case noCache => // Table uncached successfully - } + assertCached(table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - TestSQLContext.uncacheTable("testData") + uncacheTable("testData") } } test("SELECT Star Cached Table") { - TestSQLContext.sql("SELECT * FROM testData").registerTempTable("selectStar") - TestSQLContext.cacheTable("selectStar") - TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect() - TestSQLContext.uncacheTable("selectStar") + sql("SELECT * FROM testData").registerTempTable("selectStar") + cacheTable("selectStar") + sql("SELECT * FROM selectStar WHERE key = 1").collect() + uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = - TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - TestSQLContext.cacheTable("testData") + sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() + cacheTable("testData") checkAnswer( - TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), + sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - TestSQLContext.uncacheTable("testData") + uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { - TestSQLContext.sql("CACHE TABLE testData") - TestSQLContext.table("testData").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => // Found evidence of caching - case _ => fail(s"Table 'testData' should be cached") - } - assert(TestSQLContext.isCached("testData"), "Table 'testData' should be cached") + sql("CACHE TABLE testData") + assertCached(table("testData")) - TestSQLContext.sql("UNCACHE TABLE testData") - TestSQLContext.table("testData").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => fail(s"Table 'testData' should not be cached") - case _ => // Found evidence of uncaching - } - assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached") + assert(isCached("testData"), "Table 'testData' should be cached") + + sql("UNCACHE TABLE testData") + assertCached(table("testData"), 0) + assert(!isCached("testData"), "Table 'testData' should not be cached") } test("CACHE TABLE tableName AS SELECT Star Table") { - TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect() - assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") - TestSQLContext.uncacheTable("testCacheTable") + sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + sql("SELECT * FROM testCacheTable WHERE key = 1").collect() + assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + uncacheTable("testCacheTable") } test("'CACHE TABLE tableName AS SELECT ..'") { - TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") - TestSQLContext.uncacheTable("testCacheTable") + sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + uncacheTable("testCacheTable") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index c1278248ef..9775dd26b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{QueryTest, TestData} +import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest { import org.apache.spark.sql.TestData._ @@ -27,7 +28,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, 5, plan) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) } @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("projection") { val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, 5, plan) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -51,7 +52,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, 5, plan) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 989a9784a4..cc0605b0ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -133,11 +133,6 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => castChildOutput(p, table, child) - - case p @ logical.InsertIntoTable( - InMemoryRelation(_, _, _, - HiveTableScan(_, table, _)), _, child, _) => - castChildOutput(p, table, child) } def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { @@ -306,7 +301,7 @@ private[hive] case class MetastoreRelation HiveMetastoreTypes.toDataType(f.getType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true - )(qualifiers = tableName +: alias.toSeq) + )(qualifiers = Seq(alias.getOrElse(tableName))) } // Must be a stable value since new attributes are born here. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8ac17f3720..508d8239c7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.StringType -import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ @@ -161,10 +160,7 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil - case logical.InsertIntoTable( - InMemoryRelation(_, _, _, - HiveTableScan(_, table, _)), partition, child, overwrite) => - InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil + case logical.CreateTableAsSelect(database, tableName, child) => val query = planLater(child) CreateTableAsSelect( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 4a999b98ad..c0e69393cc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -353,7 +353,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { var cacheTables: Boolean = false def loadTestTable(name: String) { if (!(loadedTables contains name)) { - // Marks the table as loaded first to prevent infite mutually recursive table loading. + // Marks the table as loaded first to prevent infinite mutually recursive table loading. loadedTables += name logInfo(s"Loading test table $name") val createCmds = @@ -383,6 +383,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } + clearCache() loadedTables.clear() catalog.client.getAllTables("default").foreach { t => logDebug(s"Deleting table $t") @@ -428,7 +429,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadTestTable("srcpart") } catch { case e: Exception => - logError(s"FATAL ERROR: Failed to reset TestDB state. $e") + logError("FATAL ERROR: Failed to reset TestDB state.", e) // At this point there is really no reason to continue, but the test framework traps exits. // So instead we just pause forever so that at least the developer can see where things // started to go wrong. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 16a8c782ac..f8b4e898ec 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -267,6 +267,9 @@ case class InsertIntoHiveTable( holdDDLTime) } + // Invalidate the cache. + sqlContext.invalidateCache(table) + // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index b3057cd618..158cfb5bbe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -17,22 +17,60 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.{QueryTest, SchemaRDD} import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} -import org.apache.spark.sql.hive.execution.HiveComparisonTest import org.apache.spark.sql.hive.test.TestHive -class CachedTableSuite extends HiveComparisonTest { +class CachedTableSuite extends QueryTest { import TestHive._ - TestHive.loadTestTable("src") + /** + * Throws a test failed exception when the number of cached tables differs from the expected + * number. + */ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + if (cachedData.size != numCachedTables) { + fail( + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } test("cache table") { - TestHive.cacheTable("src") + val preCacheResults = sql("SELECT * FROM src").collect().toSeq + + cacheTable("src") + assertCached(sql("SELECT * FROM src")) + + checkAnswer( + sql("SELECT * FROM src"), + preCacheResults) + + uncacheTable("src") + assertCached(sql("SELECT * FROM src"), 0) } - createQueryTest("read from cached table", - "SELECT * FROM src LIMIT 1", reset = false) + test("cache invalidation") { + sql("CREATE TABLE cachedTable(key INT, value STRING)") + + sql("INSERT INTO TABLE cachedTable SELECT * FROM src") + checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq) + + cacheTable("cachedTable") + checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq) + + sql("INSERT INTO TABLE cachedTable SELECT * FROM src") + checkAnswer( + sql("SELECT * FROM cachedTable"), + table("src").collect().toSeq ++ table("src").collect().toSeq) + + sql("DROP TABLE cachedTable") + } test("Drop cached table") { sql("CREATE TABLE test(a INT)") @@ -48,25 +86,6 @@ class CachedTableSuite extends HiveComparisonTest { sql("DROP TABLE IF EXISTS nonexistantTable") } - test("check that table is cached and uncache") { - TestHive.table("src").queryExecution.analyzed match { - case _ : InMemoryRelation => // Found evidence of caching - case noCache => fail(s"No cache node found in plan $noCache") - } - TestHive.uncacheTable("src") - } - - createQueryTest("read from uncached table", - "SELECT * FROM src LIMIT 1", reset = false) - - test("make sure table is uncached") { - TestHive.table("src").queryExecution.analyzed match { - case cachePlan: InMemoryRelation => - fail(s"Table still cached after uncache: $cachePlan") - case noCache => // Table uncached successfully - } - } - test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { TestHive.uncacheTable("src") @@ -75,23 +94,24 @@ class CachedTableSuite extends HiveComparisonTest { test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { TestHive.sql("CACHE TABLE src") - TestHive.table("src").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => // Found evidence of caching - case _ => fail(s"Table 'src' should be cached") - } + assertCached(table("src")) assert(TestHive.isCached("src"), "Table 'src' should be cached") TestHive.sql("UNCACHE TABLE src") - TestHive.table("src").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached") - case _ => // Found evidence of uncaching - } + assertCached(table("src"), 0) assert(!TestHive.isCached("src"), "Table 'src' should not be cached") } - - test("'CACHE TABLE tableName AS SELECT ..'") { - TestHive.sql("CACHE TABLE testCacheTable AS SELECT * FROM src") - assert(TestHive.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") - TestHive.uncacheTable("testCacheTable") - } + + test("CACHE TABLE AS SELECT") { + assertCached(sql("SELECT * FROM src"), 0) + sql("CACHE TABLE test AS SELECT key FROM src") + + checkAnswer( + sql("SELECT * FROM test"), + sql("SELECT key FROM src").collect().toSeq) + + assertCached(sql("SELECT * FROM test")) + + assertCached(sql("SELECT * FROM test JOIN test"), 2) + } } -- cgit v1.2.3