From 333edd3345a2c4239c300d35cced0f54a2738a5b Mon Sep 17 00:00:00 2001
From: Chris Warwick
Date: Mon, 4 Dec 2023 09:27:14 -0500
Subject: [PATCH] Dev tool: python script for text clustering based on local embeddings (#58691)

python script for text clustering based on local embeddings
---
 .tool-versions                  |  1 +
 dev/clustering/README.md        | 47 ++++++++++++++++++++++++++
 dev/clustering/cluster.py       | 58 +++++++++++++++++++++++++++++++++
 dev/clustering/requirements.txt |  4 +++
 4 files changed, 110 insertions(+)
 create mode 100644 dev/clustering/README.md
 create mode 100644 dev/clustering/cluster.py
 create mode 100644 dev/clustering/requirements.txt

diff --git a/.tool-versions b/.tool-versions
index 398f4a9e970..31f967e05fa 100644
--- a/.tool-versions
+++ b/.tool-versions
@@ -13,3 +13,4 @@ python 3.11.3 system
 rust 1.73.0
 ruby 3.1.3
 pnpm 8.9.2
+python 3.11.3
diff --git a/dev/clustering/README.md b/dev/clustering/README.md
new file mode 100644
index 00000000000..0c0ad3377a1
--- /dev/null
+++ b/dev/clustering/README.md
@@ -0,0 +1,47 @@
+# Text Clustering
+
+This directory contains Python code to cluster text data using sentence embeddings and KMeans clustering.
+
+## Overview
+
+The `cluster.py` script takes in a TSV file with a text field, generates sentence embeddings using the SentenceTransformers library, clusters the embeddings with KMeans, and outputs a TSV file with cluster assignments.
+
+The goal is to group similar text snippets together into a predefined number of clusters.
+
+## Usage
+
+Ensure the required packages are installed:
+
+```
+pip install -r requirements.txt
+```
+
+The script accepts the following arguments:
+
+| Argument       | Description                                           | Default              |
+| -------------- | ----------------------------------------------------- | -------------------- |
+| `--input`      | Path to input TSV file                                 | _Required_           |
+| `--text_field` | Name of the text field in the TSV file to operate on  | `"text"`             |
+| `--clusters`   | Number of clusters to generate                         | 4                    |
+| `--output`     | Path for output TSV file with clusters                 | _Optional_           |
+| `--model`      | Sentence transformer model to use                      | `"all-MiniLM-L6-v2"` |
+| `--quiet`      | Hide progress bars and skip displaying plots           | False                |
+
+Example:
+
+```
+python cluster.py --input data.tsv --text_field chat_message --clusters 5 --output out.tsv
+```
+
+## Output
+
+The output TSV file contains the original data plus a new "cluster" column with the assigned cluster IDs per row.
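+
+As a quick check, one might load the clustered output back into pandas and inspect a few rows per cluster (a minimal sketch, assuming the `out.tsv` and `chat_message` field from the example above):
+
+```
+import pandas as pd
+
+# Load the clustered output written by cluster.py
+df = pd.read_csv("out.tsv", sep="\t")
+
+# Count how many rows landed in each cluster
+print(df["cluster"].value_counts().sort_index())
+
+# Show the first couple of texts assigned to each cluster
+# ("chat_message" matches the --text_field used in the example above)
+for cluster_id, group in df.groupby("cluster"):
+    print(f"Cluster {cluster_id}:")
+    print(group["chat_message"].head(2).to_string(index=False))
+```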
+
+## Code Overview
+
+**Libraries Used**
+
+- pandas - for loading and manipulating data
+- SentenceTransformers - generating embeddings
+- sklearn - KMeans clustering
+- matplotlib - visualization
diff --git a/dev/clustering/cluster.py b/dev/clustering/cluster.py
new file mode 100644
index 00000000000..3a394493654
--- /dev/null
+++ b/dev/clustering/cluster.py
@@ -0,0 +1,58 @@
+import pandas as pd
+from sentence_transformers import SentenceTransformer
+from sklearn.decomposition import PCA
+
+
+import sys
+import argparse
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"  # avoids tokenizer parallelism warnings
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--clusters", default=4, type=int, help="Number of clusters to generate")
+parser.add_argument("--input", required=True, type=str, help="Path to input TSV file containing text")
+parser.add_argument("--text_field", default="text", type=str, help="Name of column in TSV containing the text to create embeddings and cluster")
+parser.add_argument("--output", type=str, help="Path to output file")
+parser.add_argument("--model", default="all-MiniLM-L6-v2", type=str, help="Sentence transformer model name")
+parser.add_argument("--quiet", action="store_true", help="Hide progress bars and skip displaying plots")
+
+args = parser.parse_args()
+
+# Initialize an embeddings model
+embedding_model = SentenceTransformer(args.model)
+
+# Read the input TSV file into a dataframe
+df = pd.read_csv(args.input, sep='\t')
+
+# Generate normalized sentence embeddings for the text column
+embeddings = embedding_model.encode(df[args.text_field], show_progress_bar=not args.quiet, normalize_embeddings=True)
+
+from sklearn.cluster import KMeans
+import matplotlib.pyplot as plt
+
+
+# Perform KMeans clustering
+cluster_model = KMeans(n_clusters=args.clusters, random_state=42, n_init='auto')
+clusters = cluster_model.fit_predict(embeddings)
+
+# Update the dataframe with the cluster assignments
+df['cluster'] = cluster_model.labels_
+
+
+if not args.quiet:
+    # Plot the clusters
+    # Initialize PCA and reduce dimensionality to 2 components for visualization
+    pca = PCA(n_components=2)
+    reduced_embeddings = pca.fit_transform(embeddings)
+    plt.figure(figsize=(10, 6))
+    plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], c=clusters, cmap='rainbow')
+    plt.xlabel('PCA Component 1')
+    plt.ylabel('PCA Component 2')
+    plt.title('Clusters of English statements')
+    plt.grid(True)
+    plt.show()
+
+
+# If an output path was specified, write the dataframe to a TSV file sorted by cluster
+if args.output is not None:
+    df.sort_values(by=["cluster"]).to_csv(args.output, sep='\t', index=False)
diff --git a/dev/clustering/requirements.txt b/dev/clustering/requirements.txt
new file mode 100644
index 00000000000..891aec19e23
--- /dev/null
+++ b/dev/clustering/requirements.txt
@@ -0,0 +1,4 @@
+pandas~=2.1.3
+scikit-learn~=1.3.2
+matplotlib~=3.8.2
+sentence-transformers~=2.2.2