 python/pyspark/sql.py                                                                         |  15
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala               |   2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala        |   1
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala        |   2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala  |  31
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala               |  25
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala                 | 255
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala            |  82
 sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java                            |  25
 sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java                            |  31
 sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java                     |  28
 sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java                         |  19
 sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                                 |   1
 sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala                               |  13
 sql/core/src/main/scala/org/apache/spark/sql/package.scala                                    |  23
 sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala             |   8
 sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala                              |   8
 sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                              |  37
 sql/core/src/test/scala/org/apache/spark/sql/TestData.scala                                   |  11
 sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala  |  12
 20 files changed, 573 insertions(+), 56 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index f0bd3cbd98..93bfc25bca 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -313,12 +313,15 @@ class StructField(DataType):
"""
- def __init__(self, name, dataType, nullable):
+ def __init__(self, name, dataType, nullable, metadata=None):
"""Creates a StructField
:param name: the name of this field.
:param dataType: the data type of this field.
:param nullable: indicates whether values of this field
can be null.
+ :param metadata: metadata of this field, which is a map from string
+ to simple type that can be serialized to JSON
+ automatically
>>> (StructField("f1", StringType, True)
... == StructField("f1", StringType, True))
@@ -330,6 +333,7 @@ class StructField(DataType):
self.name = name
self.dataType = dataType
self.nullable = nullable
+ self.metadata = metadata or {}
def __repr__(self):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
@@ -338,13 +342,15 @@ class StructField(DataType):
def jsonValue(self):
return {"name": self.name,
"type": self.dataType.jsonValue(),
- "nullable": self.nullable}
+ "nullable": self.nullable,
+ "metadata": self.metadata}
@classmethod
def fromJson(cls, json):
return StructField(json["name"],
_parse_datatype_json_value(json["type"]),
- json["nullable"])
+ json["nullable"],
+ json["metadata"])
class StructType(DataType):
@@ -423,7 +429,8 @@ def _parse_datatype_json_string(json_string):
... StructField("simpleArray", simple_arraytype, True),
... StructField("simpleMap", simple_maptype, True),
... StructField("simpleStruct", simple_structtype, True),
- ... StructField("boolean", BooleanType(), False)])
+ ... StructField("boolean", BooleanType(), False),
+ ... StructField("withMeta", DoubleType(), False, {"name": "age"})])
>>> check_datatype(complex_structtype)
True
>>> # Complex ArrayType.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d76c743d3f..75923d9e8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -46,7 +46,7 @@ object ScalaReflection {
/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
case Schema(s: StructType, _) =>
- s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
+ s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
}
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 1eb260efa6..39b120e8de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType}
+import org.apache.spark.sql.catalyst.util.Metadata
abstract class Expression extends TreeNode[Expression] {
self: Product =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 9c865254e0..ab0701fd9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -43,7 +43,7 @@ abstract class Generator extends Expression {
override type EvaluatedType = TraversableOnce[Row]
override lazy val dataType =
- ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable))))
+ ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
override def nullable = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index fe13a661f6..3310566087 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.util.Metadata
object NamedExpression {
private val curId = new java.util.concurrent.atomic.AtomicLong()
@@ -43,6 +44,9 @@ abstract class NamedExpression extends Expression {
def toAttribute: Attribute
+ /** Returns the metadata when an expression is a reference to another expression with metadata. */
+ def metadata: Metadata = Metadata.empty
+
protected def typeSuffix =
if (resolved) {
dataType match {
@@ -88,10 +92,16 @@ case class Alias(child: Expression, name: String)
override def dataType = child.dataType
override def nullable = child.nullable
+ override def metadata: Metadata = {
+ child match {
+ case named: NamedExpression => named.metadata
+ case _ => Metadata.empty
+ }
+ }
override def toAttribute = {
if (resolved) {
- AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers)
+ AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers)
} else {
UnresolvedAttribute(name)
}
@@ -108,15 +118,20 @@ case class Alias(child: Expression, name: String)
* @param name The name of this attribute, should only be used during analysis or for debugging.
* @param dataType The [[DataType]] of this attribute.
* @param nullable True if null is a valid value for this attribute.
+ * @param metadata The metadata of this attribute.
* @param exprId A globally unique id used to check if different AttributeReferences refer to the
* same attribute.
* @param qualifiers a list of strings that can be used to referred to this attribute in a fully
* qualified way. Consider the examples tableName.name, subQueryAlias.name.
* tableName and subQueryAlias are possible qualifiers.
*/
-case class AttributeReference(name: String, dataType: DataType, nullable: Boolean = true)
- (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
- extends Attribute with trees.LeafNode[Expression] {
+case class AttributeReference(
+ name: String,
+ dataType: DataType,
+ nullable: Boolean = true,
+ override val metadata: Metadata = Metadata.empty)(
+ val exprId: ExprId = NamedExpression.newExprId,
+ val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] {
override def equals(other: Any) = other match {
case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
@@ -128,10 +143,12 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
var h = 17
h = h * 37 + exprId.hashCode()
h = h * 37 + dataType.hashCode()
+ h = h * 37 + metadata.hashCode()
h
}
- override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
+ override def newInstance() =
+ AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers)
/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
@@ -140,7 +157,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
if (nullable == newNullability) {
this
} else {
- AttributeReference(name, dataType, newNullability)(exprId, qualifiers)
+ AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers)
}
}
@@ -159,7 +176,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
if (newQualifiers.toSet == qualifiers.toSet) {
this
} else {
- AttributeReference(name, dataType, nullable)(exprId, newQualifiers)
+ AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 4e6e1166bf..6069f9b0a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -24,16 +24,16 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
import scala.util.parsing.combinator.RegexParsers
-import org.json4s.JsonAST.JValue
import org.json4s._
+import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.util.Metadata
import org.apache.spark.util.Utils
-
object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
@@ -70,10 +70,11 @@ object DataType {
private def parseStructField(json: JValue): StructField = json match {
case JSortedObject(
+ ("metadata", metadata: JObject),
("name", JString(name)),
("nullable", JBool(nullable)),
("type", dataType: JValue)) =>
- StructField(name, parseDataType(dataType), nullable)
+ StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata))
}
@deprecated("Use DataType.fromJson instead", "1.2.0")
@@ -388,24 +389,34 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
* @param name The name of this field.
* @param dataType The data type of this field.
* @param nullable Indicates if values of this field can be `null` values.
+ * @param metadata The metadata of this field. The metadata should be preserved during
+ *                 transformation if the content of the column is not modified, e.g., in selection.
*/
-case class StructField(name: String, dataType: DataType, nullable: Boolean) {
+case class StructField(
+ name: String,
+ dataType: DataType,
+ nullable: Boolean,
+ metadata: Metadata = Metadata.empty) {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n")
DataType.buildFormattedString(dataType, s"$prefix |", builder)
}
+ // override the default toString to be compatible with legacy parquet files.
+ override def toString: String = s"StructField($name,$dataType,$nullable)"
+
private[sql] def jsonValue: JValue = {
("name" -> name) ~
("type" -> dataType.jsonValue) ~
- ("nullable" -> nullable)
+ ("nullable" -> nullable) ~
+ ("metadata" -> metadata.jsonValue)
}
}
object StructType {
protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
- StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable)))
+ StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
}
case class StructType(fields: Seq[StructField]) extends DataType {
@@ -439,7 +450,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
}
protected[sql] def toAttributes =
- fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
+ fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
def treeString: String = {
val builder = new StringBuilder
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala
new file mode 100644
index 0000000000..2f2082fa3c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import scala.collection.mutable
+
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+/**
+ * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean,
+ * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and
+ * Array[Metadata]. JSON is used for serialization.
+ *
+ * The default constructor is private. User should use either [[MetadataBuilder]] or
+ * [[Metadata$#fromJson]] to create Metadata instances.
+ *
+ * @param map an immutable map that stores the data
+ */
+sealed class Metadata private[util] (private[util] val map: Map[String, Any]) extends Serializable {
+
+ /** Gets a Long. */
+ def getLong(key: String): Long = get(key)
+
+ /** Gets a Double. */
+ def getDouble(key: String): Double = get(key)
+
+ /** Gets a Boolean. */
+ def getBoolean(key: String): Boolean = get(key)
+
+ /** Gets a String. */
+ def getString(key: String): String = get(key)
+
+ /** Gets a Metadata. */
+ def getMetadata(key: String): Metadata = get(key)
+
+ /** Gets a Long array. */
+ def getLongArray(key: String): Array[Long] = get(key)
+
+ /** Gets a Double array. */
+ def getDoubleArray(key: String): Array[Double] = get(key)
+
+ /** Gets a Boolean array. */
+ def getBooleanArray(key: String): Array[Boolean] = get(key)
+
+ /** Gets a String array. */
+ def getStringArray(key: String): Array[String] = get(key)
+
+ /** Gets a Metadata array. */
+ def getMetadataArray(key: String): Array[Metadata] = get(key)
+
+ /** Converts to its JSON representation. */
+ def json: String = compact(render(jsonValue))
+
+ override def toString: String = json
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case that: Metadata =>
+ if (map.keySet == that.map.keySet) {
+ map.keys.forall { k =>
+ (map(k), that.map(k)) match {
+ case (v0: Array[_], v1: Array[_]) =>
+ v0.view == v1.view
+ case (v0, v1) =>
+ v0 == v1
+ }
+ }
+ } else {
+ false
+ }
+ case other =>
+ false
+ }
+ }
+
+ override def hashCode: Int = Metadata.hash(this)
+
+ private def get[T](key: String): T = {
+ map(key).asInstanceOf[T]
+ }
+
+ private[sql] def jsonValue: JValue = Metadata.toJsonValue(this)
+}
+
+object Metadata {
+
+ /** Returns an empty Metadata. */
+ def empty: Metadata = new Metadata(Map.empty)
+
+ /** Creates a Metadata instance from JSON. */
+ def fromJson(json: String): Metadata = {
+ fromJObject(parse(json).asInstanceOf[JObject])
+ }
+
+ /** Creates a Metadata instance from JSON AST. */
+ private[sql] def fromJObject(jObj: JObject): Metadata = {
+ val builder = new MetadataBuilder
+ jObj.obj.foreach {
+ case (key, JInt(value)) =>
+ builder.putLong(key, value.toLong)
+ case (key, JDouble(value)) =>
+ builder.putDouble(key, value)
+ case (key, JBool(value)) =>
+ builder.putBoolean(key, value)
+ case (key, JString(value)) =>
+ builder.putString(key, value)
+ case (key, o: JObject) =>
+ builder.putMetadata(key, fromJObject(o))
+ case (key, JArray(value)) =>
+ if (value.isEmpty) {
+ // If it is an empty array, we cannot infer its element type. We put an empty Array[Long].
+ builder.putLongArray(key, Array.empty)
+ } else {
+ value.head match {
+ case _: JInt =>
+ builder.putLongArray(key, value.asInstanceOf[List[JInt]].map(_.num.toLong).toArray)
+ case _: JDouble =>
+ builder.putDoubleArray(key, value.asInstanceOf[List[JDouble]].map(_.num).toArray)
+ case _: JBool =>
+ builder.putBooleanArray(key, value.asInstanceOf[List[JBool]].map(_.value).toArray)
+ case _: JString =>
+ builder.putStringArray(key, value.asInstanceOf[List[JString]].map(_.s).toArray)
+ case _: JObject =>
+ builder.putMetadataArray(
+ key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray)
+ case other =>
+ throw new RuntimeException(s"Do not support array of type ${other.getClass}.")
+ }
+ }
+ case other =>
+ throw new RuntimeException(s"Do not support type ${other.getClass}.")
+ }
+ builder.build()
+ }
+
+ /** Converts to JSON AST. */
+ private def toJsonValue(obj: Any): JValue = {
+ obj match {
+ case map: Map[_, _] =>
+ val fields = map.toList.map { case (k: String, v) => (k, toJsonValue(v)) }
+ JObject(fields)
+ case arr: Array[_] =>
+ val values = arr.toList.map(toJsonValue)
+ JArray(values)
+ case x: Long =>
+ JInt(x)
+ case x: Double =>
+ JDouble(x)
+ case x: Boolean =>
+ JBool(x)
+ case x: String =>
+ JString(x)
+ case x: Metadata =>
+ toJsonValue(x.map)
+ case other =>
+ throw new RuntimeException(s"Do not support type ${other.getClass}.")
+ }
+ }
+
+ /** Computes the hash code for the types we support. */
+ private def hash(obj: Any): Int = {
+ obj match {
+ case map: Map[_, _] =>
+ map.mapValues(hash).##
+ case arr: Array[_] =>
+ // Seq.empty[T] has the same hashCode regardless of T.
+ arr.toSeq.map(hash).##
+ case x: Long =>
+ x.##
+ case x: Double =>
+ x.##
+ case x: Boolean =>
+ x.##
+ case x: String =>
+ x.##
+ case x: Metadata =>
+ hash(x.map)
+ case other =>
+ throw new RuntimeException(s"Do not support type ${other.getClass}.")
+ }
+ }
+}
+
+/**
+ * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former.
+ */
+class MetadataBuilder {
+
+ private val map: mutable.Map[String, Any] = mutable.Map.empty
+
+ /** Returns the immutable version of this map. Used for java interop. */
+ protected def getMap = map.toMap
+
+ /** Include the content of an existing [[Metadata]] instance. */
+ def withMetadata(metadata: Metadata): this.type = {
+ map ++= metadata.map
+ this
+ }
+
+ /** Puts a Long. */
+ def putLong(key: String, value: Long): this.type = put(key, value)
+
+ /** Puts a Double. */
+ def putDouble(key: String, value: Double): this.type = put(key, value)
+
+ /** Puts a Boolean. */
+ def putBoolean(key: String, value: Boolean): this.type = put(key, value)
+
+ /** Puts a String. */
+ def putString(key: String, value: String): this.type = put(key, value)
+
+ /** Puts a [[Metadata]]. */
+ def putMetadata(key: String, value: Metadata): this.type = put(key, value)
+
+ /** Puts a Long array. */
+ def putLongArray(key: String, value: Array[Long]): this.type = put(key, value)
+
+ /** Puts a Double array. */
+ def putDoubleArray(key: String, value: Array[Double]): this.type = put(key, value)
+
+ /** Puts a Boolean array. */
+ def putBooleanArray(key: String, value: Array[Boolean]): this.type = put(key, value)
+
+ /** Puts a String array. */
+ def putStringArray(key: String, value: Array[String]): this.type = put(key, value)
+
+ /** Puts a [[Metadata]] array. */
+ def putMetadataArray(key: String, value: Array[Metadata]): this.type = put(key, value)
+
+ /** Builds the [[Metadata]] instance. */
+ def build(): Metadata = {
+ new Metadata(map.toMap)
+ }
+
+ private def put(key: String, value: Any): this.type = {
+ map.put(key, value)
+ this
+ }
+}
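
For context, a minimal Scala sketch of how the Metadata / MetadataBuilder API introduced above is meant to be used; the key names ("doc", "stats", "average") and the "age" field are illustrative only and not part of this patch:

    import org.apache.spark.sql.catalyst.types.{IntegerType, StructField}
    import org.apache.spark.sql.catalyst.util.{Metadata, MetadataBuilder}

    // Build a metadata map of simple values plus nested metadata.
    val stats = new MetadataBuilder().putDouble("average", 45.0).build()
    val meta = new MetadataBuilder()
      .putString("doc", "customer age in years")
      .putBoolean("categorical", false)
      .putMetadata("stats", stats)
      .build()

    // Attach it to a schema field; StructField now carries it alongside name, type, and nullability.
    val field = StructField("age", IntegerType, nullable = false, meta)

    // Round-trip through JSON, the serialization format used for schemas.
    val restored = Metadata.fromJson(meta.json)
    assert(restored == meta)
    assert(field.metadata.getString("doc") == "customer age in years")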
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala
new file mode 100644
index 0000000000..0063d31666
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.json4s.jackson.JsonMethods.parse
+import org.scalatest.FunSuite
+
+class MetadataSuite extends FunSuite {
+
+ val baseMetadata = new MetadataBuilder()
+ .putString("purpose", "ml")
+ .putBoolean("isBase", true)
+ .build()
+
+ val summary = new MetadataBuilder()
+ .putLong("numFeatures", 10L)
+ .build()
+
+ val age = new MetadataBuilder()
+ .putString("name", "age")
+ .putLong("index", 1L)
+ .putBoolean("categorical", false)
+ .putDouble("average", 45.0)
+ .build()
+
+ val gender = new MetadataBuilder()
+ .putString("name", "gender")
+ .putLong("index", 5)
+ .putBoolean("categorical", true)
+ .putStringArray("categories", Array("male", "female"))
+ .build()
+
+ val metadata = new MetadataBuilder()
+ .withMetadata(baseMetadata)
+ .putBoolean("isBase", false) // overwrite an existing key
+ .putMetadata("summary", summary)
+ .putLongArray("long[]", Array(0L, 1L))
+ .putDoubleArray("double[]", Array(3.0, 4.0))
+ .putBooleanArray("boolean[]", Array(true, false))
+ .putMetadataArray("features", Array(age, gender))
+ .build()
+
+ test("metadata builder and getters") {
+ assert(age.getLong("index") === 1L)
+ assert(age.getDouble("average") === 45.0)
+ assert(age.getBoolean("categorical") === false)
+ assert(age.getString("name") === "age")
+ assert(metadata.getString("purpose") === "ml")
+ assert(metadata.getBoolean("isBase") === false)
+ assert(metadata.getMetadata("summary") === summary)
+ assert(metadata.getLongArray("long[]").toSeq === Seq(0L, 1L))
+ assert(metadata.getDoubleArray("double[]").toSeq === Seq(3.0, 4.0))
+ assert(metadata.getBooleanArray("boolean[]").toSeq === Seq(true, false))
+ assert(gender.getStringArray("categories").toSeq === Seq("male", "female"))
+ assert(metadata.getMetadataArray("features").toSeq === Seq(age, gender))
+ }
+
+ test("metadata json conversion") {
+ val json = metadata.json
+ withClue("toJson must produce a valid JSON string") {
+ parse(json)
+ }
+ val parsed = Metadata.fromJson(json)
+ assert(parsed === metadata)
+ assert(parsed.## === metadata.##)
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
index 37e88d72b9..0c85cdc0aa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
@@ -17,9 +17,7 @@
package org.apache.spark.sql.api.java;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
+import java.util.*;
/**
* The base type of all Spark SQL data types.
@@ -151,15 +149,31 @@ public abstract class DataType {
* Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and
* whether values of this field can be null values ({@code nullable}).
*/
- public static StructField createStructField(String name, DataType dataType, boolean nullable) {
+ public static StructField createStructField(
+ String name,
+ DataType dataType,
+ boolean nullable,
+ Metadata metadata) {
if (name == null) {
throw new IllegalArgumentException("name should not be null.");
}
if (dataType == null) {
throw new IllegalArgumentException("dataType should not be null.");
}
+ if (metadata == null) {
+ throw new IllegalArgumentException("metadata should not be null.");
+ }
+
+ return new StructField(name, dataType, nullable, metadata);
+ }
- return new StructField(name, dataType, nullable);
+ /**
+ * Creates a StructField with empty metadata.
+ *
+ * @see #createStructField(String, DataType, boolean, Metadata)
+ */
+ public static StructField createStructField(String name, DataType dataType, boolean nullable) {
+ return createStructField(name, dataType, nullable, (new MetadataBuilder()).build());
}
/**
@@ -191,5 +205,4 @@ public abstract class DataType {
return new StructType(fields);
}
-
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java
new file mode 100644
index 0000000000..0f819fb01a
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+/**
+ * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean,
+ * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and
+ * Array[Metadata]. JSON is used for serialization.
+ *
+ * The default constructor is private. User should use [[MetadataBuilder]].
+ */
+class Metadata extends org.apache.spark.sql.catalyst.util.Metadata {
+ Metadata(scala.collection.immutable.Map<String, Object> map) {
+ super(map);
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java
new file mode 100644
index 0000000000..6e6b12f072
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+/**
+ * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former.
+ */
+public class MetadataBuilder extends org.apache.spark.sql.catalyst.util.MetadataBuilder {
+ @Override
+ public Metadata build() {
+ return new Metadata(getMap());
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java
index b48e2a2c5f..7c60d492bc 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java
@@ -17,6 +17,8 @@
package org.apache.spark.sql.api.java;
+import java.util.Map;
+
/**
* A StructField object represents a field in a StructType object.
* A StructField object comprises three fields, {@code String name}, {@code DataType dataType},
@@ -24,20 +26,27 @@ package org.apache.spark.sql.api.java;
* The field of {@code dataType} specifies the data type of a StructField.
* The field of {@code nullable} specifies if values of a StructField can contain {@code null}
* values.
+ * The field of {@code metadata} provides extra information of the StructField.
*
* To create a {@link StructField},
- * {@link DataType#createStructField(String, DataType, boolean)}
+ * {@link DataType#createStructField(String, DataType, boolean, Metadata)}
* should be used.
*/
public class StructField {
private String name;
private DataType dataType;
private boolean nullable;
+ private Metadata metadata;
- protected StructField(String name, DataType dataType, boolean nullable) {
+ protected StructField(
+ String name,
+ DataType dataType,
+ boolean nullable,
+ Metadata metadata) {
this.name = name;
this.dataType = dataType;
this.nullable = nullable;
+ this.metadata = metadata;
}
public String getName() {
@@ -52,6 +61,10 @@ public class StructField {
return nullable;
}
+ public Metadata getMetadata() {
+ return metadata;
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
@@ -62,6 +75,7 @@ public class StructField {
if (nullable != that.nullable) return false;
if (!dataType.equals(that.dataType)) return false;
if (!name.equals(that.name)) return false;
+ if (!metadata.equals(that.metadata)) return false;
return true;
}
@@ -71,6 +85,7 @@ public class StructField {
int result = name.hashCode();
result = 31 * result + dataType.hashCode();
result = 31 * result + (nullable ? 1 : 0);
+ result = 31 * result + metadata.hashCode();
return result;
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a41a500c9a..4953f8399a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.sql.execution.{SparkStrategies, _}
import org.apache.spark.sql.json._
import org.apache.spark.sql.parquet.ParquetRelation
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 047dc85df6..eabe312f92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -117,10 +117,7 @@ private[sql] object JsonRDD extends Logging {
}
}.flatMap(field => field).toSeq
- StructType(
- (topLevelFields ++ structFields).sortBy {
- case StructField(name, _, _) => name
- })
+ StructType((topLevelFields ++ structFields).sortBy(_.name))
}
makeStruct(resolved.keySet.toSeq, Nil)
@@ -128,7 +125,7 @@ private[sql] object JsonRDD extends Logging {
private[sql] def nullTypeToStringType(struct: StructType): StructType = {
val fields = struct.fields.map {
- case StructField(fieldName, dataType, nullable) => {
+ case StructField(fieldName, dataType, nullable, _) => {
val newType = dataType match {
case NullType => StringType
case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
@@ -163,9 +160,7 @@ private[sql] object JsonRDD extends Logging {
StructField(name, dataType, true)
}
}
- StructType(newFields.toSeq.sortBy {
- case StructField(name, _, _) => name
- })
+ StructType(newFields.toSeq.sortBy(_.name))
}
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
@@ -413,7 +408,7 @@ private[sql] object JsonRDD extends Logging {
// TODO: Reuse the row instead of creating a new one for every record.
val row = new GenericMutableRow(schema.fields.length)
schema.fields.zipWithIndex.foreach {
- case (StructField(name, dataType, _), i) =>
+ case (StructField(name, dataType, _, _), i) =>
row.update(i, json.get(name).flatMap(v => Option(v)).map(
enforceCorrectType(_, dataType)).orNull)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index e98d151286..f0e57e2a74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -125,6 +125,9 @@ package object sql {
@DeveloperApi
type DataType = catalyst.types.DataType
+ @DeveloperApi
+ val DataType = catalyst.types.DataType
+
/**
* :: DeveloperApi ::
*
@@ -414,4 +417,24 @@ package object sql {
*/
@DeveloperApi
val StructField = catalyst.types.StructField
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean,
+ * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and
+ * Array[Metadata]. JSON is used for serialization.
+ *
+ * The default constructor is private. User should use either [[MetadataBuilder]] or
+ * [[Metadata$#fromJson]] to create Metadata instances.
+ *
+ * @param map an immutable map that stores the data
+ */
+ @DeveloperApi
+ type Metadata = catalyst.util.Metadata
+
+ /**
+ * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former.
+ */
+ type MetadataBuilder = catalyst.util.MetadataBuilder
}
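
The aliases above expose the same API through the public org.apache.spark.sql package. A short sketch of end-user access, with field and key names chosen to mirror the SQLQuerySuite test added later in this patch:

    import org.apache.spark.sql.{IntegerType, MetadataBuilder, StringType, StructField, StructType}

    val docMeta = new MetadataBuilder().putString("doc", "first name").build()

    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("name", StringType, nullable = true, docMeta)))

    // The metadata travels with the field and can be read back from the schema.
    assert(schema("name").metadata.getString("doc") == "first name")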
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 609f7db562..142598c904 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.types.util
import org.apache.spark.sql._
-import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField}
+import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder}
import scala.collection.JavaConverters._
@@ -31,7 +31,8 @@ protected[sql] object DataTypeConversions {
JDataType.createStructField(
scalaStructField.name,
asJavaDataType(scalaStructField.dataType),
- scalaStructField.nullable)
+ scalaStructField.nullable,
+ (new JMetaDataBuilder).withMetadata(scalaStructField.metadata).build())
}
/**
@@ -68,7 +69,8 @@ protected[sql] object DataTypeConversions {
StructField(
javaStructField.getName,
asScalaDataType(javaStructField.getDataType),
- javaStructField.isNullable)
+ javaStructField.isNullable,
+ javaStructField.getMetadata)
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
index 100ecb45e9..6c9db639c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql
import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.types.DataType
-
class DataTypeSuite extends FunSuite {
test("construct an ArrayType") {
@@ -79,8 +77,12 @@ class DataTypeSuite extends FunSuite {
checkDataTypeJsonRepr(ArrayType(StringType, false))
checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+ val metadata = new MetadataBuilder()
+ .putString("name", "age")
+ .build()
checkDataTypeJsonRepr(
StructType(Seq(
StructField("a", IntegerType, nullable = true),
- StructField("b", ArrayType(DoubleType), nullable = false))))
+ StructField("b", ArrayType(DoubleType), nullable = false),
+ StructField("c", DoubleType, nullable = false, metadata))))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4acd92d33d..6befe1b755 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,17 +17,16 @@
package org.apache.spark.sql
+import java.util.TimeZone
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.joins.BroadcastHashJoin
-import org.apache.spark.sql.test._
-import org.scalatest.BeforeAndAfterAll
-import java.util.TimeZone
-/* Implicits */
-import TestSQLContext._
-import TestData._
+import org.apache.spark.sql.test.TestSQLContext._
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// Make sure the tables are loaded.
@@ -697,6 +696,30 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
("true", "false") :: Nil)
}
+ test("metadata is propagated correctly") {
+ val person = sql("SELECT * FROM person")
+ val schema = person.schema
+ val docKey = "doc"
+ val docValue = "first name"
+ val metadata = new MetadataBuilder()
+ .putString(docKey, docValue)
+ .build()
+ val schemaWithMeta = new StructType(Seq(
+ schema("id"), schema("name").copy(metadata = metadata), schema("age")))
+ val personWithMeta = applySchema(person, schemaWithMeta)
+ def validateMetadata(rdd: SchemaRDD): Unit = {
+ assert(rdd.schema("name").metadata.getString(docKey) == docValue)
+ }
+ personWithMeta.registerTempTable("personWithMeta")
+ validateMetadata(personWithMeta.select('name))
+ validateMetadata(personWithMeta.select("name".attr))
+ validateMetadata(personWithMeta.select('id, 'name))
+ validateMetadata(sql("SELECT * FROM personWithMeta"))
+ validateMetadata(sql("SELECT id, name FROM personWithMeta"))
+ validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId"))
+ validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
+ }
+
test("SPARK-3371 Renaming a function expression with group by gives error") {
registerFunction("len", (s: String) => s.length)
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index c4dd3e860f..836dd17fcc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -166,4 +166,15 @@ object TestData {
// An RDD with 4 elements and 8 partitions
val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
withEmptyParts.registerTempTable("withEmptyParts")
+
+ case class Person(id: Int, name: String, age: Int)
+ case class Salary(personId: Int, salary: Double)
+ val person = TestSQLContext.sparkContext.parallelize(
+ Person(0, "mike", 30) ::
+ Person(1, "jim", 20) :: Nil)
+ person.registerTempTable("person")
+ val salary = TestSQLContext.sparkContext.parallelize(
+ Salary(0, 2000.0) ::
+ Salary(1, 1000.0) :: Nil)
+ salary.registerTempTable("salary")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
index 8415af41be..e0e0ff9cb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
@@ -17,12 +17,10 @@
package org.apache.spark.sql.api.java
-import org.apache.spark.sql.types.util.DataTypeConversions
import org.scalatest.FunSuite
-import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField}
-import org.apache.spark.sql.{StructType => SStructType}
-import DataTypeConversions._
+import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField, StructType => SStructType}
+import org.apache.spark.sql.types.util.DataTypeConversions._
class ScalaSideDataTypeConversionSuite extends FunSuite {
@@ -67,11 +65,15 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
checkDataType(simpleScalaStructType)
// Complex StructType.
+ val metadata = new MetadataBuilder()
+ .putString("name", "age")
+ .build()
val complexScalaStructType = SStructType(
SStructField("simpleArray", simpleScalaArrayType, true) ::
SStructField("simpleMap", simpleScalaMapType, true) ::
SStructField("simpleStruct", simpleScalaStructType, true) ::
- SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: Nil)
+ SStructField("boolean", org.apache.spark.sql.BooleanType, false) ::
+ SStructField("withMeta", org.apache.spark.sql.DoubleType, false, metadata) :: Nil)
checkDataType(complexScalaStructType)
// Complex ArrayType.