author     Yin Huai <yhuai@databricks.com>  2015-02-06 12:38:07 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-02-06 12:38:07 -0800
commit     3eccf29ce061559c86e6f7338851932fc89a9afd
tree       9d8dca5ab32cd82f04d6e96287e1d56d8d56559e
parent     0b7eb3f3b700080bf6cb810d092709a8a468e5db
[SPARK-5595][SPARK-5603][SQL] Add a rule to do pre-insert type casting and field renaming, and invalidate the in-memory cache after INSERT
This PR adds a rule to the Analyzer that adds pre-insert data type casting and field renaming to the SELECT clause of an `INSERT INTO/OVERWRITE` statement. With this change, we also always invalidate the in-memory data cache after inserting into a BaseRelation.

cc marmbrus liancheng

Author: Yin Huai <yhuai@databricks.com>

Closes #4373 from yhuai/insertFollowUp and squashes the following commits:

08237a7 [Yin Huai] Merge remote-tracking branch 'upstream/master' into insertFollowUp
316542e [Yin Huai] Doc update.
c9ccfeb [Yin Huai] Revert an unnecessary change.
84aecc4 [Yin Huai] Address comments.
1951fe1 [Yin Huai] Merge remote-tracking branch 'upstream/master'
c18da34 [Yin Huai] Invalidate cache after insert.
727f21a [Yin Huai] Preinsert casting and renaming.
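A hedged usage sketch of the behavior described above (illustrative only, not part of the patch): it assumes a JSON-backed data source table jsonTable(a INT, b STRING) and a source table jt, set up as in the InsertIntoSuite tests below.

// The target column `b` is a STRING while the second SELECT expression is an INT:
// the new PreInsertCastAndRename rule adds the cast and renames the output columns
// to match the target schema before the insert runs.
sqlContext.cacheTable("jsonTable")
sqlContext.sql("INSERT OVERWRITE TABLE jsonTable SELECT a * 2, a * 4 FROM jt")

// InsertIntoDataSource invalidates the cached data for jsonTable, so this query
// reads the freshly written rows rather than a stale cached copy.
sqlContext.sql("SELECT a, b FROM jsonTable").collect()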
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                    6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala            12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala    2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala             10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala           16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala                76
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala               25
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala        6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala      80
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala              1
10 files changed, 227 insertions(+), 7 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 706ef6ad4f..bf39906710 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -91,7 +91,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] lazy val analyzer: Analyzer =
- new Analyzer(catalog, functionRegistry, caseSensitive = true)
+ new Analyzer(catalog, functionRegistry, caseSensitive = true) {
+ override val extendedRules =
+ sources.PreInsertCastAndRename ::
+ Nil
+ }
@transient
protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index f27585d05a..c4e14c6c92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -72,7 +72,6 @@ private[sql] case class JSONRelation(
userSpecifiedSchema: Option[StructType])(
@transient val sqlContext: SQLContext)
extends TableScan with InsertableRelation {
-
// TODO: Support partitioned JSON relation.
private def baseRDD = sqlContext.sparkContext.textFile(path)
@@ -99,10 +98,21 @@ private[sql] case class JSONRelation(
s"Unable to clear output directory ${filesystemPath.toString} prior"
+ s" to INSERT OVERWRITE a JSON table:\n${e.toString}")
}
+ // Write the data.
data.toJSON.saveAsTextFile(path)
+ // Right now, we assume that the schema is not changed. We will not update the schema.
+ // schema = data.schema
} else {
// TODO: Support INSERT INTO
sys.error("JSON table only support INSERT OVERWRITE for now.")
}
}
+
+ override def hashCode(): Int = 41 * (41 + path.hashCode) + schema.hashCode()
+
+ override def equals(other: Any): Boolean = other match {
+ case that: JSONRelation =>
+ (this.path == that.path) && (this.schema == that.schema)
+ case _ => false
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index d23ffb8b7a..624369afe8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -59,7 +59,7 @@ private[sql] object DataSourceStrategy extends Strategy {
if (partition.nonEmpty) {
sys.error(s"Insert into a partition is not allowed because $l is not partitioned.")
}
- execution.ExecutedCommand(InsertIntoRelation(t, query, overwrite)) :: Nil
+ execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil
case _ => Nil
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index d7942dc309..c9cd0e6e93 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -19,17 +19,21 @@ package org.apache.spark.sql.sources
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.sql.execution.{LogicalRDD, RunnableCommand}
-private[sql] case class InsertIntoRelation(
- relation: InsertableRelation,
+private[sql] case class InsertIntoDataSource(
+ logicalRelation: LogicalRelation,
query: LogicalPlan,
overwrite: Boolean)
extends RunnableCommand {
override def run(sqlContext: SQLContext) = {
+ val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
relation.insert(DataFrame(sqlContext, query), overwrite)
+ // Invalidate the cache.
+ sqlContext.cacheManager.invalidateCache(logicalRelation)
+
Seq.empty[Row]
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 40fc1f2aa2..a640ba57e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -158,6 +158,22 @@ trait CatalystScan extends BaseRelation {
}
@DeveloperApi
+/**
+ * ::DeveloperApi::
+ * A BaseRelation into which data can be inserted through its insert method.
+ * If the overwrite argument of the insert method is true, the old data in the relation should be
+ * overwritten with the new data. If it is false, the new data should be appended.
+ *
+ * InsertableRelation makes the following three assumptions.
+ * 1. It assumes that the data (Rows in the DataFrame) provided to the insert method
+ * exactly matches the ordinals of the fields in the schema of the BaseRelation.
+ * 2. It assumes that the schema of this relation will not be changed.
+ * Even if the insert method updates the schema (e.g. a relation of JSON or Parquet data may have a
+ * schema update after an insert operation), the new schema will not be used.
+ * 3. It assumes that the fields of the data provided to the insert method are nullable.
+ * If a data source needs to check the actual nullability of a field, it needs to do it in the
+ * insert method.
+ */
trait InsertableRelation extends BaseRelation {
def insert(data: DataFrame, overwrite: Boolean): Unit
}
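To make the three assumptions above concrete, here is a minimal, hypothetical InsertableRelation sketch (not part of the patch) that keeps its rows in memory; it assumes the BaseRelation/TableScan shapes shown in this file and that Row, DataFrame, and SQLContext are importable from org.apache.spark.sql.

package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.types.StructType

// Hypothetical sketch: rows live in memory; overwrite replaces them, append extends them.
private[sql] case class InMemoryInsertableRelation(schema: StructType)(
    @transient val sqlContext: SQLContext)
  extends TableScan with InsertableRelation {

  @transient private var rows: Seq[Row] = Seq.empty

  // TableScan: expose the current rows as an RDD.
  override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(rows)

  // InsertableRelation: assumption 1 means `data` already matches `schema` by ordinal,
  // so the rows can be stored as-is.
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    val newRows = data.collect().toSeq
    rows = if (overwrite) newRows else rows ++ newRows
  }
}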
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
new file mode 100644
index 0000000000..4ed22d363d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Alias}
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A rule to do pre-insert data type casting and field renaming. Before we insert into
+ * an [[InsertableRelation]], we will use this rule to make sure that
+ * the columns to be inserted have the correct data types and that the fields have the correct names.
+ * The rule is applied only after the child query of the INSERT statement has been resolved.
+ */
+private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ plan.transform {
+ // Wait until children are resolved.
+ case p: LogicalPlan if !p.childrenResolved => p
+
+ // We are inserting into an InsertableRelation.
+ case i @ InsertIntoTable(
+ l @ LogicalRelation(r: InsertableRelation), partition, child, overwrite) => {
+ // First, make sure the data to be inserted have the same number of fields with the
+ // schema of the relation.
+ if (l.output.size != child.output.size) {
+ sys.error(
+ s"$l requires that the query in the SELECT clause of the INSERT INTO/OVERWRITE " +
+ s"statement generates the same number of columns as its schema.")
+ }
+ castAndRenameChildOutput(i, l.output, child)
+ }
+ }
+ }
+
+ /** If necessary, cast data types and rename fields to the expected types and names. */
+ def castAndRenameChildOutput(
+ insertInto: InsertIntoTable,
+ expectedOutput: Seq[Attribute],
+ child: LogicalPlan) = {
+ val newChildOutput = expectedOutput.zip(child.output).map {
+ case (expected, actual) =>
+ val needCast = !DataType.equalsIgnoreNullability(expected.dataType, actual.dataType)
+ // We want to make sure the field names in the data to be inserted exactly match
+ // the names in the schema.
+ val needRename = expected.name != actual.name
+ (needCast, needRename) match {
+ case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)()
+ case (false, true) => Alias(actual, expected.name)()
+ case (_, _) => actual
+ }
+ }
+
+ if (newChildOutput == child.output) {
+ insertInto
+ } else {
+ insertInto.copy(child = Project(newChildOutput, child))
+ }
+ }
+}
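Because the rule above only rewrites Catalyst plans, a small self-contained sketch of the per-column decision may help; it mirrors castAndRenameChildOutput over plain (name, type) pairs (hypothetical code, and it simplifies the type comparison to string equality rather than DataType.equalsIgnoreNullability).

// Standalone sketch of the decision made per column: cast when types differ,
// alias when only the name differs, and pass the column through otherwise.
object CastAndRenameSketch {
  case class Col(name: String, dataType: String)

  def adjust(expected: Seq[Col], actual: Seq[Col]): Seq[String] =
    expected.zip(actual).map { case (e, a) =>
      val needCast = e.dataType != a.dataType
      val needRename = e.name != a.name
      (needCast, needRename) match {
        case (true, _)     => s"CAST(${a.name} AS ${e.dataType}) AS ${e.name}"
        case (false, true) => s"${a.name} AS ${e.name}"
        case (_, _)        => a.name
      }
    }

  def main(args: Array[String]): Unit = {
    // Target schema jsonTable(a int, b string); SELECT a * 2, a * 4 produces two int columns.
    val expected = Seq(Col("a", "int"), Col("b", "string"))
    val actual   = Seq(Col("(a * 2)", "int"), Col("(a * 4)", "int"))
    // Prints: (a * 2) AS a, CAST((a * 4) AS string) AS b
    println(adjust(expected, actual).mkString(", "))
  }
}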
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 1396c6b724..926ba68828 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
+import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types._
@@ -923,6 +924,30 @@ class JsonSuite extends QueryTest {
sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"),
Row(5, null)
)
+ }
+ test("JSONRelation equality test") {
+ val relation1 =
+ JSONRelation("path", 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(null)
+ val logicalRelation1 = LogicalRelation(relation1)
+ val relation2 =
+ JSONRelation("path", 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(
+ org.apache.spark.sql.test.TestSQLContext)
+ val logicalRelation2 = LogicalRelation(relation2)
+ val relation3 =
+ JSONRelation("path", 1.0, Some(StructType(StructField("b", StringType, true) :: Nil)))(null)
+ val logicalRelation3 = LogicalRelation(relation3)
+
+ assert(relation1 === relation2)
+ assert(logicalRelation1.sameResult(logicalRelation2),
+ s"$logicalRelation1 and $logicalRelation2 should be considered having the same result.")
+
+ assert(relation1 !== relation3)
+ assert(!logicalRelation1.sameResult(logicalRelation3),
+ s"$logicalRelation1 and $logicalRelation3 should be considered not having the same result.")
+
+ assert(relation2 !== relation3)
+ assert(!logicalRelation2.sameResult(logicalRelation3),
+ s"$logicalRelation2 and $logicalRelation3 should be considered not having the same result.")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 9626252e74..53f5f7426e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -28,7 +28,11 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) {
@transient
override protected[sql] lazy val analyzer: Analyzer =
- new Analyzer(catalog, functionRegistry, caseSensitive = false)
+ new Analyzer(catalog, functionRegistry, caseSensitive = false) {
+ override val extendedRules =
+ PreInsertCastAndRename ::
+ Nil
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
index f91cea6a37..36e504e759 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
@@ -63,6 +63,41 @@ class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll {
)
}
+ test("PreInsert casting and renaming") {
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, a * 4 FROM jt
+ """.stripMargin)
+
+ checkAnswer(
+ sql("SELECT a, b FROM jsonTable"),
+ (1 to 10).map(i => Row(i * 2, s"${i * 4}"))
+ )
+
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE jsonTable SELECT a * 4 AS A, a * 6 as c FROM jt
+ """.stripMargin)
+
+ checkAnswer(
+ sql("SELECT a, b FROM jsonTable"),
+ (1 to 10).map(i => Row(i * 4, s"${i * 6}"))
+ )
+ }
+
+ test("SELECT clause generating a different number of columns is not allowed.") {
+ val message = intercept[RuntimeException] {
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt
+ """.stripMargin)
+ }.getMessage
+ assert(
+ message.contains("generates the same number of columns as its schema"),
+ "SELECT clause generating a different number of columns should not be not allowed."
+ )
+ }
+
test("INSERT OVERWRITE a JSONRelation multiple times") {
sql(
s"""
@@ -93,4 +128,49 @@ class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll {
""".stripMargin)
}
}
+
+ test("Caching") {
+ // Cached Query Execution
+ cacheTable("jsonTable")
+ assertCached(sql("SELECT * FROM jsonTable"))
+ checkAnswer(
+ sql("SELECT * FROM jsonTable"),
+ (1 to 10).map(i => Row(i, s"str$i")))
+
+ assertCached(sql("SELECT a FROM jsonTable"))
+ checkAnswer(
+ sql("SELECT a FROM jsonTable"),
+ (1 to 10).map(Row(_)).toSeq)
+
+ assertCached(sql("SELECT a FROM jsonTable WHERE a < 5"))
+ checkAnswer(
+ sql("SELECT a FROM jsonTable WHERE a < 5"),
+ (1 to 4).map(Row(_)).toSeq)
+
+ assertCached(sql("SELECT a * 2 FROM jsonTable"))
+ checkAnswer(
+ sql("SELECT a * 2 FROM jsonTable"),
+ (1 to 10).map(i => Row(i * 2)).toSeq)
+
+ assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
+ checkAnswer(
+ sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
+ (2 to 10).map(i => Row(i, i - 1)).toSeq)
+
+ // Insert overwrite and keep the same schema.
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt
+ """.stripMargin)
+ // jsonTable should be recached.
+ assertCached(sql("SELECT * FROM jsonTable"))
+ // The cached data is the new data.
+ checkAnswer(
+ sql("SELECT a, b FROM jsonTable"),
+ sql("SELECT a * 2, b FROM jt").collect())
+
+ // Verify uncaching
+ uncacheTable("jsonTable")
+ assertCached(sql("SELECT * FROM jsonTable"), 0)
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index d2371d4a55..ad37b7d0e6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -324,6 +324,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.PreInsertionCasts ::
ExtractPythonUdfs ::
ResolveUdtfsAlias ::
+ sources.PreInsertCastAndRename ::
Nil
}