diff options
Diffstat (limited to 'examples/src/main/python/logistic_regression.py')
-rwxr-xr-x | examples/src/main/python/logistic_regression.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 7d33be7e81..01c938454b 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -27,7 +27,7 @@ from __future__ import print_function import sys import numpy as np -from pyspark import SparkContext +from pyspark.sql import SparkSession D = 10 # Number of dimensions @@ -55,8 +55,13 @@ if __name__ == "__main__": Please refer to examples/src/main/python/ml/logistic_regression_with_elastic_net.py to see how ML's implementation is used.""", file=sys.stderr) - sc = SparkContext(appName="PythonLR") - points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache() + spark = SparkSession\ + .builder\ + .appName("PythonLR")\ + .getOrCreate() + + points = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])\ + .mapPartitions(readPointBatch).cache() iterations = int(sys.argv[2]) # Initialize w to a random value @@ -80,4 +85,4 @@ if __name__ == "__main__": print("Final w: " + str(w)) - sc.stop() + spark.stop() |