diff options
Diffstat (limited to 'sql/hive-thriftserver')
-rw-r--r-- | sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index aeabd6a158..517b01f183 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -50,8 +50,13 @@ private[hive] class SparkExecuteStatementOperation(
   with Logging {
 
   private var result: DataFrame = _
+
+  // We cache the returned rows to get iterators again in case the user wants to use FETCH_FIRST.
+  // This is only used when `spark.sql.thriftServer.incrementalCollect` is set to `false`.
+  // In case of `true`, this will be `None` and FETCH_FIRST will trigger re-execution.
+  private var resultList: Option[Array[SparkRow]] = _
+
   private var iter: Iterator[SparkRow] = _
-  private var iterHeader: Iterator[SparkRow] = _
   private var dataTypes: Array[DataType] = _
   private var statementId: String = _
 
@@ -111,9 +116,15 @@ private[hive] class SparkExecuteStatementOperation(
 
     // Reset iter to header when fetching start from first row
     if (order.equals(FetchOrientation.FETCH_FIRST)) {
-      val (ita, itb) = iterHeader.duplicate
-      iter = ita
-      iterHeader = itb
+      iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
+        resultList = None
+        result.toLocalIterator.asScala
+      } else {
+        if (resultList.isEmpty) {
+          resultList = Some(result.collect())
+        }
+        resultList.get.iterator
+      }
     }
 
     if (!iter.hasNext) {
@@ -227,17 +238,14 @@ private[hive] class SparkExecuteStatementOperation(
       }
       HiveThriftServer2.listener.onStatementParsed(statementId, result.queryExecution.toString())
       iter = {
-        val useIncrementalCollect =
-          sqlContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
-        if (useIncrementalCollect) {
+        if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
+          resultList = None
           result.toLocalIterator.asScala
         } else {
-          result.collect().iterator
+          resultList = Some(result.collect())
+          resultList.get.iterator
         }
       }
-      val (itra, itrb) = iter.duplicate
-      iterHeader = itra
-      iter = itrb
       dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray
     } catch {
       case e: HiveSQLException =>