14 files changed, 429 insertions, 61 deletions
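The patch, in brief: a new SetCommand logical plan, a SQLConf trait mixed into SQLContext (and, transitively, HiveContext), and a SetCommandPhysical operator, so configuration such as spark.sql.shuffle.partitions can be set and queried either programmatically or through "SET (key) (= value)" statements. A minimal usage sketch of the resulting behaviour follows; it is illustrative only and not part of the diff, and the SparkContext setup (sc, "local" master, app name) is hypothetical:

  import org.apache.spark.SparkContext
  import org.apache.spark.sql.SQLContext

  val sc = new SparkContext("local", "sqlconf-demo")   // hypothetical local context
  val sqlContext = new SQLContext(sc)

  // Programmatic access through the new SQLConf trait mixed into SQLContext.
  sqlContext.set("spark.sql.shuffle.partitions", "10")
  assert(sqlContext.get("spark.sql.shuffle.partitions") == "10")

  // The same parameters can be set and inspected through SQL text.
  sqlContext.sql("SET spark.sql.shuffle.partitions=20")         // applied eagerly to SQLConf
  sqlContext.sql("SET spark.sql.shuffle.partitions").collect()  // one row: "spark.sql.shuffle.partitions=20"
  sqlContext.sql("SET").collect()                               // all key=value pairs currently in SQLConf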
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index cc650128c2..36758f3114 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -41,10 +41,25 @@ import org.apache.spark.sql.catalyst.types._
  * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
  */
 class SqlParser extends StandardTokenParsers with PackratParsers {
+
   def apply(input: String): LogicalPlan = {
-    phrase(query)(new lexical.Scanner(input)) match {
-      case Success(r, x) => r
-      case x => sys.error(x.toString)
+    // Special-case out set commands since the value fields can be
+    // complex to handle without RegexParsers. Also this approach
+    // is clearer for the several possible cases of set commands.
+    if (input.trim.toLowerCase.startsWith("set")) {
+      input.trim.drop(3).split("=", 2).map(_.trim) match {
+        case Array("") => // "set"
+          SetCommand(None, None)
+        case Array(key) => // "set key"
+          SetCommand(Some(key), None)
+        case Array(key, value) => // "set key=value"
+          SetCommand(Some(key), Some(value))
+      }
+    } else {
+      phrase(query)(new lexical.Scanner(input)) match {
+        case Success(r, x) => r
+        case x => sys.error(x.toString)
+      }
     }
   }

@@ -169,11 +184,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     }
   }

-  protected lazy val query: Parser[LogicalPlan] =
+  protected lazy val query: Parser[LogicalPlan] = (
     select * (
-        UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } |
-        UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
-      ) | insert
+      UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } |
+      UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
+    )
+    | insert
+  )

   protected lazy val select: Parser[LogicalPlan] =
     SELECT ~> opt(DISTINCT) ~ projections ~
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 4f641cd3a6..7eeb98aea6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -102,7 +102,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
  */
 abstract class Command extends LeafNode {
   self: Product =>
-  def output: Seq[Attribute] = Seq.empty
+  def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this
 }

 /**
@@ -112,6 +112,16 @@ abstract class Command extends LeafNode {
 case class NativeCommand(cmd: String) extends Command

 /**
+ * Commands of the form "SET (key) (= value)".
+ */
+case class SetCommand(key: Option[String], value: Option[String]) extends Command {
+  override def output = Seq(
+    AttributeReference("key", StringType, nullable = false)(),
+    AttributeReference("value", StringType, nullable = false)()
+  )
+}
+
+/**
  * Returned by a parser when the users only wants to see what query plan would be executed, without
  * actually performing the execution.
 */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
new file mode 100644
index 0000000000..b378252ba2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Properties
+
+import scala.collection.JavaConverters._
+
+/**
+ * SQLConf holds mutable config parameters and hints. These can be set and
+ * queried either by passing SET commands into Spark SQL's DSL
+ * functions (sql(), hql(), etc.), or by programmatically using setters and
+ * getters of this class. This class is thread-safe.
+ */
+trait SQLConf {
+
+  /** Number of partitions to use for shuffle operators. */
+  private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt
+
+  @transient
+  private val settings = java.util.Collections.synchronizedMap(
+    new java.util.HashMap[String, String]())
+
+  def set(props: Properties): Unit = {
+    props.asScala.foreach { case (k, v) => this.settings.put(k, v) }
+  }
+
+  def set(key: String, value: String): Unit = {
+    require(key != null, "key cannot be null")
+    require(value != null, s"value cannot be null for ${key}")
+    settings.put(key, value)
+  }
+
+  def get(key: String): String = {
+    if (!settings.containsKey(key)) {
+      throw new NoSuchElementException(key)
+    }
+    settings.get(key)
+  }
+
+  def get(key: String, defaultValue: String): String = {
+    if (!settings.containsKey(key)) defaultValue else settings.get(key)
+  }
+
+  def getAll: Array[(String, String)] = settings.asScala.toArray
+
+  def getOption(key: String): Option[String] = {
+    if (!settings.containsKey(key)) None else Some(settings.get(key))
+  }
+
+  def contains(key: String): Boolean = settings.containsKey(key)
+
+  def toDebugString: String = {
+    settings.synchronized {
+      settings.asScala.toArray.sorted.map{ case (k, v) => s"$k=$v" }.mkString("\n")
+    }
+  }
+
+  private[spark] def clear() {
+    settings.clear()
+  }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index fde4c485b5..021e0e8245 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.{ScalaReflection, dsl}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
-import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{SetCommand, LogicalPlan}
 import
org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.columnar.InMemoryColumnarTableScan @@ -52,6 +52,7 @@ import org.apache.spark.sql.parquet.ParquetRelation @AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) extends Logging + with SQLConf with dsl.ExpressionConversions with Serializable { @@ -190,6 +191,8 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext = self.sparkContext + def numPartitions = self.numShufflePartitions + val strategies: Seq[Strategy] = CommandStrategy(self) :: TakeOrdered :: @@ -246,6 +249,10 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val planner = new SparkPlanner + @transient + protected[sql] lazy val emptyResult = + sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + /** * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and * inserting shuffle operations as needed. @@ -253,15 +260,10 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange) :: + Batch("Add exchange", Once, AddExchange(self)) :: Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil } - // TODO: or should we make QueryExecution protected[sql]? - protected[sql] def mkQueryExecution(plan: LogicalPlan) = new QueryExecution { - val logical = plan - } - /** * The primary workflow for executing relational queries using Spark. Designed to allow easy * access to the intermediate phases of query execution for developers. @@ -269,6 +271,22 @@ class SQLContext(@transient val sparkContext: SparkContext) protected abstract class QueryExecution { def logical: LogicalPlan + def eagerlyProcess(plan: LogicalPlan): RDD[Row] = plan match { + case SetCommand(key, value) => + // Only this case needs to be executed eagerly. The other cases will + // be taken care of when the actual results are being extracted. + // In the case of HiveContext, sqlConf is overridden to also pass the + // pair into its HiveConf. + if (key.isDefined && value.isDefined) { + set(key.get, value.get) + } + // It doesn't matter what we return here, since this is only used + // to force the evaluation to happen eagerly. To query the results, + // one must use SchemaRDD operations to extract them. + emptyResult + case _ => executedPlan.execute() + } + lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... @@ -276,7 +294,12 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. 
Avoids copies and has no schema */ - lazy val toRdd: RDD[Row] = executedPlan.execute() + lazy val toRdd: RDD[Row] = { + logical match { + case s: SetCommand => eagerlyProcess(s) + case _ => executedPlan.execute() + } + } protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 3b4acb72e8..cef294167f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{SQLConf, SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.{MutableProjection, RowOrdering} import org.apache.spark.sql.catalyst.plans.physical._ @@ -86,9 +86,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * [[catalyst.plans.physical.Distribution Distribution]] requirements for each operator by inserting * [[Exchange]] Operators where required. */ -private[sql] object AddExchange extends Rule[SparkPlan] { +private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. - val numPartitions = 150 + def numPartitions = sqlContext.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 295c265b16..0455748d40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, execution} +import org.apache.spark.sql.{SQLConf, SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -193,8 +193,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? 
object BasicOperators extends Strategy { - // TODO: Set - val numPartitions = 200 + def numPartitions = self.numPartitions + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Distinct(child) => execution.Aggregate( @@ -234,11 +234,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - // TODO: this should be merged with SPARK-1508's SetCommandStrategy case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.SetCommand(key, value) => + Seq(execution.SetCommandPhysical(key, value, plan.output)(context)) case logical.ExplainCommand(child) => - val qe = context.mkQueryExecution(child) + val qe = context.executePlan(child) Seq(execution.ExplainCommandPhysical(qe.executedPlan, plan.output)(context)) case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 5371d2f479..9364506691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -17,10 +17,45 @@ package org.apache.spark.sql.execution +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute} +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class SetCommandPhysical(key: Option[String], value: Option[String], output: Seq[Attribute]) + (@transient context: SQLContext) extends LeafNode { + def execute(): RDD[Row] = (key, value) match { + // Set value for key k; the action itself would + // have been performed in QueryExecution eagerly. + case (Some(k), Some(v)) => context.emptyResult + // Query the value bound to key k. + case (Some(k), None) => + val resultString = context.getOption(k) match { + case Some(v) => s"$k=$v" + case None => s"$k is undefined" + } + context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](resultString))), 1) + // Query all key-value pairs that are set in the SQLConf of the context. + case (None, None) => + val pairs = context.getAll + val rows = pairs.map { case (k, v) => + new GenericRow(Array[Any](s"$k=$v")) + }.toSeq + // Assume config parameters can fit into one split (machine) ;) + context.sparkContext.parallelize(rows, 1) + // The only other case is invalid semantics and is impossible. + case _ => context.emptyResult + } +} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute]) (@transient context: SQLContext) extends UnaryNode { def execute(): RDD[Row] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala new file mode 100644 index 0000000000..5eb73a4eff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -0,0 +1,71 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +class SQLConfSuite extends QueryTest { + + val testKey = "test.key.0" + val testVal = "test.val.0" + + test("programmatic ways of basic setting and getting") { + assert(getOption(testKey).isEmpty) + assert(getAll.toSet === Set()) + + set(testKey, testVal) + assert(get(testKey) == testVal) + assert(get(testKey, testVal + "_") == testVal) + assert(getOption(testKey) == Some(testVal)) + assert(contains(testKey)) + + // Tests SQLConf as accessed from a SQLContext is mutable after + // the latter is initialized, unlike SparkConf inside a SparkContext. + assert(TestSQLContext.get(testKey) == testVal) + assert(TestSQLContext.get(testKey, testVal + "_") == testVal) + assert(TestSQLContext.getOption(testKey) == Some(testVal)) + assert(TestSQLContext.contains(testKey)) + + clear() + } + + test("parse SQL set commands") { + sql(s"set $testKey=$testVal") + assert(get(testKey, testVal + "_") == testVal) + assert(TestSQLContext.get(testKey, testVal + "_") == testVal) + + sql("set mapred.reduce.tasks=20") + assert(get("mapred.reduce.tasks", "0") == "20") + sql("set mapred.reduce.tasks = 40") + assert(get("mapred.reduce.tasks", "0") == "40") + + val key = "spark.sql.key" + val vs = "val0,val_1,val2.3,my_table" + sql(s"set $key=$vs") + assert(get(key, "0") == vs) + + sql(s"set $key=") + assert(get(key, "0") == "") + + clear() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d651b967a6..f2d850ad6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -361,6 +361,41 @@ class SQLQuerySuite extends QueryTest { (1, "abc"), (2, "abc"), (3, null))) - } - + } + + test("SET commands semantics using sql()") { + clear() + val testKey = "test.key.0" + val testVal = "test.val.0" + val nonexistentKey = "nonexistent" + + // "set" itself returns all config variables currently specified in SQLConf. 
+ assert(sql("SET").collect().size == 0) + + // "set key=val" + sql(s"SET $testKey=$testVal") + checkAnswer( + sql("SET"), + Seq(Seq(s"$testKey=$testVal")) + ) + + sql(s"SET ${testKey + testKey}=${testVal + testVal}") + checkAnswer( + sql("set"), + Seq( + Seq(s"$testKey=$testVal"), + Seq(s"${testKey + testKey}=${testVal + testVal}")) + ) + + // "set key" + checkAnswer( + sql(s"SET $testKey"), + Seq(Seq(s"$testKey=$testVal")) + ) + checkAnswer( + sql(s"SET $nonexistentKey"), + Seq(Seq(s"$nonexistentKey is undefined")) + ) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index c563d63627..df6b118360 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -30,8 +30,8 @@ class PlannerSuite extends FunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head - val logicalUnions = query collect { case u: logical.Union => u} - val physicalUnions = planned collect { case u: execution.Union => u} + val logicalUnions = query collect { case u: logical.Union => u } + val physicalUnions = planned collect { case u: execution.Union => u } assert(logicalUnions.size === 2) assert(physicalUnions.size === 1) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 4b97dc25ac..6497821554 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql package hive -import scala.language.implicitConversions - import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.util.{ArrayList => JArrayList} +import scala.collection.JavaConversions._ +import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.hive.conf.HiveConf @@ -30,20 +30,15 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema} -import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand, ExplainCommand} -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution._ -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** * Starts up an instance of hive where metadata is stored locally. An in-process metadata data is * created with data stored in ./metadata. Warehouse data is stored in in ./warehouse. @@ -55,10 +50,9 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { /** Sets up the system initially or after a RESET command */ protected def configure() { - // TODO: refactor this so we can work with other databases. 
- runSqlHive( - s"set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$metastorePath;create=true") - runSqlHive("set hive.metastore.warehouse.dir=" + warehousePath) + set("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$metastorePath;create=true") + set("hive.metastore.warehouse.dir", warehousePath) } configure() // Must be called before initializing the catalog below. @@ -129,12 +123,27 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } + /** + * SQLConf and HiveConf contracts: when the hive session is first initialized, params in + * HiveConf will get picked up by the SQLConf. Additionally, any properties set by + * set() or a SET command inside hql() or sql() will be set in the SQLConf *as well as* + * in the HiveConf. + */ @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState]) - @transient protected[hive] lazy val sessionState = new SessionState(hiveconf) + @transient protected[hive] lazy val sessionState = { + val ss = new SessionState(hiveconf) + set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. + ss + } sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") + override def set(key: String, value: String): Unit = { + super.set(key, value) + runSqlHive(s"SET $key=$value") + } + /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog { @@ -236,30 +245,31 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient override protected[sql] val planner = hivePlanner - @transient - protected lazy val emptyResult = - sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) - /** Extends QueryExecution with hive specific features. */ protected[sql] abstract class QueryExecution extends super.QueryExecution { // TODO: Create mixin for the analyzer instead of overriding things here. 
override lazy val optimizedPlan = optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) - override lazy val toRdd: RDD[Row] = - analyzed match { - case NativeCommand(cmd) => - val output = runSqlHive(cmd) + override lazy val toRdd: RDD[Row] = { + def processCmd(cmd: String): RDD[Row] = { + val output = runSqlHive(cmd) + if (output.size == 0) { + emptyResult + } else { + val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]])) + sparkContext.parallelize(asRows, 1) + } + } - if (output.size == 0) { - emptyResult - } else { - val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]])) - sparkContext.parallelize(asRows, 1) - } - case _ => - executedPlan.execute().map(_.copy()) + logical match { + case s: SetCommand => eagerlyProcess(s) + case _ => analyzed match { + case NativeCommand(cmd) => processCmd(cmd) + case _ => executedPlan.execute().map(_.copy()) + } } + } protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, @@ -305,7 +315,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { */ def stringResult(): Seq[String] = analyzed match { case NativeCommand(cmd) => runSqlHive(cmd) - case ExplainCommand(plan) => mkQueryExecution(plan).toString.split("\n") + case ExplainCommand(plan) => executePlan(plan).toString.split("\n") case query => val result: Seq[Seq[Any]] = toRdd.collect().toSeq // We need the types so we can output struct field names @@ -318,6 +328,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def simpleString: String = logical match { case _: NativeCommand => "<Executed by Hive>" + case _: SetCommand => "<Set Command: Executed by Hive, and noted by SQLContext>" case _ => executedPlan.toString } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index cc9e24a057..4e74d9bc90 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -207,8 +207,17 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = { try { - if (sql.toLowerCase.startsWith("set")) { - NativeCommand(sql) + if (sql.trim.toLowerCase.startsWith("set")) { + // Split in two parts since we treat the part before the first "=" + // as key, and the part after as value, which may contain other "=" signs. + sql.trim.drop(3).split("=", 2).map(_.trim) match { + case Array("") => // "set" + SetCommand(None, None) + case Array(key) => // "set key" + SetCommand(Some(key), None) + case Array(key, value) => // "set key=value" + SetCommand(Some(key), Some(value)) + } } else if (sql.toLowerCase.startsWith("add jar")) { AddJar(sql.drop(8)) } else if (sql.toLowerCase.startsWith("add file")) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 0f954103a8..357c7e654b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -138,6 +138,9 @@ abstract class HiveComparisonTest val orderedAnswer = hiveQuery.logical match { // Clean out non-deterministic time schema info. 
+ // Hack: Hive simply prints the result of a SET command to screen, + // and does not return it as a query answer. + case _: SetCommand => Seq("0") case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") case _: ExplainCommand => answer case plan => if (isSorted(plan)) answer else answer.sorted diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index c56eee2580..6c239b02ed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.Row import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive @@ -171,4 +172,78 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.reset() } + test("parse HQL set commands") { + // Adapted from its SQL counterpart. + val testKey = "spark.sql.key.usedfortestonly" + val testVal = "val0,val_1,val2.3,my_table" + + hql(s"set $testKey=$testVal") + assert(get(testKey, testVal + "_") == testVal) + + hql("set mapred.reduce.tasks=20") + assert(get("mapred.reduce.tasks", "0") == "20") + hql("set mapred.reduce.tasks = 40") + assert(get("mapred.reduce.tasks", "0") == "40") + + hql(s"set $testKey=$testVal") + assert(get(testKey, "0") == testVal) + + hql(s"set $testKey=") + assert(get(testKey, "0") == "") + } + + test("SET commands semantics for a HiveContext") { + // Adapted from its SQL counterpart. + val testKey = "spark.sql.key.usedfortestonly" + var testVal = "test.val.0" + val nonexistentKey = "nonexistent" + def fromRows(row: Array[Row]): Array[String] = row.map(_.getString(0)) + + clear() + + // "set" itself returns all config variables currently specified in SQLConf. + assert(hql("set").collect().size == 0) + + // "set key=val" + hql(s"SET $testKey=$testVal") + assert(fromRows(hql("SET").collect()) sameElements Array(s"$testKey=$testVal")) + assert(hiveconf.get(testKey, "") == testVal) + + hql(s"SET ${testKey + testKey}=${testVal + testVal}") + assert(fromRows(hql("SET").collect()) sameElements + Array( + s"$testKey=$testVal", + s"${testKey + testKey}=${testVal + testVal}")) + assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + + // "set key" + assert(fromRows(hql(s"SET $testKey").collect()) sameElements + Array(s"$testKey=$testVal")) + assert(fromRows(hql(s"SET $nonexistentKey").collect()) sameElements + Array(s"$nonexistentKey is undefined")) + + // Assert that sql() should have the same effects as hql() by repeating the above using sql(). + clear() + assert(sql("set").collect().size == 0) + + sql(s"SET $testKey=$testVal") + assert(fromRows(sql("SET").collect()) sameElements Array(s"$testKey=$testVal")) + assert(hiveconf.get(testKey, "") == testVal) + + sql(s"SET ${testKey + testKey}=${testVal + testVal}") + assert(fromRows(sql("SET").collect()) sameElements + Array( + s"$testKey=$testVal", + s"${testKey + testKey}=${testVal + testVal}")) + assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + + assert(fromRows(sql(s"SET $testKey").collect()) sameElements + Array(s"$testKey=$testVal")) + assert(fromRows(sql(s"SET $nonexistentKey").collect()) sameElements + Array(s"$nonexistentKey is undefined")) + } + + // Put tests that depend on specific Hive settings before these last two test, + // since they modify /clear stuff. + } |
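As a closing note on the parser changes: SqlParser.apply and HiveQl.parseSql now share the same splitting convention for SET statements, where everything before the first "=" is the key and everything after it (which may itself contain further "=" signs) is the value. A standalone sketch of that convention follows; parseSet is a hypothetical helper written for illustration, not code from the diff:

  // Mirrors the "set ..." handling added to SqlParser.apply and HiveQl.parseSql:
  // drop the leading "set", then split on the first "=" only.
  def parseSet(input: String): (Option[String], Option[String]) =
    input.trim.drop(3).split("=", 2).map(_.trim) match {
      case Array("")         => (None, None)             // "set"           -> list all pairs
      case Array(key)        => (Some(key), None)        // "set key"       -> query one key
      case Array(key, value) => (Some(key), Some(value)) // "set key=value" -> assign
    }

  assert(parseSet("set") == (None, None))
  assert(parseSet("set mapred.reduce.tasks = 40") == (Some("mapred.reduce.tasks"), Some("40")))
  assert(parseSet("set spark.sql.key=val0,val_1,val2.3") == (Some("spark.sql.key"), Some("val0,val_1,val2.3")))
  assert(parseSet("set spark.sql.key=") == (Some("spark.sql.key"), Some("")))  // empty value, as exercised in the test suites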