Dev tool: python script for text clustering based on local embeddings (#58691)

python script for text clustering based on local embeddings
Chris Warwick 2023-12-04 09:27:14 -05:00 committed by GitHub
parent 335ad96623
commit 333edd3345
4 changed files with 110 additions and 0 deletions


@@ -13,3 +13,4 @@ python 3.11.3 system
rust 1.73.0
ruby 3.1.3
pnpm 8.9.2
python 3.11.3

dev/clustering/README.md Normal file

@@ -0,0 +1,47 @@
# Text Clustering
This directory contains Python code to cluster text data using sentence embeddings and KMeans clustering.
## Overview
The `cluster.py` script takes in a TSV file with a text field, generates sentence embeddings using the SentenceTransformers library, clusters the embeddings with KMeans, and outputs a TSV file with cluster assignments.
The goal is to group similar text snippets together into a predefined number of clusters.
## Usage
Ensure the required packages are installed:
```
pip install -r requirements.txt
```
The script accepts the following arguments:
| Argument       | Description                                      | Default            |
| -------------- | ------------------------------------------------ | ------------------ |
| `--input`      | Path to input TSV file                           | _Required_         |
| `--text_field` | Name of text field in the TSV file to operate on | "text"             |
| `--clusters`   | Number of clusters to generate                   | 4                  |
| `--output`     | Path for output TSV file with clusters           | _Optional_         |
| `--model`      | Sentence transformer model to use                | "all-MiniLM-L6-v2" |
| `--quiet`      | Hide progress bars and skip displaying plots     | False              |

Example:
```
python cluster.py --input data.tsv --text_field chat_message --clusters 5 --output out.tsv
```
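The input is expected to be a tab-separated file with a header row that includes the column named by `--text_field`. As a minimal sketch, assuming a hypothetical `chat_message` column and made-up sample rows, such a file could be produced with Python's standard library like this (the `write_tsv` helper, the `id` column, and the temporary file location are illustrative, not part of `cluster.py`):

```python
import csv
import os
import tempfile

# Hypothetical sample rows. Only the "chat_message" column is read by
# cluster.py when it is run with --text_field chat_message.
rows = [
    {"id": 1, "chat_message": "How do I reset my password?"},
    {"id": 2, "chat_message": "I forgot my login details"},
    {"id": 3, "chat_message": "The build fails on the main branch"},
    {"id": 4, "chat_message": "CI is red after the last merge"},
    {"id": 5, "chat_message": "Can I export my data as CSV?"},
    {"id": 6, "chat_message": "Where is the billing page?"},
]


def write_tsv(path, rows):
    """Write a list of dicts to a tab-separated file with a header row."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=list(rows[0].keys()),
            delimiter="\t",
            lineterminator="\n",
        )
        writer.writeheader()
        writer.writerows(rows)


# Written to the system temp directory here; any path can be passed to --input.
input_path = os.path.join(tempfile.gettempdir(), "data.tsv")
write_tsv(input_path, rows)
```

Note that KMeans needs at least as many rows as clusters, so an input with fewer rows than `--clusters` will fail.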
## Output
The output TSV file contains the original data, sorted by cluster, plus a new "cluster" column holding the cluster ID assigned to each row.
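One way to inspect the result is to load the output file with pandas and group by the "cluster" column. The sketch below uses a small inline string as a stand-in for a real output file; the `chat_message` column and its contents are hypothetical.

```python
import io

import pandas as pd

# Hypothetical output content, standing in for a file written via --output.
sample_output = (
    "chat_message\tcluster\n"
    "How do I reset my password?\t0\n"
    "I forgot my login details\t0\n"
    "The build fails on the main branch\t1\n"
)

# With a real file this would be: pd.read_csv("out.tsv", sep="\t")
df = pd.read_csv(io.StringIO(sample_output), sep="\t")

# Number of rows assigned to each cluster ID
cluster_sizes = df.groupby("cluster").size()

# Up to two example texts per cluster, to get a feel for what each group contains
examples = df.groupby("cluster")["chat_message"].apply(lambda s: s.head(2).tolist())

print(cluster_sizes)
print(examples)
```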
## Code Overview
**Libraries Used**
- pandas - for loading and manipulating data
- SentenceTransformers - generating embeddings
- sklearn - KMeans clustering
- matplotlib - visualization
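
As a rough illustration of how the clustering and dimensionality-reduction steps fit together, the sketch below runs KMeans and PCA on random vectors that stand in for sentence embeddings, so it does not need to download a model. The array sizes and variable names are assumptions made for the example; `cluster.py` itself uses real embeddings and `n_init='auto'`.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)

# Stand-in for sentence embeddings: 20 random 384-dimensional vectors
# (384 is the output size of all-MiniLM-L6-v2).
fake_embeddings = rng.normal(size=(20, 384))

# L2-normalize each row, mirroring normalize_embeddings=True in cluster.py
fake_embeddings /= np.linalg.norm(fake_embeddings, axis=1, keepdims=True)

# Assign each vector to one of 4 clusters (n_init is fixed at 10 here so the
# sketch also runs on scikit-learn versions that predate n_init="auto")
cluster_model = KMeans(n_clusters=4, random_state=42, n_init=10)
labels = cluster_model.fit_predict(fake_embeddings)

# Project to 2 dimensions, as the script does before plotting
reduced = PCA(n_components=2).fit_transform(fake_embeddings)

print(labels.shape, reduced.shape)
```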

dev/clustering/cluster.py Normal file

@@ -0,0 +1,58 @@
import argparse
import os

import matplotlib.pyplot as plt
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # avoids parallelism warnings

parser = argparse.ArgumentParser()
parser.add_argument("--clusters", default=4, type=int, help="Number of clusters to generate")
parser.add_argument("--input", required=True, type=str, help="Path to input TSV file containing text")
parser.add_argument("--text_field", default="text", type=str, help="Name of column in TSV containing the text to create embeddings and cluster")
parser.add_argument("--output", type=str, help="Path to output file")
parser.add_argument("--model", default="all-MiniLM-L6-v2", type=str, help="Sentence transformer model name")
parser.add_argument("--quiet", action="store_true", help="Hide progress bars and skip displaying plots")
args = parser.parse_args()

# Initialize an embeddings model
embedding_model = SentenceTransformer(args.model)

# Read the input TSV file into a dataframe
df = pd.read_csv(args.input, sep='\t')

# Generate L2-normalized embeddings; pass a plain list so encoding does not depend on the dataframe index
embeddings = embedding_model.encode(df[args.text_field].tolist(), show_progress_bar=not args.quiet, normalize_embeddings=True)

# Perform KMeans clustering
cluster_model = KMeans(n_clusters=args.clusters, random_state=42, n_init='auto')
clusters = cluster_model.fit_predict(embeddings)

# Update the dataframe with the cluster assignments
df['cluster'] = clusters
if not args.quiet:
    # Plot the clusters
    # Initialize PCA and reduce dimensionality to 2 components
    pca = PCA(n_components=2)
    reduced_embeddings = pca.fit_transform(embeddings)
    plt.figure(figsize=(10, 6))
    plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], c=clusters, cmap='rainbow')
    plt.xlabel('PCA Component 1')
    plt.ylabel('PCA Component 2')
    plt.title('Clusters of English statements')
    plt.grid(True)
    plt.show()
# If an output path was specified, write the dataframe to a TSV file, sorted by cluster
if args.output is not None:
    df.sort_values(by=["cluster"]).to_csv(args.output, sep='\t', index=False)


@@ -0,0 +1,4 @@
pandas~=2.1.3
scikit-learn~=1.3.2
matplotlib~=3.8.2
sentence-transformers~=2.2.2