aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/python/logistic_regression.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/python/logistic_regression.py')
-rwxr-xr-xexamples/src/main/python/logistic_regression.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py
index 7d33be7e81..01c938454b 100755
--- a/examples/src/main/python/logistic_regression.py
+++ b/examples/src/main/python/logistic_regression.py
@@ -27,7 +27,7 @@ from __future__ import print_function
import sys
import numpy as np
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
D = 10 # Number of dimensions
@@ -55,8 +55,13 @@ if __name__ == "__main__":
Please refer to examples/src/main/python/ml/logistic_regression_with_elastic_net.py
to see how ML's implementation is used.""", file=sys.stderr)
- sc = SparkContext(appName="PythonLR")
- points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache()
+ spark = SparkSession\
+ .builder\
+ .appName("PythonLR")\
+ .getOrCreate()
+
+ points = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])\
+ .mapPartitions(readPointBatch).cache()
iterations = int(sys.argv[2])
# Initialize w to a random value
@@ -80,4 +85,4 @@ if __name__ == "__main__":
print("Final w: " + str(w))
- sc.stop()
+ spark.stop()