aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java
blob: da3a5dfe8628bbbbe434a042604883ca2c16019c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.examples.sql.streaming;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.api.java.function.MapGroupsWithStateFunction;
import org.apache.spark.sql.*;
import org.apache.spark.sql.streaming.GroupState;
import org.apache.spark.sql.streaming.GroupStateTimeout;
import org.apache.spark.sql.streaming.StreamingQuery;

import java.io.Serializable;
import java.sql.Timestamp;
import java.util.*;

import scala.Tuple2;

/**
 * Sessionizes events in UTF8 encoded, '\n' delimited text received from the network.
 * Each word of a received line is treated as the session id of one event.
 * <p>
 * Usage: JavaStructuredSessionization <hostname> <port>
 * <hostname> and <port> describe the TCP server that Structured Streaming
 * would connect to receive data.
 * <p>
 * To run this on your local machine, you need to first run a Netcat server
 * `$ nc -lk 9999`
 * and then run the example
 * `$ bin/run-example sql.streaming.JavaStructuredSessionization
 * localhost 9999`
 */
public final class JavaStructuredSessionization {

  /**
   * Reads lines from a TCP socket, turns every space-separated word into an {@link Event} whose
   * session id is that word, keeps one {@link SessionInfo} per session id with
   * {@code mapGroupsWithState}, and prints the resulting {@link SessionUpdate}s to the console.
   * Blocks until the streaming query terminates.
   *
   * @param args {@code <hostname> <port>} of the TCP server to read lines from; the process
   *             exits with status 1 when fewer than two arguments are given or the port is not
   *             an integer
   * @throws Exception if starting or awaiting the streaming query throws
   */
  public static void main(String[] args) throws Exception {
    if (args.length < 2) {
      System.err.println("Usage: JavaStructuredSessionization <hostname> <port>");
      System.exit(1);
    }

    String host = args[0];
    int port;
    try {
      port = Integer.parseInt(args[1]);
    } catch (NumberFormatException e) {
      // A non-numeric port is a usage error: report it like missing arguments instead of
      // letting the exception escape as a stack trace.
      System.err.println("Invalid port: " + args[1]);
      System.err.println("Usage: JavaStructuredSessionization <hostname> <port>");
      System.exit(1);
      return; // Not reached; lets the compiler see that port is always assigned below
    }

    SparkSession spark = SparkSession
        .builder()
        .appName("JavaStructuredSessionization")
        .getOrCreate();

    // Create DataFrame representing the stream of input lines from connection to host:port
    Dataset<Row> lines = spark
        .readStream()
        .format("socket")
        .option("host", host)
        .option("port", port)
        .option("includeTimestamp", true)
        .load();

    FlatMapFunction<LineWithTimestamp, Event> linesToEvents =
      new FlatMapFunction<LineWithTimestamp, Event>() {
        @Override
        public Iterator<Event> call(LineWithTimestamp lineWithTimestamp) throws Exception {
          // Every space-separated word becomes one event stamped with the line's timestamp
          List<Event> eventList = new ArrayList<>();
          for (String word : lineWithTimestamp.getLine().split(" ")) {
            eventList.add(new Event(word, lineWithTimestamp.getTimestamp()));
          }
          System.out.println(
              "Number of events from " + lineWithTimestamp.getLine() + " = " + eventList.size());
          return eventList.iterator();
        }
      };

    // Split the lines into words, treat words as sessionId of events
    Dataset<Event> events = lines
        .withColumnRenamed("value", "line")
        .as(Encoders.bean(LineWithTimestamp.class))
        .flatMap(linesToEvents, Encoders.bean(Event.class));

    // Sessionize the events. Track number of events, start and end timestamps of session, and
    // report session updates.
    //
    // Step 1: Define the state update function
    MapGroupsWithStateFunction<String, Event, SessionInfo, SessionUpdate> stateUpdateFunc =
      new MapGroupsWithStateFunction<String, Event, SessionInfo, SessionUpdate>() {
        @Override public SessionUpdate call(
            String sessionId, Iterator<Event> events, GroupState<SessionInfo> state)
              throws Exception {
          // If timed out, then remove session and send final update
          if (state.hasTimedOut()) {
            SessionInfo expiredSession = state.get();
            SessionUpdate finalUpdate = new SessionUpdate(
                sessionId, expiredSession.getDurationMs(), expiredSession.getNumEvents(), true);
            state.remove();
            return finalUpdate;

          } else {
            // Find max and min timestamps in events
            long maxTimestampMs = Long.MIN_VALUE;
            long minTimestampMs = Long.MAX_VALUE;
            int numNewEvents = 0;
            while (events.hasNext()) {
              Event e = events.next();
              long timestampMs = e.getTimestamp().getTime();
              maxTimestampMs = Math.max(timestampMs, maxTimestampMs);
              minTimestampMs = Math.min(timestampMs, minTimestampMs);
              numNewEvents += 1;
            }
            SessionInfo updatedSession = new SessionInfo();

            // Update start and end timestamps in session
            if (state.exists()) {
              // Existing session: keep its start, add the new events, extend the end if needed
              SessionInfo oldSession = state.get();
              updatedSession.setNumEvents(oldSession.getNumEvents() + numNewEvents);
              updatedSession.setStartTimestampMs(oldSession.getStartTimestampMs());
              updatedSession.setEndTimestampMs(
                  Math.max(oldSession.getEndTimestampMs(), maxTimestampMs));
            } else {
              // New session: its span is exactly the span of the events just received
              updatedSession.setNumEvents(numNewEvents);
              updatedSession.setStartTimestampMs(minTimestampMs);
              updatedSession.setEndTimestampMs(maxTimestampMs);
            }
            state.update(updatedSession);
            // Set timeout such that the session will be expired if no data received for 10 seconds
            state.setTimeoutDuration("10 seconds");
            return new SessionUpdate(
                sessionId, updatedSession.getDurationMs(), updatedSession.getNumEvents(), false);
          }
        }
      };

    // Step 2: Apply the state update function to the events streaming Dataset grouped by sessionId
    Dataset<SessionUpdate> sessionUpdates = events
        .groupByKey(
            new MapFunction<Event, String>() {
              @Override public String call(Event event) throws Exception {
                return event.getSessionId();
              }
            }, Encoders.STRING())
        .mapGroupsWithState(
            stateUpdateFunc,
            Encoders.bean(SessionInfo.class),
            Encoders.bean(SessionUpdate.class),
            GroupStateTimeout.ProcessingTimeTimeout());

    // Start running the query that prints the session updates to the console
    StreamingQuery query = sessionUpdates
        .writeStream()
        .outputMode("update")
        .format("console")
        .start();

    query.awaitTermination();
  }

  /**
   * User-defined data type representing the raw lines with timestamps.
   * <p>
   * Written as a bean (implicit no-arg constructor plus a getter/setter pair per property)
   * because {@code main} builds its encoder with {@code Encoders.bean}.
   */
  public static class LineWithTimestamp implements Serializable {
    private String line;
    private Timestamp timestamp;

    /** Returns the timestamp attached to this line, or {@code null} if none has been set. */
    public Timestamp getTimestamp() { return timestamp; }

    /** Sets the timestamp attached to this line. */
    public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; }

    /** Returns the text of the line, or {@code null} if none has been set. */
    public String getLine() { return line; }

    /** Sets the text of the line. */
    public void setLine(String line) { this.line = line; }
  }

  /**
   * A single input event: the id of the session it belongs to and the time it was observed.
   * Kept as a plain bean so that {@code Encoders.bean(Event.class)} in {@code main} can use it.
   */
  public static class Event implements Serializable {
    private String sessionId;
    private Timestamp timestamp;

    /** Creates an event with both properties left unset. */
    public Event() {
    }

    /** Creates an event for the given session id, observed at the given timestamp. */
    public Event(String sessionId, Timestamp timestamp) {
      this.timestamp = timestamp;
      this.sessionId = sessionId;
    }

    /** Returns the id of the session this event belongs to. */
    public String getSessionId() {
      return sessionId;
    }

    /** Sets the id of the session this event belongs to. */
    public void setSessionId(String sessionId) {
      this.sessionId = sessionId;
    }

    /** Returns the time at which this event was observed. */
    public Timestamp getTimestamp() {
      return timestamp;
    }

    /** Sets the time at which this event was observed. */
    public void setTimestamp(Timestamp timestamp) {
      this.timestamp = timestamp;
    }
  }

  /**
   * Per-session state kept by mapGroupsWithState: how many events the session has seen and the
   * first and last event timestamps, in milliseconds.
   */
  public static class SessionInfo implements Serializable {
    private int numEvents = 0;
    private long startTimestampMs = -1;
    private long endTimestampMs = -1;

    /** Returns the number of events recorded for the session. */
    public int getNumEvents() {
      return numEvents;
    }

    /** Sets the number of events recorded for the session. */
    public void setNumEvents(int numEvents) {
      this.numEvents = numEvents;
    }

    /** Returns the session start timestamp in milliseconds (-1 until set). */
    public long getStartTimestampMs() {
      return startTimestampMs;
    }

    /** Sets the session start timestamp in milliseconds. */
    public void setStartTimestampMs(long startTimestampMs) {
      this.startTimestampMs = startTimestampMs;
    }

    /** Returns the session end timestamp in milliseconds (-1 until set). */
    public long getEndTimestampMs() {
      return endTimestampMs;
    }

    /** Sets the session end timestamp in milliseconds. */
    public void setEndTimestampMs(long endTimestampMs) {
      this.endTimestampMs = endTimestampMs;
    }

    /** Returns the end timestamp minus the start timestamp, in milliseconds. */
    public long getDurationMs() {
      long elapsedMs = endTimestampMs - startTimestampMs;
      return elapsedMs;
    }

    /** Renders the event count and the start/end timestamps for debugging output. */
    @Override
    public String toString() {
      StringBuilder text = new StringBuilder("SessionInfo(numEvents = ");
      text.append(numEvents);
      text.append(", timestamps = ").append(startTimestampMs);
      text.append(" to ").append(endTimestampMs);
      text.append(")");
      return text.toString();
    }
  }

  /**
   * The record emitted by mapGroupsWithState for a session: its id, its duration in
   * milliseconds, its event count, and whether this is the final update sent on expiry.
   */
  public static class SessionUpdate implements Serializable {
    private String id;
    private long durationMs;
    private int numEvents;
    private boolean expired;

    /** Creates an update with every property at its Java default value. */
    public SessionUpdate() {
    }

    /** Creates an update carrying the given session id, duration, event count and expiry flag. */
    public SessionUpdate(String id, long durationMs, int numEvents, boolean expired) {
      this.expired = expired;
      this.numEvents = numEvents;
      this.durationMs = durationMs;
      this.id = id;
    }

    /** Returns the session id. */
    public String getId() {
      return id;
    }

    /** Sets the session id. */
    public void setId(String id) {
      this.id = id;
    }

    /** Returns the session duration in milliseconds. */
    public long getDurationMs() {
      return durationMs;
    }

    /** Sets the session duration in milliseconds. */
    public void setDurationMs(long durationMs) {
      this.durationMs = durationMs;
    }

    /** Returns the number of events in the session. */
    public int getNumEvents() {
      return numEvents;
    }

    /** Sets the number of events in the session. */
    public void setNumEvents(int numEvents) {
      this.numEvents = numEvents;
    }

    /** Returns whether the session has expired. */
    public boolean isExpired() {
      return expired;
    }

    /** Sets whether the session has expired. */
    public void setExpired(boolean expired) {
      this.expired = expired;
    }
  }
}