aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorvidmantas zemleris <vidmantas@vinted.com>2015-04-21 14:47:09 -0700
committerMichael Armbrust <michael@databricks.com>2015-04-21 14:47:09 -0700
commit2e8c6ca47df14681c1110f0736234ce76a3eca9b (patch)
treee233586bbe6b07e810df14f3a9e4cdd6407e634b /sql
parent04bf34e34f22e3d7e972fe755251774fc6a6d52e (diff)
downloadspark-2e8c6ca47df14681c1110f0736234ce76a3eca9b.tar.gz
spark-2e8c6ca47df14681c1110f0736234ce76a3eca9b.tar.bz2
spark-2e8c6ca47df14681c1110f0736234ce76a3eca9b.zip
[SPARK-6994] Allow to fetch field values by name in sql.Row
It looked weird that up to now there was no way in Spark's Scala API to access fields of `DataFrame/sql.Row` by name, only by their index. This tries to solve this issue. Author: vidmantas zemleris <vidmantas@vinted.com> Closes #5573 from vidma/features/row-with-named-fields and squashes the following commits: 6145ae3 [vidmantas zemleris] [SPARK-6994][SQL] Allow to fetch field values by name on Row 9564ebb [vidmantas zemleris] [SPARK-6994][SQL] Add fieldIndex to schema (StructType)
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala32
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala9
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala71
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala13
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala10
6 files changed, 137 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index ac8a782976..4190b7ffe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -306,6 +306,38 @@ trait Row extends Serializable {
*/
def getAs[T](i: Int): T = apply(i).asInstanceOf[T]
+ /**
+ * Returns the value of a given fieldName.
+ *
+ * @throws UnsupportedOperationException when schema is not defined.
+ * @throws IllegalArgumentException when fieldName does not exist.
+ * @throws ClassCastException when data type does not match.
+ */
+ def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName))
+
+ /**
+ * Returns the index of a given field name.
+ *
+ * @throws UnsupportedOperationException when schema is not defined.
+ * @throws IllegalArgumentException when fieldName does not exist.
+ */
+ def fieldIndex(name: String): Int = {
+ throw new UnsupportedOperationException("fieldIndex on a Row without schema is undefined.")
+ }
+
+ /**
+ * Returns a Map(name -> value) for the requested fieldNames
+ *
+ * @throws UnsupportedOperationException when schema is not defined.
+ * @throws IllegalArgumentException when a fieldName does not exist.
+ * @throws ClassCastException when data type does not match.
+ */
+ def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = {
+ fieldNames.map { name =>
+ name -> getAs[T](name)
+ }.toMap
+ }
+
override def toString(): String = s"[${this.mkString(",")}]"
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index b6ec7d3417..981373477a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -181,6 +181,8 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
/** No-arg constructor for serialization. */
protected def this() = this(null, null)
+
+ override def fieldIndex(name: String): Int = schema.fieldIndex(name)
}
class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index a108413497..7cd7bd1914 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -1025,6 +1025,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
+ private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
/**
* Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
@@ -1049,6 +1050,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
StructType(fields.filter(f => names.contains(f.name)))
}
+ /**
+ * Returns index of a given field
+ */
+ def fieldIndex(name: String): Int = {
+ nameToIndex.getOrElse(name,
+ throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
+ }
+
protected[sql] def toAttributes: Seq[AttributeReference] =
map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
new file mode 100644
index 0000000000..bbb9739e9c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema}
+import org.apache.spark.sql.types._
+import org.scalatest.{Matchers, FunSpec}
+
+class RowTest extends FunSpec with Matchers {
+
+ val schema = StructType(
+ StructField("col1", StringType) ::
+ StructField("col2", StringType) ::
+ StructField("col3", IntegerType) :: Nil)
+ val values = Array("value1", "value2", 1)
+
+ val sampleRow: Row = new GenericRowWithSchema(values, schema)
+ val noSchemaRow: Row = new GenericRow(values)
+
+ describe("Row (without schema)") {
+ it("throws an exception when accessing by fieldName") {
+ intercept[UnsupportedOperationException] {
+ noSchemaRow.fieldIndex("col1")
+ }
+ intercept[UnsupportedOperationException] {
+ noSchemaRow.getAs("col1")
+ }
+ }
+ }
+
+ describe("Row (with schema)") {
+ it("fieldIndex(name) returns field index") {
+ sampleRow.fieldIndex("col1") shouldBe 0
+ sampleRow.fieldIndex("col3") shouldBe 2
+ }
+
+ it("getAs[T] retrieves a value by fieldname") {
+ sampleRow.getAs[String]("col1") shouldBe "value1"
+ sampleRow.getAs[Int]("col3") shouldBe 1
+ }
+
+ it("Accessing non existent field throws an exception") {
+ intercept[IllegalArgumentException] {
+ sampleRow.getAs[String]("non_existent")
+ }
+ }
+
+ it("getValuesMap() retrieves values of multiple fields as a Map(field -> value)") {
+ val expected = Map(
+ "col1" -> "value1",
+ "col2" -> "value2"
+ )
+ sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index a1341ea13d..d797510f36 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -56,6 +56,19 @@ class DataTypeSuite extends FunSuite {
}
}
+ test("extract field index from a StructType") {
+ val struct = StructType(
+ StructField("a", LongType) ::
+ StructField("b", FloatType) :: Nil)
+
+ assert(struct.fieldIndex("a") === 0)
+ assert(struct.fieldIndex("b") === 1)
+
+ intercept[IllegalArgumentException] {
+ struct.fieldIndex("non_existent")
+ }
+ }
+
def checkDataTypeJsonRepr(dataType: DataType): Unit = {
test(s"JSON - $dataType") {
assert(DataType.fromJson(dataType.json) === dataType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index bf6cf1321a..fb3ba4bc1b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -62,4 +62,14 @@ class RowSuite extends FunSuite {
val de = instance.deserialize(ser).asInstanceOf[Row]
assert(de === row)
}
+
+ test("get values by field name on Row created via .toDF") {
+ val row = Seq((1, Seq(1))).toDF("a", "b").first()
+ assert(row.getAs[Int]("a") === 1)
+ assert(row.getAs[Seq[Int]]("b") === Seq(1))
+
+ intercept[IllegalArgumentException]{
+ row.getAs[Int]("c")
+ }
+ }
}