aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main
diff options
context:
space:
mode:
authorvidmantas zemleris <vidmantas@vinted.com>2015-04-21 14:47:09 -0700
committerMichael Armbrust <michael@databricks.com>2015-04-21 14:47:09 -0700
commit2e8c6ca47df14681c1110f0736234ce76a3eca9b (patch)
treee233586bbe6b07e810df14f3a9e4cdd6407e634b /sql/catalyst/src/main
parent04bf34e34f22e3d7e972fe755251774fc6a6d52e (diff)
downloadspark-2e8c6ca47df14681c1110f0736234ce76a3eca9b.tar.gz
spark-2e8c6ca47df14681c1110f0736234ce76a3eca9b.tar.bz2
spark-2e8c6ca47df14681c1110f0736234ce76a3eca9b.zip
[SPARK-6994] Allow to fetch field values by name in sql.Row
It looked weird that up to now there was no way in Spark's Scala API to access fields of `DataFrame/sql.Row` by name, only by their index. This tries to solve this issue. Author: vidmantas zemleris <vidmantas@vinted.com> Closes #5573 from vidma/features/row-with-named-fields and squashes the following commits: 6145ae3 [vidmantas zemleris] [SPARK-6994][SQL] Allow to fetch field values by name on Row 9564ebb [vidmantas zemleris] [SPARK-6994][SQL] Add fieldIndex to schema (StructType)
Diffstat (limited to 'sql/catalyst/src/main')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala32
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala9
3 files changed, 43 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index ac8a782976..4190b7ffe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -306,6 +306,38 @@ trait Row extends Serializable {
*/
def getAs[T](i: Int): T = apply(i).asInstanceOf[T]
+ /**
+ * Returns the value of a given fieldName.
+ *
+ * @throws UnsupportedOperationException when schema is not defined.
+ * @throws IllegalArgumentException when fieldName does not exist.
+ * @throws ClassCastException when data type does not match.
+ */
+ def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName))
+
+ /**
+ * Returns the index of a given field name.
+ *
+ * @throws UnsupportedOperationException when schema is not defined.
+ * @throws IllegalArgumentException when fieldName does not exist.
+ */
+ def fieldIndex(name: String): Int = {
+ throw new UnsupportedOperationException("fieldIndex on a Row without schema is undefined.")
+ }
+
+ /**
+ * Returns a Map(name -> value) for the requested fieldNames
+ *
+ * @throws UnsupportedOperationException when schema is not defined.
+ * @throws IllegalArgumentException when fieldName does not exist.
+ * @throws ClassCastException when data type does not match.
+ */
+ def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = {
+ fieldNames.map { name =>
+ name -> getAs[T](name)
+ }.toMap
+ }
+
override def toString(): String = s"[${this.mkString(",")}]"
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index b6ec7d3417..981373477a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -181,6 +181,8 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
/** No-arg constructor for serialization. */
protected def this() = this(null, null)
+
+ override def fieldIndex(name: String): Int = schema.fieldIndex(name)
}
class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index a108413497..7cd7bd1914 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -1025,6 +1025,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
+ private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
/**
* Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
@@ -1049,6 +1050,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
StructType(fields.filter(f => names.contains(f.name)))
}
+ /**
+ * Returns index of a given field
+ */
+ def fieldIndex(name: String): Int = {
+ nameToIndex.getOrElse(name,
+ throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
+ }
+
protected[sql] def toAttributes: Seq[AttributeReference] =
map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())