1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
|
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.types.StructType
object MemoryStream {
  protected val currentBlockId = new AtomicInteger(0)
  protected val memoryStreamId = new AtomicInteger(0)

  /**
   * Creates a new [[MemoryStream]] bound to the given `sqlContext`.
   *
   * Each call draws the next value from the shared `memoryStreamId` counter, so every stream
   * created through this factory receives a distinct id.
   */
  def apply[A: Encoder](implicit sqlContext: SQLContext): MemoryStream[A] = {
    val nextId = memoryStreamId.getAndIncrement()
    new MemoryStream[A](nextId, sqlContext)
  }
}
/**
* A [[Source]] that produces values stored in memory as they are added by the user. This [[Source]]
* is primarily intended for use in unit tests as it can only replay data when the object is still
* available.
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
  extends Source with Logging {
  protected val encoder = encoderFor[A]
  protected val logicalPlan = StreamingExecutionRelation(this)
  protected val output = logicalPlan.output

  /**
   * Datasets added via `addData`, in insertion order. Index `i` holds the batch whose offset is
   * `i`. All reads and writes happen while holding this object's monitor.
   */
  protected val batches = new ArrayBuffer[Dataset[A]]

  /**
   * Offset of the most recently added batch, or -1 when nothing has been added yet. Always equal
   * to `batches.size - 1`. Guarded by this object's monitor.
   */
  protected var currentOffset: LongOffset = new LongOffset(-1)

  def schema: StructType = encoder.schema

  def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
    Dataset(sqlContext, logicalPlan)
  }

  def toDF()(implicit sqlContext: SQLContext): DataFrame = {
    Dataset.ofRows(sqlContext, logicalPlan)
  }

  /** Adds the given values as a single new batch and returns the offset assigned to it. */
  def addData(data: A*): Offset = {
    addData(data.toTraversable)
  }

  /**
   * Adds the given values as a single new batch and returns the offset assigned to it.
   *
   * The Dataset is built before the offset is advanced. If building it throws, `currentOffset`
   * and `batches` therefore stay consistent with each other.
   */
  def addData(data: TraversableOnce[A]): Offset = {
    import sqlContext.implicits._
    val ds = data.toVector.toDS()
    logDebug(s"Adding ds: $ds")
    this.synchronized {
      // Increment and append together under the lock so readers never observe an offset
      // without its batch (or vice versa).
      currentOffset = currentOffset + 1
      batches.append(ds)
      currentOffset
    }
  }

  override def toString: String = s"MemoryStream[${output.mkString(",")}]"

  /** Returns the offset of the latest added batch, or `None` if no data has been added yet. */
  override def getOffset: Option[Offset] = this.synchronized {
    // Read both fields under the same lock `addData` uses so they are mutually consistent
    // and writes from other threads are visible.
    if (batches.isEmpty) None else Some(currentOffset)
  }

  /**
   * Returns the next batch of data that is available after `start`, if any is available.
   *
   * Both offsets are expected to be [[LongOffset]]s produced by this stream. `start = None`
   * means "from the beginning". The result is the union of batches in the range
   * (`start`, `end`].
   *
   * @throws RuntimeException if the selected range contains no batches.
   */
  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
    val startOrdinal =
      start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1
    val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1
    // `slice` copies into a new buffer, so only the copy itself needs the lock.
    val newBlocks = this.synchronized { batches.slice(startOrdinal, endOrdinal) }
    logDebug(
      s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
    newBlocks
      .map(_.toDF())
      .reduceOption(_ union _)
      .getOrElse {
        sys.error("No data selected!")
      }
  }
}
/**
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
*/
class MemorySink(val schema: StructType) extends Sink with Logging {

  /**
   * An ordered list of batches that have been written to this [[Sink]]. Index `i` holds the rows
   * of batch id `i`. All access happens while holding this object's monitor.
   */
  private val batches = new ArrayBuffer[Array[Row]]()

  /** Returns all rows that are stored in this [[Sink]]. */
  def allData: Seq[Row] = synchronized {
    batches.flatten
  }

  /**
   * Returns the rows of the most recently committed batch.
   *
   * @throws NoSuchElementException if no batch has been committed yet.
   */
  def lastBatch: Seq[Row] = synchronized {
    batches.last.toSeq
  }

  /** Renders every stored batch as `index: row row ...`, one batch per line. */
  def toDebugString: String = synchronized {
    batches.zipWithIndex.map { case (b, i) =>
      val dataStr = try b.mkString(" ") catch {
        case NonFatal(_) => "[Error converting to string]"
      }
      s"$i: $dataStr"
    }.mkString("\n")
  }

  /**
   * Stores the rows of `data` as batch `batchId` if it is the next expected batch. Otherwise it
   * treats the batch as already committed and ignores it. This makes replays of earlier batch ids
   * no-ops.
   *
   * The size check and the append happen under the same lock used by the readers above. Readers
   * therefore never observe a partially updated buffer.
   */
  override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
    if (batchId == batches.size) {
      logDebug(s"Committing batch $batchId")
      batches.append(data.collect())
    } else {
      logDebug(s"Skipping already committed batch: $batchId")
    }
  }
}
|