Diffstat (limited to 'python/pyspark/streaming/tests.py')
-rw-r--r--  python/pyspark/streaming/tests.py  16
1 file changed, 16 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 57049beea4..91ce681fbe 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import glob
 import os
 import sys
 from itertools import chain
@@ -677,4 +678,19 @@ class KafkaStreamTests(PySparkStreamingTestCase):
         self._validateRddResult(sendData, rdd)
 
 if __name__ == "__main__":
+    SPARK_HOME = os.environ["SPARK_HOME"]
+    kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly")
+    jars = glob.glob(
+        os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar"))
+    if not jars:
+        raise Exception(
+            ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) +
+            "You need to build Spark with "
+            "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or "
+            "'build/mvn package' before running this test")
+    elif len(jars) > 1:
+        raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please "
+                         "remove all but one") % kafka_assembly_dir)
+    else:
+        os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0]
     unittest.main()
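
The patch makes the test module self-configuring: when run as a script it locates the single spark-streaming-kafka assembly jar under SPARK_HOME and hands it to spark-submit through PYSPARK_SUBMIT_ARGS before unittest.main() starts. Below is a minimal standalone sketch of the same idea, not part of the patch, assuming SPARK_HOME points at a Spark checkout where the Kafka assembly has already been built.

import glob
import os
import subprocess
import sys

# Locate the Kafka assembly jar the same way the patched __main__ block does.
spark_home = os.environ["SPARK_HOME"]
pattern = os.path.join(
    spark_home, "external/kafka-assembly",
    "target/scala-*/spark-streaming-kafka-assembly-*.jar")
jars = glob.glob(pattern)
if len(jars) != 1:
    sys.exit("Expected exactly one Kafka assembly jar, found %d for %s" % (len(jars), pattern))

# Hand the jar to spark-submit via PYSPARK_SUBMIT_ARGS and run the streaming
# tests in a child process, leaving the current environment untouched.
env = dict(os.environ, PYSPARK_SUBMIT_ARGS="--jars %s pyspark-shell" % jars[0])
subprocess.check_call(
    [sys.executable, os.path.join(spark_home, "python/pyspark/streaming/tests.py")],
    env=env)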