diff options
author | vidmantas zemleris <vidmantas@vinted.com> | 2015-04-21 14:47:09 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-04-21 14:47:09 -0700 |
commit | 2e8c6ca47df14681c1110f0736234ce76a3eca9b (patch) | |
tree | e233586bbe6b07e810df14f3a9e4cdd6407e634b /sql/catalyst/src/main | |
parent | 04bf34e34f22e3d7e972fe755251774fc6a6d52e (diff) | |
download | spark-2e8c6ca47df14681c1110f0736234ce76a3eca9b.tar.gz spark-2e8c6ca47df14681c1110f0736234ce76a3eca9b.tar.bz2 spark-2e8c6ca47df14681c1110f0736234ce76a3eca9b.zip |
[SPARK-6994] Allow to fetch field values by name in sql.Row
It looked weird that up to now there was no way in Spark's Scala API to access fields of `DataFrame/sql.Row` by name, only by their index.
This tries to solve this issue.
Author: vidmantas zemleris <vidmantas@vinted.com>
Closes #5573 from vidma/features/row-with-named-fields and squashes the following commits:
6145ae3 [vidmantas zemleris] [SPARK-6994][SQL] Allow to fetch field values by name on Row
9564ebb [vidmantas zemleris] [SPARK-6994][SQL] Add fieldIndex to schema (StructType)
Diffstat (limited to 'sql/catalyst/src/main')
3 files changed, 43 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index ac8a782976..4190b7ffe1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -306,6 +306,38 @@ trait Row extends Serializable { */ def getAs[T](i: Int): T = apply(i).asInstanceOf[T] + /** + * Returns the value of a given fieldName. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + * @throws ClassCastException when data type does not match. + */ + def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName)) + + /** + * Returns the index of a given field name. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + */ + def fieldIndex(name: String): Int = { + throw new UnsupportedOperationException("fieldIndex on a Row without schema is undefined.") + } + + /** + * Returns a Map(name -> value) for the requested fieldNames + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + * @throws ClassCastException when data type does not match. + */ + def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = { + fieldNames.map { name => + name -> getAs[T](name) + }.toMap + } + override def toString(): String = s"[${this.mkString(",")}]" /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index b6ec7d3417..981373477a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -181,6 +181,8 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) /** No-arg constructor for serialization. */ protected def this() = this(null, null) + + override def fieldIndex(name: String): Int = schema.fieldIndex(name) } class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index a108413497..7cd7bd1914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -1025,6 +1025,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val fieldNamesSet: Set[String] = fieldNames.toSet private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not @@ -1049,6 +1050,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(fields.filter(f => names.contains(f.name))) } + /** + * Returns index of a given field + */ + def fieldIndex(name: String): Int = { + nameToIndex.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + } + protected[sql] def toAttributes: Seq[AttributeReference] = map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) |