diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/streaming/tests.py | 16 | ||||
-rw-r--r-- | python/pyspark/tests.py | 3 |
2 files changed, 18 insertions, 1 deletions
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 57049beea4..91ce681fbe 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -15,6 +15,7 @@ # limitations under the License. # +import glob import os import sys from itertools import chain @@ -677,4 +678,19 @@ class KafkaStreamTests(PySparkStreamingTestCase): self._validateRddResult(sendData, rdd) if __name__ == "__main__": + SPARK_HOME = os.environ["SPARK_HOME"] + kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") + jars = glob.glob( + os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " + "remove all but one") % kafka_assembly_dir) + else: + os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0] unittest.main() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7826542368..17256dfc95 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1421,7 +1421,8 @@ class DaemonTests(unittest.TestCase): # start daemon daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") - daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE) + python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") + daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) # read the port number port = read_int(daemon.stdout) |