From 3fdce814348fae34df379a6ab9655dbbb2c3427c Mon Sep 17 00:00:00 2001
From: aokolnychyi
Date: Tue, 24 Jan 2017 22:13:17 -0800
Subject: [SPARK-16046][DOCS] Aggregations in the Spark SQL programming guide
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What changes were proposed in this pull request?

- A separate subsection for Aggregations under “Getting Started” in the Spark SQL programming guide. It describes which aggregate functions are predefined and how users can create their own.
- Examples of using the `UserDefinedAggregateFunction` abstract class for untyped aggregations in Java and Scala.
- Examples of using the `Aggregator` abstract class for type-safe aggregations in Java and Scala.
- Python is not covered.
- The PR might not fully resolve the ticket since I do not know exactly what the author originally planned.

In total, there are four new standalone examples that can be executed via `spark-submit` or `run-example`. The updated Spark SQL programming guide references these examples and does not contain hard-coded snippets. A short usage sketch for the new aggregators is included below for reference.
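For illustration only (this snippet is not part of the patch), the typed `Aggregator` from the new `UserDefinedTypedAggregation.scala` example can also be applied per group via `groupByKey`. A minimal sketch, assuming the `Employee` case class and `MyAverage` object added in that example; the object name `TypedAggregationUsageSketch` is made up:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical usage sketch, not part of this patch.
object TypedAggregationUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("Typed aggregation usage sketch").getOrCreate()
    import spark.implicits._
    // Reuse the Employee case class and the MyAverage aggregator defined in the new example
    import org.apache.spark.examples.sql.UserDefinedTypedAggregation.{Employee, MyAverage}

    val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
    // toColumn turns the Aggregator into a TypedColumn usable in select/agg
    val averageSalary = MyAverage.toColumn.name("average_salary")

    // Global average, as in the example
    ds.select(averageSalary).show()
    // Per-group average: one row per employee name
    ds.groupByKey(_.name).agg(averageSalary).show()

    spark.stop()
  }
}
```

Similarly, once the untyped `MyAverage` is registered as `myAverage`, it can be used in grouped queries, e.g. `SELECT name, myAverage(salary) FROM employees GROUP BY name`.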
## How was this patch tested?

The patch was tested locally by building the docs. The examples were run as well.

![image](https://cloud.githubusercontent.com/assets/6235869/21292915/04d9d084-c515-11e6-811a-999d598dffba.png)

Author: aokolnychyi

Closes #16329 from aokolnychyi/SPARK-16046.
---
 .../sql/JavaUserDefinedTypedAggregation.java       | 160 +++++++++++++++++++++
 .../sql/JavaUserDefinedUntypedAggregation.java     | 132 +++++++++++++++++
 examples/src/main/resources/employees.json         |   4 +
 .../examples/sql/UserDefinedTypedAggregation.scala |  91 ++++++++++++
 .../sql/UserDefinedUntypedAggregation.scala        | 100 +++++++++++++
 5 files changed, 487 insertions(+)
 create mode 100644 examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java
 create mode 100644 examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java
 create mode 100644 examples/src/main/resources/employees.json
 create mode 100644 examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala
 create mode 100644 examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala

diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java
new file mode 100644
index 0000000000..78e9011be4
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql;
+
+// $example on:typed_custom_aggregation$
+import java.io.Serializable;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.expressions.Aggregator;
+// $example off:typed_custom_aggregation$
+
+public class JavaUserDefinedTypedAggregation {
+
+  // $example on:typed_custom_aggregation$
+  public static class Employee implements Serializable {
+    private String name;
+    private long salary;
+
+    // Constructors, getters, setters...
+    // $example off:typed_custom_aggregation$
+    public String getName() {
+      return name;
+    }
+
+    public void setName(String name) {
+      this.name = name;
+    }
+
+    public long getSalary() {
+      return salary;
+    }
+
+    public void setSalary(long salary) {
+      this.salary = salary;
+    }
+    // $example on:typed_custom_aggregation$
+  }
+
+  public static class Average implements Serializable {
+    private long sum;
+    private long count;
+
+    // Constructors, getters, setters...
+    // $example off:typed_custom_aggregation$
+    public Average() {
+    }
+
+    public Average(long sum, long count) {
+      this.sum = sum;
+      this.count = count;
+    }
+
+    public long getSum() {
+      return sum;
+    }
+
+    public void setSum(long sum) {
+      this.sum = sum;
+    }
+
+    public long getCount() {
+      return count;
+    }
+
+    public void setCount(long count) {
+      this.count = count;
+    }
+    // $example on:typed_custom_aggregation$
+  }
+
+  public static class MyAverage extends Aggregator<Employee, Average, Double> {
+    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
+    public Average zero() {
+      return new Average(0L, 0L);
+    }
+    // Combine two values to produce a new value. For performance, the function may modify `buffer`
+    // and return it instead of constructing a new object
+    public Average reduce(Average buffer, Employee employee) {
+      long newSum = buffer.getSum() + employee.getSalary();
+      long newCount = buffer.getCount() + 1;
+      buffer.setSum(newSum);
+      buffer.setCount(newCount);
+      return buffer;
+    }
+    // Merge two intermediate values
+    public Average merge(Average b1, Average b2) {
+      long mergedSum = b1.getSum() + b2.getSum();
+      long mergedCount = b1.getCount() + b2.getCount();
+      b1.setSum(mergedSum);
+      b1.setCount(mergedCount);
+      return b1;
+    }
+    // Transform the output of the reduction
+    public Double finish(Average reduction) {
+      return ((double) reduction.getSum()) / reduction.getCount();
+    }
+    // Specifies the Encoder for the intermediate value type
+    public Encoder<Average> bufferEncoder() {
+      return Encoders.bean(Average.class);
+    }
+    // Specifies the Encoder for the final output value type
+    public Encoder<Double> outputEncoder() {
+      return Encoders.DOUBLE();
+    }
+  }
+  // $example off:typed_custom_aggregation$
+
+  public static void main(String[] args) {
+    SparkSession spark = SparkSession
+      .builder()
+      .appName("Java Spark SQL user-defined Datasets aggregation example")
+      .getOrCreate();
+
+    // $example on:typed_custom_aggregation$
+    Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
+    String path = "examples/src/main/resources/employees.json";
+    Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
+    ds.show();
+    // +-------+------+
+    // |   name|salary|
+    // +-------+------+
+    // |Michael|  3000|
+    // |   Andy|  4500|
+    // | Justin|  3500|
+    // |  Berta|  4000|
+    // +-------+------+
+
+    MyAverage myAverage = new MyAverage();
+    // Convert the function to a `TypedColumn` and give it a name
+    TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
+    Dataset<Double> result = ds.select(averageSalary);
+    result.show();
+    // +--------------+
+    // |average_salary|
+    // +--------------+
+    // |        3750.0|
+    // +--------------+
+    // $example off:typed_custom_aggregation$
+    spark.stop();
+  }
+
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java
new file mode 100644
index 0000000000..6da60a1fc6
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql;
+
+// $example on:untyped_custom_aggregation$
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+// $example off:untyped_custom_aggregation$
+
+public class JavaUserDefinedUntypedAggregation {
+
+  // $example on:untyped_custom_aggregation$
+  public static class MyAverage extends UserDefinedAggregateFunction {
+
+    private StructType inputSchema;
+    private StructType bufferSchema;
+
+    public MyAverage() {
+      List<StructField> inputFields = new ArrayList<>();
+      inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
+      inputSchema = DataTypes.createStructType(inputFields);
+
+      List<StructField> bufferFields = new ArrayList<>();
+      bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
+      bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
+      bufferSchema = DataTypes.createStructType(bufferFields);
+    }
+    // Data types of input arguments of this aggregate function
+    public StructType inputSchema() {
+      return inputSchema;
+    }
+    // Data types of values in the aggregation buffer
+    public StructType bufferSchema() {
+      return bufferSchema;
+    }
+    // The data type of the returned value
+    public DataType dataType() {
+      return DataTypes.DoubleType;
+    }
+    // Whether this function always returns the same output on the identical input
+    public boolean deterministic() {
+      return true;
+    }
+    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
+    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
+    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
+    // immutable.
+    public void initialize(MutableAggregationBuffer buffer) {
+      buffer.update(0, 0L);
+      buffer.update(1, 0L);
+    }
+    // Updates the given aggregation buffer `buffer` with new input data from `input`
+    public void update(MutableAggregationBuffer buffer, Row input) {
+      if (!input.isNullAt(0)) {
+        long updatedSum = buffer.getLong(0) + input.getLong(0);
+        long updatedCount = buffer.getLong(1) + 1;
+        buffer.update(0, updatedSum);
+        buffer.update(1, updatedCount);
+      }
+    }
+    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
+    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+      long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
+      long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
+      buffer1.update(0, mergedSum);
+      buffer1.update(1, mergedCount);
+    }
+    // Calculates the final result
+    public Double evaluate(Row buffer) {
+      return ((double) buffer.getLong(0)) / buffer.getLong(1);
+    }
+  }
+  // $example off:untyped_custom_aggregation$
+
+  public static void main(String[] args) {
+    SparkSession spark = SparkSession
+      .builder()
+      .appName("Java Spark SQL user-defined DataFrames aggregation example")
+      .getOrCreate();
+
+    // $example on:untyped_custom_aggregation$
+    // Register the function to access it
+    spark.udf().register("myAverage", new MyAverage());
+
+    Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
+    df.createOrReplaceTempView("employees");
+    df.show();
+    // +-------+------+
+    // |   name|salary|
+    // +-------+------+
+    // |Michael|  3000|
+    // |   Andy|  4500|
+    // | Justin|  3500|
+    // |  Berta|  4000|
+    // +-------+------+
+
+    Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
+    result.show();
+    // +--------------+
+    // |average_salary|
+    // +--------------+
+    // |        3750.0|
+    // +--------------+
+    // $example off:untyped_custom_aggregation$
+
+    spark.stop();
+  }
+}
diff --git a/examples/src/main/resources/employees.json b/examples/src/main/resources/employees.json
new file mode 100644
index 0000000000..6b2e6329a1
--- /dev/null
+++ b/examples/src/main/resources/employees.json
@@ -0,0 +1,4 @@
+{"name":"Michael", "salary":3000}
+{"name":"Andy", "salary":4500}
+{"name":"Justin", "salary":3500}
+{"name":"Berta", "salary":4000}
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala
new file mode 100644
index 0000000000..ac617d19d3
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql
+
+// $example on:typed_custom_aggregation$
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.SparkSession
+// $example off:typed_custom_aggregation$
+
+object UserDefinedTypedAggregation {
+
+  // $example on:typed_custom_aggregation$
+  case class Employee(name: String, salary: Long)
+  case class Average(var sum: Long, var count: Long)
+
+  object MyAverage extends Aggregator[Employee, Average, Double] {
+    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
+    def zero: Average = Average(0L, 0L)
+    // Combine two values to produce a new value. For performance, the function may modify `buffer`
+    // and return it instead of constructing a new object
+    def reduce(buffer: Average, employee: Employee): Average = {
+      buffer.sum += employee.salary
+      buffer.count += 1
+      buffer
+    }
+    // Merge two intermediate values
+    def merge(b1: Average, b2: Average): Average = {
+      b1.sum += b2.sum
+      b1.count += b2.count
+      b1
+    }
+    // Transform the output of the reduction
+    def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
+    // Specifies the Encoder for the intermediate value type
+    def bufferEncoder: Encoder[Average] = Encoders.product
+    // Specifies the Encoder for the final output value type
+    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
+  }
+  // $example off:typed_custom_aggregation$
+
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession
+      .builder()
+      .appName("Spark SQL user-defined Datasets aggregation example")
+      .getOrCreate()
+
+    import spark.implicits._
+
+    // $example on:typed_custom_aggregation$
+    val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
+    ds.show()
+    // +-------+------+
+    // |   name|salary|
+    // +-------+------+
+    // |Michael|  3000|
+    // |   Andy|  4500|
+    // | Justin|  3500|
+    // |  Berta|  4000|
+    // +-------+------+
+
+    // Convert the function to a `TypedColumn` and give it a name
+    val averageSalary = MyAverage.toColumn.name("average_salary")
+    val result = ds.select(averageSalary)
+    result.show()
+    // +--------------+
+    // |average_salary|
+    // +--------------+
+    // |        3750.0|
+    // +--------------+
+    // $example off:typed_custom_aggregation$
+
+    spark.stop()
+  }
+
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala
new file mode 100644
index 0000000000..9c9ebc5516
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql
+
+// $example on:untyped_custom_aggregation$
+import org.apache.spark.sql.expressions.MutableAggregationBuffer
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.SparkSession
+// $example off:untyped_custom_aggregation$
+
+object UserDefinedUntypedAggregation {
+
+  // $example on:untyped_custom_aggregation$
+  object MyAverage extends UserDefinedAggregateFunction {
+    // Data types of input arguments of this aggregate function
+    def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
+    // Data types of values in the aggregation buffer
+    def bufferSchema: StructType = {
+      StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
+    }
+    // The data type of the returned value
+    def dataType: DataType = DoubleType
+    // Whether this function always returns the same output on the identical input
+    def deterministic: Boolean = true
+    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
+    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
+    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
+    // immutable.
+    def initialize(buffer: MutableAggregationBuffer): Unit = {
+      buffer(0) = 0L
+      buffer(1) = 0L
+    }
+    // Updates the given aggregation buffer `buffer` with new input data from `input`
+    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+      if (!input.isNullAt(0)) {
+        buffer(0) = buffer.getLong(0) + input.getLong(0)
+        buffer(1) = buffer.getLong(1) + 1
+      }
+    }
+    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
+    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
+      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
+    }
+    // Calculates the final result
+    def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
+  }
+  // $example off:untyped_custom_aggregation$
+
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession
+      .builder()
+      .appName("Spark SQL user-defined DataFrames aggregation example")
+      .getOrCreate()
+
+    // $example on:untyped_custom_aggregation$
+    // Register the function to access it
+    spark.udf.register("myAverage", MyAverage)
+
+    val df = spark.read.json("examples/src/main/resources/employees.json")
+    df.createOrReplaceTempView("employees")
+    df.show()
+    // +-------+------+
+    // |   name|salary|
+    // +-------+------+
+    // |Michael|  3000|
+    // |   Andy|  4500|
+    // | Justin|  3500|
+    // |  Berta|  4000|
+    // +-------+------+
+
+    val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
+    result.show()
+    // +--------------+
+    // |average_salary|
+    // +--------------+
+    // |        3750.0|
+    // +--------------+
+    // $example off:untyped_custom_aggregation$
+
+    spark.stop()
+  }
+
+}