author     Yun Ni <yunn@uber.com>            2017-02-15 16:26:05 -0800
committer  Yanbo Liang <ybliang8@gmail.com>  2017-02-15 16:26:05 -0800
commit     08c1972a0661d42f300520cc6e5fb31023de093b (patch)
tree       8b392b4520df66ca32834c11fc376009be70e8b8 /examples/src/main/java
parent     21b4ba2d6f21a9759af879471715c123073bd67a (diff)
[SPARK-18080][ML][PYTHON] Python API & Examples for Locality Sensitive Hashing
## What changes were proposed in this pull request?
This pull request adds the Python API and examples for LSH. The API changes are based on yanboliang's PR #15768, with conflicts resolved and the Scala API changes incorporated. The examples are consistent with the Scala examples for MinHashLSH and BucketedRandomProjectionLSH.
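As a rough illustration of what the new Python API looks like, here is a minimal sketch of BucketedRandomProjectionLSH usage, assuming the PySpark API mirrors the Scala one (the toy vectors here are illustrative, not the data from the bundled example file):

```python
from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("BucketedRandomProjectionLSHExample").getOrCreate()

# Illustrative toy data (the bundled example ships its own vectors).
dfA = spark.createDataFrame([(0, Vectors.dense([1.0, 1.0])),
                             (1, Vectors.dense([1.0, -1.0])),
                             (2, Vectors.dense([-1.0, -1.0]))], ["id", "features"])
dfB = spark.createDataFrame([(3, Vectors.dense([1.0, 0.0])),
                             (4, Vectors.dense([-1.0, 0.0])),
                             (5, Vectors.dense([0.0, 1.0]))], ["id", "features"])
key = Vectors.dense([1.0, 0.0])

brp = BucketedRandomProjectionLSH(inputCol="features", outputCol="hashes",
                                  bucketLength=2.0, numHashTables=3)
model = brp.fit(dfA)

# Feature transformation: the hashed values land in the 'hashes' column.
model.transform(dfA).show()

# Approximate similarity join on Euclidean distance smaller than 1.5.
model.approxSimilarityJoin(dfA, dfB, 1.5, distCol="EuclideanDistance") \
    .select(col("datasetA.id").alias("idA"),
            col("datasetB.id").alias("idB"),
            col("EuclideanDistance")).show()

# Approximate search for the 2 nearest neighbors of the key.
model.approxNearestNeighbors(dfA, key, 2).show()

spark.stop()
```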
## How was this patch tested?
The API and examples were tested using spark-submit:
`bin/spark-submit examples/src/main/python/ml/min_hash_lsh.py`
`bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh.py`
User guide changes were generated and manually inspected:
`SKIP_API=1 jekyll build`
Author: Yun Ni <yunn@uber.com>
Author: Yanbo Liang <ybliang8@gmail.com>
Author: Yunni <Euler57721@gmail.com>
Closes #16715 from Yunni/spark-18080.
Diffstat (limited to 'examples/src/main/java')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java  38
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java                   57
2 files changed, 74 insertions(+), 21 deletions(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java
index ca3ee5a285..4594e3462b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java
@@ -35,8 +35,15 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.functions.col;
 // $example off$
 
+/**
+ * An example demonstrating BucketedRandomProjectionLSH.
+ * Run with:
+ *   bin/run-example org.apache.spark.examples.ml.JavaBucketedRandomProjectionLSHExample
+ */
 public class JavaBucketedRandomProjectionLSHExample {
   public static void main(String[] args) {
     SparkSession spark = SparkSession
@@ -61,7 +68,7 @@ public class JavaBucketedRandomProjectionLSHExample {
 
     StructType schema = new StructType(new StructField[]{
       new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
-      new StructField("keys", new VectorUDT(), false, Metadata.empty())
+      new StructField("features", new VectorUDT(), false, Metadata.empty())
     });
     Dataset<Row> dfA = spark.createDataFrame(dataA, schema);
     Dataset<Row> dfB = spark.createDataFrame(dataB, schema);
@@ -71,26 +78,31 @@ public class JavaBucketedRandomProjectionLSHExample {
     BucketedRandomProjectionLSH mh = new BucketedRandomProjectionLSH()
       .setBucketLength(2.0)
       .setNumHashTables(3)
-      .setInputCol("keys")
-      .setOutputCol("values");
+      .setInputCol("features")
+      .setOutputCol("hashes");
 
     BucketedRandomProjectionLSHModel model = mh.fit(dfA);
 
     // Feature Transformation
+    System.out.println("The hashed dataset where hashed values are stored in the column 'hashes':");
     model.transform(dfA).show();
-    // Cache the transformed columns
-    Dataset<Row> transformedA = model.transform(dfA).cache();
-    Dataset<Row> transformedB = model.transform(dfB).cache();
 
-    // Approximate similarity join
-    model.approxSimilarityJoin(dfA, dfB, 1.5).show();
-    model.approxSimilarityJoin(transformedA, transformedB, 1.5).show();
-    // Self Join
-    model.approxSimilarityJoin(dfA, dfA, 2.5).filter("datasetA.id < datasetB.id").show();
+    // Compute the locality sensitive hashes for the input rows, then perform approximate
+    // similarity join.
+    // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+    // `model.approxSimilarityJoin(transformedA, transformedB, 1.5)`
+    System.out.println("Approximately joining dfA and dfB on distance smaller than 1.5:");
+    model.approxSimilarityJoin(dfA, dfB, 1.5, "EuclideanDistance")
+      .select(col("datasetA.id").alias("idA"),
+        col("datasetB.id").alias("idB"),
+        col("EuclideanDistance")).show();
 
-    // Approximate nearest neighbor search
+    // Compute the locality sensitive hashes for the input rows, then perform approximate nearest
+    // neighbor search.
+    // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+    // `model.approxNearestNeighbors(transformedA, key, 2)`
+    System.out.println("Approximately searching dfA for 2 nearest neighbors of the key:");
     model.approxNearestNeighbors(dfA, key, 2).show();
-    model.approxNearestNeighbors(transformedA, key, 2).show();
     // $example off$
 
     spark.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java
index 9dbbf6d117..0aace46939 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java
@@ -25,6 +25,7 @@ import java.util.List;
 
 import org.apache.spark.ml.feature.MinHashLSH;
 import org.apache.spark.ml.feature.MinHashLSHModel;
+import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
@@ -34,8 +35,15 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.functions.col;
 // $example off$
 
+/**
+ * An example demonstrating MinHashLSH.
+ * Run with:
+ *   bin/run-example org.apache.spark.examples.ml.JavaMinHashLSHExample
+ */
 public class JavaMinHashLSHExample {
   public static void main(String[] args) {
     SparkSession spark = SparkSession
@@ -44,25 +52,58 @@ public class JavaMinHashLSHExample {
       .getOrCreate();
 
     // $example on$
-    List<Row> data = Arrays.asList(
+    List<Row> dataA = Arrays.asList(
       RowFactory.create(0, Vectors.sparse(6, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0})),
       RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 4}, new double[]{1.0, 1.0, 1.0})),
       RowFactory.create(2, Vectors.sparse(6, new int[]{0, 2, 4}, new double[]{1.0, 1.0, 1.0}))
     );
 
+    List<Row> dataB = Arrays.asList(
+      RowFactory.create(0, Vectors.sparse(6, new int[]{1, 3, 5}, new double[]{1.0, 1.0, 1.0})),
+      RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 5}, new double[]{1.0, 1.0, 1.0})),
+      RowFactory.create(2, Vectors.sparse(6, new int[]{1, 2, 4}, new double[]{1.0, 1.0, 1.0}))
+    );
+
     StructType schema = new StructType(new StructField[]{
       new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
-      new StructField("keys", new VectorUDT(), false, Metadata.empty())
+      new StructField("features", new VectorUDT(), false, Metadata.empty())
     });
-    Dataset<Row> dataFrame = spark.createDataFrame(data, schema);
+    Dataset<Row> dfA = spark.createDataFrame(dataA, schema);
+    Dataset<Row> dfB = spark.createDataFrame(dataB, schema);
+
+    int[] indices = {1, 3};
+    double[] values = {1.0, 1.0};
+    Vector key = Vectors.sparse(6, indices, values);
 
     MinHashLSH mh = new MinHashLSH()
-      .setNumHashTables(1)
-      .setInputCol("keys")
-      .setOutputCol("values");
+      .setNumHashTables(5)
+      .setInputCol("features")
+      .setOutputCol("hashes");
+
+    MinHashLSHModel model = mh.fit(dfA);
+
+    // Feature Transformation
+    System.out.println("The hashed dataset where hashed values are stored in the column 'hashes':");
+    model.transform(dfA).show();
+
+    // Compute the locality sensitive hashes for the input rows, then perform approximate
+    // similarity join.
+    // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+    // `model.approxSimilarityJoin(transformedA, transformedB, 0.6)`
+    System.out.println("Approximately joining dfA and dfB on Jaccard distance smaller than 0.6:");
+    model.approxSimilarityJoin(dfA, dfB, 0.6, "JaccardDistance")
+      .select(col("datasetA.id").alias("idA"),
+        col("datasetB.id").alias("idB"),
+        col("JaccardDistance")).show();
 
-    MinHashLSHModel model = mh.fit(dataFrame);
-    model.transform(dataFrame).show();
+    // Compute the locality sensitive hashes for the input rows, then perform approximate nearest
+    // neighbor search.
+    // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+    // `model.approxNearestNeighbors(transformedA, key, 2)`
+    // It may return less than 2 rows when not enough approximate near-neighbor candidates are
+    // found.
+    System.out.println("Approximately searching dfA for 2 nearest neighbors of the key:");
+    model.approxNearestNeighbors(dfA, key, 2).show();
     // $example off$
 
     spark.stop();
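For comparison with the Java example above, the equivalent usage through the new Python API looks roughly like the following. This is a sketch built from the Java MinHashLSH diff and the public PySpark API (it reuses the same sparse vectors, but it is not the exact text of the bundled `min_hash_lsh.py`):

```python
from pyspark.ml.feature import MinHashLSH
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("MinHashLSHExample").getOrCreate()

# The same sparse vectors as the Java example above.
dataA = [(0, Vectors.sparse(6, [0, 1, 2], [1.0, 1.0, 1.0])),
         (1, Vectors.sparse(6, [2, 3, 4], [1.0, 1.0, 1.0])),
         (2, Vectors.sparse(6, [0, 2, 4], [1.0, 1.0, 1.0]))]
dfA = spark.createDataFrame(dataA, ["id", "features"])

dataB = [(0, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0])),
         (1, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0])),
         (2, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]))]
dfB = spark.createDataFrame(dataB, ["id", "features"])

key = Vectors.sparse(6, [1, 3], [1.0, 1.0])

mh = MinHashLSH(inputCol="features", outputCol="hashes", numHashTables=5)
model = mh.fit(dfA)

# Feature transformation: the hashed values land in the 'hashes' column.
model.transform(dfA).show()

# Approximate similarity join on Jaccard distance smaller than 0.6.
model.approxSimilarityJoin(dfA, dfB, 0.6, distCol="JaccardDistance") \
    .select(col("datasetA.id").alias("idA"),
            col("datasetB.id").alias("idB"),
            col("JaccardDistance")).show()

# Approximate search for the 2 nearest neighbors of the key; may return
# fewer than 2 rows when not enough candidates are found.
model.approxNearestNeighbors(dfA, key, 2).show()

spark.stop()
```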