path: root/sql/catalyst
author    Xiangrui Meng <meng@databricks.com>    2014-11-01 14:37:00 -0700
committer Michael Armbrust <michael@databricks.com>    2014-11-01 14:37:00 -0700
commit    1d4f3552037cb667971bea2e5078d8b3ce6c2eae (patch)
tree      b4318e8bddec8a5fceaf41ce5a5fd1c3fdab2f41 /sql/catalyst
parent    59e626c701227634336110e1bc23afd94c535ede (diff)
[SPARK-3569][SQL] Add metadata field to StructField
Add `metadata: Metadata` to `StructField` to store extra information about columns. `Metadata` is a simple wrapper over `Map[String, Any]` with value types restricted to Boolean, Long, Double, String, Metadata, and arrays of those types. SerDe is via JSON. Metadata is preserved through simple operations like `SELECT`.

marmbrus liancheng

Author: Xiangrui Meng <meng@databricks.com>
Author: Michael Armbrust <michael@databricks.com>

Closes #2701 from mengxr/structfield-metadata and squashes the following commits:

dedda56 [Xiangrui Meng] merge remote
5ef930a [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into structfield-metadata
c35203f [Xiangrui Meng] Merge pull request #1 from marmbrus/pr/2701
886b85c [Michael Armbrust] Expose Metadata and MetadataBuilder through the public scala and java packages.
589f314 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into structfield-metadata
1e2abcf [Xiangrui Meng] change default value of metadata to None in python
611d3c2 [Xiangrui Meng] move metadata from Expr to NamedExpr
ddfcfad [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into structfield-metadata
a438440 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into structfield-metadata
4266f4d [Xiangrui Meng] add StructField.toString back for backward compatibility
3f49aab [Xiangrui Meng] remove StructField.toString
24a9f80 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into structfield-metadata
473a7c5 [Xiangrui Meng] merge master
c9d7301 [Xiangrui Meng] organize imports
1fcbf13 [Xiangrui Meng] change metadata type in StructField for Scala/Java
60cc131 [Xiangrui Meng] add doc and header
60614c7 [Xiangrui Meng] add metadata
e42c452 [Xiangrui Meng] merge master
93518fb [Xiangrui Meng] support metadata in python
905bb89 [Xiangrui Meng] java conversions
618e349 [Xiangrui Meng] make tests work in scala
61b8e0f [Xiangrui Meng] merge master
7e5a322 [Xiangrui Meng] do not output metadata in StructField.toString
c41a664 [Xiangrui Meng] merge master
d8af0ed [Xiangrui Meng] move tests to SQLQuerySuite
67fdebb [Xiangrui Meng] add test on join
d65072e [Xiangrui Meng] remove Map.empty
367d237 [Xiangrui Meng] add test
c194d5e [Xiangrui Meng] add metadata field to StructField and Attribute
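For a concrete feel of the additions below, here is a minimal sketch of building `Metadata` with `MetadataBuilder` and attaching it to a `StructField`, using the catalyst-internal packages as of this patch (the key names "comment" and "maxCardinality" are illustrative, not part of the API):

    import org.apache.spark.sql.catalyst.types.{StringType, StructField}
    import org.apache.spark.sql.catalyst.util.{Metadata, MetadataBuilder}

    // Build column-level metadata; entries are restricted to the simple
    // types listed above.
    val meta = new MetadataBuilder()
      .putString("comment", "user age in years")
      .putLong("maxCardinality", 120L)
      .build()

    // The new fourth parameter defaults to Metadata.empty, so existing
    // call sites keep compiling unchanged.
    val field = StructField("age", StringType, nullable = true, metadata = meta)

    // SerDe is via JSON and round-trips losslessly.
    assert(Metadata.fromJson(meta.json) == meta)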
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala              2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala       1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala       2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala 31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala              25
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala                255
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala           82
7 files changed, 382 insertions, 16 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d76c743d3f..75923d9e8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -46,7 +46,7 @@ object ScalaReflection {
/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
case Schema(s: StructType, _) =>
- s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
+ s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
}
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 1eb260efa6..39b120e8de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType}
+import org.apache.spark.sql.catalyst.util.Metadata
abstract class Expression extends TreeNode[Expression] {
self: Product =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 9c865254e0..ab0701fd9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -43,7 +43,7 @@ abstract class Generator extends Expression {
override type EvaluatedType = TraversableOnce[Row]
override lazy val dataType =
- ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable))))
+ ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
override def nullable = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index fe13a661f6..3310566087 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.util.Metadata
object NamedExpression {
private val curId = new java.util.concurrent.atomic.AtomicLong()
@@ -43,6 +44,9 @@ abstract class NamedExpression extends Expression {
def toAttribute: Attribute
+ /** Returns the metadata when an expression is a reference to another expression with metadata. */
+ def metadata: Metadata = Metadata.empty
+
protected def typeSuffix =
if (resolved) {
dataType match {
@@ -88,10 +92,16 @@ case class Alias(child: Expression, name: String)
override def dataType = child.dataType
override def nullable = child.nullable
+ override def metadata: Metadata = {
+ child match {
+ case named: NamedExpression => named.metadata
+ case _ => Metadata.empty
+ }
+ }
override def toAttribute = {
if (resolved) {
- AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers)
+ AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers)
} else {
UnresolvedAttribute(name)
}
@@ -108,15 +118,20 @@ case class Alias(child: Expression, name: String)
* @param name The name of this attribute, should only be used during analysis or for debugging.
* @param dataType The [[DataType]] of this attribute.
* @param nullable True if null is a valid value for this attribute.
+ * @param metadata The metadata of this attribute.
* @param exprId A globally unique id used to check if different AttributeReferences refer to the
* same attribute.
* @param qualifiers a list of strings that can be used to referred to this attribute in a fully
* qualified way. Consider the examples tableName.name, subQueryAlias.name.
* tableName and subQueryAlias are possible qualifiers.
*/
-case class AttributeReference(name: String, dataType: DataType, nullable: Boolean = true)
- (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
- extends Attribute with trees.LeafNode[Expression] {
+case class AttributeReference(
+ name: String,
+ dataType: DataType,
+ nullable: Boolean = true,
+ override val metadata: Metadata = Metadata.empty)(
+ val exprId: ExprId = NamedExpression.newExprId,
+ val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] {
override def equals(other: Any) = other match {
case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
@@ -128,10 +143,12 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
var h = 17
h = h * 37 + exprId.hashCode()
h = h * 37 + dataType.hashCode()
+ h = h * 37 + metadata.hashCode()
h
}
- override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
+ override def newInstance() =
+ AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers)
/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
@@ -140,7 +157,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
if (nullable == newNullability) {
this
} else {
- AttributeReference(name, dataType, newNullability)(exprId, qualifiers)
+ AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers)
}
}
@@ -159,7 +176,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
if (newQualifiers.toSet == qualifiers.toSet) {
this
} else {
- AttributeReference(name, dataType, nullable)(exprId, newQualifiers)
+ AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers)
}
}
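The `Alias` override above is what makes metadata survive `SELECT`: aliasing a `NamedExpression` copies its metadata into the attribute the alias resolves to, while aliasing anything else falls back to `Metadata.empty`. A sketch under the constructor shapes shown in this diff (the key name "source" is illustrative):

    import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
    import org.apache.spark.sql.catalyst.types.IntegerType
    import org.apache.spark.sql.catalyst.util.MetadataBuilder

    val meta = new MetadataBuilder().putString("source", "sensor").build()
    val age = AttributeReference("age", IntegerType, nullable = false, meta)()
    // The attribute produced for `SELECT age AS a` still carries the metadata:
    assert(Alias(age, "a")().toAttribute.metadata == meta)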
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 4e6e1166bf..6069f9b0a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -24,16 +24,16 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
import scala.util.parsing.combinator.RegexParsers
-import org.json4s.JsonAST.JValue
import org.json4s._
+import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.util.Metadata
import org.apache.spark.util.Utils
-
object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
@@ -70,10 +70,11 @@ object DataType {
private def parseStructField(json: JValue): StructField = json match {
case JSortedObject(
+ ("metadata", metadata: JObject),
("name", JString(name)),
("nullable", JBool(nullable)),
("type", dataType: JValue)) =>
- StructField(name, parseDataType(dataType), nullable)
+ StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata))
}
@deprecated("Use DataType.fromJson instead", "1.2.0")
@@ -388,24 +389,34 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
* @param name The name of this field.
* @param dataType The data type of this field.
* @param nullable Indicates if values of this field can be `null` values.
+ * @param metadata The metadata of this field. The metadata should be preserved during
+ * transformation if the content of the column is not modified, e.g., in selection.
*/
-case class StructField(name: String, dataType: DataType, nullable: Boolean) {
+case class StructField(
+ name: String,
+ dataType: DataType,
+ nullable: Boolean,
+ metadata: Metadata = Metadata.empty) {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n")
DataType.buildFormattedString(dataType, s"$prefix |", builder)
}
+ // override the default toString to be compatible with legacy parquet files.
+ override def toString: String = s"StructField($name,$dataType,$nullable)"
+
private[sql] def jsonValue: JValue = {
("name" -> name) ~
("type" -> dataType.jsonValue) ~
- ("nullable" -> nullable)
+ ("nullable" -> nullable) ~
+ ("metadata" -> metadata.jsonValue)
}
}
object StructType {
protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
- StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable)))
+ StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
}
case class StructType(fields: Seq[StructField]) extends DataType {
@@ -439,7 +450,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
}
protected[sql] def toAttributes =
- fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
+ fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
def treeString: String = {
val builder = new StringBuilder
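With `metadata` added to `jsonValue` and to `parseStructField` (matched through `JSortedObject`, hence the alphabetical `metadata`/`name`/`nullable`/`type` key order above), schema JSON now round-trips the new field. A sketch, assuming the `json` rendering on `DataType` that pairs with `DataType.fromJson`:

    import org.apache.spark.sql.catalyst.types.{DataType, StringType, StructField, StructType}
    import org.apache.spark.sql.catalyst.util.MetadataBuilder

    val meta = new MetadataBuilder().putBoolean("categorical", true).build()
    val schema = StructType(StructField("gender", StringType, nullable = true, meta) :: Nil)
    // Fields serialize as {"metadata":{...},"name":...,"nullable":...,"type":...}
    // and parse back with their metadata intact:
    assert(DataType.fromJson(schema.json) == schema)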
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala
new file mode 100644
index 0000000000..2f2082fa3c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import scala.collection.mutable
+
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+/**
+ * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean,
+ * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and
+ * Array[Metadata]. JSON is used for serialization.
+ *
+ * The default constructor is private. Users should use either [[MetadataBuilder]] or
+ * [[Metadata$#fromJson]] to create Metadata instances.
+ *
+ * @param map an immutable map that stores the data
+ */
+sealed class Metadata private[util] (private[util] val map: Map[String, Any]) extends Serializable {
+
+ /** Gets a Long. */
+ def getLong(key: String): Long = get(key)
+
+ /** Gets a Double. */
+ def getDouble(key: String): Double = get(key)
+
+ /** Gets a Boolean. */
+ def getBoolean(key: String): Boolean = get(key)
+
+ /** Gets a String. */
+ def getString(key: String): String = get(key)
+
+ /** Gets a Metadata. */
+ def getMetadata(key: String): Metadata = get(key)
+
+ /** Gets a Long array. */
+ def getLongArray(key: String): Array[Long] = get(key)
+
+ /** Gets a Double array. */
+ def getDoubleArray(key: String): Array[Double] = get(key)
+
+ /** Gets a Boolean array. */
+ def getBooleanArray(key: String): Array[Boolean] = get(key)
+
+ /** Gets a String array. */
+ def getStringArray(key: String): Array[String] = get(key)
+
+ /** Gets a Metadata array. */
+ def getMetadataArray(key: String): Array[Metadata] = get(key)
+
+ /** Converts to its JSON representation. */
+ def json: String = compact(render(jsonValue))
+
+ override def toString: String = json
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case that: Metadata =>
+ if (map.keySet == that.map.keySet) {
+ map.keys.forall { k =>
+ (map(k), that.map(k)) match {
+ case (v0: Array[_], v1: Array[_]) =>
+ v0.view == v1.view
+ case (v0, v1) =>
+ v0 == v1
+ }
+ }
+ } else {
+ false
+ }
+ case other =>
+ false
+ }
+ }
+
+ override def hashCode: Int = Metadata.hash(this)
+
+ private def get[T](key: String): T = {
+ map(key).asInstanceOf[T]
+ }
+
+ private[sql] def jsonValue: JValue = Metadata.toJsonValue(this)
+}
+
+object Metadata {
+
+ /** Returns an empty Metadata. */
+ def empty: Metadata = new Metadata(Map.empty)
+
+ /** Creates a Metadata instance from JSON. */
+ def fromJson(json: String): Metadata = {
+ fromJObject(parse(json).asInstanceOf[JObject])
+ }
+
+ /** Creates a Metadata instance from JSON AST. */
+ private[sql] def fromJObject(jObj: JObject): Metadata = {
+ val builder = new MetadataBuilder
+ jObj.obj.foreach {
+ case (key, JInt(value)) =>
+ builder.putLong(key, value.toLong)
+ case (key, JDouble(value)) =>
+ builder.putDouble(key, value)
+ case (key, JBool(value)) =>
+ builder.putBoolean(key, value)
+ case (key, JString(value)) =>
+ builder.putString(key, value)
+ case (key, o: JObject) =>
+ builder.putMetadata(key, fromJObject(o))
+ case (key, JArray(value)) =>
+ if (value.isEmpty) {
+ // If it is an empty array, we cannot infer its element type. We put an empty Array[Long].
+ builder.putLongArray(key, Array.empty)
+ } else {
+ value.head match {
+ case _: JInt =>
+ builder.putLongArray(key, value.asInstanceOf[List[JInt]].map(_.num.toLong).toArray)
+ case _: JDouble =>
+ builder.putDoubleArray(key, value.asInstanceOf[List[JDouble]].map(_.num).toArray)
+ case _: JBool =>
+ builder.putBooleanArray(key, value.asInstanceOf[List[JBool]].map(_.value).toArray)
+ case _: JString =>
+ builder.putStringArray(key, value.asInstanceOf[List[JString]].map(_.s).toArray)
+ case _: JObject =>
+ builder.putMetadataArray(
+ key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray)
+ case other =>
+ throw new RuntimeException(s"Do not support array of type ${other.getClass}.")
+ }
+ }
+ case other =>
+ throw new RuntimeException(s"Do not support type ${other.getClass}.")
+ }
+ builder.build()
+ }
+
+ /** Converts to JSON AST. */
+ private def toJsonValue(obj: Any): JValue = {
+ obj match {
+ case map: Map[_, _] =>
+ val fields = map.toList.map { case (k: String, v) => (k, toJsonValue(v)) }
+ JObject(fields)
+ case arr: Array[_] =>
+ val values = arr.toList.map(toJsonValue)
+ JArray(values)
+ case x: Long =>
+ JInt(x)
+ case x: Double =>
+ JDouble(x)
+ case x: Boolean =>
+ JBool(x)
+ case x: String =>
+ JString(x)
+ case x: Metadata =>
+ toJsonValue(x.map)
+ case other =>
+ throw new RuntimeException(s"Do not support type ${other.getClass}.")
+ }
+ }
+
+ /** Computes the hash code for the types we support. */
+ private def hash(obj: Any): Int = {
+ obj match {
+ case map: Map[_, _] =>
+ map.mapValues(hash).##
+ case arr: Array[_] =>
+ // Seq.empty[T] has the same hashCode regardless of T.
+ arr.toSeq.map(hash).##
+ case x: Long =>
+ x.##
+ case x: Double =>
+ x.##
+ case x: Boolean =>
+ x.##
+ case x: String =>
+ x.##
+ case x: Metadata =>
+ hash(x.map)
+ case other =>
+ throw new RuntimeException(s"Do not support type ${other.getClass}.")
+ }
+ }
+}
+
+/**
+ * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former.
+ */
+class MetadataBuilder {
+
+ private val map: mutable.Map[String, Any] = mutable.Map.empty
+
+ /** Returns the immutable version of this map. Used for java interop. */
+ protected def getMap = map.toMap
+
+ /** Include the content of an existing [[Metadata]] instance. */
+ def withMetadata(metadata: Metadata): this.type = {
+ map ++= metadata.map
+ this
+ }
+
+ /** Puts a Long. */
+ def putLong(key: String, value: Long): this.type = put(key, value)
+
+ /** Puts a Double. */
+ def putDouble(key: String, value: Double): this.type = put(key, value)
+
+ /** Puts a Boolean. */
+ def putBoolean(key: String, value: Boolean): this.type = put(key, value)
+
+ /** Puts a String. */
+ def putString(key: String, value: String): this.type = put(key, value)
+
+ /** Puts a [[Metadata]]. */
+ def putMetadata(key: String, value: Metadata): this.type = put(key, value)
+
+ /** Puts a Long array. */
+ def putLongArray(key: String, value: Array[Long]): this.type = put(key, value)
+
+ /** Puts a Double array. */
+ def putDoubleArray(key: String, value: Array[Double]): this.type = put(key, value)
+
+ /** Puts a Boolean array. */
+ def putBooleanArray(key: String, value: Array[Boolean]): this.type = put(key, value)
+
+ /** Puts a String array. */
+ def putStringArray(key: String, value: Array[String]): this.type = put(key, value)
+
+ /** Puts a [[Metadata]] array. */
+ def putMetadataArray(key: String, value: Array[Metadata]): this.type = put(key, value)
+
+ /** Builds the [[Metadata]] instance. */
+ def build(): Metadata = {
+ new Metadata(map.toMap)
+ }
+
+ private def put(key: String, value: Any): this.type = {
+ map.put(key, value)
+ this
+ }
+}
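Note how `equals` compares array values through views and `hash` converts arrays to `Seq`: raw Scala arrays use reference equality and identity hashing, so this keeps `==` consistent with `##` across independently built instances. A quick sketch:

    import org.apache.spark.sql.catalyst.util.MetadataBuilder

    val a = new MetadataBuilder().putLongArray("xs", Array(0L, 1L)).build()
    val b = new MetadataBuilder().putLongArray("xs", Array(0L, 1L)).build()
    // Array(0L, 1L) != Array(0L, 1L) for raw arrays, but the wrappers agree:
    assert(a == b && a.## == b.##)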
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala
new file mode 100644
index 0000000000..0063d31666
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.json4s.jackson.JsonMethods.parse
+import org.scalatest.FunSuite
+
+class MetadataSuite extends FunSuite {
+
+ val baseMetadata = new MetadataBuilder()
+ .putString("purpose", "ml")
+ .putBoolean("isBase", true)
+ .build()
+
+ val summary = new MetadataBuilder()
+ .putLong("numFeatures", 10L)
+ .build()
+
+ val age = new MetadataBuilder()
+ .putString("name", "age")
+ .putLong("index", 1L)
+ .putBoolean("categorical", false)
+ .putDouble("average", 45.0)
+ .build()
+
+ val gender = new MetadataBuilder()
+ .putString("name", "gender")
+ .putLong("index", 5)
+ .putBoolean("categorical", true)
+ .putStringArray("categories", Array("male", "female"))
+ .build()
+
+ val metadata = new MetadataBuilder()
+ .withMetadata(baseMetadata)
+ .putBoolean("isBase", false) // overwrite an existing key
+ .putMetadata("summary", summary)
+ .putLongArray("long[]", Array(0L, 1L))
+ .putDoubleArray("double[]", Array(3.0, 4.0))
+ .putBooleanArray("boolean[]", Array(true, false))
+ .putMetadataArray("features", Array(age, gender))
+ .build()
+
+ test("metadata builder and getters") {
+ assert(age.getLong("index") === 1L)
+ assert(age.getDouble("average") === 45.0)
+ assert(age.getBoolean("categorical") === false)
+ assert(age.getString("name") === "age")
+ assert(metadata.getString("purpose") === "ml")
+ assert(metadata.getBoolean("isBase") === false)
+ assert(metadata.getMetadata("summary") === summary)
+ assert(metadata.getLongArray("long[]").toSeq === Seq(0L, 1L))
+ assert(metadata.getDoubleArray("double[]").toSeq === Seq(3.0, 4.0))
+ assert(metadata.getBooleanArray("boolean[]").toSeq === Seq(true, false))
+ assert(gender.getStringArray("categories").toSeq === Seq("male", "female"))
+ assert(metadata.getMetadataArray("features").toSeq === Seq(age, gender))
+ }
+
+ test("metadata json conversion") {
+ val json = metadata.json
+ withClue("toJson must produce a valid JSON string") {
+ parse(json)
+ }
+ val parsed = Metadata.fromJson(json)
+ assert(parsed === metadata)
+ assert(parsed.## === metadata.##)
+ }
+}