 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 4
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 42
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala | 6
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 4
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala | 62
 sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala | 139
 sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 51
 sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 23
 sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala | 5
 sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala | 10
 sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala | 28
 sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 119
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 33
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 9
 sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 39
 sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 103
 sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala | 7
 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 7
 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala | 6
 sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala | 5
 sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala | 3
 sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala | 100
 23 files changed, 567 insertions(+), 241 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 71810b798b..fe83eb1250 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -93,6 +93,9 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
*/
object ResolveRelations extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case i @ InsertIntoTable(UnresolvedRelation(databaseName, name, alias), _, _, _) =>
+ i.copy(
+ table = EliminateAnalysisOperators(catalog.lookupRelation(databaseName, name, alias)))
case UnresolvedRelation(databaseName, name, alias) =>
catalog.lookupRelation(databaseName, name, alias)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 59fb0311a9..e5a958d599 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression {
def withName(newName: String): Attribute
def toAttribute = this
- def newInstance: Attribute
+ def newInstance(): Attribute
}
@@ -131,7 +131,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
h
}
- override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
+ override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 28d863e58b..4f8ad8a7e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.catalyst.trees
@@ -73,6 +74,47 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
def childrenResolved: Boolean = !children.exists(!_.resolved)
/**
+ * Returns true when the given logical plan will return the same results as this logical plan.
+ *
+ * Since it is likely undecidable to determine in general whether two given plans will produce the same
+ * results, it is okay for this function to return false, even if the results are actually
+ * the same. Such behavior will not affect correctness, only the application of performance
+ * enhancements like caching. However, it is not acceptable to return true if the results could
+ * possibly be different.
+ *
+ * By default this function performs a modified version of equality that is tolerant of cosmetic
+ * differences such as attribute naming and expression ids. Logical operators that
+ * can do better should override this function.
+ */
+ def sameResult(plan: LogicalPlan): Boolean = {
+ plan.getClass == this.getClass &&
+ plan.children.size == children.size && {
+ logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]")
+ cleanArgs == plan.cleanArgs
+ } &&
+ (plan.children, children).zipped.forall(_ sameResult _)
+ }
+
+ /** Args that have been cleaned such that differences in expression ids should not affect equality */
+ protected lazy val cleanArgs: Seq[Any] = {
+ val input = children.flatMap(_.output)
+ productIterator.map {
+ // Children are checked using sameResult above.
+ case tn: TreeNode[_] if children contains tn => null
+ case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
+ case s: Option[_] => s.map {
+ case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
+ case other => other
+ }
+ case s: Seq[_] => s.map {
+ case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
+ case other => other
+ }
+ case other => other
+ }.toSeq
+ }
+
+ /**
* Optionally resolves the given string to a [[NamedExpression]] using the input from all child
* nodes of this LogicalPlan. The attribute is expressed as a
* string in the following form: `[scope].AttributeName.[nested].[fields]...`.
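
The contract above is deliberately conservative: a false negative only costs a missed cache hit, while a false positive would silently return wrong data. The following minimal, self-contained Scala sketch illustrates the idea; the toy classes (ToyPlan, ToyScan, ToyProject) are purely illustrative and are not part of this patch.

// Toy plan hierarchy: result equality ignores cosmetic ids and compares structure.
sealed trait ToyPlan {
  def children: Seq[ToyPlan]
  // Conservative: may return false for equivalent plans, but must never return
  // true for plans that could produce different results.
  def sameResult(other: ToyPlan): Boolean = (this, other) match {
    case (ToyScan(t1, _), ToyScan(t2, _)) => t1 == t2            // ignore the cosmetic id
    case (ToyProject(c1, l), ToyProject(c2, r)) => c1 == c2 && l.sameResult(r)
    case _ => false                                              // unknown shape => assume different
  }
}
case class ToyScan(table: String, exprId: Long) extends ToyPlan {
  def children = Nil
}
case class ToyProject(columns: Seq[String], child: ToyPlan) extends ToyPlan {
  def children = Seq(child)
}

object SameResultSketch extends App {
  val a = ToyProject(Seq("key"), ToyScan("src", exprId = 1L))
  val b = ToyProject(Seq("key"), ToyScan("src", exprId = 2L))  // differs only in the id
  assert(a.sameResult(b))                                      // same result despite the ids
  assert(!a.sameResult(ToyScan("src", exprId = 3L)))           // different shape => false
}
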
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
index f8fe558511..19769986ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
@@ -41,4 +41,10 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
}
override protected def stringArgs = Iterator(output)
+
+ override def sameResult(plan: LogicalPlan): Boolean = plan match {
+ case LocalRelation(otherOutput, otherData) =>
+ otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data
+ case _ => false
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 391508279b..f8e9930ac2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -105,8 +105,8 @@ case class InsertIntoTable(
child: LogicalPlan,
overwrite: Boolean)
extends LogicalPlan {
- // The table being inserted into is a child for the purposes of transformations.
- override def children = table :: child :: Nil
+
+ override def children = child :: Nil
override def output = child.output
override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
new file mode 100644
index 0000000000..e8a793d107
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.util._
+
+/**
+ * Tests for the sameResult function on logical plans.
+ */
+class SameResultSuite extends FunSuite {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)
+
+ def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true) = {
+ val aAnalyzed = a.analyze
+ val bAnalyzed = b.analyze
+
+ if (aAnalyzed.sameResult(bAnalyzed) != result) {
+ val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n")
+ fail(s"Plans should return sameResult = $result\n$comparison")
+ }
+ }
+
+ test("relations") {
+ assertSameResult(testRelation, testRelation2)
+ }
+
+ test("projections") {
+ assertSameResult(testRelation.select('a), testRelation2.select('a))
+ assertSameResult(testRelation.select('b), testRelation2.select('b))
+ assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
+ assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))
+
+ assertSameResult(testRelation, testRelation2.select('a), false)
+ assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), false)
+ }
+
+ test("filters") {
+ assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
new file mode 100644
index 0000000000..aebdbb68e4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.locks.ReentrantReadWriteLock
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.columnar.InMemoryRelation
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
+
+/** Holds a cached logical plan and its data */
+private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
+
+/**
+ * Provides support in a SQLContext for caching query results and automatically using these cached
+ * results when subsequent queries are executed. Data is cached using byte buffers stored in an
+ * InMemoryRelation. This relation is automatically substituted into query plans that return
+ * the same results (as determined by `sameResult`) as the originally cached query.
+ */
+private[sql] trait CacheManager {
+ self: SQLContext =>
+
+ @transient
+ private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
+
+ @transient
+ private val cacheLock = new ReentrantReadWriteLock
+
+ /** Returns true if the table is currently cached in-memory. */
+ def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty
+
+ /** Caches the specified table in-memory. */
+ def cacheTable(tableName: String): Unit = cacheQuery(table(tableName))
+
+ /** Removes the specified table from the in-memory cache. */
+ def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName))
+
+ /** Acquires a read lock on the cache for the duration of `f`. */
+ private def readLock[A](f: => A): A = {
+ val lock = cacheLock.readLock()
+ lock.lock()
+ try f finally {
+ lock.unlock()
+ }
+ }
+
+ /** Acquires a write lock on the cache for the duration of `f`. */
+ private def writeLock[A](f: => A): A = {
+ val lock = cacheLock.writeLock()
+ lock.lock()
+ try f finally {
+ lock.unlock()
+ }
+ }
+
+ private[sql] def clearCache(): Unit = writeLock {
+ cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+ cachedData.clear()
+ }
+
+ /** Caches the data produced by the logical representation of the given SchemaRDD. */
+ private[sql] def cacheQuery(
+ query: SchemaRDD,
+ storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock {
+ val planToCache = query.queryExecution.optimizedPlan
+ if (lookupCachedData(planToCache).nonEmpty) {
+ logWarning("Asked to cache already cached data.")
+ } else {
+ cachedData +=
+ CachedData(
+ planToCache,
+ InMemoryRelation(
+ useCompression, columnBatchSize, storageLevel, query.queryExecution.executedPlan))
+ }
+ }
+
+ /** Removes the data for the given SchemaRDD from the cache */
+ private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = false): Unit = writeLock {
+ val planToCache = query.queryExecution.optimizedPlan
+ val dataIndex = cachedData.indexWhere(_.plan.sameResult(planToCache))
+
+ if (dataIndex < 0) {
+ throw new IllegalArgumentException(s"Table $query is not cached.")
+ }
+
+ cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+ cachedData.remove(dataIndex)
+ }
+
+
+ /** Optionally returns cached data for the given SchemaRDD */
+ private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock {
+ lookupCachedData(query.queryExecution.optimizedPlan)
+ }
+
+ /** Optionally returns cached data for the given LogicalPlan. */
+ private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
+ cachedData.find(_.plan.sameResult(plan))
+ }
+
+ /** Replaces segments of the given logical plan with cached versions where possible. */
+ private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = {
+ plan transformDown {
+ case currentFragment =>
+ lookupCachedData(currentFragment)
+ .map(_.cachedRepresentation.withOutput(currentFragment.output))
+ .getOrElse(currentFragment)
+ }
+ }
+
+ /**
+ * Invalidates the cache of any data that contains `plan`. Note that it is possible that this
+ * function will over-invalidate.
+ */
+ private[sql] def invalidateCache(plan: LogicalPlan): Unit = writeLock {
+ cachedData.foreach {
+ case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
+ data.cachedRepresentation.recache()
+ case _ =>
+ }
+ }
+
+}
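
A hedged usage sketch of the public methods this trait mixes into SQLContext (cacheTable, isCached, uncacheTable). The local SparkContext setup and the Person case class are illustrative assumptions, not part of the patch.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

case class Person(name: String, age: Int)

object CacheManagerUsage {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("cache-demo"))
    val sqlContext = new SQLContext(sc)
    import sqlContext._

    sc.parallelize(Seq(Person("a", 1), Person("b", 2))).registerTempTable("people")

    cacheTable("people")                      // registers an InMemoryRelation for the table's plan
    assert(isCached("people"))
    sql("SELECT name FROM people").collect()  // planned against the cached relation
    uncacheTable("people")                    // unpersists the cached column buffers
    assert(!isCached("people"))

    sc.stop()
  }
}
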
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a42bedbe6c..7a55c5bf97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -50,6 +50,7 @@ import org.apache.spark.{Logging, SparkContext}
class SQLContext(@transient val sparkContext: SparkContext)
extends org.apache.spark.Logging
with SQLConf
+ with CacheManager
with ExpressionConversions
with UDFRegistration
with Serializable {
@@ -96,7 +97,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = {
SparkPlan.currentContext.set(self)
- new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self))
+ new SchemaRDD(this,
+ LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self))
}
/**
@@ -133,7 +135,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = {
// TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied
// schema differs from the existing schema on any field data type.
- val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))(self)
+ val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
new SchemaRDD(this, logicalPlan)
}
@@ -272,45 +274,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
def table(tableName: String): SchemaRDD =
new SchemaRDD(this, catalog.lookupRelation(None, tableName))
- /** Caches the specified table in-memory. */
- def cacheTable(tableName: String): Unit = {
- val currentTable = table(tableName).queryExecution.analyzed
- val asInMemoryRelation = currentTable match {
- case _: InMemoryRelation =>
- currentTable
-
- case _ =>
- InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan)
- }
-
- catalog.registerTable(None, tableName, asInMemoryRelation)
- }
-
- /** Removes the specified table from the in-memory cache. */
- def uncacheTable(tableName: String): Unit = {
- table(tableName).queryExecution.analyzed match {
- // This is kind of a hack to make sure that if this was just an RDD registered as a table,
- // we reregister the RDD as a table.
- case inMem @ InMemoryRelation(_, _, _, e: ExistingRdd) =>
- inMem.cachedColumnBuffers.unpersist()
- catalog.unregisterTable(None, tableName)
- catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self))
- case inMem: InMemoryRelation =>
- inMem.cachedColumnBuffers.unpersist()
- catalog.unregisterTable(None, tableName)
- case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan")
- }
- }
-
- /** Returns true if the table is currently cached in-memory. */
- def isCached(tableName: String): Boolean = {
- val relation = table(tableName).queryExecution.analyzed
- relation match {
- case _: InMemoryRelation => true
- case _ => false
- }
- }
-
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext
@@ -401,10 +364,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
lazy val analyzed = ExtractPythonUdfs(analyzer(logical))
lazy val optimizedPlan = optimizer(analyzed)
+ lazy val withCachedData = useCachedData(optimizedPlan)
+
// TODO: Don't just pick the first one...
lazy val sparkPlan = {
SparkPlan.currentContext.set(self)
- planner(optimizedPlan).next()
+ planner(withCachedData).next()
}
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
@@ -526,6 +491,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
iter.map { m => new GenericRow(m): Row}
}
- new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self))
+ new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 3b873f7c62..594bf8ffc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
import java.util.{Map => JMap, List => JList}
+import org.apache.spark.storage.StorageLevel
+
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
@@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
+import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.api.java.JavaRDD
/**
@@ -442,8 +444,7 @@ class SchemaRDD(
*/
private def applySchema(rdd: RDD[Row]): SchemaRDD = {
new SchemaRDD(sqlContext,
- SparkLogicalPlan(
- ExistingRdd(queryExecution.analyzed.output.map(_.newInstance), rdd))(sqlContext))
+ LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext))
}
// =======================================================================
@@ -497,4 +498,20 @@ class SchemaRDD(
override def subtract(other: RDD[Row], p: Partitioner)
(implicit ord: Ordering[Row] = null): SchemaRDD =
applySchema(super.subtract(other, p)(ord))
+
+ /** Overridden cache function will always use the in-memory columnar caching. */
+ override def cache(): this.type = {
+ sqlContext.cacheQuery(this)
+ this
+ }
+
+ override def persist(newLevel: StorageLevel): this.type = {
+ sqlContext.cacheQuery(this, newLevel)
+ this
+ }
+
+ override def unpersist(blocking: Boolean): this.type = {
+ sqlContext.uncacheQuery(this, blocking)
+ this
+ }
}
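
With the overrides above, persist on a SchemaRDD drives the in-memory columnar cache and honours the requested storage level instead of caching raw rows. A hedged sketch follows; the local setup and the Event case class are illustrative assumptions.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.storage.StorageLevel

case class Event(id: Int, payload: String)

object PersistLevelSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("persist-demo"))
    val sqlContext = new SQLContext(sc)
    import sqlContext._

    val events = createSchemaRDD(sc.parallelize((1 to 100).map(i => Event(i, "x" * 10))))

    events.persist(StorageLevel.MEMORY_AND_DISK)  // columnar cache that may spill to disk
    events.count()                                // materializes the cached column buffers
    events.unpersist(blocking = true)             // drops the buffers again

    sc.stop()
  }
}
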
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
index e52eeb3e1c..25ba7d88ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.SparkLogicalPlan
+import org.apache.spark.sql.execution.LogicalRDD
/**
* Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
@@ -55,8 +55,7 @@ private[sql] trait SchemaRDDLike {
// For various commands (like DDL) and queries with side effects, we force query optimization to
// happen right away to let these side effects take place eagerly.
case _: Command | _: InsertIntoTable | _: CreateTableAsSelect |_: WriteToFile =>
- queryExecution.toRdd
- SparkLogicalPlan(queryExecution.executedPlan)(sqlContext)
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
case _ =>
baseLogicalPlan
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 150ff8a420..c006c4330f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
+import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType
import org.apache.spark.util.Utils
@@ -100,7 +100,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow
}
}
- new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(sqlContext))
+ new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext))
}
/**
@@ -114,7 +114,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
val scalaRowRDD = rowRDD.rdd.map(r => r.row)
val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType]
val logicalPlan =
- SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext)
+ LogicalRDD(scalaSchema.toAttributes, scalaRowRDD)(sqlContext)
new JavaSchemaRDD(sqlContext, logicalPlan)
}
@@ -151,7 +151,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))
val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
val logicalPlan =
- SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext)
+ LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext)
new JavaSchemaRDD(sqlContext, logicalPlan)
}
@@ -167,7 +167,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType]
val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
val logicalPlan =
- SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext)
+ LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext)
new JavaSchemaRDD(sqlContext, logicalPlan)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 8a3612cdf1..cec82a7f2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -27,10 +27,15 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
+import org.apache.spark.storage.StorageLevel
private[sql] object InMemoryRelation {
- def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation =
- new InMemoryRelation(child.output, useCompression, batchSize, child)()
+ def apply(
+ useCompression: Boolean,
+ batchSize: Int,
+ storageLevel: StorageLevel,
+ child: SparkPlan): InMemoryRelation =
+ new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)()
}
private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row)
@@ -39,6 +44,7 @@ private[sql] case class InMemoryRelation(
output: Seq[Attribute],
useCompression: Boolean,
batchSize: Int,
+ storageLevel: StorageLevel,
child: SparkPlan)
(private var _cachedColumnBuffers: RDD[CachedBatch] = null)
extends LogicalPlan with MultiInstanceRelation {
@@ -51,6 +57,16 @@ private[sql] case class InMemoryRelation(
// If the cached column buffers were not passed in, we calculate them in the constructor.
// As in Spark, the actual work of caching is lazy.
if (_cachedColumnBuffers == null) {
+ buildBuffers()
+ }
+
+ def recache() = {
+ _cachedColumnBuffers.unpersist()
+ _cachedColumnBuffers = null
+ buildBuffers()
+ }
+
+ private def buildBuffers(): Unit = {
val output = child.output
val cached = child.execute().mapPartitions { rowIterator =>
new Iterator[CachedBatch] {
@@ -80,12 +96,17 @@ private[sql] case class InMemoryRelation(
def hasNext = rowIterator.hasNext
}
- }.cache()
+ }.persist(storageLevel)
cached.setName(child.toString)
_cachedColumnBuffers = cached
}
+ def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
+ InMemoryRelation(
+ newOutput, useCompression, batchSize, storageLevel, child)(_cachedColumnBuffers)
+ }
+
override def children = Seq.empty
override def newInstance() = {
@@ -93,6 +114,7 @@ private[sql] case class InMemoryRelation(
output.map(_.newInstance),
useCompression,
batchSize,
+ storageLevel,
child)(
_cachedColumnBuffers).asInstanceOf[this.type]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
new file mode 100644
index 0000000000..2ddf513b6f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+object RDDConversions {
+ def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
+ data.mapPartitions { iterator =>
+ if (iterator.isEmpty) {
+ Iterator.empty
+ } else {
+ val bufferedIterator = iterator.buffered
+ val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
+
+ bufferedIterator.map { r =>
+ var i = 0
+ while (i < mutableRow.length) {
+ mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
+ i += 1
+ }
+
+ mutableRow
+ }
+ }
+ }
+ }
+
+ /*
+ def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = {
+ LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))
+ }
+ */
+}
+
+case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext)
+ extends LogicalPlan with MultiInstanceRelation {
+
+ def children = Nil
+
+ def newInstance() =
+ LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]
+
+ override def sameResult(plan: LogicalPlan) = plan match {
+ case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id
+ case _ => false
+ }
+
+ @transient override lazy val statistics = Statistics(
+ // TODO: Instead of returning a default value here, find a way to return a meaningful size
+ // estimate for RDDs. See PR 1238 for more discussions.
+ sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
+ )
+}
+
+case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
+ override def execute() = rdd
+}
+
+@deprecated("Use LogicalRDD", "1.2.0")
+case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
+ override def execute() = rdd
+}
+
+@deprecated("Use LogicalRDD", "1.2.0")
+case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
+ extends LogicalPlan with MultiInstanceRelation {
+
+ def output = alreadyPlanned.output
+ override def children = Nil
+
+ override final def newInstance(): this.type = {
+ SparkLogicalPlan(
+ alreadyPlanned match {
+ case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
+ case _ => sys.error("Multiple instances of the same relation detected.")
+ })(sqlContext).asInstanceOf[this.type]
+ }
+
+ override def sameResult(plan: LogicalPlan) = plan match {
+ case SparkLogicalPlan(ExistingRdd(_, rdd)) =>
+ rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id
+ case _ => false
+ }
+
+ @transient override lazy val statistics = Statistics(
+ // TODO: Instead of returning a default value here, find a way to return a meaningful size
+ // estimate for RDDs. See PR 1238 for more discussions.
+ sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
+ )
+}
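
Because LogicalRDD.sameResult above compares rdd.id, plans built over the same RDD instance are treated as the same data, while an equal-but-distinct RDD is not. A hedged sketch of the consequence for caching; the schema, row construction, and all names here are illustrative assumptions.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql._

object LogicalRDDSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("logicalrdd-demo"))
    val sqlContext = new SQLContext(sc)

    val schema = StructType(StructField("i", IntegerType, nullable = false) :: Nil)
    val rows = sc.parallelize(1 to 10).map(i => Row(i))

    val a = sqlContext.applySchema(rows, schema)
    val b = sqlContext.applySchema(rows, schema)               // same rdd.id as `a`
    val c = sqlContext.applySchema(rows.map(identity), schema) // a new RDD, new id

    a.cache()   // caches a's LogicalRDD plan through the CacheManager
    a.count()   // materializes the in-memory column buffers
    b.count()   // same rdd.id => sameResult holds, so the cached relation is substituted
    c.count()   // different rdd.id => not recognised as cached data

    sc.stop()
  }
}
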
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 2b8913985b..b1a7948b66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -126,39 +126,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}
-/**
- * :: DeveloperApi ::
- * Allows already planned SparkQueries to be linked into logical query plans.
- *
- * Note that in general it is not valid to use this class to link multiple copies of the same
- * physical operator into the same query plan as this violates the uniqueness of expression ids.
- * Special handling exists for ExistingRdd as these are already leaf operators and thus we can just
- * replace the output attributes with new copies of themselves without breaking any attribute
- * linking.
- */
-@DeveloperApi
-case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
- extends LogicalPlan with MultiInstanceRelation {
-
- def output = alreadyPlanned.output
- override def children = Nil
-
- override final def newInstance(): this.type = {
- SparkLogicalPlan(
- alreadyPlanned match {
- case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
- case _ => sys.error("Multiple instance of the same relation detected.")
- })(sqlContext).asInstanceOf[this.type]
- }
-
- @transient override lazy val statistics = Statistics(
- // TODO: Instead of returning a default value here, find a way to return a meaningful size
- // estimate for RDDs. See PR 1238 for more discussions.
- sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
- )
-
-}
-
private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] {
self: Product =>
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 45687d9604..cf93d5ad7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -272,10 +272,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+ case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil
case logical.LocalRelation(output, data) =>
- ExistingRdd(
+ PhysicalRDD(
output,
- ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
+ RDDConversions.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
@@ -287,12 +288,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Generate(generator, join, outer, _, child) =>
execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
case logical.NoRelation =>
- execution.ExistingRdd(Nil, singleRowRdd) :: Nil
+ execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
case e @ EvaluatePython(udf, child) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
- case SparkLogicalPlan(existingPlan) => existingPlan :: Nil
+ case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index cac376608b..977f3c9f32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -212,45 +212,6 @@ case class Sort(
/**
* :: DeveloperApi ::
- */
-@DeveloperApi
-object ExistingRdd {
- def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
- data.mapPartitions { iterator =>
- if (iterator.isEmpty) {
- Iterator.empty
- } else {
- val bufferedIterator = iterator.buffered
- val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
-
- bufferedIterator.map { r =>
- var i = 0
- while (i < mutableRow.length) {
- mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
- i += 1
- }
-
- mutableRow
- }
- }
- }
- }
-
- def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = {
- ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))
- }
-}
-
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
- override def execute() = rdd
-}
-
-/**
- * :: DeveloperApi ::
* Computes the set of distinct input rows using a HashSet.
* @param partial when true the distinct operation is performed partially, per partition, without
* shuffling the data.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 591592841e..957388e99b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -20,13 +20,30 @@ package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext._
case class BigData(s: String)
class CachedTableSuite extends QueryTest {
+ import TestSQLContext._
TestData // Load test tables.
+ /**
+ * Throws a test failed exception when the number of cached tables differs from the expected
+ * number.
+ */
+ def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
+ val planWithCaching = query.queryExecution.withCachedData
+ val cachedData = planWithCaching collect {
+ case cached: InMemoryRelation => cached
+ }
+
+ if (cachedData.size != numCachedTables) {
+ fail(
+ s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
+ planWithCaching)
+ }
+ }
+
test("too big for memory") {
val data = "*" * 10000
sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData")
@@ -35,19 +52,21 @@ class CachedTableSuite extends QueryTest {
uncacheTable("bigData")
}
+ test("calling .cache() should use inmemory columnar caching") {
+ table("testData").cache()
+
+ assertCached(table("testData"))
+ }
+
test("SPARK-1669: cacheTable should be idempotent") {
assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
cacheTable("testData")
- table("testData").queryExecution.analyzed match {
- case _: InMemoryRelation =>
- case _ =>
- fail("testData should be cached")
- }
+ assertCached(table("testData"))
cacheTable("testData")
table("testData").queryExecution.analyzed match {
- case InMemoryRelation(_, _, _, _: InMemoryColumnarTableScan) =>
+ case InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) =>
fail("cacheTable is not idempotent")
case _ =>
@@ -55,81 +74,69 @@ class CachedTableSuite extends QueryTest {
}
test("read from cached table and uncache") {
- TestSQLContext.cacheTable("testData")
+ cacheTable("testData")
checkAnswer(
- TestSQLContext.table("testData"),
+ table("testData"),
testData.collect().toSeq
)
- TestSQLContext.table("testData").queryExecution.analyzed match {
- case _ : InMemoryRelation => // Found evidence of caching
- case noCache => fail(s"No cache node found in plan $noCache")
- }
+ assertCached(table("testData"))
- TestSQLContext.uncacheTable("testData")
+ uncacheTable("testData")
checkAnswer(
- TestSQLContext.table("testData"),
+ table("testData"),
testData.collect().toSeq
)
- TestSQLContext.table("testData").queryExecution.analyzed match {
- case cachePlan: InMemoryRelation =>
- fail(s"Table still cached after uncache: $cachePlan")
- case noCache => // Table uncached successfully
- }
+ assertCached(table("testData"), 0)
}
test("correct error on uncache of non-cached table") {
intercept[IllegalArgumentException] {
- TestSQLContext.uncacheTable("testData")
+ uncacheTable("testData")
}
}
test("SELECT Star Cached Table") {
- TestSQLContext.sql("SELECT * FROM testData").registerTempTable("selectStar")
- TestSQLContext.cacheTable("selectStar")
- TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect()
- TestSQLContext.uncacheTable("selectStar")
+ sql("SELECT * FROM testData").registerTempTable("selectStar")
+ cacheTable("selectStar")
+ sql("SELECT * FROM selectStar WHERE key = 1").collect()
+ uncacheTable("selectStar")
}
test("Self-join cached") {
val unCachedAnswer =
- TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
- TestSQLContext.cacheTable("testData")
+ sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
+ cacheTable("testData")
checkAnswer(
- TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
+ sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
unCachedAnswer.toSeq)
- TestSQLContext.uncacheTable("testData")
+ uncacheTable("testData")
}
test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") {
- TestSQLContext.sql("CACHE TABLE testData")
- TestSQLContext.table("testData").queryExecution.executedPlan match {
- case _: InMemoryColumnarTableScan => // Found evidence of caching
- case _ => fail(s"Table 'testData' should be cached")
- }
- assert(TestSQLContext.isCached("testData"), "Table 'testData' should be cached")
+ sql("CACHE TABLE testData")
+ assertCached(table("testData"))
- TestSQLContext.sql("UNCACHE TABLE testData")
- TestSQLContext.table("testData").queryExecution.executedPlan match {
- case _: InMemoryColumnarTableScan => fail(s"Table 'testData' should not be cached")
- case _ => // Found evidence of uncaching
- }
- assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached")
+ assert(isCached("testData"), "Table 'testData' should be cached")
+
+ sql("UNCACHE TABLE testData")
+ assertCached(table("testData"), 0)
+ assert(!isCached("testData"), "Table 'testData' should not be cached")
}
test("CACHE TABLE tableName AS SELECT Star Table") {
- TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
- TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect()
- assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
- TestSQLContext.uncacheTable("testCacheTable")
+ sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
+ sql("SELECT * FROM testCacheTable WHERE key = 1").collect()
+ assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
+ uncacheTable("testCacheTable")
}
test("'CACHE TABLE tableName AS SELECT ..'") {
- TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
- assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
- TestSQLContext.uncacheTable("testCacheTable")
+ sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
+ assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
+ uncacheTable("testCacheTable")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index c1278248ef..9775dd26b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{QueryTest, TestData}
+import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
class InMemoryColumnarQuerySuite extends QueryTest {
import org.apache.spark.sql.TestData._
@@ -27,7 +28,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("simple columnar query") {
val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
- val scan = InMemoryRelation(useCompression = true, 5, plan)
+ val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
checkAnswer(scan, testData.collect().toSeq)
}
@@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("projection") {
val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
- val scan = InMemoryRelation(useCompression = true, 5, plan)
+ val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
checkAnswer(scan, testData.collect().map {
case Row(key: Int, value: String) => value -> key
@@ -51,7 +52,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
- val scan = InMemoryRelation(useCompression = true, 5, plan)
+ val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
checkAnswer(scan, testData.collect().toSeq)
checkAnswer(scan, testData.collect().toSeq)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 989a9784a4..cc0605b0ad 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -133,11 +133,6 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
castChildOutput(p, table, child)
-
- case p @ logical.InsertIntoTable(
- InMemoryRelation(_, _, _,
- HiveTableScan(_, table, _)), _, child, _) =>
- castChildOutput(p, table, child)
}
def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = {
@@ -306,7 +301,7 @@ private[hive] case class MetastoreRelation
HiveMetastoreTypes.toDataType(f.getType),
// Since data can be dumped in randomly with no validation, everything is nullable.
nullable = true
- )(qualifiers = tableName +: alias.toSeq)
+ )(qualifiers = Seq(alias.getOrElse(tableName)))
}
// Must be a stable value since new attributes are born here.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 8ac17f3720..508d8239c7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.types.StringType
-import org.apache.spark.sql.columnar.InMemoryRelation
import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan}
import org.apache.spark.sql.hive
import org.apache.spark.sql.hive.execution._
@@ -161,10 +160,7 @@ private[hive] trait HiveStrategies {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
- case logical.InsertIntoTable(
- InMemoryRelation(_, _, _,
- HiveTableScan(_, table, _)), partition, child, overwrite) =>
- InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
+
case logical.CreateTableAsSelect(database, tableName, child) =>
val query = planLater(child)
CreateTableAsSelect(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 4a999b98ad..c0e69393cc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -353,7 +353,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
var cacheTables: Boolean = false
def loadTestTable(name: String) {
if (!(loadedTables contains name)) {
- // Marks the table as loaded first to prevent infite mutually recursive table loading.
+ // Marks the table as loaded first to prevent infinite mutually recursive table loading.
loadedTables += name
logInfo(s"Loading test table $name")
val createCmds =
@@ -383,6 +383,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN)
}
+ clearCache()
loadedTables.clear()
catalog.client.getAllTables("default").foreach { t =>
logDebug(s"Deleting table $t")
@@ -428,7 +429,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
loadTestTable("srcpart")
} catch {
case e: Exception =>
- logError(s"FATAL ERROR: Failed to reset TestDB state. $e")
+ logError("FATAL ERROR: Failed to reset TestDB state.", e)
// At this point there is really no reason to continue, but the test framework traps exits.
// So instead we just pause forever so that at least the developer can see where things
// started to go wrong.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 16a8c782ac..f8b4e898ec 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -267,6 +267,9 @@ case class InsertIntoHiveTable(
holdDDLTime)
}
+ // Invalidate the cache.
+ sqlContext.invalidateCache(table)
+
// It would be nice to just return the childRdd unchanged so insert operations could be chained,
// however for now we return an empty list to simplify compatibility checks with hive, which
// does not return anything for insert operations.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index b3057cd618..158cfb5bbe 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -17,22 +17,60 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.execution.SparkLogicalPlan
+import org.apache.spark.sql.{QueryTest, SchemaRDD}
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
-import org.apache.spark.sql.hive.execution.HiveComparisonTest
import org.apache.spark.sql.hive.test.TestHive
-class CachedTableSuite extends HiveComparisonTest {
+class CachedTableSuite extends QueryTest {
import TestHive._
- TestHive.loadTestTable("src")
+ /**
+ * Throws a test failed exception when the number of cached tables differs from the expected
+ * number.
+ */
+ def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
+ val planWithCaching = query.queryExecution.withCachedData
+ val cachedData = planWithCaching collect {
+ case cached: InMemoryRelation => cached
+ }
+
+ if (cachedData.size != numCachedTables) {
+ fail(
+ s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
+ planWithCaching)
+ }
+ }
test("cache table") {
- TestHive.cacheTable("src")
+ val preCacheResults = sql("SELECT * FROM src").collect().toSeq
+
+ cacheTable("src")
+ assertCached(sql("SELECT * FROM src"))
+
+ checkAnswer(
+ sql("SELECT * FROM src"),
+ preCacheResults)
+
+ uncacheTable("src")
+ assertCached(sql("SELECT * FROM src"), 0)
}
- createQueryTest("read from cached table",
- "SELECT * FROM src LIMIT 1", reset = false)
+ test("cache invalidation") {
+ sql("CREATE TABLE cachedTable(key INT, value STRING)")
+
+ sql("INSERT INTO TABLE cachedTable SELECT * FROM src")
+ checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq)
+
+ cacheTable("cachedTable")
+ checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq)
+
+ sql("INSERT INTO TABLE cachedTable SELECT * FROM src")
+ checkAnswer(
+ sql("SELECT * FROM cachedTable"),
+ table("src").collect().toSeq ++ table("src").collect().toSeq)
+
+ sql("DROP TABLE cachedTable")
+ }
test("Drop cached table") {
sql("CREATE TABLE test(a INT)")
@@ -48,25 +86,6 @@ class CachedTableSuite extends HiveComparisonTest {
sql("DROP TABLE IF EXISTS nonexistantTable")
}
- test("check that table is cached and uncache") {
- TestHive.table("src").queryExecution.analyzed match {
- case _ : InMemoryRelation => // Found evidence of caching
- case noCache => fail(s"No cache node found in plan $noCache")
- }
- TestHive.uncacheTable("src")
- }
-
- createQueryTest("read from uncached table",
- "SELECT * FROM src LIMIT 1", reset = false)
-
- test("make sure table is uncached") {
- TestHive.table("src").queryExecution.analyzed match {
- case cachePlan: InMemoryRelation =>
- fail(s"Table still cached after uncache: $cachePlan")
- case noCache => // Table uncached successfully
- }
- }
-
test("correct error on uncache of non-cached table") {
intercept[IllegalArgumentException] {
TestHive.uncacheTable("src")
@@ -75,23 +94,24 @@ class CachedTableSuite extends HiveComparisonTest {
test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") {
TestHive.sql("CACHE TABLE src")
- TestHive.table("src").queryExecution.executedPlan match {
- case _: InMemoryColumnarTableScan => // Found evidence of caching
- case _ => fail(s"Table 'src' should be cached")
- }
+ assertCached(table("src"))
assert(TestHive.isCached("src"), "Table 'src' should be cached")
TestHive.sql("UNCACHE TABLE src")
- TestHive.table("src").queryExecution.executedPlan match {
- case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached")
- case _ => // Found evidence of uncaching
- }
+ assertCached(table("src"), 0)
assert(!TestHive.isCached("src"), "Table 'src' should not be cached")
}
-
- test("'CACHE TABLE tableName AS SELECT ..'") {
- TestHive.sql("CACHE TABLE testCacheTable AS SELECT * FROM src")
- assert(TestHive.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
- TestHive.uncacheTable("testCacheTable")
- }
+
+ test("CACHE TABLE AS SELECT") {
+ assertCached(sql("SELECT * FROM src"), 0)
+ sql("CACHE TABLE test AS SELECT key FROM src")
+
+ checkAnswer(
+ sql("SELECT * FROM test"),
+ sql("SELECT key FROM src").collect().toSeq)
+
+ assertCached(sql("SELECT * FROM test"))
+
+ assertCached(sql("SELECT * FROM test JOIN test"), 2)
+ }
}