sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming

import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.types.StructType

object MemoryStream {
  protected val currentBlockId = new AtomicInteger(0)
  protected val memoryStreamId = new AtomicInteger(0)

  def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
    new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
}

/**
 * A [[Source]] that produces values stored in memory as they are added by the user.  This [[Source]]
 * is primarily intended for use in unit tests as it can only replay data when the object is still
 * available.
 */
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
    extends Source with Logging {
  protected val encoder = encoderFor[A]
  protected val logicalPlan = StreamingExecutionRelation(this)
  protected val output = logicalPlan.output
  protected val batches = new ArrayBuffer[Dataset[A]]

  protected var currentOffset: LongOffset = new LongOffset(-1)

  def schema: StructType = encoder.schema

  def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
    Dataset(sqlContext, logicalPlan)
  }

  def toDF()(implicit sqlContext: SQLContext): DataFrame = {
    Dataset.ofRows(sqlContext, logicalPlan)
  }

  def addData(data: A*): Offset = {
    addData(data.toTraversable)
  }

  def addData(data: TraversableOnce[A]): Offset = {
    import sqlContext.implicits._
    this.synchronized {
      currentOffset = currentOffset + 1
      val ds = data.toVector.toDS()
      logDebug(s"Adding ds: $ds")
      batches.append(ds)
      currentOffset
    }
  }

  override def toString: String = s"MemoryStream[${output.mkString(",")}]"

  override def getOffset: Option[Offset] = if (batches.isEmpty) {
    None
  } else {
    Some(currentOffset)
  }

  /**
   * Returns the next batch of data that is available after `start`, if any.
   */
  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
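    // Offsets are assigned starting at 0 for the first batch, so the batch at offset n sits at
    // index n of `batches`; the exclusive `start` / inclusive `end` offsets therefore map to
    // slice(start + 1, end + 1).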
    val startOrdinal =
      start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1
    val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1
    val newBlocks = batches.slice(startOrdinal, endOrdinal)

    logDebug(
      s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
    newBlocks
      .map(_.toDF())
      .reduceOption(_ union _)
      .getOrElse {
        sys.error("No data selected!")
      }
  }
}
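
// ---------------------------------------------------------------------------------------------
// Illustrative sketch only (not part of the original file): a minimal, self-contained driver
// showing the intended test-time usage of MemoryStream -- push data in with addData, observe the
// latest offset with getOffset, and read the rows back with getBatch. The object name and the way
// the SQLContext is constructed here are assumptions made for the sake of the example.
// ---------------------------------------------------------------------------------------------
private object MemoryStreamUsageSketch {
  import org.apache.spark.{SparkConf, SparkContext}

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("memory-stream-sketch"))
    implicit val sqlContext: SQLContext = new SQLContext(sc)
    import sqlContext.implicits._

    val stream = MemoryStream[Int]
    stream.addData(1, 2, 3)                 // first batch, stored at offset 0
    val end = stream.addData(4, 5)          // second batch, stored at offset 1

    assert(stream.getOffset.contains(end))  // latest available offset
    val batch = stream.getBatch(None, end)  // all rows up to and including offset 1
    assert(batch.collect().map(_.getInt(0)).sorted.toSeq == Seq(1, 2, 3, 4, 5))

    sc.stop()
  }
}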

/**
 * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
 * tests and does not provide durability.
 */
class MemorySink(val schema: StructType) extends Sink with Logging {
  /** An ordered list of batches that have been written to this [[Sink]]. */
  private val batches = new ArrayBuffer[Array[Row]]()

  /** Returns all rows that are stored in this [[Sink]]. */
  def allData: Seq[Row] = synchronized {
    batches.flatten
  }

  def lastBatch: Seq[Row] = batches.last

  def toDebugString: String = synchronized {
    batches.zipWithIndex.map { case (b, i) =>
      val dataStr = try b.mkString(" ") catch {
        case NonFatal(e) => "[Error converting to string]"
      }
      s"$i: $dataStr"
    }.mkString("\n")
  }

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId == batches.size) {
      logDebug(s"Committing batch $batchId")
      batches.append(data.collect())
    } else {
      logDebug(s"Skipping already committed batch: $batchId")
    }
  }
}
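
// ---------------------------------------------------------------------------------------------
// Illustrative sketch only (not part of the original file): a minimal driver showing how
// MemorySink deduplicates commits by batch id -- a batch is stored the first time its id is seen
// and silently skipped on replay. The object name and the SQLContext setup are assumptions.
// ---------------------------------------------------------------------------------------------
private object MemorySinkUsageSketch {
  import org.apache.spark.{SparkConf, SparkContext}

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("memory-sink-sketch"))
    implicit val sqlContext: SQLContext = new SQLContext(sc)
    import sqlContext.implicits._

    val batch0 = Seq(1, 2).toDF("value")
    val batch1 = Seq(3, 4).toDF("value")

    val sink = new MemorySink(batch0.schema)
    sink.addBatch(0, batch0)    // committed: batches.size (0) matches batchId 0
    sink.addBatch(0, batch0)    // replay of batch 0 is skipped
    sink.addBatch(1, batch1)    // committed as the second batch

    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4))
    assert(sink.lastBatch.map(_.getInt(0)) == Seq(3, 4))

    sc.stop()
  }
}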