author      Zhan Zhang <zhazhan@gmail.com>              2015-05-18 12:03:27 -0700
committer   Michael Armbrust <michael@databricks.com>   2015-05-18 12:03:40 -0700
commit      aa31e431fc09f0477f1c2351c6275769a31aca90 (patch)
tree        58ec159706a2dc7703b1eeba5d466b92a20e3147 /sql
parent      9c7e802a5a2b8cd3eb77642f84c54a8e976fc996 (diff)
download    spark-aa31e431fc09f0477f1c2351c6275769a31aca90.tar.gz
            spark-aa31e431fc09f0477f1c2351c6275769a31aca90.tar.bz2
            spark-aa31e431fc09f0477f1c2351c6275769a31aca90.zip
[SPARK-2883] [SQL] ORC data source for Spark SQL
This PR updates PR #6135 authored by zhzhan from Hortonworks.

----

This PR implements a Spark SQL data source for accessing ORC files.

> **NOTE**
>
> Although ORC is now an Apache TLP, the codebase is still tightly coupled with Hive. That's why the new ORC data source lives under the `org.apache.spark.sql.hive` package and must be used with `HiveContext`. However, it doesn't require an existing Hive installation to access ORC files.

Main features:

1. Saving/loading ORC files without contacting the Hive metastore
2. Support for complex data types (i.e. array, map, and struct)
3. Awareness of common optimizations provided by Spark SQL:
   - Column pruning
   - Partition pruning
   - Filter push-down
4. Schema evolution support
5. Hive metastore table conversion

This PR also includes initial work done by scwf from Huawei (PR #3753).

Author: Zhan Zhang <zhazhan@gmail.com>
Author: Cheng Lian <lian@databricks.com>

Closes #6194 from liancheng/polishing-orc and squashes the following commits:

55ecd96 [Cheng Lian] Reorganizes ORC test suites
d4afeed [Cheng Lian] Addresses comments
21ada22 [Cheng Lian] Adds @since and @Experimental annotations
128bd3b [Cheng Lian] ORC filter bug fix
d734496 [Cheng Lian] Polishes the ORC data source
2650a42 [Zhan Zhang] resolve review comments
3c9038e [Zhan Zhang] resolve review comments
7b3c7c5 [Zhan Zhang] save mode fix
f95abfd [Zhan Zhang] reuse test suite
7cc2c64 [Zhan Zhang] predicate fix
4e61c16 [Zhan Zhang] minor change
305418c [Zhan Zhang] orc data source support
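For illustration, a minimal usage sketch of the new data source (Spark 1.4-era DataFrame reader/writer API, as exercised by the test suites below; the `sc` handle, column names, and output path are illustrative):

```scala
import org.apache.spark.sql.hive.HiveContext

// The ORC source lives under sql/hive, so a HiveContext is required,
// but no Hive metastore is contacted for plain ORC files.
val sqlContext = new HiveContext(sc)
import sqlContext.implicits._

val df = sc.parallelize(1 to 10).map(i => (i, s"val_$i")).toDF("key", "value")

// Save as ORC through the data source API.
df.write.format("orc").save("/tmp/orc-demo")

// Load it back; the schema is read from the ORC files themselves.
val loaded = sqlContext.read.format("orc").load("/tmp/orc-demo")
loaded.show()
```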
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala                              |   7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala                  |  61
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala                          |  18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala                    |  81
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala                  |  40
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala             |  69
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala                  | 144
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala                 | 290
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala    |  59
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala  | 256
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala               | 294
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala              | 146
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala                     |  82
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala       |   6
14 files changed, 1477 insertions, 76 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index f07bb196c1..6da910e332 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -43,6 +43,8 @@ private[spark] object SQLConf {
val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown"
val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi"
+ val ORC_FILTER_PUSHDOWN_ENABLED = "spark.sql.orc.filterPushdown"
+
val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath"
val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord"
@@ -143,6 +145,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def parquetUseDataSourceApi =
getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean
+ private[spark] def orcFilterPushDown =
+ getConf(ORC_FILTER_PUSHDOWN_ENABLED, "false").toBoolean
+
/** When true uses verifyPartitionPath to prune the path which is not exists. */
private[spark] def verifyPartitionPath =
getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean
@@ -254,7 +259,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def dataFrameRetainGroupColumns: Boolean =
getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean
-
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
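A usage note on the new setting: `spark.sql.orc.filterPushdown` defaults to `false`, so ORC predicate push-down has to be enabled explicitly per session, for example:

```scala
// Either programmatically...
sqlContext.setConf("spark.sql.orc.filterPushdown", "true")
// ...or through the session-level SET command.
sqlContext.sql("SET spark.sql.orc.filterPushdown=true")
```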
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
index 7a73b6f1ac..516ba373f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
@@ -21,10 +21,9 @@ import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
-import scala.util.Try
-import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.{DataFrame, SaveMode}
/**
* A helper trait that provides convenient facilities for Parquet testing.
@@ -33,54 +32,9 @@ import org.apache.spark.util.Utils
* convenient to use tuples rather than special case classes when writing test cases/suites.
* Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
*/
-private[sql] trait ParquetTest {
- val sqlContext: SQLContext
-
+private[sql] trait ParquetTest extends SQLTestUtils {
import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder}
- import sqlContext.{conf, sparkContext}
-
- protected def configuration = sparkContext.hadoopConfiguration
-
- /**
- * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
- * configurations.
- *
- * @todo Probably this method should be moved to a more general place
- */
- protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
- val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(conf.getConf(key)).toOption)
- (keys, values).zipped.foreach(conf.setConf)
- try f finally {
- keys.zip(currentValues).foreach {
- case (key, Some(value)) => conf.setConf(key, value)
- case (key, None) => conf.unsetConf(key)
- }
- }
- }
-
- /**
- * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If
- * a file/directory is created there by `f`, it will be delete after `f` returns.
- *
- * @todo Probably this method should be moved to a more general place
- */
- protected def withTempPath(f: File => Unit): Unit = {
- val path = Utils.createTempDir()
- path.delete()
- try f(path) finally Utils.deleteRecursively(path)
- }
-
- /**
- * Creates a temporary directory, which is then passed to `f` and will be deleted after `f`
- * returns.
- *
- * @todo Probably this method should be moved to a more general place
- */
- protected def withTempDir(f: File => Unit): Unit = {
- val dir = Utils.createTempDir().getCanonicalFile
- try f(dir) finally Utils.deleteRecursively(dir)
- }
+ import sqlContext.sparkContext
/**
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
@@ -106,13 +60,6 @@ private[sql] trait ParquetTest {
}
/**
- * Drops temporary table `tableName` after calling `f`.
- */
- protected def withTempTable(tableName: String)(f: => Unit): Unit = {
- try f finally sqlContext.dropTempTable(tableName)
- }
-
- /**
* Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a
* temporary table named `tableName`, then call `f`. The temporary table together with the
* Parquet file will be dropped/deleted after `f` returns.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 37a569db31..a13ab74852 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -188,18 +188,20 @@ private[sql] class DDLParser(
private[sql] object ResolvedDataSource {
private val builtinSources = Map(
- "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource],
- "json" -> classOf[org.apache.spark.sql.json.DefaultSource],
- "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource]
+ "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource",
+ "json" -> "org.apache.spark.sql.json.DefaultSource",
+ "parquet" -> "org.apache.spark.sql.parquet.DefaultSource",
+ "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource"
)
/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider: String): Class[_] = {
+ val loader = Utils.getContextOrSparkClassLoader
+
if (builtinSources.contains(provider)) {
- return builtinSources(provider)
+ return loader.loadClass(builtinSources(provider))
}
- val loader = Utils.getContextOrSparkClassLoader
try {
loader.loadClass(provider)
} catch {
@@ -208,7 +210,11 @@ private[sql] object ResolvedDataSource {
loader.loadClass(provider + ".DefaultSource")
} catch {
case cnf: java.lang.ClassNotFoundException =>
- sys.error(s"Failed to load class for data source: $provider")
+ if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
+ sys.error("The ORC data source must be used with Hive support enabled.")
+ } else {
+ sys.error(s"Failed to load class for data source: $provider")
+ }
}
}
}
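With the builtin map now holding class names instead of `Class[_]` objects, sql/core can refer to the ORC source without a compile-time dependency on sql/hive; the class is only loaded reflectively when requested. A hedged sketch of the two ways the provider can be spelled (the path is illustrative):

```scala
// Short alias registered in `builtinSources`.
sqlContext.read.format("orc").load("/tmp/orc-demo")

// Fully qualified package name; `lookupDataSource` appends ".DefaultSource".
sqlContext.read.format("org.apache.spark.sql.hive.orc").load("/tmp/orc-demo")
```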
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala
new file mode 100644
index 0000000000..75d290625e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import java.io.File
+
+import scala.util.Try
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Utils
+
+trait SQLTestUtils {
+ val sqlContext: SQLContext
+
+ import sqlContext.{conf, sparkContext}
+
+ protected def configuration = sparkContext.hadoopConfiguration
+
+ /**
+ * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
+ * configurations.
+ *
+ * @todo Probably this method should be moved to a more general place
+ */
+ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+ val (keys, values) = pairs.unzip
+ val currentValues = keys.map(key => Try(conf.getConf(key)).toOption)
+ (keys, values).zipped.foreach(conf.setConf)
+ try f finally {
+ keys.zip(currentValues).foreach {
+ case (key, Some(value)) => conf.setConf(key, value)
+ case (key, None) => conf.unsetConf(key)
+ }
+ }
+ }
+
+ /**
+ * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If
+ * a file/directory is created there by `f`, it will be delete after `f` returns.
+ *
+ * @todo Probably this method should be moved to a more general place
+ */
+ protected def withTempPath(f: File => Unit): Unit = {
+ val path = Utils.createTempDir()
+ path.delete()
+ try f(path) finally Utils.deleteRecursively(path)
+ }
+
+ /**
+ * Creates a temporary directory, which is then passed to `f` and will be deleted after `f`
+ * returns.
+ *
+ * @todo Probably this method should be moved to a more general place
+ */
+ protected def withTempDir(f: File => Unit): Unit = {
+ val dir = Utils.createTempDir().getCanonicalFile
+ try f(dir) finally Utils.deleteRecursively(dir)
+ }
+
+ /**
+ * Drops temporary table `tableName` after calling `f`.
+ */
+ protected def withTempTable(tableName: String)(f: => Unit): Unit = {
+ try f finally sqlContext.dropTempTable(tableName)
+ }
+}
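A hypothetical test body showing how the extracted helpers compose (assumes a concrete FunSuite that mixes in `SQLTestUtils` and binds `sqlContext` to a Hive-enabled context such as `TestHive`, since the example uses the ORC format):

```scala
import sqlContext.implicits._

test("orc round trip with push-down enabled") {
  withSQLConf("spark.sql.orc.filterPushdown" -> "true") {
    withTempPath { file =>
      // `file` does not exist yet; the write creates it and the helper deletes it afterwards.
      val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("a", "b")
      df.write.format("orc").save(file.getCanonicalPath)
      assert(sqlContext.read.format("orc").load(file.getCanonicalPath).count() == 10)
    }
  }
}
```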
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 7c7666f6e4..0a694c70e4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.hive
import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
-import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
+import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _}
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
@@ -122,7 +122,7 @@ import scala.collection.JavaConversions._
* even a normal java object (POJO)
* UnionObjectInspector: (tag: Int, object data) (TODO: not supported by SparkSQL yet)
*
- * 3) ConstantObjectInspector:
+ * 3) ConstantObjectInspector:
* Constant object inspector can be either primitive type or Complex type, and it bundles a
* constant value as its property, usually the value is created when the constant object inspector
* constructed.
@@ -133,7 +133,7 @@ import scala.collection.JavaConversions._
}
}}}
* Hive provides 3 built-in constant object inspectors:
- * Primitive Object Inspectors:
+ * Primitive Object Inspectors:
* WritableConstantStringObjectInspector
* WritableConstantHiveVarcharObjectInspector
* WritableConstantHiveDecimalObjectInspector
@@ -147,9 +147,9 @@ import scala.collection.JavaConversions._
* WritableConstantByteObjectInspector
* WritableConstantBinaryObjectInspector
* WritableConstantDateObjectInspector
- * Map Object Inspector:
+ * Map Object Inspector:
* StandardConstantMapObjectInspector
- * List Object Inspector:
+ * List Object Inspector:
* StandardConstantListObjectInspector]]
* Struct Object Inspector: Hive doesn't provide the built-in constant object inspector for Struct
* Union Object Inspector: Hive doesn't provide the built-in constant object inspector for Union
@@ -250,9 +250,9 @@ private[hive] trait HiveInspectors {
poi.getWritableConstantValue.getHiveDecimal)
case poi: WritableConstantTimestampObjectInspector =>
poi.getWritableConstantValue.getTimestamp.clone()
- case poi: WritableConstantIntObjectInspector =>
+ case poi: WritableConstantIntObjectInspector =>
poi.getWritableConstantValue.get()
- case poi: WritableConstantDoubleObjectInspector =>
+ case poi: WritableConstantDoubleObjectInspector =>
poi.getWritableConstantValue.get()
case poi: WritableConstantBooleanObjectInspector =>
poi.getWritableConstantValue.get()
@@ -306,7 +306,7 @@ private[hive] trait HiveInspectors {
// In order to keep backward-compatible, we have to copy the
// bytes with old apis
val bw = x.getPrimitiveWritableObject(data)
- val result = new Array[Byte](bw.getLength())
+ val result = new Array[Byte](bw.getLength())
System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength())
result
case x: DateObjectInspector if x.preferWritable() =>
@@ -395,6 +395,30 @@ private[hive] trait HiveInspectors {
}
/**
+ * Builds specific unwrappers ahead of time according to object inspector
+ * types to avoid pattern matching and branching costs per row.
+ */
+ def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit =
+ field.getFieldObjectInspector match {
+ case oi: BooleanObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value))
+ case oi: ByteObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value))
+ case oi: ShortObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value))
+ case oi: IntObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value))
+ case oi: LongObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value))
+ case oi: FloatObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value))
+ case oi: DoubleObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value))
+ case oi =>
+ (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi)
+ }
+
+ /**
* Converts native catalyst types to the types expected by Hive
* @param a the value to be wrapped
* @param oi This ObjectInspector associated with the value returned by this function, and
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala
new file mode 100644
index 0000000000..1e51173a19
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader}
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
+
+import org.apache.spark.Logging
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.hive.HiveMetastoreTypes
+import org.apache.spark.sql.types.StructType
+
+private[orc] object OrcFileOperator extends Logging{
+ def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = {
+ val conf = config.getOrElse(new Configuration)
+ val fspath = new Path(pathStr)
+ val fs = fspath.getFileSystem(conf)
+ val orcFiles = listOrcFiles(pathStr, conf)
+
+ // TODO Need to consider all files when schema evolution is taken into account.
+ OrcFile.createReader(fs, orcFiles.head)
+ }
+
+ def readSchema(path: String, conf: Option[Configuration]): StructType = {
+ val reader = getFileReader(path, conf)
+ val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector]
+ val schema = readerInspector.getTypeName
+ HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType]
+ }
+
+ def getObjectInspector(path: String, conf: Option[Configuration]): StructObjectInspector = {
+ getFileReader(path, conf).getObjectInspector.asInstanceOf[StructObjectInspector]
+ }
+
+ def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = {
+ val origPath = new Path(pathStr)
+ val fs = origPath.getFileSystem(conf)
+ val path = origPath.makeQualified(fs)
+ val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath)
+ .filterNot(_.isDir)
+ .map(_.getPath)
+ .filterNot(_.getName.startsWith("_"))
+ .filterNot(_.getName.startsWith("."))
+
+ if (paths == null || paths.size == 0) {
+ throw new IllegalArgumentException(
+ s"orcFileOperator: path $path does not have valid orc files matching the pattern")
+ }
+
+ paths
+ }
+}
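A hedged usage sketch for the helper above (the object is `private[orc]`, so this only compiles from inside that package; the path is illustrative):

```scala
package org.apache.spark.sql.hive.orc

import org.apache.hadoop.conf.Configuration

object OrcSchemaProbe {
  def main(args: Array[String]): Unit = {
    // Reads the metadata of the first ORC file found under the path and
    // converts its type description into a Spark SQL StructType.
    val schema = OrcFileOperator.readSchema("/tmp/orc-demo", Some(new Configuration()))
    schema.fields.foreach(f => println(s"${f.name}: ${f.dataType.simpleString}"))
  }
}
```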
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala
new file mode 100644
index 0000000000..250e73a4db
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar}
+import org.apache.hadoop.hive.ql.io.sarg.SearchArgument
+import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder
+import org.apache.hadoop.hive.serde2.io.DateWritable
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.sources._
+
+/**
+ * It may be optimized by push down partial filters. But we are conservative here.
+ * Because if some filters fail to be parsed, the tree may be corrupted,
+ * and cannot be used anymore.
+ */
+private[orc] object OrcFilters extends Logging {
+ def createFilter(expr: Array[Filter]): Option[SearchArgument] = {
+ expr.reduceOption(And).flatMap { conjunction =>
+ val builder = SearchArgument.FACTORY.newBuilder()
+ buildSearchArgument(conjunction, builder).map(_.build())
+ }
+ }
+
+ private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = {
+ def newBuilder = SearchArgument.FACTORY.newBuilder()
+
+ def isSearchableLiteral(value: Any) = value match {
+ // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method.
+ case _: String | _: Long | _: Double | _: DateWritable | _: HiveDecimal | _: HiveChar |
+ _: HiveVarchar | _: Byte | _: Short | _: Integer | _: Float => true
+ case _ => false
+ }
+
+ // lian: I probably missed something here, and had to end up with a pretty weird double-checking
+ // pattern when converting `And`/`Or`/`Not` filters.
+ //
+ // The annoying part is that, `SearchArgument` builder methods like `startAnd()` `startOr()`,
+ // and `startNot()` mutate internal state of the builder instance. This forces us to translate
+ // all convertible filters with a single builder instance. However, before actually converting a
+ // filter, we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible
+ // filter is found, we may already end up with a builder whose internal state is inconsistent.
+ //
+ // For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and
+ // then try to convert its children. Say we convert `left` child successfully, but find that
+ // `right` child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is
+ // inconsistent now.
+ //
+ // The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their
+ // children with brand new builders, and only do the actual conversion with the right builder
+ // instance when the children are proven to be convertible.
+ //
+ // P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only.
+ // Usage of builder methods mentioned above can only be found in test code, where all tested
+ // filters are known to be convertible.
+
+ expression match {
+ case And(left, right) =>
+ val tryLeft = buildSearchArgument(left, newBuilder)
+ val tryRight = buildSearchArgument(right, newBuilder)
+
+ val conjunction = for {
+ _ <- tryLeft
+ _ <- tryRight
+ lhs <- buildSearchArgument(left, builder.startAnd())
+ rhs <- buildSearchArgument(right, lhs)
+ } yield rhs.end()
+
+ // For filter `left AND right`, we can still push down `left` even if `right` is not
+ // convertible, and vice versa.
+ conjunction
+ .orElse(tryLeft.flatMap(_ => buildSearchArgument(left, builder)))
+ .orElse(tryRight.flatMap(_ => buildSearchArgument(right, builder)))
+
+ case Or(left, right) =>
+ for {
+ _ <- buildSearchArgument(left, newBuilder)
+ _ <- buildSearchArgument(right, newBuilder)
+ lhs <- buildSearchArgument(left, builder.startOr())
+ rhs <- buildSearchArgument(right, lhs)
+ } yield rhs.end()
+
+ case Not(child) =>
+ for {
+ _ <- buildSearchArgument(child, newBuilder)
+ negate <- buildSearchArgument(child, builder.startNot())
+ } yield negate.end()
+
+ case EqualTo(attribute, value) =>
+ Option(value)
+ .filter(isSearchableLiteral)
+ .map(builder.equals(attribute, _))
+
+ case LessThan(attribute, value) =>
+ Option(value)
+ .filter(isSearchableLiteral)
+ .map(builder.lessThan(attribute, _))
+
+ case LessThanOrEqual(attribute, value) =>
+ Option(value)
+ .filter(isSearchableLiteral)
+ .map(builder.lessThanEquals(attribute, _))
+
+ case GreaterThan(attribute, value) =>
+ Option(value)
+ .filter(isSearchableLiteral)
+ .map(builder.startNot().lessThanEquals(attribute, _).end())
+
+ case GreaterThanOrEqual(attribute, value) =>
+ Option(value)
+ .filter(isSearchableLiteral)
+ .map(builder.startNot().lessThan(attribute, _).end())
+
+ case IsNull(attribute) =>
+ Some(builder.isNull(attribute))
+
+ case IsNotNull(attribute) =>
+ Some(builder.startNot().isNull(attribute).end())
+
+ case In(attribute, values) =>
+ Option(values)
+ .filter(_.forall(isSearchableLiteral))
+ .map(builder.in(attribute, _))
+
+ case _ => None
+ }
+ }
+}
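A hedged sketch of the conversion (again placed inside the `org.apache.spark.sql.hive.orc` package because of the `private[orc]` modifier; the attribute name and values are illustrative):

```scala
package org.apache.spark.sql.hive.orc

import org.apache.spark.sql.sources.{Filter, GreaterThan, LessThanOrEqual}

object OrcFiltersProbe {
  def main(args: Array[String]): Unit = {
    // `age > 5 AND age <= 8`, expressed as data source filters.
    val filters: Array[Filter] = Array(GreaterThan("age", 5), LessThanOrEqual("age", 8))

    OrcFilters.createFilter(filters) match {
      case Some(sarg) =>
        // Roughly: leaf-0 = (LESS_THAN_EQUALS age 5), leaf-1 = (LESS_THAN_EQUALS age 8),
        // expr = (and (not leaf-0) leaf-1)
        println(sarg)
      case None =>
        println("Inconvertible filters; Spark evaluates them after the scan instead.")
    }
  }
}
```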
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
new file mode 100644
index 0000000000..9708199f07
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -0,0 +1,290 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.util.{Objects, Properties}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit}
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
+import org.apache.hadoop.io.{NullWritable, Writable}
+import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, RecordWriter, Reporter}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
+import org.apache.spark.rdd.{HadoopRDD, RDD}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim}
+import org.apache.spark.sql.sources.{Filter, _}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.{Logging, SerializableWritable}
+
+/* Implicit conversions */
+import scala.collection.JavaConversions._
+
+private[sql] class DefaultSource extends HadoopFsRelationProvider {
+ def createRelation(
+ sqlContext: SQLContext,
+ paths: Array[String],
+ schema: Option[StructType],
+ partitionColumns: Option[StructType],
+ parameters: Map[String, String]): HadoopFsRelation = {
+ assert(
+ sqlContext.isInstanceOf[HiveContext],
+ "The ORC data source can only be used with HiveContext.")
+
+ val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty[Partition]))
+ OrcRelation(paths, parameters, schema, partitionSpec)(sqlContext)
+ }
+}
+
+private[orc] class OrcOutputWriter(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext)
+ extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors {
+
+ private val serializer = {
+ val table = new Properties()
+ table.setProperty("columns", dataSchema.fieldNames.mkString(","))
+ table.setProperty("columns.types", dataSchema.map { f =>
+ HiveMetastoreTypes.toMetastoreType(f.dataType)
+ }.mkString(":"))
+
+ val serde = new OrcSerde
+ serde.initialize(context.getConfiguration, table)
+ serde
+ }
+
+ // Object inspector converted from the schema of the relation to be written.
+ private val structOI = {
+ val typeInfo =
+ TypeInfoUtils.getTypeInfoFromTypeString(
+ HiveMetastoreTypes.toMetastoreType(dataSchema))
+
+ TypeInfoUtils
+ .getStandardJavaObjectInspectorFromTypeInfo(typeInfo)
+ .asInstanceOf[StructObjectInspector]
+ }
+
+ // Used to hold temporary `Writable` fields of the next row to be written.
+ private val reusableOutputBuffer = new Array[Any](dataSchema.length)
+
+ // Used to convert Catalyst values into Hadoop `Writable`s.
+ private val wrappers = structOI.getAllStructFieldRefs.map { ref =>
+ wrapperFor(ref.getFieldObjectInspector)
+ }.toArray
+
+ // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this
+ // flag to decide whether `OrcRecordWriter.close()` needs to be called.
+ private var recordWriterInstantiated = false
+
+ private lazy val recordWriter: RecordWriter[NullWritable, Writable] = {
+ recordWriterInstantiated = true
+
+ val conf = context.getConfiguration
+ val partition = context.getTaskAttemptID.getTaskID.getId
+ val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc"
+
+ new OrcOutputFormat().getRecordWriter(
+ new Path(path, filename).getFileSystem(conf),
+ conf.asInstanceOf[JobConf],
+ new Path(path, filename).toUri.getPath,
+ Reporter.NULL
+ ).asInstanceOf[RecordWriter[NullWritable, Writable]]
+ }
+
+ override def write(row: Row): Unit = {
+ var i = 0
+ while (i < row.length) {
+ reusableOutputBuffer(i) = wrappers(i)(row(i))
+ i += 1
+ }
+
+ recordWriter.write(
+ NullWritable.get(),
+ serializer.serialize(reusableOutputBuffer, structOI))
+ }
+
+ override def close(): Unit = {
+ if (recordWriterInstantiated) {
+ recordWriter.close(Reporter.NULL)
+ }
+ }
+}
+
+@DeveloperApi
+private[sql] case class OrcRelation(
+ override val paths: Array[String],
+ parameters: Map[String, String],
+ maybeSchema: Option[StructType] = None,
+ maybePartitionSpec: Option[PartitionSpec] = None)(
+ @transient val sqlContext: SQLContext)
+ extends HadoopFsRelation(maybePartitionSpec)
+ with Logging {
+
+ override val dataSchema: StructType = maybeSchema.getOrElse {
+ OrcFileOperator.readSchema(
+ paths.head, Some(sqlContext.sparkContext.hadoopConfiguration))
+ }
+
+ override def userDefinedPartitionColumns: Option[StructType] =
+ maybePartitionSpec.map(_.partitionColumns)
+
+ override def needConversion: Boolean = false
+
+ override def equals(other: Any): Boolean = other match {
+ case that: OrcRelation =>
+ paths.toSet == that.paths.toSet &&
+ dataSchema == that.dataSchema &&
+ schema == that.schema &&
+ partitionColumns == that.partitionColumns
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ Objects.hashCode(
+ paths.toSet,
+ dataSchema,
+ schema,
+ maybePartitionSpec)
+ }
+
+ override def buildScan(requiredColumns: Array[String],
+ filters: Array[Filter],
+ inputPaths: Array[String]): RDD[Row] = {
+ val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes
+ OrcTableScan(output, this, filters, inputPaths).execute()
+ }
+
+ override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+ new OutputWriterFactory {
+ override def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ new OrcOutputWriter(path, dataSchema, context)
+ }
+ }
+ }
+}
+
+private[orc] case class OrcTableScan(
+ attributes: Seq[Attribute],
+ @transient relation: OrcRelation,
+ filters: Array[Filter],
+ inputPaths: Array[String])
+ extends Logging
+ with HiveInspectors {
+
+ @transient private val sqlContext = relation.sqlContext
+
+ private def addColumnIds(
+ output: Seq[Attribute],
+ relation: OrcRelation,
+ conf: Configuration): Unit = {
+ val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer)
+ val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip
+ HiveShim.appendReadColumns(conf, sortedIds, sortedNames)
+ }
+
+ // Transform all given raw `Writable`s into `Row`s.
+ private def fillObject(
+ path: String,
+ conf: Configuration,
+ iterator: Iterator[Writable],
+ nonPartitionKeyAttrs: Seq[(Attribute, Int)],
+ mutableRow: MutableRow): Iterator[Row] = {
+ val deserializer = new OrcSerde
+ val soi = OrcFileOperator.getObjectInspector(path, Some(conf))
+ val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map {
+ case (attr, ordinal) =>
+ soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal
+ }.unzip
+ val unwrappers = fieldRefs.map(unwrapperFor)
+ // Map each tuple to a row object
+ iterator.map { value =>
+ val raw = deserializer.deserialize(value)
+ var i = 0
+ while (i < fieldRefs.length) {
+ val fieldValue = soi.getStructFieldData(raw, fieldRefs(i))
+ if (fieldValue == null) {
+ mutableRow.setNullAt(fieldOrdinals(i))
+ } else {
+ unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i))
+ }
+ i += 1
+ }
+ mutableRow: Row
+ }
+ }
+
+ def execute(): RDD[Row] = {
+ val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
+ val conf = job.getConfiguration
+
+ // Tries to push down filters if ORC filter push-down is enabled
+ if (sqlContext.conf.orcFilterPushDown) {
+ OrcFilters.createFilter(filters).foreach { f =>
+ conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo)
+ conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true)
+ }
+ }
+
+ // Sets requested columns
+ addColumnIds(attributes, relation, conf)
+
+ if (inputPaths.nonEmpty) {
+ FileInputFormat.setInputPaths(job, inputPaths.map(new Path(_)): _*)
+ }
+
+ val inputFormatClass =
+ classOf[OrcInputFormat]
+ .asInstanceOf[Class[_ <: MapRedInputFormat[NullWritable, Writable]]]
+
+ val rdd = sqlContext.sparkContext.hadoopRDD(
+ conf.asInstanceOf[JobConf],
+ inputFormatClass,
+ classOf[NullWritable],
+ classOf[Writable]
+ ).asInstanceOf[HadoopRDD[NullWritable, Writable]]
+
+ val wrappedConf = new SerializableWritable(conf)
+
+ rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) =>
+ val mutableRow = new SpecificMutableRow(attributes.map(_.dataType))
+ fillObject(
+ split.getPath.toString,
+ wrappedConf.value,
+ iterator.map(_._2),
+ attributes.zipWithIndex,
+ mutableRow)
+ }
+ }
+}
+
+private[orc] object OrcTableScan {
+ // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public.
+ private[orc] val SARG_PUSHDOWN = "sarg.pushdown"
+}
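A hedged end-to-end sketch tying the pieces together (paths and data illustrative): the write path goes through `OrcOutputWriter`, the read path through `OrcTableScan`, and with the flag enabled the predicate is handed to ORC as a `SearchArgument`.

```scala
import org.apache.spark.sql.hive.HiveContext

val hiveContext = new HiveContext(sc)
import hiveContext.implicits._

hiveContext.setConf("spark.sql.orc.filterPushdown", "true")

val people = sc.parallelize(1 to 100).map(i => (i, s"name_$i")).toDF("age", "name")
people.write.format("orc").save("/tmp/people_orc")

// Only the `name` column is requested (column pruning), and `age > 90` is
// pushed down to the ORC reader when it is convertible.
hiveContext.read.format("orc")
  .load("/tmp/people_orc")
  .filter("age > 90")
  .select("name")
  .show()
```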
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
new file mode 100644
index 0000000000..080af5bb23
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.sources.HadoopFsRelationTest
+import org.apache.spark.sql.types._
+
+class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
+ override val dataSourceName: String = classOf[DefaultSource].getCanonicalName
+
+ import sqlContext._
+ import sqlContext.implicits._
+
+ test("save()/load() - partitioned table - simple queries - partition columns in data") {
+ withTempDir { file =>
+ val basePath = new Path(file.getCanonicalPath)
+ val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
+ val qualifiedBasePath = fs.makeQualified(basePath)
+
+ for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
+ val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
+ sparkContext
+ .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
+ .toDF("a", "b", "p1")
+ .write
+ .format("orc")
+ .save(partitionDir.toString)
+ }
+
+ val dataSchemaWithPartition =
+ StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))
+
+ checkQueries(
+ load(
+ source = dataSourceName,
+ options = Map(
+ "path" -> file.getCanonicalPath,
+ "dataSchema" -> dataSchemaWithPartition.json)))
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
new file mode 100644
index 0000000000..88c99e3526
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.io.File
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.util.Utils
+import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
+
+
+// The data where the partitioning key exists only in the directory structure.
+case class OrcParData(intField: Int, stringField: String)
+
+// The data that also includes the partitioning key
+case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
+
+// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot
+class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
+ val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal
+
+ def withTempDir(f: File => Unit): Unit = {
+ val dir = Utils.createTempDir().getCanonicalFile
+ try f(dir) finally Utils.deleteRecursively(dir)
+ }
+
+ def makeOrcFile[T <: Product: ClassTag: TypeTag](
+ data: Seq[T], path: File): Unit = {
+ data.toDF().write.format("orc").mode("overwrite").save(path.getCanonicalPath)
+ }
+
+
+ def makeOrcFile[T <: Product: ClassTag: TypeTag](
+ df: DataFrame, path: File): Unit = {
+ df.write.format("orc").mode("overwrite").save(path.getCanonicalPath)
+ }
+
+ protected def withTempTable(tableName: String)(f: => Unit): Unit = {
+ try f finally TestHive.dropTempTable(tableName)
+ }
+
+ protected def makePartitionDir(
+ basePath: File,
+ defaultPartitionName: String,
+ partitionCols: (String, Any)*): File = {
+ val partNames = partitionCols.map { case (k, v) =>
+ val valueString = if (v == null || v == "") defaultPartitionName else v.toString
+ s"$k=$valueString"
+ }
+
+ val partDir = partNames.foldLeft(basePath) { (parent, child) =>
+ new File(parent, child)
+ }
+
+ assert(partDir.mkdirs(), s"Couldn't create directory $partDir")
+ partDir
+ }
+
+ test("read partitioned table - normal case") {
+ withTempDir { base =>
+ for {
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } {
+ makeOrcFile(
+ (1 to 10).map(i => OrcParData(i, i.toString)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ read.format("orc").load(base.getCanonicalPath).registerTempTable("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } yield Row(i, i.toString, pi, ps))
+
+ checkAnswer(
+ sql("SELECT intField, pi FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ _ <- Seq("foo", "bar")
+ } yield Row(i, pi))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE pi = 1"),
+ for {
+ i <- 1 to 10
+ ps <- Seq("foo", "bar")
+ } yield Row(i, i.toString, 1, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps = 'foo'"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ } yield Row(i, i.toString, pi, "foo"))
+ }
+ }
+ }
+
+ test("read partitioned table - partition key included in orc file") {
+ withTempDir { base =>
+ for {
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } {
+ makeOrcFile(
+ (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ read.format("orc").load(base.getCanonicalPath).registerTempTable("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } yield Row(i, pi, i.toString, ps))
+
+ checkAnswer(
+ sql("SELECT intField, pi FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ _ <- Seq("foo", "bar")
+ } yield Row(i, pi))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE pi = 1"),
+ for {
+ i <- 1 to 10
+ ps <- Seq("foo", "bar")
+ } yield Row(i, 1, i.toString, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps = 'foo'"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ } yield Row(i, pi, i.toString, "foo"))
+ }
+ }
+ }
+
+
+ test("read partitioned table - with nulls") {
+ withTempDir { base =>
+ for {
+ // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero...
+ pi <- Seq(1, null.asInstanceOf[Integer])
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } {
+ makeOrcFile(
+ (1 to 10).map(i => OrcParData(i, i.toString)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ read
+ .format("orc")
+ .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName)
+ .load(base.getCanonicalPath)
+ .registerTempTable("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, null.asInstanceOf[Integer])
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } yield Row(i, i.toString, pi, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE pi IS NULL"),
+ for {
+ i <- 1 to 10
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } yield Row(i, i.toString, null, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps IS NULL"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, null.asInstanceOf[Integer])
+ } yield Row(i, i.toString, pi, null))
+ }
+ }
+ }
+
+ test("read partitioned table - with nulls and partition keys are included in Orc file") {
+ withTempDir { base =>
+ for {
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } {
+ makeOrcFile(
+ (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ read
+ .format("orc")
+ .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName)
+ .load(base.getCanonicalPath)
+ .registerTempTable("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } yield Row(i, pi, i.toString, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps IS NULL"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ } yield Row(i, pi, i.toString, null))
+ }
+ }
+ }
+}
+
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
new file mode 100644
index 0000000000..cdd6e705f4
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.io.File
+
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hadoop.hive.ql.io.orc.CompressionKind
+import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+case class AllDataTypesWithNonPrimitiveType(
+ stringField: String,
+ intField: Int,
+ longField: Long,
+ floatField: Float,
+ doubleField: Double,
+ shortField: Short,
+ byteField: Byte,
+ booleanField: Boolean,
+ array: Seq[Int],
+ arrayContainsNull: Seq[Option[Int]],
+ map: Map[Int, Long],
+ mapValueContainsNull: Map[Int, Option[Long]],
+ data: (Seq[Int], (Int, String)))
+
+case class BinaryData(binaryData: Array[Byte])
+
+case class Contact(name: String, phone: String)
+
+case class Person(name: String, age: Int, contacts: Seq[Contact])
+
+class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll with OrcTest {
+ override val sqlContext = TestHive
+
+ import TestHive.read
+
+ def getTempFilePath(prefix: String, suffix: String = ""): File = {
+ val tempFile = File.createTempFile(prefix, suffix)
+ tempFile.delete()
+ tempFile
+ }
+
+ test("Read/write All Types") {
+ val data = (0 to 255).map { i =>
+ (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0)
+ }
+
+ withOrcFile(data) { file =>
+ checkAnswer(
+ read.format("orc").load(file),
+ data.toDF().collect())
+ }
+ }
+
+ test("Read/write binary data") {
+ withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file =>
+ val bytes = read.format("orc").load(file).head().getAs[Array[Byte]](0)
+ assert(new String(bytes, "utf8") === "test")
+ }
+ }
+
+ test("Read/write all types with non-primitive type") {
+ val data = (0 to 255).map { i =>
+ AllDataTypesWithNonPrimitiveType(
+ s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0,
+ 0 until i,
+ (0 until i).map(Option(_).filter(_ % 3 == 0)),
+ (0 until i).map(i => i -> i.toLong).toMap,
+ (0 until i).map(i => i -> Option(i.toLong)).toMap + (i -> None),
+ (0 until i, (i, s"$i")))
+ }
+
+ withOrcFile(data) { file =>
+ checkAnswer(
+ read.format("orc").load(file),
+ data.toDF().collect())
+ }
+ }
+
+ test("Creating case class RDD table") {
+ val data = (1 to 100).map(i => (i, s"val_$i"))
+ sparkContext.parallelize(data).toDF().registerTempTable("t")
+ withTempTable("t") {
+ checkAnswer(sql("SELECT * FROM t"), data.toDF().collect())
+ }
+ }
+
+ test("Simple selection form ORC table") {
+ val data = (1 to 10).map { i =>
+ Person(s"name_$i", i, (0 to 1).map { m => Contact(s"contact_$m", s"phone_$m") })
+ }
+
+ withOrcTable(data, "t") {
+ // ppd:
+ // leaf-0 = (LESS_THAN_EQUALS age 5)
+ // expr = leaf-0
+ assert(sql("SELECT name FROM t WHERE age <= 5").count() === 5)
+
+ // ppd:
+ // leaf-0 = (LESS_THAN_EQUALS age 5)
+ // expr = (not leaf-0)
+ assertResult(10) {
+ sql("SELECT name, contacts FROM t where age > 5")
+ .flatMap(_.getAs[Seq[_]]("contacts"))
+ .count()
+ }
+
+ // ppd:
+ // leaf-0 = (LESS_THAN_EQUALS age 5)
+ // leaf-1 = (LESS_THAN age 8)
+ // expr = (and (not leaf-0) leaf-1)
+ {
+ val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8")
+ assert(df.count() === 2)
+ assertResult(4) {
+ df.flatMap(_.getAs[Seq[_]]("contacts")).count()
+ }
+ }
+
+ // ppd:
+ // leaf-0 = (LESS_THAN age 2)
+ // leaf-1 = (LESS_THAN_EQUALS age 8)
+ // expr = (or leaf-0 (not leaf-1))
+ {
+ val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8")
+ assert(df.count() === 3)
+ assertResult(6) {
+ df.flatMap(_.getAs[Seq[_]]("contacts")).count()
+ }
+ }
+ }
+ }
+
+ test("save and load case class RDD with `None`s as orc") {
+ val data = (
+ None: Option[Int],
+ None: Option[Long],
+ None: Option[Float],
+ None: Option[Double],
+ None: Option[Boolean]
+ ) :: Nil
+
+ withOrcFile(data) { file =>
+ checkAnswer(
+ read.format("orc").load(file),
+ Row(Seq.fill(5)(null): _*))
+ }
+ }
+
+ // We only support zlib in Hive 0.12.0 now
+ test("Default compression options for writing to an ORC file") {
+ withOrcFile((1 to 100).map(i => (i, s"val_$i"))) { file =>
+ assertResult(CompressionKind.ZLIB) {
+ OrcFileOperator.getFileReader(file).getCompression
+ }
+ }
+ }
+
+ // Following codec is supported in hive-0.13.1, ignore it now
+ ignore("Other compression options for writing to an ORC file - 0.13.1 and above") {
+ val data = (1 to 100).map(i => (i, s"val_$i"))
+ val conf = sparkContext.hadoopConfiguration
+
+ conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "SNAPPY")
+ withOrcFile(data) { file =>
+ assertResult(CompressionKind.SNAPPY) {
+ OrcFileOperator.getFileReader(file).getCompression
+ }
+ }
+
+ conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "NONE")
+ withOrcFile(data) { file =>
+ assertResult(CompressionKind.NONE) {
+ OrcFileOperator.getFileReader(file).getCompression
+ }
+ }
+
+ conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "LZO")
+ withOrcFile(data) { file =>
+ assertResult(CompressionKind.LZO) {
+ OrcFileOperator.getFileReader(file).getCompression
+ }
+ }
+ }
+
+ test("simple select queries") {
+ withOrcTable((0 until 10).map(i => (i, i.toString)), "t") {
+ checkAnswer(
+ sql("SELECT `_1` FROM t where t.`_1` > 5"),
+ (6 until 10).map(Row.apply(_)))
+
+ checkAnswer(
+ sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"),
+ (0 until 5).map(Row.apply(_)))
+ }
+ }
+
+ test("appending") {
+ val data = (0 until 10).map(i => (i, i.toString))
+ createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ withOrcTable(data, "t") {
+ sql("INSERT INTO TABLE t SELECT * FROM tmp")
+ checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
+ }
+ catalog.unregisterTable(Seq("tmp"))
+ }
+
+ test("overwriting") {
+ val data = (0 until 10).map(i => (i, i.toString))
+ createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ withOrcTable(data, "t") {
+ sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
+ checkAnswer(table("t"), data.map(Row.fromTuple))
+ }
+ catalog.unregisterTable(Seq("tmp"))
+ }
+
+ test("self-join") {
+ // 4 rows, cells of column 1 of row 2 and row 4 are null
+ val data = (1 to 4).map { i =>
+ val maybeInt = if (i % 2 == 0) None else Some(i)
+ (maybeInt, i.toString)
+ }
+
+ withOrcTable(data, "t") {
+ val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`")
+ val queryOutput = selfJoin.queryExecution.analyzed.output
+
+ assertResult(4, "Field count mismatches")(queryOutput.size)
+ assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") {
+ queryOutput.filter(_.name == "_1").map(_.exprId).size
+ }
+
+ checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3")))
+ }
+ }
+
+ test("nested data - struct with array field") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i"))))
+ withOrcTable(data, "t") {
+ checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map {
+ case Tuple1((_, Seq(string))) => Row(string)
+ })
+ }
+ }
+
+ test("nested data - array of struct") {
+ val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i")))
+ withOrcTable(data, "t") {
+ checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map {
+ case Tuple1(Seq((_, string))) => Row(string)
+ })
+ }
+ }
+
+ test("columns only referenced by pushed down filters should remain") {
+ withOrcTable((1 to 10).map(Tuple1.apply), "t") {
+ checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_)))
+ }
+ }
+
+ test("SPARK-5309 strings stored using dictionary compression in orc") {
+ withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") {
+ checkAnswer(
+ sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"),
+ (0 until 10).map(i => Row("same", "run_" + i, 100)))
+
+ checkAnswer(
+ sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"),
+ List(Row("same", "run_5", 100)))
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
new file mode 100644
index 0000000000..82e08caf46
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.io.File
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.{QueryTest, Row}
+
+case class OrcData(intField: Int, stringField: String)
+
+abstract class OrcSuite extends QueryTest with BeforeAndAfterAll {
+ var orcTableDir: File = null
+ var orcTableAsDir: File = null
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ orcTableAsDir = File.createTempFile("orctests", "sparksql")
+ orcTableAsDir.delete()
+ orcTableAsDir.mkdir()
+
+ // Hack: to prepare orc data files using hive external tables
+ orcTableDir = File.createTempFile("orctests", "sparksql")
+ orcTableDir.delete()
+ orcTableDir.mkdir()
+ import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+ sparkContext
+ .makeRDD(1 to 10)
+ .map(i => OrcData(i, s"part-$i"))
+ .toDF()
+ .registerTempTable(s"orc_temp_table")
+
+ sql(
+ s"""CREATE EXTERNAL TABLE normal_orc(
+ | intField INT,
+ | stringField STRING
+ |)
+ |STORED AS ORC
+ |LOCATION '${orcTableAsDir.getCanonicalPath}'
+ """.stripMargin)
+
+ sql(
+ s"""INSERT INTO TABLE normal_orc
+ |SELECT intField, stringField FROM orc_temp_table
+ """.stripMargin)
+ }
+
+ override def afterAll(): Unit = {
+ orcTableDir.delete()
+ orcTableAsDir.delete()
+ }
+
+ test("create temporary orc table") {
+ checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10))
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_source"),
+ (1 to 10).map(i => Row(i, s"part-$i")))
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_source where intField > 5"),
+ (6 to 10).map(i => Row(i, s"part-$i")))
+
+ checkAnswer(
+ sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"),
+ (1 to 10).map(i => Row(1, s"part-$i")))
+ }
+
+ test("create temporary orc table as") {
+ checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10))
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_source"),
+ (1 to 10).map(i => Row(i, s"part-$i")))
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_source WHERE intField > 5"),
+ (6 to 10).map(i => Row(i, s"part-$i")))
+
+ checkAnswer(
+ sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"),
+ (1 to 10).map(i => Row(1, s"part-$i")))
+ }
+
+ test("appending insert") {
+ sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5")
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_source"),
+ (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i =>
+ Seq.fill(2)(Row(i, s"part-$i"))
+ })
+ }
+
+ test("overwrite insert") {
+ sql(
+ """INSERT OVERWRITE TABLE normal_orc_as_source
+ |SELECT * FROM orc_temp_table WHERE intField > 5
+ """.stripMargin)
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_as_source"),
+ (6 to 10).map(i => Row(i, s"part-$i")))
+ }
+}
+
+class OrcSourceSuite extends OrcSuite {
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ sql(
+ s"""CREATE TEMPORARY TABLE normal_orc_source
+ |USING org.apache.spark.sql.hive.orc
+ |OPTIONS (
+ | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}'
+ |)
+ """.stripMargin)
+
+ sql(
+ s"""CREATE TEMPORARY TABLE normal_orc_as_source
+ |USING org.apache.spark.sql.hive.orc
+ |OPTIONS (
+ | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}'
+ |)
+ """.stripMargin)
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
new file mode 100644
index 0000000000..750f0b04aa
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.io.File
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql._
+
+private[sql] trait OrcTest extends SQLTestUtils {
+ protected def hiveContext = sqlContext.asInstanceOf[HiveContext]
+
+ import sqlContext.sparkContext
+ import sqlContext.implicits._
+
+ /**
+ * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f`
+ * returns.
+ */
+ protected def withOrcFile[T <: Product: ClassTag: TypeTag]
+ (data: Seq[T])
+ (f: String => Unit): Unit = {
+ withTempPath { file =>
+ sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath)
+ f(file.getCanonicalPath)
+ }
+ }
+
+ /**
+ * Writes `data` to a Orc file and reads it back as a [[DataFrame]],
+ * which is then passed to `f`. The Orc file will be deleted after `f` returns.
+ */
+ protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag]
+ (data: Seq[T])
+ (f: DataFrame => Unit): Unit = {
+ withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path)))
+ }
+
+ /**
+ * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a
+ * temporary table named `tableName`, then call `f`. The temporary table together with the
+ * Orc file will be dropped/deleted after `f` returns.
+ */
+ protected def withOrcTable[T <: Product: ClassTag: TypeTag]
+ (data: Seq[T], tableName: String)
+ (f: => Unit): Unit = {
+ withOrcDataFrame(data) { df =>
+ hiveContext.registerDataFrameAsTable(df, tableName)
+ withTempTable(tableName)(f)
+ }
+ }
+
+ protected def makeOrcFile[T <: Product: ClassTag: TypeTag](
+ data: Seq[T], path: File): Unit = {
+ data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath)
+ }
+
+ protected def makeOrcFile[T <: Product: ClassTag: TypeTag](
+ df: DataFrame, path: File): Unit = {
+ df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath)
+ }
+}
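A hypothetical companion suite showing how these helpers are meant to be used (mirrors the setup of `OrcQuerySuite`; the data and table name are illustrative):

```scala
package org.apache.spark.sql.hive.orc

import org.scalatest.FunSuiteLike

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._

class OrcRoundTripSuite extends QueryTest with FunSuiteLike with OrcTest {
  override val sqlContext = TestHive

  test("round trip through a temporary ORC table") {
    // Writes the tuples to a temporary ORC file, registers it as table `t`,
    // and drops both once the body returns.
    withOrcTable((1 to 5).map(i => (i, s"val_$i")), "t") {
      checkAnswer(sql("SELECT `_2` FROM t WHERE `_1` = 3"), Row("val_3"))
    }
  }
}
```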
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 9d9b436cab..ad4a4826c6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -23,12 +23,10 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.parquet.ParquetTest
+import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
-// TODO Don't extend ParquetTest
-// This test suite extends ParquetTest for some convenient utility methods. These methods should be
-// moved to some more general places, maybe QueryTest.
-class HadoopFsRelationTest extends QueryTest with ParquetTest {
+abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
override val sqlContext: SQLContext = TestHive
import sqlContext._