mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 15:31:48 +00:00
Dev tool: python script for text clustering based on local embeddings (#58691)
python script for text clustering based on local embeddings
This commit is contained in:
parent
335ad96623
commit
333edd3345
@ -13,3 +13,4 @@ python 3.11.3 system
|
||||
rust 1.73.0
|
||||
ruby 3.1.3
|
||||
pnpm 8.9.2
|
||||
python 3.11.3
|
||||
|
||||
47
dev/clustering/README.md
Normal file
47
dev/clustering/README.md
Normal file
@ -0,0 +1,47 @@
|
||||
# Text Clustering
|
||||
|
||||
This directory contains Python code to cluster text data using sentence embeddings and KMeans clustering.
|
||||
|
||||
## Overview
|
||||
|
||||
The `cluster.py` script takes in a TSV file with a text field, generates sentence embeddings using the SentenceTransformers library, clusters the embeddings with KMeans, and outputs a TSV file with cluster assignments.
|
||||
|
||||
The goal is to group similar text snippets together into a predefined number of clusters.
|
||||
|
||||
## Usage
|
||||
|
||||
Ensure the required packages are installed:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
The script accepts the following arguments:
|
||||
|
||||
| Argument | Description | Default |
|
||||
| -------------- | ------------------------------------------------ | ---------- |
|
||||
| `--input` | Path to input TSV file | _Required_ |
|
||||
| `--text_field` | Name of text field in the tsv file to operate on | "text" |
|
||||
| `--clusters` | Number of clusters to generate | 4 |
|
||||
| `--output` | Path for output TSV file with clusters | _Optional_ |
|
||||
| `--model` | Sentence transformer model to use | "all-MiniLM-L6-v2" |
|
||||
| `--quiet` | Hide progress bars and skip displaying plots | False |
|
||||
|
||||
Example
|
||||
|
||||
```
|
||||
python cluster.py --input data.tsv --text_field chat_message --clusters 5 --output out.tsv
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
The output TSV file contains the original data plus a new "cluster" column with the assigned cluster IDs per row.
|
||||
|
||||
## Code Overview
|
||||
|
||||
**Libraries Used**
|
||||
|
||||
- pandas - for loading and manipulating data
|
||||
- SentenceTransformers - generating embeddings
|
||||
- sklearn - KMeans clustering
|
||||
- matplotlib - visualization
|
||||
58
dev/clustering/cluster.py
Normal file
58
dev/clustering/cluster.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""Cluster text from a TSV file using sentence embeddings and KMeans.

Reads a TSV file, embeds the chosen text column with a SentenceTransformer
model, clusters the embeddings with KMeans, optionally shows a 2-D PCA
projection of the clusters, and optionally writes the cluster assignments
back out as TSV (original columns plus a new "cluster" column).
"""

import argparse
import os

# Silence tokenizers fork-parallelism warnings; must be set before the
# tokenizers library is initialized by sentence_transformers at runtime.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import matplotlib.pyplot as plt
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--clusters", default=4, type=int,
                        help="Number of clusters to generate")
    # The README documents --input as required; enforce it so a missing path
    # fails with a clear argparse error instead of a pandas traceback.
    parser.add_argument("--input", type=str, required=True,
                        help="Path to input TSV file containing text")
    parser.add_argument("--text_field", default="text", type=str,
                        help="Name of column in TSV containing the text to create embeddings and cluster")
    parser.add_argument("--output", type=str,
                        help="Path to output file")
    parser.add_argument("--model", default="all-MiniLM-L6-v2", type=str,
                        help="Sentence transformer model name")
    parser.add_argument("--quiet", action="store_true",
                        help="hides progress bars and skip displaying plots")
    return parser.parse_args()


def _plot_clusters(embeddings, clusters) -> None:
    """Show a 2-D PCA projection of the embeddings, colored by cluster."""
    pca = PCA(n_components=2)
    reduced_embeddings = pca.fit_transform(embeddings)
    plt.figure(figsize=(10, 6))
    plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1],
                c=clusters, cmap='rainbow')
    plt.xlabel('PCA Component 1')
    plt.ylabel('PCA Component 2')
    plt.title('Clusters of English statements')
    plt.grid(True)
    plt.show()


def main() -> None:
    """Entry point: embed, cluster, optionally plot, optionally write TSV."""
    args = _parse_args()

    # Initialize the embeddings model and load the input data.
    embedding_model = SentenceTransformer(args.model)
    df = pd.read_csv(args.input, sep='\t')

    # Generate embeddings for the requested text column. Normalized
    # embeddings make Euclidean distance (used by KMeans) monotonic in
    # cosine distance.
    embeddings = embedding_model.encode(
        df[args.text_field],
        show_progress_bar=not args.quiet,
        normalize_embeddings=True,
    )

    # Perform KMeans clustering and record the assignment per row.
    cluster_model = KMeans(n_clusters=args.clusters, random_state=42, n_init='auto')
    clusters = cluster_model.fit_predict(embeddings)
    df['cluster'] = cluster_model.labels_

    if not args.quiet:
        _plot_clusters(embeddings, clusters)

    # If an output was specified, write the dataframe (sorted by cluster)
    # to a TSV file.
    if args.output is not None:
        df.sort_values(by=["cluster"]).to_csv(args.output, sep='\t', index=False)


if __name__ == "__main__":
    main()
|
||||
4
dev/clustering/requirements.txt
Normal file
4
dev/clustering/requirements.txt
Normal file
@ -0,0 +1,4 @@
|
||||
pandas~=2.1.3
|
||||
scikit-learn~=1.3.2
|
||||
matplotlib~=3.8.2
|
||||
sentence-transformers~=2.2.2
|
||||
Loading…
Reference in New Issue
Block a user