From 8de038eb366ded2ac74f72517e40545dbbab8cdd Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Fri, 4 Apr 2014 21:15:33 -0700
Subject: [SQL] SPARK-1366 Consistent sql function across different types of SQLContexts

Now users who want to use HiveQL should explicitly say `hiveql` or `hql`.

Author: Michael Armbrust

Closes #319 from marmbrus/standardizeSqlHql and squashes the following commits:

de68d0e [Michael Armbrust] Fix sampling test.
fbe4a54 [Michael Armbrust] Make `sql` always use spark sql parser, users of hive context can now use hql or hiveql to run queries using HiveQL instead.
---
 .../apache/spark/sql/examples/HiveFromSpark.scala  | 12 +++---
 .../org/apache/spark/sql/hive/HiveContext.scala    | 17 ++++----
 .../scala/org/apache/spark/sql/hive/TestHive.scala | 12 +++---
 .../sql/hive/execution/HiveComparisonTest.scala    | 10 ++---
 .../spark/sql/hive/execution/HiveQuerySuite.scala  | 12 +++++-
 .../sql/hive/execution/HiveResolutionSuite.scala   |  2 +-
 .../spark/sql/hive/execution/PruningSuite.scala    |  2 +-
 .../spark/sql/parquet/HiveParquetSuite.scala       | 46 +++++++++++-----------
 8 files changed, 63 insertions(+), 50 deletions(-)

diff --git a/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala
index abcc1f04d4..62329bde84 100644
--- a/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala
@@ -33,20 +33,20 @@ object HiveFromSpark {
     val hiveContext = new LocalHiveContext(sc)
     import hiveContext._
 
-    sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
-    sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")
+    hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+    hql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")
 
     // Queries are expressed in HiveQL
     println("Result of 'SELECT *': ")
-    sql("SELECT * FROM src").collect.foreach(println)
+    hql("SELECT * FROM src").collect.foreach(println)
 
     // Aggregation queries are also supported.
-    val count = sql("SELECT COUNT(*) FROM src").collect().head.getInt(0)
+    val count = hql("SELECT COUNT(*) FROM src").collect().head.getInt(0)
     println(s"COUNT(*): $count")
 
     // The results of SQL queries are themselves RDDs and support all normal RDD functions. The
     // items in the RDD are of type Row, which allows you to access each column by ordinal.
-    val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key")
+    val rddFromSql = hql("SELECT key, value FROM src WHERE key < 10 ORDER BY key")
 
     println("Result of RDD.map:")
     val rddAsStrings = rddFromSql.map {
@@ -59,6 +59,6 @@ object HiveFromSpark {
 
     // Queries can then join RDD data with data stored in Hive.
println("Result of SELECT *:") - sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) + hql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index ff8eaacded..f66a667c0a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -67,14 +67,13 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { class HiveContext(sc: SparkContext) extends SQLContext(sc) { self => - override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql) - override def executePlan(plan: LogicalPlan): this.QueryExecution = + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } /** * Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD. */ - def hql(hqlQuery: String): SchemaRDD = { + def hiveql(hqlQuery: String): SchemaRDD = { val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) // We force query optimization to happen right away instead of letting it happen lazily like // when using the query DSL. This is so DDL commands behave as expected. This is only @@ -83,6 +82,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { result } + /** An alias for `hiveql`. */ + def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery) + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @transient protected val outputBuffer = new java.io.OutputStream { @@ -120,7 +122,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient - override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog { + override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog { override def lookupRelation( databaseName: Option[String], tableName: String, @@ -132,7 +134,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /* An analyzer that uses the Hive metastore. */ @transient - override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) + override protected[sql] lazy val analyzer = + new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) /** * Runs the specified SQL query using Hive. @@ -214,14 +217,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } @transient - override val planner = hivePlanner + override protected[sql] val planner = hivePlanner @transient protected lazy val emptyResult = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) /** Extends QueryExecution with hive specific features. */ - abstract class QueryExecution extends super.QueryExecution { + protected[sql] abstract class QueryExecution extends super.QueryExecution { // TODO: Create mixin for the analyzer instead of overriding things here. 
     override lazy val optimizedPlan =
       optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 0a6bea0162..2fea970295 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -110,10 +110,10 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
 
   val describedTable = "DESCRIBE (\\w+)".r
 
-  class SqlQueryExecution(sql: String) extends this.QueryExecution {
-    lazy val logical = HiveQl.parseSql(sql)
-    def hiveExec() = runSqlHive(sql)
-    override def toString = sql + "\n" + super.toString
+  protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution {
+    lazy val logical = HiveQl.parseSql(hql)
+    def hiveExec() = runSqlHive(hql)
+    override def toString = hql + "\n" + super.toString
   }
 
   /**
@@ -140,8 +140,8 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
 
   case class TestTable(name: String, commands: (()=>Unit)*)
 
-  implicit class SqlCmd(sql: String) {
-    def cmd = () => new SqlQueryExecution(sql).stringResult(): Unit
+  protected[hive] implicit class SqlCmd(sql: String) {
+    def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit
   }
 
   /**
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 18654b308d..3cc4562a88 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -125,7 +125,7 @@ abstract class HiveComparisonTest
   }
 
   protected def prepareAnswer(
-    hiveQuery: TestHive.type#SqlQueryExecution,
+    hiveQuery: TestHive.type#HiveQLQueryExecution,
     answer: Seq[String]): Seq[String] = {
     val orderedAnswer = hiveQuery.logical match {
       // Clean out non-deterministic time schema info.
@@ -227,7 +227,7 @@ abstract class HiveComparisonTest
 
       try {
         // MINOR HACK: You must run a query before calling reset the first time.
-        TestHive.sql("SHOW TABLES")
+        TestHive.hql("SHOW TABLES")
         if (reset) { TestHive.reset() }
 
         val hiveCacheFiles = queryList.zipWithIndex.map {
@@ -256,7 +256,7 @@ abstract class HiveComparisonTest
             hiveCachedResults
           } else {
 
-            val hiveQueries = queryList.map(new TestHive.SqlQueryExecution(_))
+            val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_))
             // Make sure we can at least parse everything before attempting hive execution.
             hiveQueries.foreach(_.logical)
             val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map {
@@ -302,7 +302,7 @@ abstract class HiveComparisonTest
 
         // Run w/ catalyst
        val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
-          val query = new TestHive.SqlQueryExecution(queryString)
+          val query = new TestHive.HiveQLQueryExecution(queryString)
          try { (query, prepareAnswer(query, query.stringResult())) } catch {
            case e: Exception =>
              val errorMessage =
@@ -359,7 +359,7 @@ abstract class HiveComparisonTest
            // When we encounter an error we check to see if the environment is still okay by running a simple query.
            // If this fails then we halt testing since something must have gone seriously wrong.
            try {
-              new TestHive.SqlQueryExecution("SELECT key FROM src").stringResult()
+              new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult()
              TestHive.runSqlHive("SELECT key FROM src")
            } catch {
              case e: Exception =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index c184ebe288..0c27498a93 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -23,6 +23,16 @@ import org.apache.spark.sql.hive.TestHive._
  * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
  */
 class HiveQuerySuite extends HiveComparisonTest {
+
+  test("Query expressed in SQL") {
+    assert(sql("SELECT 1").collect() === Array(Seq(1)))
+  }
+
+  test("Query expressed in HiveQL") {
+    hql("FROM src SELECT key").collect()
+    hiveql("FROM src SELECT key").collect()
+  }
+
   createQueryTest("Simple Average",
     "SELECT AVG(key) FROM src")
 
@@ -133,7 +143,7 @@ class HiveQuerySuite extends HiveComparisonTest {
     "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v")
 
   test("sampling") {
-    sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
+    hql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
   }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index 40c4e23f90..8883e5b16d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -56,7 +56,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
     TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2)) :: Nil)
       .registerAsTable("caseSensitivityTest")
 
-    sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
+    hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
   }
 
   /**
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index 1318ac1968..d9ccb93e23 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -136,7 +136,7 @@ class PruningSuite extends HiveComparisonTest {
       expectedScannedColumns: Seq[String],
       expectedPartValues: Seq[Seq[String]]) = {
     test(s"$testCaseName - pruning test") {
-      val plan = new TestHive.SqlQueryExecution(sql).executedPlan
+      val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan
       val actualOutputColumns = plan.output.map(_.name)
       val (actualScannedColumns, actualPartValues) = plan.collect {
         case p @ HiveTableScan(columns, relation, _) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
index 314ca48ad8..aade62eb8f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
@@ -57,34 +57,34 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
   }
 
   test("SELECT on Parquet table") {
-    val rdd = sql("SELECT * FROM testsource").collect()
+    val rdd = hql("SELECT * FROM testsource").collect()
     assert(rdd != null)
     assert(rdd.forall(_.size == 6))
   }
 
   test("Simple column projection + filter on Parquet table") {
-    val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect()
+    val rdd = hql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect()
     assert(rdd.size === 5, "Filter returned incorrect number of rows")
     assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value")
   }
 
   test("Converting Hive to Parquet Table via saveAsParquetFile") {
-    sql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath)
+    hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath)
     parquetFile(dirname.getAbsolutePath).registerAsTable("ptable")
-    val rddOne = sql("SELECT * FROM src").collect().sortBy(_.getInt(0))
-    val rddTwo = sql("SELECT * from ptable").collect().sortBy(_.getInt(0))
+    val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0))
+    val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0))
     compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String"))
   }
 
   test("INSERT OVERWRITE TABLE Parquet table") {
-    sql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath)
+    hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath)
     parquetFile(dirname.getAbsolutePath).registerAsTable("ptable")
     // let's do three overwrites for good measure
-    sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
-    sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
-    sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
-    val rddCopy = sql("SELECT * FROM ptable").collect()
-    val rddOrig = sql("SELECT * FROM testsource").collect()
+    hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
+    hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
+    hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
+    val rddCopy = hql("SELECT * FROM ptable").collect()
+    val rddOrig = hql("SELECT * FROM testsource").collect()
     assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??")
     compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames)
   }
@@ -93,13 +93,13 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
     createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
       .registerAsTable("tmp")
     val rddCopy =
-      sql("INSERT INTO TABLE tmp SELECT * FROM src")
+      hql("INSERT INTO TABLE tmp SELECT * FROM src")
         .collect()
         .sortBy[Int](_.apply(0) match {
          case x: Int => x
          case _ => 0
        })
-    val rddOrig = sql("SELECT * FROM src")
+    val rddOrig = hql("SELECT * FROM src")
      .collect()
      .sortBy(_.getInt(0))
     compareRDDs(rddOrig, rddCopy, "src (Hive)", Seq("key:Int", "value:String"))
@@ -108,22 +108,22 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
   test("Appending to Parquet table") {
     createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
       .registerAsTable("tmpnew")
-    sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
-    sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
-    sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
-    val rddCopies = sql("SELECT * FROM tmpnew").collect()
-    val rddOrig = sql("SELECT * FROM src").collect()
+    hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
+    hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
+    hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
+    val rddCopies = hql("SELECT * FROM tmpnew").collect()
+    val rddOrig = hql("SELECT * FROM src").collect()
     assert(rddCopies.size === 3 * rddOrig.size,
       "number of copied rows via INSERT INTO did not match correct number")
   }
 
   test("Appending to and then overwriting Parquet table") {
     createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
       .registerAsTable("tmp")
-    sql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
-    sql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
-    sql("INSERT OVERWRITE TABLE tmp SELECT * FROM src").collect()
-    val rddCopies = sql("SELECT * FROM tmp").collect()
-    val rddOrig = sql("SELECT * FROM src").collect()
+    hql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
+    hql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
+    hql("INSERT OVERWRITE TABLE tmp SELECT * FROM src").collect()
+    val rddCopies = hql("SELECT * FROM tmp").collect()
+    val rddOrig = hql("SELECT * FROM src").collect()
     assert(rddCopies.size === rddOrig.size, "INSERT OVERWRITE did not actually overwrite")
   }
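
Note for readers of this patch: a minimal sketch of how the two entry points behave after the change, following the HiveFromSpark example above. The SparkContext setup, the object name SqlVsHql, and the queries are illustrative only and are not part of the patch.

    // Illustrative sketch (not part of the patch). Assumes a local SparkContext; the `src`
    // table is created here just as in the HiveFromSpark example.
    import org.apache.spark.SparkContext
    import org.apache.spark.sql.hive.LocalHiveContext

    object SqlVsHql {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "SqlVsHql")
        val hiveContext = new LocalHiveContext(sc)
        import hiveContext._

        // `sql` now always goes through the Spark SQL parser, even on a HiveContext.
        sql("SELECT 1").collect().foreach(println)

        // HiveQL must be requested explicitly, via `hql` or its long form `hiveql`.
        hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
        hiveql("FROM src SELECT key").collect().foreach(println)
      }
    }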