Diffstat (limited to 'examples/src/main/python/ml/simple_text_classification_pipeline.py')
 examples/src/main/python/ml/simple_text_classification_pipeline.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py
index c73edb7fd6..fab21f003b 100644
--- a/examples/src/main/python/ml/simple_text_classification_pipeline.py
+++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py
@@ -15,6 +15,8 @@
# limitations under the License.
#
+from __future__ import print_function
+
from pyspark import SparkContext
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
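Note: the added future import makes print a function under Python 2 as well, so the parenthesized print call later in this diff runs identically on both interpreter lines. A minimal standalone sketch of the effect (illustrative, not part of the example file):

    from __future__ import print_function

    # With the future import in place, print is a function even on Python 2,
    # so this parenthesized form behaves the same on 2.x and 3.x.
    print("label:", 1.0)
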
@@ -37,10 +39,10 @@ if __name__ == "__main__":
# Prepare training documents, which are labeled.
LabeledDocument = Row("id", "text", "label")
- training = sc.parallelize([(0L, "a b c d e spark", 1.0),
- (1L, "b d", 0.0),
- (2L, "spark f g h", 1.0),
- (3L, "hadoop mapreduce", 0.0)]) \
+ training = sc.parallelize([(0, "a b c d e spark", 1.0),
+ (1, "b d", 0.0),
+ (2, "spark f g h", 1.0),
+ (3, "hadoop mapreduce", 0.0)]) \
.map(lambda x: LabeledDocument(*x)).toDF()
# Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
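Note: the 0L-style literals removed above are Python 2 long literals; the L suffix is a SyntaxError on Python 3, where int and long are unified. Plain int literals are correct on both versions, since Python 2 promotes them to arbitrary precision automatically when needed. A small standalone sketch of that behavior:

    # Python 2 treats 0L as a long; Python 3 rejects the L suffix outright.
    # Plain int literals promote to arbitrary precision on their own, so the
    # explicit long form buys nothing even on Python 2.
    big = 2 ** 64
    print(type(big).__name__)  # 'long' on Python 2, 'int' on Python 3
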
@@ -54,16 +56,16 @@ if __name__ == "__main__":
# Prepare test documents, which are unlabeled.
Document = Row("id", "text")
- test = sc.parallelize([(4L, "spark i j k"),
- (5L, "l m n"),
- (6L, "mapreduce spark"),
- (7L, "apache hadoop")]) \
+ test = sc.parallelize([(4, "spark i j k"),
+ (5, "l m n"),
+ (6, "mapreduce spark"),
+ (7, "apache hadoop")]) \
.map(lambda x: Document(*x)).toDF()
# Make predictions on test documents and print columns of interest.
prediction = model.transform(test)
selected = prediction.select("id", "text", "prediction")
for row in selected.collect():
- print row
+ print(row)
sc.stop()
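
For context, the tokenizer, hashingTF, and lr stages named in the comment above are wired together in a hunk of this file that the diff does not touch, roughly as follows (a sketch of the surrounding code; the parameter values are illustrative, not the commit's exact ones):

    from pyspark.ml import Pipeline
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.feature import HashingTF, Tokenizer

    # Each stage reads the previous stage's output column.
    tokenizer = Tokenizer(inputCol="text", outputCol="words")
    hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
    lr = LogisticRegression(maxIter=10, regParam=0.01)  # illustrative values

    pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
    model = pipeline.fit(training)

With these changes applied, the example runs under both Python 2 and Python 3, for instance via bin/spark-submit examples/src/main/python/ml/simple_text_classification_pipeline.py.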