-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala            6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala             5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala              6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala           8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala       24
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala     39
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala    3
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala      2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala            2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala 17
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala     68
11 files changed, 150 insertions(+), 30 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 5094058164..5770f59b53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -75,6 +75,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
override def simpleString: String = s"array<${elementType.simpleString}>"
- private[spark] override def asNullable: ArrayType =
+ override private[spark] def asNullable: ArrayType =
ArrayType(elementType.asNullable, containsNull = true)
+
+ override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+ f(this) || elementType.existsRecursively(f)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index f4428c2e8b..7bcd623b3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -77,6 +77,11 @@ abstract class DataType extends AbstractDataType {
*/
private[spark] def asNullable: DataType
+ /**
+ * Returns true if any `DataType` in this DataType tree satisfies the given function `f`.
+ */
+ private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this)
+
override private[sql] def defaultConcreteType: DataType = this
override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
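
Note: the default existsRecursively above tests only the receiver; the container
types (ArrayType, MapType, StructType) override it to test themselves and then
recurse into their element, key/value, or field types. A minimal standalone
sketch of the same pattern, using illustrative names rather than the actual
Spark classes:

    // Leaf types inherit the default, which only tests the node itself.
    sealed trait TypeNode {
      def existsRecursively(f: TypeNode => Boolean): Boolean = f(this)
    }
    case object IntNode extends TypeNode
    case object DecimalNode extends TypeNode

    // Container types short-circuit on themselves, then descend.
    case class ArrayNode(element: TypeNode) extends TypeNode {
      override def existsRecursively(f: TypeNode => Boolean): Boolean =
        f(this) || element.existsRecursively(f)
    }

    object ExistsRecursivelyDemo extends App {
      val schema = ArrayNode(DecimalNode)
      println(schema.existsRecursively(_ == DecimalNode)) // true
      println(schema.existsRecursively(_ == IntNode))     // false
    }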
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index ac34b64282..00461e529c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -62,8 +62,12 @@ case class MapType(
override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
- private[spark] override def asNullable: MapType =
+ override private[spark] def asNullable: MapType =
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
+
+ override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+ f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 9cbc207538..d8968ef806 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -24,7 +24,7 @@ import org.json4s.JsonDSL._
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
/**
@@ -292,7 +292,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
private[sql] def merge(that: StructType): StructType =
StructType.merge(this, that).asInstanceOf[StructType]
- private[spark] override def asNullable: StructType = {
+ override private[spark] def asNullable: StructType = {
val newFields = fields.map {
case StructField(name, dataType, nullable, metadata) =>
StructField(name, dataType.asNullable, nullable = true, metadata)
@@ -301,6 +301,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
StructType(newFields)
}
+ override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+ f(this) || fields.exists(field => field.dataType.existsRecursively(f))
+ }
+
private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 88b221cd81..706ecd29d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -170,6 +170,30 @@ class DataTypeSuite extends SparkFunSuite {
}
}
+ test("existsRecursively") {
+ val struct = StructType(
+ StructField("a", LongType) ::
+ StructField("b", FloatType) :: Nil)
+ assert(struct.existsRecursively(_.isInstanceOf[LongType]))
+ assert(struct.existsRecursively(_.isInstanceOf[StructType]))
+ assert(!struct.existsRecursively(_.isInstanceOf[IntegerType]))
+
+ val mapType = MapType(struct, StringType)
+ assert(mapType.existsRecursively(_.isInstanceOf[LongType]))
+ assert(mapType.existsRecursively(_.isInstanceOf[StructType]))
+ assert(mapType.existsRecursively(_.isInstanceOf[StringType]))
+ assert(mapType.existsRecursively(_.isInstanceOf[MapType]))
+ assert(!mapType.existsRecursively(_.isInstanceOf[IntegerType]))
+
+ val arrayType = ArrayType(mapType)
+ assert(arrayType.existsRecursively(_.isInstanceOf[LongType]))
+ assert(arrayType.existsRecursively(_.isInstanceOf[StructType]))
+ assert(arrayType.existsRecursively(_.isInstanceOf[StringType]))
+ assert(arrayType.existsRecursively(_.isInstanceOf[MapType]))
+ assert(arrayType.existsRecursively(_.isInstanceOf[ArrayType]))
+ assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType]))
+ }
+
def checkDataTypeJsonRepr(dataType: DataType): Unit = {
test(s"JSON - $dataType") {
assert(DataType.fromJson(dataType.json) === dataType)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 5e5497837a..6770462bb0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -33,15 +33,14 @@ import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier}
-import org.apache.spark.sql.execution.{FileRelation, datasources}
+import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource}
+import org.apache.spark.sql.execution.{FileRelation, datasources}
import org.apache.spark.sql.hive.client._
-import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode}
@@ -86,9 +85,9 @@ private[hive] object HiveSerDe {
serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")))
val key = source.toLowerCase match {
- case _ if source.startsWith("org.apache.spark.sql.parquet") => "parquet"
- case _ if source.startsWith("org.apache.spark.sql.orc") => "orc"
- case _ => source.toLowerCase
+ case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet"
+ case s if s.startsWith("org.apache.spark.sql.orc") => "orc"
+ case s => s
}
serdeMap.get(key)
@@ -309,11 +308,31 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val hiveTable = (maybeSerDe, dataSource.relation) match {
case (Some(serde), relation: HadoopFsRelation)
if relation.paths.length == 1 && relation.partitionColumns.isEmpty =>
- logInfo {
- "Persisting data source relation with a single input path into Hive metastore in Hive " +
- s"compatible format. Input path: ${relation.paths.head}"
+ // Hive ParquetSerDe doesn't support decimal type until 1.2.0.
+ val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet"))
+ val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType])
+
+ val hiveParquetSupportsDecimal = client.version match {
+ case org.apache.spark.sql.hive.client.hive.v1_2 => true
+ case _ => false
+ }
+
+ if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) {
+ // If Hive version is below 1.2.0, we cannot save Hive compatible schema to
+ // metastore when the file format is Parquet and the schema has DecimalType.
+ logWarning {
+ "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " +
+ "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " +
+ s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384."
+ }
+ newSparkSQLSpecificMetastoreTable()
+ } else {
+ logInfo {
+ "Persisting data source relation with a single input path into Hive metastore in " +
+ s"Hive compatible format. Input path: ${relation.paths.head}"
+ }
+ newHiveCompatibleMetastoreTable(relation, serde)
}
- newHiveCompatibleMetastoreTable(relation, serde)
case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty =>
logWarning {
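
Note: the new gate above persists a Parquet relation in Hive-compatible format
only when the schema has no decimal fields or the client's Hive version has
Parquet decimal support (added in Hive 1.2.0, HIVE-6384); otherwise it falls
back to the Spark SQL specific metastore format. A hedged sketch of that
decision extracted as a pure function (illustrative names, not the Spark API):

    object MetastoreFormatChoice {
      sealed trait Format
      case object HiveCompatible extends Format
      case object SparkSqlSpecific extends Format

      // Mirrors the condition in the diff: fall back only when all three hold.
      def choose(
          serdeIsParquet: Boolean,
          hiveSupportsParquetDecimal: Boolean, // true for Hive 1.2.x clients
          schemaHasDecimal: Boolean): Format =
        if (serdeIsParquet && !hiveSupportsParquetDecimal && schemaHasDecimal) {
          SparkSqlSpecific
        } else {
          HiveCompatible
        }
    }

Keeping the predicate pure like this makes the three inputs explicit: the SerDe
family, the client version's capability, and whether the schema contains a
DecimalType anywhere — which is exactly the job of existsRecursively.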
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index a82e152dcd..3811c152a7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -88,6 +88,9 @@ private[hive] case class HiveTable(
*/
private[hive] trait ClientInterface {
+ /** Returns the Hive version of this client. */
+ def version: HiveVersion
+
/** Returns the configuration for the given key in the current session. */
def getConf(key: String, defaultValue: String): String
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index 3d05b583cf..f49c97de8f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -58,7 +58,7 @@ import org.apache.spark.util.{CircularBuffer, Utils}
* this ClientWrapper.
*/
private[hive] class ClientWrapper(
- version: HiveVersion,
+ override val version: HiveVersion,
config: Map[String, String],
initClassLoader: ClassLoader)
extends ClientInterface
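
Note: declaring the constructor parameter as `override val version` is what lets
ClientWrapper satisfy the new abstract `version` member on ClientInterface
without any extra code. A minimal sketch of the pattern (illustrative names):

    trait Client {
      def version: String              // abstract member
    }
    // A val constructor parameter implements the abstract def.
    class ClientImpl(override val version: String) extends Client

    // new ClientImpl("0.13.1").version returns "0.13.1"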
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
index 0503691a44..b1b8439efa 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
@@ -25,7 +25,7 @@ package object client {
val exclusions: Seq[String] = Nil)
// scalastyle:off
- private[client] object hive {
+ private[hive] object hive {
case object v12 extends HiveVersion("0.12.0")
case object v13 extends HiveVersion("0.13.1")
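
Note: widening the visibility from private[client] to private[hive] is what
allows HiveMetastoreCatalog, which lives in the parent hive package, to pattern
match on org.apache.spark.sql.hive.client.hive.v1_2 when deciding whether the
client's Hive version supports Parquet decimals.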
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 332c3ec0c2..59e65ff97b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.sql.hive
import java.io.File
-import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, ManagedTable}
+import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.sources.DataSourceTest
import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.{Logging, SparkFunSuite}
@@ -55,7 +55,10 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging {
class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils {
override val sqlContext = TestHive
- private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1)
+ private val testDF = range(1, 3).select(
+ ('id + 0.1) cast DecimalType(10, 3) as 'd1,
+ 'id cast StringType as 'd2
+ ).coalesce(1)
Seq(
"parquet" -> (
@@ -88,10 +91,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes
val columns = hiveTable.schema
assert(columns.map(_.name) === Seq("d1", "d2"))
- assert(columns.map(_.hiveType) === Seq("int", "string"))
+ assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string"))
checkAnswer(table("t"), testDF)
- assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2"))
+ assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2"))
}
}
@@ -117,10 +120,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes
val columns = hiveTable.schema
assert(columns.map(_.name) === Seq("d1", "d2"))
- assert(columns.map(_.hiveType) === Seq("int", "string"))
+ assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string"))
checkAnswer(table("t"), testDF)
- assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2"))
+ assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2"))
}
}
}
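
Note: switching testDF to a decimal(10,3) column means every data source variant
this suite parameterizes over now persists a decimal schema and verifies, via
runSqlHive, that Hive itself can read the values back, not just Spark SQL.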
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 1e1972d1ac..0c29646114 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -20,16 +20,18 @@ package org.apache.spark.sql.hive
import java.io.File
import scala.collection.mutable.ArrayBuffer
-import scala.sys.process.{ProcessLogger, Process}
+import scala.sys.process.{Process, ProcessLogger}
+import org.scalatest.Matchers
+import org.scalatest.concurrent.Timeouts
import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import org.scalatest.time.SpanSugar._
import org.apache.spark._
+import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
+import org.apache.spark.sql.types.DecimalType
import org.apache.spark.util.{ResetSystemProperties, Utils}
-import org.scalatest.Matchers
-import org.scalatest.concurrent.Timeouts
-import org.scalatest.time.SpanSugar._
/**
* This suite tests spark-submit with applications using HiveContext.
@@ -50,8 +52,8 @@ class HiveSparkSubmitSuite
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB"))
- val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()
- val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath()
+ val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath
+ val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath
val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",")
val args = Seq(
"--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"),
@@ -91,6 +93,16 @@ class HiveSparkSubmitSuite
runSparkSubmit(args)
}
+ test("SPARK-9757 Persist Parquet relation with decimal column") {
+ val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+ val args = Seq(
+ "--class", SPARK_9757.getClass.getName.stripSuffix("$"),
+ "--name", "SparkSQLConfTest",
+ "--master", "local-cluster[2,1,1024]",
+ unusedJar.toString)
+ runSparkSubmit(args)
+ }
+
// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
// This is copied from org.apache.spark.deploy.SparkSubmitSuite
private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -213,7 +225,7 @@ object SparkSQLConfTest extends Logging {
// before spark.sql.hive.metastore.jars get set, we will see the following exception:
// Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only
// be used when hive execution version == hive metastore version.
- // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars
+ // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars
// using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1.
val conf = new SparkConf() {
override def getAll: Array[(String, String)] = {
@@ -239,3 +251,45 @@ object SparkSQLConfTest extends Logging {
sc.stop()
}
}
+
+object SPARK_9757 extends QueryTest with Logging {
+ def main(args: Array[String]): Unit = {
+ Utils.configTestLog4j("INFO")
+
+ val sparkContext = new SparkContext(
+ new SparkConf()
+ .set("spark.sql.hive.metastore.version", "0.13.1")
+ .set("spark.sql.hive.metastore.jars", "maven"))
+
+ val hiveContext = new TestHiveContext(sparkContext)
+ import hiveContext.implicits._
+ import org.apache.spark.sql.functions._
+
+ val dir = Utils.createTempDir()
+ dir.delete()
+
+ try {
+ {
+ val df =
+ hiveContext
+ .range(10)
+ .select(('id + 0.1) cast DecimalType(10, 3) as 'dec)
+ df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t")
+ checkAnswer(hiveContext.table("t"), df)
+ }
+
+ {
+ val df =
+ hiveContext
+ .range(10)
+ .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct)
+ df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t")
+ checkAnswer(hiveContext.table("t"), df)
+ }
+ } finally {
+ dir.delete()
+ hiveContext.sql("DROP TABLE t")
+ sparkContext.stop()
+ }
+ }
+}
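
Note: SPARK_9757 runs through spark-submit (see the new test case above) with
spark.sql.hive.metastore.version set to 0.13.1 and metastore jars resolved via
Maven, so it exercises a client whose Hive version predates Parquet decimal
support. It round-trips both a top-level decimal column and a decimal nested
inside a struct; the latter confirms that existsRecursively descends through
StructType fields.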