about summary refs log tree commit diff
path: root/examples/src
diff options
context:
space:
mode:
author: aokolnychyi <okolnychyyanton@gmail.com> 2017-01-24 22:13:17 -0800
committer: gatorsmile <gatorsmile@gmail.com> 2017-01-24 22:13:17 -0800
commit: 3fdce814348fae34df379a6ab9655dbbb2c3427c (patch)
tree: c49543e0d997f7b671b25c309be5330fb00eb3d7 /examples/src
parent: 40a4cfc7c7911107d1cf7a2663469031dcf1f576 (diff)
downloadspark-3fdce814348fae34df379a6ab9655dbbb2c3427c.tar.gz
spark-3fdce814348fae34df379a6ab9655dbbb2c3427c.tar.bz2
spark-3fdce814348fae34df379a6ab9655dbbb2c3427c.zip
[SPARK-16046][DOCS] Aggregations in the Spark SQL programming guide
## What changes were proposed in this pull request? - A separate subsection for Aggregations under “Getting Started” in the Spark SQL programming guide. It mentions which aggregate functions are predefined and how users can create their own. - Examples of using the `UserDefinedAggregateFunction` abstract class for untyped aggregations in Java and Scala. - Examples of using the `Aggregator` abstract class for type-safe aggregations in Java and Scala. - Python is not covered. - The PR might not resolve the ticket since I do not know what exactly was planned by the author. In total, there are four new standalone examples that can be executed via `spark-submit` or `run-example`. The updated Spark SQL programming guide references these examples and does not contain hard-coded snippets. ## How was this patch tested? The patch was tested locally by building the docs. The examples were run as well. ![image](https://cloud.githubusercontent.com/assets/6235869/21292915/04d9d084-c515-11e6-811a-999d598dffba.png) Author: aokolnychyi <okolnychyyanton@gmail.com> Closes #16329 from aokolnychyi/SPARK-16046.
Diffstat (limited to 'examples/src')
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java | 160
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java | 132
-rw-r--r-- examples/src/main/resources/employees.json | 4
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala | 91
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala | 100
5 files changed, 487 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java
new file mode 100644
index 0000000000..78e9011be4
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql;
+
+// $example on:typed_custom_aggregation$
+import java.io.Serializable;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.expressions.Aggregator;
+// $example off:typed_custom_aggregation$
+
+public class JavaUserDefinedTypedAggregation {
+
+ // $example on:typed_custom_aggregation$
+ public static class Employee implements Serializable {
+ private String name;
+ private long salary;
+
+ // Constructors, getters, setters...
+ // $example off:typed_custom_aggregation$
+ public String getName() {
+ return name;
+ }
+
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ public long getSalary() {
+ return salary;
+ }
+
+ public void setSalary(long salary) {
+ this.salary = salary;
+ }
+ // $example on:typed_custom_aggregation$
+ }
+
+ public static class Average implements Serializable {
+ private long sum;
+ private long count;
+
+ // Constructors, getters, setters...
+ // $example off:typed_custom_aggregation$
+ public Average() {
+ }
+
+ public Average(long sum, long count) {
+ this.sum = sum;
+ this.count = count;
+ }
+
+ public long getSum() {
+ return sum;
+ }
+
+ public void setSum(long sum) {
+ this.sum = sum;
+ }
+
+ public long getCount() {
+ return count;
+ }
+
+ public void setCount(long count) {
+ this.count = count;
+ }
+ // $example on:typed_custom_aggregation$
+ }
+
+ public static class MyAverage extends Aggregator<Employee, Average, Double> {
+ // A zero value for this aggregation. Should satisfy the property that any b + zero = b
+ public Average zero() {
+ return new Average(0L, 0L);
+ }
+ // Combine two values to produce a new value. For performance, the function may modify `buffer`
+ // and return it instead of constructing a new object
+ public Average reduce(Average buffer, Employee employee) {
+ long newSum = buffer.getSum() + employee.getSalary();
+ long newCount = buffer.getCount() + 1;
+ buffer.setSum(newSum);
+ buffer.setCount(newCount);
+ return buffer;
+ }
+ // Merge two intermediate values
+ public Average merge(Average b1, Average b2) {
+ long mergedSum = b1.getSum() + b2.getSum();
+ long mergedCount = b1.getCount() + b2.getCount();
+ b1.setSum(mergedSum);
+ b1.setCount(mergedCount);
+ return b1;
+ }
+ // Transform the output of the reduction
+ public Double finish(Average reduction) {
+ return ((double) reduction.getSum()) / reduction.getCount();
+ }
+ // Specifies the Encoder for the intermediate value type
+ public Encoder<Average> bufferEncoder() {
+ return Encoders.bean(Average.class);
+ }
+ // Specifies the Encoder for the final output value type
+ public Encoder<Double> outputEncoder() {
+ return Encoders.DOUBLE();
+ }
+ }
+ // $example off:typed_custom_aggregation$
+
+ public static void main(String[] args) {
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("Java Spark SQL user-defined Datasets aggregation example")
+ .getOrCreate();
+
+ // $example on:typed_custom_aggregation$
+ Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
+ String path = "examples/src/main/resources/employees.json";
+ Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
+ ds.show();
+ // +-------+------+
+ // | name|salary|
+ // +-------+------+
+ // |Michael| 3000|
+ // | Andy| 4500|
+ // | Justin| 3500|
+ // | Berta| 4000|
+ // +-------+------+
+
+ MyAverage myAverage = new MyAverage();
+ // Convert the function to a `TypedColumn` and give it a name
+ TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
+ Dataset<Double> result = ds.select(averageSalary);
+ result.show();
+ // +--------------+
+ // |average_salary|
+ // +--------------+
+ // | 3750.0|
+ // +--------------+
+ // $example off:typed_custom_aggregation$
+ spark.stop();
+ }
+
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java
new file mode 100644
index 0000000000..6da60a1fc6
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql;
+
+// $example on:untyped_custom_aggregation$
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+// $example off:untyped_custom_aggregation$
+
+public class JavaUserDefinedUntypedAggregation {
+
+ // $example on:untyped_custom_aggregation$
+ public static class MyAverage extends UserDefinedAggregateFunction {
+
+ private StructType inputSchema;
+ private StructType bufferSchema;
+
+ public MyAverage() {
+ List<StructField> inputFields = new ArrayList<>();
+ inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
+ inputSchema = DataTypes.createStructType(inputFields);
+
+ List<StructField> bufferFields = new ArrayList<>();
+ bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
+ bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
+ bufferSchema = DataTypes.createStructType(bufferFields);
+ }
+ // Data types of input arguments of this aggregate function
+ public StructType inputSchema() {
+ return inputSchema;
+ }
+ // Data types of values in the aggregation buffer
+ public StructType bufferSchema() {
+ return bufferSchema;
+ }
+ // The data type of the returned value
+ public DataType dataType() {
+ return DataTypes.DoubleType;
+ }
+ // Whether this function always returns the same output on the identical input
+ public boolean deterministic() {
+ return true;
+ }
+ // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
+ // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
+ // the opportunity to update its values. Note that arrays and maps inside the buffer are still
+ // immutable.
+ public void initialize(MutableAggregationBuffer buffer) {
+ buffer.update(0, 0L);
+ buffer.update(1, 0L);
+ }
+ // Updates the given aggregation buffer `buffer` with new input data from `input`
+ public void update(MutableAggregationBuffer buffer, Row input) {
+ if (!input.isNullAt(0)) {
+ long updatedSum = buffer.getLong(0) + input.getLong(0);
+ long updatedCount = buffer.getLong(1) + 1;
+ buffer.update(0, updatedSum);
+ buffer.update(1, updatedCount);
+ }
+ }
+ // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
+ public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
+ long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
+ buffer1.update(0, mergedSum);
+ buffer1.update(1, mergedCount);
+ }
+ // Calculates the final result
+ public Double evaluate(Row buffer) {
+ return ((double) buffer.getLong(0)) / buffer.getLong(1);
+ }
+ }
+ // $example off:untyped_custom_aggregation$
+
+ public static void main(String[] args) {
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("Java Spark SQL user-defined DataFrames aggregation example")
+ .getOrCreate();
+
+ // $example on:untyped_custom_aggregation$
+ // Register the function to access it
+ spark.udf().register("myAverage", new MyAverage());
+
+ Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
+ df.createOrReplaceTempView("employees");
+ df.show();
+ // +-------+------+
+ // | name|salary|
+ // +-------+------+
+ // |Michael| 3000|
+ // | Andy| 4500|
+ // | Justin| 3500|
+ // | Berta| 4000|
+ // +-------+------+
+
+ Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
+ result.show();
+ // +--------------+
+ // |average_salary|
+ // +--------------+
+ // | 3750.0|
+ // +--------------+
+ // $example off:untyped_custom_aggregation$
+
+ spark.stop();
+ }
+}
diff --git a/examples/src/main/resources/employees.json b/examples/src/main/resources/employees.json
new file mode 100644
index 0000000000..6b2e6329a1
--- /dev/null
+++ b/examples/src/main/resources/employees.json
@@ -0,0 +1,4 @@
+{"name":"Michael", "salary":3000}
+{"name":"Andy", "salary":4500}
+{"name":"Justin", "salary":3500}
+{"name":"Berta", "salary":4000}
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala
new file mode 100644
index 0000000000..ac617d19d3
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql
+
+// $example on:typed_custom_aggregation$
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.SparkSession
+// $example off:typed_custom_aggregation$
+
+object UserDefinedTypedAggregation {
+
+ // $example on:typed_custom_aggregation$
+ case class Employee(name: String, salary: Long)
+ case class Average(var sum: Long, var count: Long)
+
+ object MyAverage extends Aggregator[Employee, Average, Double] {
+ // A zero value for this aggregation. Should satisfy the property that any b + zero = b
+ def zero: Average = Average(0L, 0L)
+ // Combine two values to produce a new value. For performance, the function may modify `buffer`
+ // and return it instead of constructing a new object
+ def reduce(buffer: Average, employee: Employee): Average = {
+ buffer.sum += employee.salary
+ buffer.count += 1
+ buffer
+ }
+ // Merge two intermediate values
+ def merge(b1: Average, b2: Average): Average = {
+ b1.sum += b2.sum
+ b1.count += b2.count
+ b1
+ }
+ // Transform the output of the reduction
+ def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
+ // Specifies the Encoder for the intermediate value type
+ def bufferEncoder: Encoder[Average] = Encoders.product
+ // Specifies the Encoder for the final output value type
+ def outputEncoder: Encoder[Double] = Encoders.scalaDouble
+ }
+ // $example off:typed_custom_aggregation$
+
+ def main(args: Array[String]): Unit = {
+ val spark = SparkSession
+ .builder()
+ .appName("Spark SQL user-defined Datasets aggregation example")
+ .getOrCreate()
+
+ import spark.implicits._
+
+ // $example on:typed_custom_aggregation$
+ val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
+ ds.show()
+ // +-------+------+
+ // | name|salary|
+ // +-------+------+
+ // |Michael| 3000|
+ // | Andy| 4500|
+ // | Justin| 3500|
+ // | Berta| 4000|
+ // +-------+------+
+
+ // Convert the function to a `TypedColumn` and give it a name
+ val averageSalary = MyAverage.toColumn.name("average_salary")
+ val result = ds.select(averageSalary)
+ result.show()
+ // +--------------+
+ // |average_salary|
+ // +--------------+
+ // | 3750.0|
+ // +--------------+
+ // $example off:typed_custom_aggregation$
+
+ spark.stop()
+ }
+
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala
new file mode 100644
index 0000000000..9c9ebc5516
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql
+
+// $example on:untyped_custom_aggregation$
+import org.apache.spark.sql.expressions.MutableAggregationBuffer
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.SparkSession
+// $example off:untyped_custom_aggregation$
+
+object UserDefinedUntypedAggregation {
+
+ // $example on:untyped_custom_aggregation$
+ object MyAverage extends UserDefinedAggregateFunction {
+ // Data types of input arguments of this aggregate function
+ def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
+ // Data types of values in the aggregation buffer
+ def bufferSchema: StructType = {
+ StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
+ }
+ // The data type of the returned value
+ def dataType: DataType = DoubleType
+ // Whether this function always returns the same output on the identical input
+ def deterministic: Boolean = true
+ // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
+ // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
+ // the opportunity to update its values. Note that arrays and maps inside the buffer are still
+ // immutable.
+ def initialize(buffer: MutableAggregationBuffer): Unit = {
+ buffer(0) = 0L
+ buffer(1) = 0L
+ }
+ // Updates the given aggregation buffer `buffer` with new input data from `input`
+ def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+ if (!input.isNullAt(0)) {
+ buffer(0) = buffer.getLong(0) + input.getLong(0)
+ buffer(1) = buffer.getLong(1) + 1
+ }
+ }
+ // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
+ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+ buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
+ buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
+ }
+ // Calculates the final result
+ def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
+ }
+ // $example off:untyped_custom_aggregation$
+
+ def main(args: Array[String]): Unit = {
+ val spark = SparkSession
+ .builder()
+ .appName("Spark SQL user-defined DataFrames aggregation example")
+ .getOrCreate()
+
+ // $example on:untyped_custom_aggregation$
+ // Register the function to access it
+ spark.udf.register("myAverage", MyAverage)
+
+ val df = spark.read.json("examples/src/main/resources/employees.json")
+ df.createOrReplaceTempView("employees")
+ df.show()
+ // +-------+------+
+ // | name|salary|
+ // +-------+------+
+ // |Michael| 3000|
+ // | Andy| 4500|
+ // | Justin| 3500|
+ // | Berta| 4000|
+ // +-------+------+
+
+ val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
+ result.show()
+ // +--------------+
+ // |average_salary|
+ // +--------------+
+ // | 3750.0|
+ // +--------------+
+ // $example off:untyped_custom_aggregation$
+
+ spark.stop()
+ }
+
+}