aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYin Huai <yhuai@databricks.com>2015-02-10 17:29:52 -0800
committerMichael Armbrust <michael@databricks.com>2015-02-10 17:29:52 -0800
commitaaf50d05c7616e4f8f16654b642500ae06cdd774 (patch)
tree7f30e0d08e4f2b531ac62c82a4361a2db577932d
parented167e70c6d355f39b366ea0d3b92dd26d826a0b (diff)
downloadspark-aaf50d05c7616e4f8f16654b642500ae06cdd774.tar.gz
spark-aaf50d05c7616e4f8f16654b642500ae06cdd774.tar.bz2
spark-aaf50d05c7616e4f8f16654b642500ae06cdd774.zip
[SPARK-5658][SQL] Finalize DDL and write support APIs
https://issues.apache.org/jira/browse/SPARK-5658 Author: Yin Huai <yhuai@databricks.com> This patch had conflicts when merged, resolved by Committer: Michael Armbrust <michael@databricks.com> Closes #4446 from yhuai/writeSupportFollowup and squashes the following commits: f3a96f7 [Yin Huai] davies's comments. 225ff71 [Yin Huai] Use Scala TestHiveContext to initialize the Python HiveContext in Python tests. 2306f93 [Yin Huai] Style. 2091fcd [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup 537e28f [Yin Huai] Correctly clean up temp data. ae4649e [Yin Huai] Fix Python test. 609129c [Yin Huai] Doc format. 92b6659 [Yin Huai] Python doc and other minor updates. cbc717f [Yin Huai] Rename dataSourceName to source. d1c12d3 [Yin Huai] No need to delete the duplicate rule since it has been removed in master. 22cfa70 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup d91ecb8 [Yin Huai] Fix test. 4c76d78 [Yin Huai] Simplify APIs. 3abc215 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup 0832ce4 [Yin Huai] Fix test. 98e7cdb [Yin Huai] Python style. 2bf44ef [Yin Huai] Python APIs. c204967 [Yin Huai] Format a10223d [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup 9ff97d8 [Yin Huai] Add SaveMode to saveAsTable. 9b6e570 [Yin Huai] Update doc. c2be775 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup 99950a2 [Yin Huai] Use Java enum for SaveMode. 4679665 [Yin Huai] Remove duplicate rule. 77d89dc [Yin Huai] Update doc. e04d908 [Yin Huai] Move import and add (Scala-specific) to scala APIs. cf5703d [Yin Huai] Add checkAnswer to Java tests. 7db95ff [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup 6dfd386 [Yin Huai] Add java test. f2f33ef [Yin Huai] Fix test. e702386 [Yin Huai] Apache header. b1e9b1b [Yin Huai] Format. ed4e1b4 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup af9e9b3 [Yin Huai] DDL and write support API followup. 2a6213a [Yin Huai] Update API names. e6a0b77 [Yin Huai] Update test. 43bae01 [Yin Huai] Remove createTable from HiveContext. 5ffc372 [Yin Huai] Add more load APIs to SQLContext. 5390743 [Yin Huai] Add more save APIs to DataFrame.
-rw-r--r--python/pyspark/sql/context.py68
-rw-r--r--python/pyspark/sql/dataframe.py72
-rw-r--r--python/pyspark/sql/tests.py107
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java45
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala160
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala61
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala27
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala164
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala30
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala45
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala40
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala19
-rw-r--r--sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java97
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala92
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala29
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala59
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala76
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala13
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala105
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala (renamed from sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala)20
-rw-r--r--sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java147
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala64
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala33
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala118
26 files changed, 1357 insertions, 350 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 49f016a9cf..882c0f98ea 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -21,6 +21,7 @@ from array import array
from itertools import imap
from py4j.protocol import Py4JError
+from py4j.java_collections import MapConverter
from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
@@ -87,6 +88,18 @@ class SQLContext(object):
self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
return self._scala_SQLContext
+ def setConf(self, key, value):
+ """Sets the given Spark SQL configuration property.
+ """
+ self._ssql_ctx.setConf(key, value)
+
+ def getConf(self, key, defaultValue):
+ """Returns the value of Spark SQL configuration property for the given key.
+
+ If the key is not set, returns defaultValue.
+ """
+ return self._ssql_ctx.getConf(key, defaultValue)
+
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.
@@ -455,6 +468,61 @@ class SQLContext(object):
df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return DataFrame(df, self)
+ def load(self, path=None, source=None, schema=None, **options):
+ """Returns the dataset in a data source as a DataFrame.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Optionally, a schema can be provided as the schema of the returned DataFrame.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ if schema is None:
+ df = self._ssql_ctx.load(source, joptions)
+ else:
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+ df = self._ssql_ctx.load(source, scala_datatype, joptions)
+ return DataFrame(df, self)
+
+ def createExternalTable(self, tableName, path=None, source=None,
+ schema=None, **options):
+ """Creates an external table based on the dataset in a data source.
+
+ It returns the DataFrame associated with the external table.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Optionally, a schema can be provided as the schema of the returned DataFrame and
+ created external table.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ if schema is None:
+ df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
+ else:
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+ df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
+ joptions)
+ return DataFrame(df, self)
+
def sql(self, sqlQuery):
"""Return a L{DataFrame} representing the result of the given query.
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 04be65fe24..3eef0cc376 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -146,9 +146,75 @@ class DataFrame(object):
"""
self._jdf.insertInto(tableName, overwrite)
- def saveAsTable(self, tableName):
- """Creates a new table with the contents of this DataFrame."""
- self._jdf.saveAsTable(tableName)
+ def _java_save_mode(self, mode):
+ """Returns the Java save mode based on the Python save mode represented by a string.
+ """
+ jSaveMode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode
+ jmode = jSaveMode.ErrorIfExists
+ mode = mode.lower()
+ if mode == "append":
+ jmode = jSaveMode.Append
+ elif mode == "overwrite":
+ jmode = jSaveMode.Overwrite
+ elif mode == "ignore":
+ jmode = jSaveMode.Ignore
+ elif mode == "error":
+ pass
+ else:
+ raise ValueError(
+ "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.")
+ return jmode
+
+ def saveAsTable(self, tableName, source=None, mode="append", **options):
+ """Saves the contents of the DataFrame to a data source as a table.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Additionally, mode is used to specify the behavior of the saveAsTable operation when
+ table already exists in the data source. There are four modes:
+
+ * append: Contents of this DataFrame are expected to be appended to existing table.
+ * overwrite: Data in the existing table is expected to be overwritten by the contents of \
+ this DataFrame.
+ * error: An exception is expected to be thrown.
+ * ignore: The save operation is expected to not save the contents of the DataFrame and \
+ to not change the existing table.
+ """
+ if source is None:
+ source = self.sql_ctx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ jmode = self._java_save_mode(mode)
+ joptions = MapConverter().convert(options,
+ self.sql_ctx._sc._gateway._gateway_client)
+ self._jdf.saveAsTable(tableName, source, jmode, joptions)
+
+ def save(self, path=None, source=None, mode="append", **options):
+ """Saves the contents of the DataFrame to a data source.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Additionally, mode is used to specify the behavior of the save operation when
+ data already exists in the data source. There are four modes:
+
+ * append: Contents of this DataFrame are expected to be appended to existing data.
+ * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame.
+ * error: An exception is expected to be thrown.
+ * ignore: The save operation is expected to not save the contents of the DataFrame and \
+ to not change the existing data.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.sql_ctx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ jmode = self._java_save_mode(mode)
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ self._jdf.save(source, jmode, joptions)
def schema(self):
"""Returns the schema of this DataFrame (represented by
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d25c6365ed..bc945091f7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -34,10 +34,9 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-
-from pyspark.sql import SQLContext, Column
+from pyspark.sql import SQLContext, HiveContext, Column
from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
- UserDefinedType, DoubleType, LongType
+ UserDefinedType, DoubleType, LongType, StringType
from pyspark.tests import ReusedPySparkTestCase
@@ -286,6 +285,37 @@ class SQLTests(ReusedPySparkTestCase):
self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+ def test_save_and_load(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.save(tmpPath, "org.apache.spark.sql.json", "error")
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
+ self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+
+ df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
+ noUse="this options will not be used in save.")
+ actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
+ noUse="this options will not be used in load.")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ actual = self.sqlCtx.load(path=tmpPath)
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
+
def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
@@ -296,5 +326,76 @@ class SQLTests(ReusedPySparkTestCase):
pydoc.render_doc(df.take(1))
+class HiveContextSQLTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ print "type", type(cls.sc)
+ print "type", type(cls.sc._jsc)
+ _scala_HiveContext =\
+ cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
+ cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
+ cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ rdd = cls.sc.parallelize(cls.testData)
+ cls.df = cls.sqlCtx.inferSchema(rdd)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+ def test_save_and_load_table(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
+ actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
+ "org.apache.spark.sql.json")
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+ df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.sqlCtx.createExternalTable("externalJsonTable",
+ source="org.apache.spark.sql.json",
+ schema=schema, path=tmpPath,
+ noUse="this options will not be used")
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.select("value").collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE savedJsonTable")
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+ defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+ actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE savedJsonTable")
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+ self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
+
if __name__ == "__main__":
unittest.main()
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java
new file mode 100644
index 0000000000..3109f5716d
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.sources;
+
+/**
+ * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source.
+ */
+public enum SaveMode {
+ /**
+ * Append mode means that when saving a DataFrame to a data source, if data/table already exists,
+ * contents of the DataFrame are expected to be appended to existing data.
+ */
+ Append,
+ /**
+ * Overwrite mode means that when saving a DataFrame to a data source,
+ * if data/table already exists, existing data is expected to be overwritten by the contents of
+ * the DataFrame.
+ */
+ Overwrite,
+ /**
+ * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists,
+ * an exception is expected to be thrown.
+ */
+ ErrorIfExists,
+ /**
+ * Ignore mode means that when saving a DataFrame to a data source, if data already exists,
+ * the save operation is expected to not save the contents of the DataFrame and to not
+ * change the existing data.
+ */
+ Ignore
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 04e0d09947..ca8d552c5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -17,19 +17,19 @@
package org.apache.spark.sql
+import scala.collection.JavaConversions._
import scala.reflect.ClassTag
+import scala.util.control.NonFatal
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.sources.SaveMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
-import scala.util.control.NonFatal
-
-
private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
new DataFrameImpl(sqlContext, logicalPlan)
@@ -574,8 +574,64 @@ trait DataFrame extends RDDApi[Row] {
/**
* :: Experimental ::
- * Creates a table from the the contents of this DataFrame. This will fail if the table already
- * exists.
+ * Creates a table from the the contents of this DataFrame.
+ * It will use the default data source configured by spark.sql.sources.default.
+ * This will fail if the table already exists.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ */
+ @Experimental
+ def saveAsTable(tableName: String): Unit = {
+ saveAsTable(tableName, SaveMode.ErrorIfExists)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a table from the the contents of this DataFrame, using the default data source
+ * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ */
+ @Experimental
+ def saveAsTable(tableName: String, mode: SaveMode): Unit = {
+ if (sqlContext.catalog.tableExists(Seq(tableName)) && mode == SaveMode.Append) {
+ // If table already exists and the save mode is Append,
+ // we will just call insertInto to append the contents of this DataFrame.
+ insertInto(tableName, overwrite = false)
+ } else {
+ val dataSourceName = sqlContext.conf.defaultDataSourceName
+ saveAsTable(tableName, dataSourceName, mode)
+ }
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a table at the given path from the the contents of this DataFrame
+ * based on a given data source and a set of options,
+ * using [[SaveMode.ErrorIfExists]] as the save mode.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ */
+ @Experimental
+ def saveAsTable(
+ tableName: String,
+ source: String): Unit = {
+ saveAsTable(tableName, source, SaveMode.ErrorIfExists)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a table at the given path from the the contents of this DataFrame
+ * based on a given data source, [[SaveMode]] specified by mode, and a set of options.
*
* Note that this currently only works with DataFrames that are created from a HiveContext as
* there is no notion of a persisted catalog in a standard SQL context. Instead you can write
@@ -583,12 +639,17 @@ trait DataFrame extends RDDApi[Row] {
* be the target of an `insertInto`.
*/
@Experimental
- def saveAsTable(tableName: String): Unit
+ def saveAsTable(
+ tableName: String,
+ source: String,
+ mode: SaveMode): Unit = {
+ saveAsTable(tableName, source, mode, Map.empty[String, String])
+ }
/**
* :: Experimental ::
- * Creates a table from the the contents of this DataFrame based on a given data source and
- * a set of options. This will fail if the table already exists.
+ * Creates a table at the given path from the the contents of this DataFrame
+ * based on a given data source, [[SaveMode]] specified by mode, and a set of options.
*
* Note that this currently only works with DataFrames that are created from a HiveContext as
* there is no notion of a persisted catalog in a standard SQL context. Instead you can write
@@ -598,14 +659,17 @@ trait DataFrame extends RDDApi[Row] {
@Experimental
def saveAsTable(
tableName: String,
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit
+ source: String,
+ mode: SaveMode,
+ options: java.util.Map[String, String]): Unit = {
+ saveAsTable(tableName, source, mode, options.toMap)
+ }
/**
* :: Experimental ::
- * Creates a table from the the contents of this DataFrame based on a given data source and
- * a set of options. This will fail if the table already exists.
+ * (Scala-specific)
+ * Creates a table from the the contents of this DataFrame based on a given data source,
+ * [[SaveMode]] specified by mode, and a set of options.
*
* Note that this currently only works with DataFrames that are created from a HiveContext as
* there is no notion of a persisted catalog in a standard SQL context. Instead you can write
@@ -615,22 +679,76 @@ trait DataFrame extends RDDApi[Row] {
@Experimental
def saveAsTable(
tableName: String,
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit
+
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path,
+ * using the default data source configured by spark.sql.sources.default and
+ * [[SaveMode.ErrorIfExists]] as the save mode.
+ */
+ @Experimental
+ def save(path: String): Unit = {
+ save(path, SaveMode.ErrorIfExists)
+ }
+
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode,
+ * using the default data source configured by spark.sql.sources.default.
+ */
+ @Experimental
+ def save(path: String, mode: SaveMode): Unit = {
+ val dataSourceName = sqlContext.conf.defaultDataSourceName
+ save(path, dataSourceName, mode)
+ }
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path based on the given data source,
+ * using [[SaveMode.ErrorIfExists]] as the save mode.
+ */
+ @Experimental
+ def save(path: String, source: String): Unit = {
+ save(source, SaveMode.ErrorIfExists, Map("path" -> path))
+ }
+
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path based on the given data source and
+ * [[SaveMode]] specified by mode.
+ */
@Experimental
- def save(path: String): Unit
+ def save(path: String, source: String, mode: SaveMode): Unit = {
+ save(source, mode, Map("path" -> path))
+ }
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame based on the given data source,
+ * [[SaveMode]] specified by mode, and a set of options.
+ */
@Experimental
def save(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit
+ source: String,
+ mode: SaveMode,
+ options: java.util.Map[String, String]): Unit = {
+ save(source, mode, options.toMap)
+ }
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Saves the contents of this DataFrame based on the given data source,
+ * [[SaveMode]] specified by mode, and a set of options
+ */
@Experimental
def save(
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit
/**
* :: Experimental ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 1ee16ad516..11f9334556 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -28,13 +28,14 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, ResolvedStar, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsLogicalPlan}
+import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{NumericType, StructType}
@@ -341,68 +342,34 @@ private[sql] class DataFrameImpl protected[sql](
override def saveAsParquetFile(path: String): Unit = {
if (sqlContext.conf.parquetUseDataSourceApi) {
- save("org.apache.spark.sql.parquet", "path" -> path)
+ save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path))
} else {
sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
}
}
- override def saveAsTable(tableName: String): Unit = {
- val dataSourceName = sqlContext.conf.defaultDataSourceName
- val cmd =
- CreateTableUsingAsLogicalPlan(
- tableName,
- dataSourceName,
- temporary = false,
- Map.empty,
- allowExisting = false,
- logicalPlan)
-
- sqlContext.executePlan(cmd).toRdd
- }
-
override def saveAsTable(
tableName: String,
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = {
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit = {
val cmd =
CreateTableUsingAsLogicalPlan(
tableName,
- dataSourceName,
+ source,
temporary = false,
- (option +: options).toMap,
- allowExisting = false,
+ mode,
+ options,
logicalPlan)
sqlContext.executePlan(cmd).toRdd
}
- override def saveAsTable(
- tableName: String,
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- saveAsTable(tableName, dataSourceName, opts.head, opts.tail:_*)
- }
-
- override def save(path: String): Unit = {
- val dataSourceName = sqlContext.conf.defaultDataSourceName
- save(dataSourceName, "path" -> path)
- }
-
- override def save(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = {
- ResolvedDataSource(sqlContext, dataSourceName, (option +: options).toMap, this)
- }
-
override def save(
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- save(dataSourceName, opts.head, opts.tail:_*)
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit = {
+ ResolvedDataSource(sqlContext, source, mode, options, this)
}
override def insertInto(tableName: String, overwrite: Boolean): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index ce0557b881..494e49c131 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedSt
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.sql.sources.SaveMode
import org.apache.spark.sql.types.StructType
-
private[sql] class IncomputableColumn(protected[sql] val expr: Expression) extends Column {
def this(name: String) = this(name match {
@@ -156,29 +156,16 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def saveAsParquetFile(path: String): Unit = err()
- override def saveAsTable(tableName: String): Unit = err()
-
- override def saveAsTable(
- tableName: String,
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = err()
-
override def saveAsTable(
tableName: String,
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = err()
-
- override def save(path: String): Unit = err()
-
- override def save(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = err()
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit = err()
override def save(
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = err()
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit = err()
override def insertInto(tableName: String, overwrite: Boolean): Unit = err()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 180f5e765f..39f6c2f4bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -50,7 +50,7 @@ private[spark] object SQLConf {
val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
// This is used to set the default data source
- val DEFAULT_DATA_SOURCE_NAME = "spark.sql.default.datasource"
+ val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default"
// Whether to perform eager analysis on a DataFrame.
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.dataframe.eagerAnalysis"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 97e3777f93..801505bceb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -401,27 +401,173 @@ class SQLContext(@transient val sparkContext: SparkContext)
jsonRDD(json.rdd, samplingRatio);
}
+ /**
+ * :: Experimental ::
+ * Returns the dataset stored at path as a DataFrame,
+ * using the default data source configured by spark.sql.sources.default.
+ */
@Experimental
def load(path: String): DataFrame = {
val dataSourceName = conf.defaultDataSourceName
- load(dataSourceName, ("path", path))
+ load(path, dataSourceName)
}
+ /**
+ * :: Experimental ::
+ * Returns the dataset stored at path as a DataFrame,
+ * using the given data source.
+ */
@Experimental
- def load(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): DataFrame = {
- val resolved = ResolvedDataSource(this, None, dataSourceName, (option +: options).toMap)
+ def load(path: String, source: String): DataFrame = {
+ load(source, Map("path" -> path))
+ }
+
+ /**
+ * :: Experimental ::
+ * Returns the dataset specified by the given data source and a set of options as a DataFrame.
+ */
+ @Experimental
+ def load(source: String, options: java.util.Map[String, String]): DataFrame = {
+ load(source, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Returns the dataset specified by the given data source and a set of options as a DataFrame.
+ */
+ @Experimental
+ def load(source: String, options: Map[String, String]): DataFrame = {
+ val resolved = ResolvedDataSource(this, None, source, options)
DataFrame(this, LogicalRelation(resolved.relation))
}
+ /**
+ * :: Experimental ::
+ * Returns the dataset specified by the given data source and a set of options as a DataFrame,
+ * using the given schema as the schema of the DataFrame.
+ */
@Experimental
def load(
- dataSourceName: String,
+ source: String,
+ schema: StructType,
options: java.util.Map[String, String]): DataFrame = {
- val opts = options.toSeq
- load(dataSourceName, opts.head, opts.tail:_*)
+ load(source, schema, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Returns the dataset specified by the given data source and a set of options as a DataFrame,
+ * using the given schema as the schema of the DataFrame.
+ */
+ @Experimental
+ def load(
+ source: String,
+ schema: StructType,
+ options: Map[String, String]): DataFrame = {
+ val resolved = ResolvedDataSource(this, Some(schema), source, options)
+ DataFrame(this, LogicalRelation(resolved.relation))
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path and returns the corresponding DataFrame.
+ * It will use the default data source configured by spark.sql.sources.default.
+ */
+ @Experimental
+ def createExternalTable(tableName: String, path: String): DataFrame = {
+ val dataSourceName = conf.defaultDataSourceName
+ createExternalTable(tableName, path, dataSourceName)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source
+ * and returns the corresponding DataFrame.
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ path: String,
+ source: String): DataFrame = {
+ createExternalTable(tableName, source, Map("path" -> path))
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source and a set of options.
+ * Then, returns the corresponding DataFrame.
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ options: java.util.Map[String, String]): DataFrame = {
+ createExternalTable(tableName, source, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Creates an external table from the given path based on a data source and a set of options.
+ * Then, returns the corresponding DataFrame.
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ options: Map[String, String]): DataFrame = {
+ val cmd =
+ CreateTableUsing(
+ tableName,
+ userSpecifiedSchema = None,
+ source,
+ temporary = false,
+ options,
+ allowExisting = false,
+ managedIfNoPath = false)
+ executePlan(cmd).toRdd
+ table(tableName)
+ }
+
+ /**
+ * :: Experimental ::
+ * Create an external table from the given path based on a data source, a schema and
+ * a set of options. Then, returns the corresponding DataFrame.
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: java.util.Map[String, String]): DataFrame = {
+ createExternalTable(tableName, source, schema, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Create an external table from the given path based on a data source, a schema and
+ * a set of options. Then, returns the corresponding DataFrame.
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: Map[String, String]): DataFrame = {
+ val cmd =
+ CreateTableUsing(
+ tableName,
+ userSpecifiedSchema = Some(schema),
+ source,
+ temporary = false,
+ options,
+ allowExisting = false,
+ managedIfNoPath = false)
+ executePlan(cmd).toRdd
+ table(tableName)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index edf8a5be64..e915e0e6a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -309,7 +309,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object DDLStrategy extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false) =>
+ case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) =>
ExecutedCommand(
CreateTempTableUsing(
tableName, userSpecifiedSchema, provider, opts)) :: Nil
@@ -318,24 +318,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case c: CreateTableUsing if c.temporary && c.allowExisting =>
sys.error("allowExisting should be set to false when creating a temporary table.")
- case CreateTableUsingAsSelect(tableName, provider, true, opts, false, query) =>
+ case CreateTableUsingAsSelect(tableName, provider, true, mode, opts, query) =>
val logicalPlan = sqlContext.parseSql(query)
val cmd =
- CreateTempTableUsingAsSelect(tableName, provider, opts, logicalPlan)
+ CreateTempTableUsingAsSelect(tableName, provider, mode, opts, logicalPlan)
ExecutedCommand(cmd) :: Nil
case c: CreateTableUsingAsSelect if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
- case c: CreateTableUsingAsSelect if c.temporary && c.allowExisting =>
- sys.error("allowExisting should be set to false when creating a temporary table.")
- case CreateTableUsingAsLogicalPlan(tableName, provider, true, opts, false, query) =>
+ case CreateTableUsingAsLogicalPlan(tableName, provider, true, mode, opts, query) =>
val cmd =
- CreateTempTableUsingAsSelect(tableName, provider, opts, query)
+ CreateTempTableUsingAsSelect(tableName, provider, mode, opts, query)
ExecutedCommand(cmd) :: Nil
case c: CreateTableUsingAsLogicalPlan if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
- case c: CreateTableUsingAsLogicalPlan if c.temporary && c.allowExisting =>
- sys.error("allowExisting should be set to false when creating a temporary table.")
case LogicalDescribeCommand(table, isExtended) =>
val resultPlan = self.sqlContext.executePlan(table).executedPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index c4e14c6c92..f828bcdd65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -20,8 +20,7 @@ package org.apache.spark.sql.json
import java.io.IOException
import org.apache.hadoop.fs.Path
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -29,6 +28,10 @@ import org.apache.spark.sql.types.StructType
private[sql] class DefaultSource
extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {
+ private def checkPath(parameters: Map[String, String]): String = {
+ parameters.getOrElse("path", sys.error("'path' must be specified for json data."))
+ }
+
/** Returns a new base relation with the parameters. */
override def createRelation(
sqlContext: SQLContext,
@@ -52,15 +55,30 @@ private[sql] class DefaultSource
override def createRelation(
sqlContext: SQLContext,
+ mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
- val path = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+ val path = checkPath(parameters)
val filesystemPath = new Path(path)
val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
- if (fs.exists(filesystemPath)) {
- sys.error(s"path $path already exists.")
+ val doSave = if (fs.exists(filesystemPath)) {
+ mode match {
+ case SaveMode.Append =>
+ sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
+ case SaveMode.Overwrite =>
+ fs.delete(filesystemPath, true)
+ true
+ case SaveMode.ErrorIfExists =>
+ sys.error(s"path $path already exists.")
+ case SaveMode.Ignore => false
+ }
+ } else {
+ true
+ }
+ if (doSave) {
+ // Only save data when the save mode is not ignore.
+ data.toJSON.saveAsTextFile(path)
}
- data.toJSON.saveAsTextFile(path)
createRelation(sqlContext, parameters, data.schema)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 04804f78f5..aef9c10fbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -80,18 +80,45 @@ class DefaultSource
override def createRelation(
sqlContext: SQLContext,
+ mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
val path = checkPath(parameters)
- ParquetRelation.createEmpty(
- path,
- data.schema.toAttributes,
- false,
- sqlContext.sparkContext.hadoopConfiguration,
- sqlContext)
-
- val relation = createRelation(sqlContext, parameters, data.schema)
- relation.asInstanceOf[ParquetRelation2].insert(data, true)
+ val filesystemPath = new Path(path)
+ val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+ val doSave = if (fs.exists(filesystemPath)) {
+ mode match {
+ case SaveMode.Append =>
+ sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
+ case SaveMode.Overwrite =>
+ fs.delete(filesystemPath, true)
+ true
+ case SaveMode.ErrorIfExists =>
+ sys.error(s"path $path already exists.")
+ case SaveMode.Ignore => false
+ }
+ } else {
+ true
+ }
+
+ val relation = if (doSave) {
+ // Only save data when the save mode is not ignore.
+ ParquetRelation.createEmpty(
+ path,
+ data.schema.toAttributes,
+ false,
+ sqlContext.sparkContext.hadoopConfiguration,
+ sqlContext)
+
+ val createdRelation = createRelation(sqlContext, parameters, data.schema)
+ createdRelation.asInstanceOf[ParquetRelation2].insert(data, true)
+
+ createdRelation
+ } else {
+ // If the save mode is Ignore, we will just create the relation based on existing data.
+ createRelation(sqlContext, parameters)
+ }
+
relation
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 9f64f76100..6487c14b1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -119,11 +119,20 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
throw new DDLException(
"a CREATE TABLE AS SELECT statement does not allow column definitions.")
}
+ // When IF NOT EXISTS clause appears in the query, the save mode will be ignore.
+ val mode = if (allowExisting.isDefined) {
+ SaveMode.Ignore
+ } else if (temp.isDefined) {
+ SaveMode.Overwrite
+ } else {
+ SaveMode.ErrorIfExists
+ }
+
CreateTableUsingAsSelect(tableName,
provider,
temp.isDefined,
+ mode,
options,
- allowExisting.isDefined,
query.get)
} else {
val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
@@ -133,7 +142,8 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
provider,
temp.isDefined,
options,
- allowExisting.isDefined)
+ allowExisting.isDefined,
+ managedIfNoPath = false)
}
}
)
@@ -264,6 +274,7 @@ object ResolvedDataSource {
def apply(
sqlContext: SQLContext,
provider: String,
+ mode: SaveMode,
options: Map[String, String],
data: DataFrame): ResolvedDataSource = {
val loader = Utils.getContextOrSparkClassLoader
@@ -277,7 +288,7 @@ object ResolvedDataSource {
val relation = clazz.newInstance match {
case dataSource: CreatableRelationProvider =>
- dataSource.createRelation(sqlContext, options, data)
+ dataSource.createRelation(sqlContext, mode, options, data)
case _ =>
sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
}
@@ -307,28 +318,40 @@ private[sql] case class DescribeCommand(
new MetadataBuilder().putString("comment", "comment of the column").build())())
}
+/**
+ * Used to represent the operation of create table using a data source.
+ * @param tableName
+ * @param userSpecifiedSchema
+ * @param provider
+ * @param temporary
+ * @param options
+ * @param allowExisting If it is true, we will do nothing when the table already exists.
+ * If it is false, an exception will be thrown
+ * @param managedIfNoPath
+ */
private[sql] case class CreateTableUsing(
tableName: String,
userSpecifiedSchema: Option[StructType],
provider: String,
temporary: Boolean,
options: Map[String, String],
- allowExisting: Boolean) extends Command
+ allowExisting: Boolean,
+ managedIfNoPath: Boolean) extends Command
private[sql] case class CreateTableUsingAsSelect(
tableName: String,
provider: String,
temporary: Boolean,
+ mode: SaveMode,
options: Map[String, String],
- allowExisting: Boolean,
query: String) extends Command
private[sql] case class CreateTableUsingAsLogicalPlan(
tableName: String,
provider: String,
temporary: Boolean,
+ mode: SaveMode,
options: Map[String, String],
- allowExisting: Boolean,
query: LogicalPlan) extends Command
private [sql] case class CreateTempTableUsing(
@@ -348,12 +371,13 @@ private [sql] case class CreateTempTableUsing(
private [sql] case class CreateTempTableUsingAsSelect(
tableName: String,
provider: String,
+ mode: SaveMode,
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
def run(sqlContext: SQLContext) = {
val df = DataFrame(sqlContext, query)
- val resolved = ResolvedDataSource(sqlContext, provider, options, df)
+ val resolved = ResolvedDataSource(sqlContext, provider, mode, options, df)
sqlContext.registerRDDAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
@@ -364,7 +388,7 @@ private [sql] case class CreateTempTableUsingAsSelect(
/**
* Builds a map in which keys are case insensitive
*/
-protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
+protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
with Serializable {
val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 5eecc303ef..37fda7ba6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -79,8 +79,27 @@ trait SchemaRelationProvider {
@DeveloperApi
trait CreatableRelationProvider {
+ /**
+ * Creates a relation with the given parameters based on the contents of the given
+ * DataFrame. The mode specifies the expected behavior of createRelation when
+ * data already exists.
+ * Right now, there are three modes, Append, Overwrite, and ErrorIfExists.
+ * Append mode means that when saving a DataFrame to a data source, if data already exists,
+ * contents of the DataFrame are expected to be appended to existing data.
+ * Overwrite mode means that when saving a DataFrame to a data source, if data already exists,
+ * existing data is expected to be overwritten by the contents of the DataFrame.
+ * ErrorIfExists mode means that when saving a DataFrame to a data source,
+ * if data already exists, an exception is expected to be thrown.
+ *
+ * @param sqlContext
+ * @param mode
+ * @param parameters
+ * @param data
+ * @return
+ */
def createRelation(
sqlContext: SQLContext,
+ mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation
}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
new file mode 100644
index 0000000000..852baf0e09
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.sources;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.*;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.test.TestSQLContext$;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.Utils;
+
+public class JavaSaveLoadSuite {
+
+ private transient JavaSparkContext sc;
+ private transient SQLContext sqlContext;
+
+ String originalDefaultSource;
+ File path;
+ DataFrame df;
+
+ private void checkAnswer(DataFrame actual, List<Row> expected) {
+ String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
+ if (errorMessage != null) {
+ Assert.fail(errorMessage);
+ }
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ sqlContext = TestSQLContext$.MODULE$;
+ sc = new JavaSparkContext(sqlContext.sparkContext());
+
+ originalDefaultSource = sqlContext.conf().defaultDataSourceName();
+ path =
+ Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
+ if (path.exists()) {
+ path.delete();
+ }
+
+ List<String> jsonObjects = new ArrayList<String>(10);
+ for (int i = 0; i < 10; i++) {
+ jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
+ }
+ JavaRDD<String> rdd = sc.parallelize(jsonObjects);
+ df = sqlContext.jsonRDD(rdd);
+ df.registerTempTable("jsonTable");
+ }
+
+ @Test
+ public void saveAndLoad() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("path", path.toString());
+ df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options);
+
+ DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options);
+
+ checkAnswer(loadedDF, df.collectAsList());
+ }
+
+ @Test
+ public void saveAndLoadWithSchema() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("path", path.toString());
+ df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options);
+
+ List<StructField> fields = new ArrayList<>();
+ fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options);
+
+ checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f9ddd2ca5c..dfb6858957 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
import java.util.{Locale, TimeZone}
+import scala.collection.JavaConversions._
+
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
@@ -52,9 +54,51 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param rdd the [[DataFrame]] to be executed
- * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ QueryTest.checkAnswer(rdd, expectedAnswer) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
+ checkAnswer(rdd, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
+ test(sqlString) {
+ checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ }
+ }
+
+ /**
+ * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
+ */
+ def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
+ val planWithCaching = query.queryExecution.withCachedData
+ val cachedData = planWithCaching collect {
+ case cached: InMemoryRelation => cached
+ }
+
+ assert(
+ cachedData.size == numCachedTables,
+ s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
+ planWithCaching)
+ }
+}
+
+object QueryTest {
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * If there was exception during the execution or the contents of the DataFrame does not
+ * match the expected result, an error message will be returned. Otherwise, a [[None]] will
+ * be returned.
+ * @param rdd the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -70,18 +114,20 @@ class QueryTest extends PlanTest {
}
val sparkAnswer = try rdd.collect().toSeq catch {
case e: Exception =>
- fail(
+ val errorMessage =
s"""
|Exception thrown while executing query:
|${rdd.queryExecution}
|== Exception ==
|$e
|${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin)
+ """.stripMargin
+ return Some(errorMessage)
}
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
- fail(s"""
+ val errorMessage =
+ s"""
|Results do not match for query:
|${rdd.logicalPlan}
|== Analyzed Plan ==
@@ -90,37 +136,21 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
- prepareAnswer(expectedAnswer).map(_.toString),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
- """.stripMargin)
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
+ """.stripMargin
+ return Some(errorMessage)
}
- }
- protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
- checkAnswer(rdd, Seq(expectedAnswer))
- }
-
- def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
- test(sqlString) {
- checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
- }
+ return None
}
- /**
- * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
- */
- def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
- val planWithCaching = query.queryExecution.withCachedData
- val cachedData = planWithCaching collect {
- case cached: InMemoryRelation => cached
+ def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = {
+ checkAnswer(rdd, expectedAnswer.toSeq) match {
+ case Some(errorMessage) => errorMessage
+ case None => null
}
-
- assert(
- cachedData.size == numCachedTables,
- s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
- planWithCaching)
}
-
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index b02389978b..29caed9337 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -77,12 +77,10 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT a, b FROM jsonTable"),
sql("SELECT a, b FROM jt").collect())
- dropTempTable("jsonTable")
-
- val message = intercept[RuntimeException]{
+ val message = intercept[DDLException]{
sql(
s"""
- |CREATE TEMPORARY TABLE jsonTable
+ |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
| path '${path.toString}'
@@ -91,10 +89,25 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
""".stripMargin)
}.getMessage
assert(
- message.contains(s"path ${path.toString} already exists."),
+ message.contains(s"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."),
"CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.")
- // Explicitly delete it.
+ // Overwrite the temporary table.
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE jsonTable
+ |USING org.apache.spark.sql.json.DefaultSource
+ |OPTIONS (
+ | path '${path.toString}'
+ |) AS
+ |SELECT a * 4 FROM jt
+ """.stripMargin)
+ checkAnswer(
+ sql("SELECT * FROM jsonTable"),
+ sql("SELECT a * 4 FROM jt").collect())
+
+ dropTempTable("jsonTable")
+ // Explicitly delete the data.
if (path.exists()) Utils.deleteRecursively(path)
sql(
@@ -104,12 +117,12 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
|OPTIONS (
| path '${path.toString}'
|) AS
- |SELECT a * 4 FROM jt
+ |SELECT b FROM jt
""".stripMargin)
checkAnswer(
sql("SELECT * FROM jsonTable"),
- sql("SELECT a * 4 FROM jt").collect())
+ sql("SELECT b FROM jt").collect())
dropTempTable("jsonTable")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
index fe2f76cc39..a510045671 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
@@ -21,10 +21,10 @@ import java.io.File
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.util.Utils
-
import org.apache.spark.sql.catalyst.util
+import org.apache.spark.sql.{SQLConf, DataFrame}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
@@ -38,42 +38,60 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
override def beforeAll(): Unit = {
originalDefaultSource = conf.defaultDataSourceName
- conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json")
path = util.getTempFilePath("datasource").getCanonicalFile
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
df = jsonRDD(rdd)
+ df.registerTempTable("jsonTable")
}
override def afterAll(): Unit = {
- conf.setConf("spark.sql.default.datasource", originalDefaultSource)
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
}
after {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
if (path.exists()) Utils.deleteRecursively(path)
}
def checkLoad(): Unit = {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
checkAnswer(load(path.toString), df.collect())
- checkAnswer(load("org.apache.spark.sql.json", ("path", path.toString)), df.collect())
+
+ // Test if we can pick up the data source name passed in load.
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ checkAnswer(load(path.toString, "org.apache.spark.sql.json"), df.collect())
+ checkAnswer(load("org.apache.spark.sql.json", Map("path" -> path.toString)), df.collect())
+ val schema = StructType(StructField("b", StringType, true) :: Nil)
+ checkAnswer(
+ load("org.apache.spark.sql.json", schema, Map("path" -> path.toString)),
+ sql("SELECT b FROM jsonTable").collect())
}
- test("save with overwrite and load") {
+ test("save with path and load") {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
df.save(path.toString)
- checkLoad
+ checkLoad()
+ }
+
+ test("save with path and datasource, and load") {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.save(path.toString, "org.apache.spark.sql.json")
+ checkLoad()
}
test("save with data source and options, and load") {
- df.save("org.apache.spark.sql.json", ("path", path.toString))
- checkLoad
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, Map("path" -> path.toString))
+ checkLoad()
}
test("save and save again") {
- df.save(path.toString)
+ df.save(path.toString, "org.apache.spark.sql.json")
- val message = intercept[RuntimeException] {
- df.save(path.toString)
+ var message = intercept[RuntimeException] {
+ df.save(path.toString, "org.apache.spark.sql.json")
}.getMessage
assert(
@@ -82,7 +100,18 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
if (path.exists()) Utils.deleteRecursively(path)
- df.save(path.toString)
- checkLoad
+ df.save(path.toString, "org.apache.spark.sql.json")
+ checkLoad()
+
+ df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> path.toString))
+ checkLoad()
+
+ message = intercept[RuntimeException] {
+ df.save("org.apache.spark.sql.json", SaveMode.Append, Map("path" -> path.toString))
+ }.getMessage
+
+ assert(
+ message.contains("Append mode is not supported"),
+ "We should complain that 'Append mode is not supported' for JSON source.")
}
} \ No newline at end of file
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2c00659496..7ae6ed6f84 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -80,18 +80,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}
/**
- * Creates a table using the schema of the given class.
- *
- * @param tableName The name of the table to create.
- * @param allowExisting When false, an exception will be thrown if the table already exists.
- * @tparam A A case class that is used to describe the schema of the table to be created.
- */
- @Deprecated
- def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) {
- catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting)
- }
-
- /**
* Invalidate and refresh all the cached the metadata of the given table. For performance reasons,
* Spark SQL or the external data source library it uses might cache certain metadata about a
* table, such as the location of blocks. When those change outside of Spark SQL, users should
@@ -107,70 +95,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.invalidateTable("default", tableName)
}
- @Experimental
- def createTable(tableName: String, path: String, allowExisting: Boolean): Unit = {
- val dataSourceName = conf.defaultDataSourceName
- createTable(tableName, dataSourceName, allowExisting, ("path", path))
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- allowExisting: Boolean,
- option: (String, String),
- options: (String, String)*): Unit = {
- val cmd =
- CreateTableUsing(
- tableName,
- userSpecifiedSchema = None,
- dataSourceName,
- temporary = false,
- (option +: options).toMap,
- allowExisting)
- executePlan(cmd).toRdd
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- schema: StructType,
- allowExisting: Boolean,
- option: (String, String),
- options: (String, String)*): Unit = {
- val cmd =
- CreateTableUsing(
- tableName,
- userSpecifiedSchema = Some(schema),
- dataSourceName,
- temporary = false,
- (option +: options).toMap,
- allowExisting)
- executePlan(cmd).toRdd
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- allowExisting: Boolean,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- createTable(tableName, dataSourceName, allowExisting, opts.head, opts.tail:_*)
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- schema: StructType,
- allowExisting: Boolean,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- createTable(tableName, dataSourceName, schema, allowExisting, opts.head, opts.tail:_*)
- }
-
/**
* Analyzes the given table in the current database to generate statistics, which will be
* used in query optimizations.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 95abc363ae..cb138be90e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -216,20 +216,21 @@ private[hive] trait HiveStrategies {
object HiveDDLStrategy extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, opts, allowExisting) =>
+ case CreateTableUsing(
+ tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) =>
ExecutedCommand(
CreateMetastoreDataSource(
- tableName, userSpecifiedSchema, provider, opts, allowExisting)) :: Nil
+ tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil
- case CreateTableUsingAsSelect(tableName, provider, false, opts, allowExisting, query) =>
+ case CreateTableUsingAsSelect(tableName, provider, false, mode, opts, query) =>
val logicalPlan = hiveContext.parseSql(query)
val cmd =
- CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, logicalPlan)
+ CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, logicalPlan)
ExecutedCommand(cmd) :: Nil
- case CreateTableUsingAsLogicalPlan(tableName, provider, false, opts, allowExisting, query) =>
+ case CreateTableUsingAsLogicalPlan(tableName, provider, false, mode, opts, query) =>
val cmd =
- CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, query)
+ CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, query)
ExecutedCommand(cmd) :: Nil
case _ => Nil
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 95dcaccefd..f6bea1c6a6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.sources.ResolvedDataSource
+import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.sources._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -105,7 +107,8 @@ case class CreateMetastoreDataSource(
userSpecifiedSchema: Option[StructType],
provider: String,
options: Map[String, String],
- allowExisting: Boolean) extends RunnableCommand {
+ allowExisting: Boolean,
+ managedIfNoPath: Boolean) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
@@ -120,7 +123,7 @@ case class CreateMetastoreDataSource(
var isExternal = true
val optionsWithPath =
- if (!options.contains("path")) {
+ if (!options.contains("path") && managedIfNoPath) {
isExternal = false
options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName))
} else {
@@ -141,22 +144,13 @@ case class CreateMetastoreDataSource(
case class CreateMetastoreDataSourceAsSelect(
tableName: String,
provider: String,
+ mode: SaveMode,
options: Map[String, String],
- allowExisting: Boolean,
query: LogicalPlan) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
-
- if (hiveContext.catalog.tableExists(tableName :: Nil)) {
- if (allowExisting) {
- return Seq.empty[Row]
- } else {
- sys.error(s"Table $tableName already exists.")
- }
- }
-
- val df = DataFrame(hiveContext, query)
+ var createMetastoreTable = false
var isExternal = true
val optionsWithPath =
if (!options.contains("path")) {
@@ -166,15 +160,82 @@ case class CreateMetastoreDataSourceAsSelect(
options
}
- // Create the relation based on the data of df.
- ResolvedDataSource(sqlContext, provider, optionsWithPath, df)
+ if (sqlContext.catalog.tableExists(Seq(tableName))) {
+ // Check if we need to throw an exception or just return.
+ mode match {
+ case SaveMode.ErrorIfExists =>
+ sys.error(s"Table $tableName already exists. " +
+ s"If you want to append into it, please set mode to SaveMode.Append. " +
+ s"Or, if you want to overwrite it, please set mode to SaveMode.Overwrite.")
+ case SaveMode.Ignore =>
+ // Since the table already exists and the save mode is Ignore, we will just return.
+ return Seq.empty[Row]
+ case SaveMode.Append =>
+ // Check if the specified data source match the data source of the existing table.
+ val resolved =
+ ResolvedDataSource(sqlContext, Some(query.schema), provider, optionsWithPath)
+ val createdRelation = LogicalRelation(resolved.relation)
+ EliminateAnalysisOperators(sqlContext.table(tableName).logicalPlan) match {
+ case l @ LogicalRelation(i: InsertableRelation) =>
+ if (l.schema != createdRelation.schema) {
+ val errorDescription =
+ s"Cannot append to table $tableName because the schema of this " +
+ s"DataFrame does not match the schema of table $tableName."
+ val errorMessage =
+ s"""
+ |$errorDescription
+ |== Schemas ==
+ |${sideBySide(
+ s"== Expected Schema ==" +:
+ l.schema.treeString.split("\\\n"),
+ s"== Actual Schema ==" +:
+ createdRelation.schema.treeString.split("\\\n")).mkString("\n")}
+ """.stripMargin
+ sys.error(errorMessage)
+ } else if (i != createdRelation.relation) {
+ val errorDescription =
+ s"Cannot append to table $tableName because the resolved relation does not " +
+ s"match the existing relation of $tableName. " +
+ s"You can use insertInto($tableName, false) to append this DataFrame to the " +
+ s"table $tableName and using its data source and options."
+ val errorMessage =
+ s"""
+ |$errorDescription
+ |== Relations ==
+ |${sideBySide(
+ s"== Expected Relation ==" ::
+ l.toString :: Nil,
+ s"== Actual Relation ==" ::
+ createdRelation.toString :: Nil).mkString("\n")}
+ """.stripMargin
+ sys.error(errorMessage)
+ }
+ case o =>
+ sys.error(s"Saving data in ${o.toString} is not supported.")
+ }
+ case SaveMode.Overwrite =>
+ hiveContext.sql(s"DROP TABLE IF EXISTS $tableName")
+ // Need to create the table again.
+ createMetastoreTable = true
+ }
+ } else {
+ // The table does not exist. We need to create it in metastore.
+ createMetastoreTable = true
+ }
- hiveContext.catalog.createDataSourceTable(
- tableName,
- None,
- provider,
- optionsWithPath,
- isExternal)
+ val df = DataFrame(hiveContext, query)
+
+ // Create the relation based on the data of df.
+ ResolvedDataSource(sqlContext, provider, mode, optionsWithPath, df)
+
+ if (createMetastoreTable) {
+ hiveContext.catalog.createDataSourceTable(
+ tableName,
+ Some(df.schema),
+ provider,
+ optionsWithPath,
+ isExternal)
+ }
Seq.empty[Row]
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 7c1d1133c3..840fbc1972 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -20,9 +20,6 @@ package org.apache.spark.sql.hive.test
import java.io.File
import java.util.{Set => JavaSet}
-import scala.collection.mutable
-import scala.language.implicitConversions
-
import org.apache.hadoop.hive.ql.exec.FunctionRegistry
import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat}
import org.apache.hadoop.hive.ql.metadata.Table
@@ -30,16 +27,18 @@ import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.serde2.RegexSerDe
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.hadoop.hive.serde2.avro.AvroSerDe
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.CacheTableCommand
import org.apache.spark.sql.hive._
-import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.hive.execution.HiveNativeCommand
+import org.apache.spark.util.Utils
+import org.apache.spark.{SparkConf, SparkContext}
+
+import scala.collection.mutable
+import scala.language.implicitConversions
/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -224,11 +223,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
}
}),
TestTable("src_thrift", () => {
- import org.apache.thrift.protocol.TBinaryProtocol
- import org.apache.hadoop.hive.serde2.thrift.test.Complex
import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer
- import org.apache.hadoop.mapred.SequenceFileInputFormat
- import org.apache.hadoop.mapred.SequenceFileOutputFormat
+ import org.apache.hadoop.hive.serde2.thrift.test.Complex
+ import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat}
+ import org.apache.thrift.protocol.TBinaryProtocol
val srcThrift = new Table("default", "src_thrift")
srcThrift.setFields(Nil)
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
new file mode 100644
index 0000000000..9744a2aa3f
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.spark.sql.sources.SaveMode;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.QueryTest$;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.hive.test.TestHive$;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.Utils;
+
+public class JavaMetastoreDataSourcesSuite {
+ private transient JavaSparkContext sc;
+ private transient HiveContext sqlContext;
+
+ String originalDefaultSource;
+ File path;
+ Path hiveManagedPath;
+ FileSystem fs;
+ DataFrame df;
+
+ private void checkAnswer(DataFrame actual, List<Row> expected) {
+ String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
+ if (errorMessage != null) {
+ Assert.fail(errorMessage);
+ }
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ sqlContext = TestHive$.MODULE$;
+ sc = new JavaSparkContext(sqlContext.sparkContext());
+
+ originalDefaultSource = sqlContext.conf().defaultDataSourceName();
+ path =
+ Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
+ if (path.exists()) {
+ path.delete();
+ }
+ hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath("javaSavedTable"));
+ fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration());
+ if (fs.exists(hiveManagedPath)){
+ fs.delete(hiveManagedPath, true);
+ }
+
+ List<String> jsonObjects = new ArrayList<String>(10);
+ for (int i = 0; i < 10; i++) {
+ jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
+ }
+ JavaRDD<String> rdd = sc.parallelize(jsonObjects);
+ df = sqlContext.jsonRDD(rdd);
+ df.registerTempTable("jsonTable");
+ }
+
+ @After
+ public void tearDown() throws IOException {
+ // Clean up tables.
+ sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable");
+ sqlContext.sql("DROP TABLE IF EXISTS externalTable");
+ }
+
+ @Test
+ public void saveExternalTableAndQueryIt() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("path", path.toString());
+ df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options);
+
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM javaSavedTable"),
+ df.collectAsList());
+
+ DataFrame loadedDF =
+ sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options);
+
+ checkAnswer(loadedDF, df.collectAsList());
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM externalTable"),
+ df.collectAsList());
+ }
+
+ @Test
+ public void saveExternalTableWithSchemaAndQueryIt() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("path", path.toString());
+ df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options);
+
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM javaSavedTable"),
+ df.collectAsList());
+
+ List<StructField> fields = new ArrayList<>();
+ fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame loadedDF =
+ sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options);
+
+ checkAnswer(
+ loadedDF,
+ sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList());
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM externalTable"),
+ sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList());
+ }
+
+ @Test
+ public void saveTableAndQueryIt() {
+ Map<String, String> options = new HashMap<String, String>();
+ df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options);
+
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM javaSavedTable"),
+ df.collectAsList());
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
index ba39129388..0270e63557 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,10 +17,8 @@
package org.apache.spark.sql
-import org.scalatest.FunSuite
+import scala.collection.JavaConversions._
-import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
@@ -55,9 +53,36 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param rdd the [[DataFrame]] to be executed
- * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ QueryTest.checkAnswer(rdd, expectedAnswer) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
+ checkAnswer(rdd, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
+ test(sqlString) {
+ checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ }
+ }
+}
+
+object QueryTest {
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * If there was exception during the execution or the contents of the DataFrame does not
+ * match the expected result, an error message will be returned. Otherwise, a [[None]] will
+ * be returned.
+ * @param rdd the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -73,18 +98,20 @@ class QueryTest extends PlanTest {
}
val sparkAnswer = try rdd.collect().toSeq catch {
case e: Exception =>
- fail(
+ val errorMessage =
s"""
|Exception thrown while executing query:
|${rdd.queryExecution}
|== Exception ==
|$e
|${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin)
+ """.stripMargin
+ return Some(errorMessage)
}
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
- fail(s"""
+ val errorMessage =
+ s"""
|Results do not match for query:
|${rdd.logicalPlan}
|== Analyzed Plan ==
@@ -93,22 +120,21 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
- prepareAnswer(expectedAnswer).map(_.toString),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
- """.stripMargin)
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
+ """.stripMargin
+ return Some(errorMessage)
}
- }
- protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
- checkAnswer(rdd, Seq(expectedAnswer))
+ return None
}
- def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
- test(sqlString) {
- checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = {
+ checkAnswer(rdd, expectedAnswer.toSeq) match {
+ case Some(errorMessage) => errorMessage
+ case None => null
}
}
-
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 869d01eb39..43da7519ac 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -19,7 +19,11 @@ package org.apache.spark.sql.hive
import java.io.File
+import org.scalatest.BeforeAndAfter
+
import com.google.common.io.Files
+
+import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.{QueryTest, _}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.types._
@@ -29,15 +33,22 @@ import org.apache.spark.sql.hive.test.TestHive._
case class TestData(key: Int, value: String)
-class InsertIntoHiveTableSuite extends QueryTest {
+class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
import org.apache.spark.sql.hive.test.TestHive.implicits._
val testData = TestHive.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString)))
- testData.registerTempTable("testData")
+
+ before {
+ // Since every we are doing tests for DDL statements,
+ // it is better to reset before every test.
+ TestHive.reset()
+ // Register the testData, which will be used in every test.
+ testData.registerTempTable("testData")
+ }
test("insertInto() HiveTable") {
- createTable[TestData]("createAndInsertTest")
+ sql("CREATE TABLE createAndInsertTest (key int, value string)")
// Add some data.
testData.insertInto("createAndInsertTest")
@@ -68,16 +79,18 @@ class InsertIntoHiveTableSuite extends QueryTest {
}
test("Double create fails when allowExisting = false") {
- createTable[TestData]("doubleCreateAndInsertTest")
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
- intercept[org.apache.hadoop.hive.ql.metadata.HiveException] {
- createTable[TestData]("doubleCreateAndInsertTest", allowExisting = false)
- }
+ val message = intercept[QueryExecutionException] {
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
+ }.getMessage
+
+ println("message!!!!" + message)
}
test("Double create does not fail when allowExisting = true") {
- createTable[TestData]("createAndInsertTest")
- createTable[TestData]("createAndInsertTest")
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
+ sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)")
}
test("SPARK-4052: scala.collection.Map as value type of MapType") {
@@ -98,7 +111,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
}
test("SPARK-4203:random partition directory order") {
- createTable[TestData]("tmp_table")
+ sql("CREATE TABLE tmp_table (key int, value string)")
val tmpDir = Files.createTempDir()
sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ")
sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 9ce058909f..f94aabd29a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.hive
import java.io.File
+import org.apache.spark.sql.sources.SaveMode
import org.scalatest.BeforeAndAfterEach
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred.InvalidInputException
import org.apache.spark.sql.catalyst.util
import org.apache.spark.sql._
@@ -41,11 +43,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
override def afterEach(): Unit = {
reset()
- if (ctasPath.exists()) Utils.deleteRecursively(ctasPath)
+ if (tempPath.exists()) Utils.deleteRecursively(tempPath)
}
val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile
- var ctasPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile
+ var tempPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile
test ("persistent JSON table") {
sql(
@@ -270,7 +272,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT * FROM jsonTable
""".stripMargin)
@@ -297,7 +299,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT * FROM jsonTable
""".stripMargin)
@@ -309,7 +311,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT * FROM jsonTable
""".stripMargin)
@@ -325,7 +327,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE IF NOT EXISTS ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT a FROM jsonTable
""".stripMargin)
@@ -400,38 +402,122 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
sql("DROP TABLE jsonTable").collect().foreach(println)
}
- test("save and load table") {
+ test("save table") {
val originalDefaultSource = conf.defaultDataSourceName
- conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json")
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
val df = jsonRDD(rdd)
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+ // Save the df as a managed table (by not specifiying the path).
df.saveAsTable("savedJsonTable")
checkAnswer(
sql("SELECT * FROM savedJsonTable"),
df.collect())
- createTable("createdJsonTable", catalog.hiveDefaultTableFilePath("savedJsonTable"), false)
+ // Right now, we cannot append to an existing JSON table.
+ intercept[RuntimeException] {
+ df.saveAsTable("savedJsonTable", SaveMode.Append)
+ }
+
+ // We can overwrite it.
+ df.saveAsTable("savedJsonTable", SaveMode.Overwrite)
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable"),
+ df.collect())
+
+ // When the save mode is Ignore, we will do nothing when the table already exists.
+ df.select("b").saveAsTable("savedJsonTable", SaveMode.Ignore)
+ assert(df.schema === table("savedJsonTable").schema)
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable"),
+ df.collect())
+
+ // Drop table will also delete the data.
+ sql("DROP TABLE savedJsonTable")
+ intercept[InvalidInputException] {
+ jsonFile(catalog.hiveDefaultTableFilePath("savedJsonTable"))
+ }
+
+ // Create an external table by specifying the path.
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.saveAsTable(
+ "savedJsonTable",
+ "org.apache.spark.sql.json",
+ SaveMode.Append,
+ Map("path" -> tempPath.toString))
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable"),
+ df.collect())
+
+ // Data should not be deleted after we drop the table.
+ sql("DROP TABLE savedJsonTable")
+ checkAnswer(
+ jsonFile(tempPath.toString),
+ df.collect())
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+ }
+
+ test("create external table") {
+ val originalDefaultSource = conf.defaultDataSourceName
+
+ val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
+ val df = jsonRDD(rdd)
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.saveAsTable(
+ "savedJsonTable",
+ "org.apache.spark.sql.json",
+ SaveMode.Append,
+ Map("path" -> tempPath.toString))
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+ createExternalTable("createdJsonTable", tempPath.toString)
assert(table("createdJsonTable").schema === df.schema)
checkAnswer(
sql("SELECT * FROM createdJsonTable"),
df.collect())
- val message = intercept[RuntimeException] {
- createTable("createdJsonTable", filePath.toString, false)
+ var message = intercept[RuntimeException] {
+ createExternalTable("createdJsonTable", filePath.toString)
}.getMessage
assert(message.contains("Table createdJsonTable already exists."),
"We should complain that ctasJsonTable already exists")
- createTable("createdJsonTable", filePath.toString, true)
- // createdJsonTable should be not changed.
- assert(table("createdJsonTable").schema === df.schema)
+ // Data should not be deleted.
+ sql("DROP TABLE createdJsonTable")
checkAnswer(
- sql("SELECT * FROM createdJsonTable"),
+ jsonFile(tempPath.toString),
df.collect())
- conf.setConf("spark.sql.default.datasource", originalDefaultSource)
+ // Try to specify the schema.
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ val schema = StructType(StructField("b", StringType, true) :: Nil)
+ createExternalTable(
+ "createdJsonTable",
+ "org.apache.spark.sql.json",
+ schema,
+ Map("path" -> tempPath.toString))
+ checkAnswer(
+ sql("SELECT * FROM createdJsonTable"),
+ sql("SELECT b FROM savedJsonTable").collect())
+
+ sql("DROP TABLE createdJsonTable")
+
+ message = intercept[RuntimeException] {
+ createExternalTable(
+ "createdJsonTable",
+ "org.apache.spark.sql.json",
+ schema,
+ Map.empty[String, String])
+ }.getMessage
+ assert(
+ message.contains("Option 'path' not specified"),
+ "We should complain that path is not specified.")
+
+ sql("DROP TABLE savedJsonTable")
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
}
}