diff options
Diffstat (limited to 'external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala')
-rw-r--r-- | external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala | 152 |
1 files changed, 152 insertions, 0 deletions
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala new file mode 100644 index 0000000000..3b5a96534f --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging + + +/** + * Consumer of single topicpartition, intended for cached reuse. + * Underlying consumer is not threadsafe, so neither is this, + * but processing the same topicpartition and group id in multiple threads is usually bad anyway. + */ +private[kafka010] case class CachedKafkaConsumer private( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object]) extends Logging { + + private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + private val consumer = { + val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + val tps = new ju.ArrayList[TopicPartition]() + tps.add(topicPartition) + c.assign(tps) + c + } + + /** Iterator to the already fetch data */ + private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + private var nextOffsetInFetchedData = -2L + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. + */ + def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = { + logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") + if (offset != nextOffsetInFetchedData) { + logInfo(s"Initial fetch for $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + } + + if (!fetchedData.hasNext()) { poll(pollTimeoutMs) } + assert(fetchedData.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset " + + s"after polling for $pollTimeoutMs") + var record = fetchedData.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + assert(fetchedData.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset " + + s"after polling for $pollTimeoutMs") + record = fetchedData.next() + assert(record.offset == offset, + s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset") + } + + nextOffsetInFetchedData = offset + 1 + record + } + + private def close(): Unit = consumer.close() + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $groupId $topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(pollTimeoutMs: Long): Unit = { + val p = consumer.poll(pollTimeoutMs) + val r = p.records(topicPartition) + logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") + fetchedData = r.iterator + } +} + +private[kafka010] object CachedKafkaConsumer extends Logging { + + private case class CacheKey(groupId: String, topicPartition: TopicPartition) + + private lazy val cache = { + val conf = SparkEnv.get.conf + val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64) + new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = { + if (this.size > capacity) { + logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " + + s"removing consumer for ${entry.getKey}") + try { + entry.getValue.close() + } catch { + case e: SparkException => + logError(s"Error closing earliest Kafka consumer for ${entry.getKey}", e) + } + true + } else { + false + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + */ + def getOrCreate( + topic: String, + partition: Int, + kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized { + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + val topicPartition = new TopicPartition(topic, partition) + val key = CacheKey(groupId, topicPartition) + + // If this is reattempt at running the task, then invalidate cache and start with + // a new consumer + if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) { + cache.remove(key) + new CachedKafkaConsumer(topicPartition, kafkaParams) + } else { + if (!cache.containsKey(key)) { + cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams)) + } + cache.get(key) + } + } +} |