I was working on a system that needed to sort customer feedback into categories. We tried a few rule-based methods and simpler models, but they often missed the subtle meaning in people’s words. A frustrated customer might write, “This is fast,” which could be positive or sarcastic. Traditional models struggled with that. That’s when I turned to BERT. This article is my guide to building a text classification system that actually works in a real-world application, from start to finish. I’ll share the steps and code that helped me solve this problem.
Think about the last email you wrote. How did you choose the right words? You understood the whole sentence at once, weighing every word against the others. That's what makes BERT special. Many earlier models, such as recurrent neural networks, read text one word at a time from left to right. BERT's self-attention layers let every word look at every other word in the sentence, on both sides, at the same time. This bidirectional view of context is a big shift from those older models.
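To make "looks at all words together" concrete, here is a toy sketch of scaled dot-product self-attention, the mechanism BERT's Transformer layers are built on. It is a deliberate simplification: real BERT layers add learned query/key/value projections, multiple attention heads, and feed-forward sublayers. The vectors below are made-up numbers, not real embeddings.

```python
import math

def self_attention(vectors):
    """Toy scaled dot-product self-attention without learned projections.

    Each output is a weighted average of *all* input vectors, so a token's
    output depends on tokens both before and after it.
    """
    d = len(vectors[0])
    outputs = []
    for q in vectors:
        # Similarity of this token (query) to every token (keys), scaled by sqrt(d)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in vectors]
        # Softmax turns the scores into weights that sum to 1
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of all token vectors (used as values)
        outputs.append([sum(w * v[j] for w, v in zip(weights, vectors)) for j in range(d)])
    return outputs

# Three toy 2-d "token embeddings"; the second call changes only the LAST token
a = self_attention([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
b = self_attention([[1.0, 0.0], [0.0, 1.0], [-1.0, 2.0]])
print(a[0] != b[0])  # True: the FIRST token's output changed too
```

A left-to-right model would compute the first token's representation before ever seeing the last one; here it cannot avoid seeing it.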
Here's the first thing you need to understand: tokenization. BERT doesn't read words the way we do. Its tokenizer breaks text into tokens drawn from a fixed vocabulary, and a word that isn't in that vocabulary is split into smaller subword pieces. Let's see how this works in code.
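```python
from transformers import BertTokenizer

# Load the tokenizer that matches the BERT model we'll use
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Let's process a sample sentence
sample_text = "The product delivery was surprisingly swift."
tokens = tokenizer.tokenize(sample_text)
print(tokens)
# Output: ['the', 'product', 'delivery', 'was', 'surprisingly', 'swift', '.']

# Now, convert these tokens to the IDs BERT understands
input_ids = tokenizer.convert_tokens_to_ids(tokens)
print(input_ids)
# Output: one vocabulary ID per token, e.g. 1996 for 'the' and 1012 for '.'

# tokenize() adds no special tokens; encode() adds [CLS] and [SEP] by default
with_special = tokenizer.encode(sample_text)
print(tokenizer.convert_ids_to_tokens(with_special))
# Output: ['[CLS]', 'the', 'product', 'delivery', 'was', 'surprisingly', 'swift', '.', '[SEP]']
```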
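Notice that the uncased tokenizer lowercases the text, and that every word in this sentence maps to a single token because each one is already in bert-base-uncased's vocabulary of 30,522 entries. A rarer word would be split into several subword pieces, and every piece after the first is marked with ##. The tokenize() method returns only those pieces. The special tokens BERT expects, [CLS] at the start and [SEP] at the end of each sentence, are added by encode(), encode_plus(), or by calling the tokenizer directly. [CLS] matters most for us, because its final hidden state is what we will classify. This process is the foundation of everything that follows.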
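The splitting rule itself is simple. BERT's WordPiece tokenizer takes the longest vocabulary entry that matches at each position of a word and marks pieces that continue a word with ##. Below is a minimal sketch of that greedy longest-match-first step. It uses a tiny made-up vocabulary instead of the real one and skips the earlier steps the real tokenizer performs (lowercasing, accent stripping, splitting on whitespace and punctuation).

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single lowercase word.

    Non-initial pieces carry the '##' prefix, as in BERT's vocabulary.
    """
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            # No piece fits: the whole word becomes the unknown token
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces

# Tiny hypothetical vocabulary for illustration only
vocab = {"the", "swift", "un", "##aff", "##able"}
print(wordpiece_tokenize("swift", vocab))      # ['swift']
print(wordpiece_tokenize("unaffable", vocab))  # ['un', '##aff', '##able']
print(wordpiece_tokenize("xyz", vocab))        # ['[UNK]']
```

Because the real vocabulary contains every single character as a piece, genuinely unknown words are rare in practice. Most unfamiliar words end up as a few subword pieces rather than [UNK].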
How do we prepare our own data for this? We can’t just feed raw text. We need a structured dataset. PyTorch’s Dataset class is perfect for this. It helps us manage our texts and labels, and apply the tokenization consistently.
```python
import torch
from torch.utils.data import Dataset

class FeedbackDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',  # Return PyTorch tensors
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
```
The __getitem__ method is the key. Every time the training loop (through a DataLoader) asks for an example, this method tokenizes one text, truncates or pads it to max_len tokens, and returns the token IDs together with the label and an attention mask. The mask holds 1 for real tokens and 0 for padding, so the model knows which positions to ignore. Padding every example to the same length is what lets the DataLoader stack examples into batch tensors.
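For a single text, what encode_plus does with truncation=True and padding='max_length' can be approximated in a few lines. This simplified sketch assumes the text has already been converted to token IDs, and it hard-codes bert-base-uncased's special-token IDs as defaults.

```python
def build_inputs(token_ids, max_len, cls_id=101, sep_id=102, pad_id=0):
    """Add [CLS]/[SEP], truncate, and pad a list of token IDs to max_len.

    The default IDs match bert-base-uncased ([CLS]=101, [SEP]=102, [PAD]=0).
    Returns (input_ids, attention_mask).
    """
    # Leave room for the two special tokens when truncating
    body = token_ids[: max_len - 2]
    input_ids = [cls_id] + body + [sep_id]
    # 1 marks a real token, 0 marks padding the model should ignore
    attention_mask = [1] * len(input_ids)
    padding = max_len - len(input_ids)
    input_ids += [pad_id] * padding
    attention_mask += [0] * padding
    return input_ids, attention_mask

ids, mask = build_inputs([1996, 2001], max_len=6)
print(ids)   # [101, 1996, 2001, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```

Truncation keeps [SEP] at the end even when the text is cut, so the model always sees a properly terminated sequence.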
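With our data ready, the next step is building the model itself. We don't start from scratch. We load a pre-trained BERT model, add a small classification layer on top, and then train the whole network on our labeled examples. This is called fine-tuning. BERT has already learned general language patterns during pretraining, so fine-tuning needs far less data and compute than training a model of this size from zero.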
```python
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel

class BertForTextClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Read the number of categories from the config so that
        # from_pretrained(..., num_labels=N) actually takes effect
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        # Add a dropout layer for regularization
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Add a simple linear classifier on top of BERT
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize the new layers (older transformers versions used self.init_weights())
        self.post_init()

    def forward(self, input_ids, attention_mask, labels=None):
        # Get the contextualized embeddings from BERT
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # pooler_output is the final [CLS] hidden state passed through a dense layer and tanh
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return (loss, logits) if loss is not None else logits

# Loads the pretrained BERT weights; the classifier layer starts freshly initialized
model = BertForTextClassification.from_pretrained('bert-base-uncased', num_labels=3)
```
The magic happens in the forward method. BERT processes the token IDs and produces a hidden state for each token. We use the output for the special [CLS] token, which BERT's pretraining set up to represent the whole sequence. The pooler_output is that [CLS] vector passed through one extra dense layer with a tanh activation. After dropout, a single linear layer maps this 768-dimensional vector (for bert-base) to one score, or logit, per category. If labels are provided, we also compute the cross-entropy loss.
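Turning logits into a prediction and a loss is easier to follow without PyTorch in the way. This sketch mirrors what nn.CrossEntropyLoss computes for one example: a log-softmax followed by the negative log-likelihood of the correct class. By default, PyTorch averages this value over the batch. The logit values are invented.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; the result is mathematically unchanged
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, label):
    # Negative log-probability the model assigns to the correct class
    return -math.log(softmax(logits)[label])

logits = [2.0, 0.5, -1.0]  # made-up scores for 3 classes
probs = softmax(logits)
prediction = max(range(len(logits)), key=lambda i: logits[i])
print(prediction)                          # 0
print(round(cross_entropy(logits, 0), 3))  # 0.241
```

The loss is small when the correct class already has the highest probability and grows quickly as that probability shrinks. This is why training pushes the correct logit up relative to the others.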
Now, how do we train this model effectively? Training runs for several epochs, tracks the loss, and keeps the best-performing version. At its core is a function that makes one pass over the training data, and that function relies on an optimizer and a learning rate scheduler.
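The epoch-level bookkeeping that tracks progress and keeps the best version does not depend on PyTorch, so here is a minimal sketch of it on its own. The function name fit and its three callables are hypothetical hooks. In this guide they would wrap train_epoch, evaluate_model, and a save call. The patience-based early stopping is an added assumption, not part of the original setup.

```python
def fit(train_one_epoch, evaluate, save_checkpoint, num_epochs, patience=2):
    """Run epochs, save a checkpoint whenever the validation score improves,
    and stop early after `patience` epochs without improvement.

    Hypothetical hooks:
      train_one_epoch() -> float   (average training loss)
      evaluate() -> float          (validation score, higher is better)
      save_checkpoint(epoch)       (persists the current model)
    """
    best_score = float("-inf")
    best_epoch = None
    epochs_without_improvement = 0
    history = []
    for epoch in range(num_epochs):
        train_loss = train_one_epoch()
        score = evaluate()
        history.append((epoch, train_loss, score))
        print(f"epoch {epoch}: train loss {train_loss:.4f}, val score {score:.4f}")
        if score > best_score:
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
            save_checkpoint(epoch)
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_epoch, best_score, history
```

Wired to this guide's functions, a call could pass lambda: train_epoch(model, train_loader, optimizer, scheduler, device) as the first hook. It could pass lambda: evaluate_model(model, val_loader, device)['macro avg']['f1-score'] as the second and lambda epoch: torch.save(model, 'models/best_model.pt') as the third. Here train_loader and val_loader are assumed DataLoaders. Macro-averaged F1 is one reasonable selection metric when classes are imbalanced.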
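```python
import torch
# transformers' own AdamW is deprecated; PyTorch's implementation is the standard choice
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

def train_epoch(model, data_loader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    for batch in data_loader:
        # Move the batch to the GPU if available
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        # Reset the gradients from the previous step
        optimizer.zero_grad()
        # Forward pass
        loss, logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        # Backward pass
        loss.backward()
        # Prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)

def build_optimizer_and_scheduler(model, data_loader, num_epochs=3, lr=2e-5, warmup_ratio=0.1):
    # Hypothetical helper with illustrative defaults; BERT fine-tuning commonly uses lr in 2e-5..5e-5
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(data_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(warmup_ratio * total_steps),
        num_training_steps=total_steps,
    )
    return optimizer, scheduler
```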
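The optimizer (AdamW) adjusts the model's weights to reduce the loss, applying weight decay separately from the gradient update. The scheduler changes the learning rate after every step. With get_linear_schedule_with_warmup, the rate rises linearly from 0 to the base rate during the warm-up steps, which stabilizes the early updates, and then decays linearly back to 0 by the final step. Gradient clipping rescales the gradients whenever their combined norm exceeds max_norm (1.0 here), so a single bad batch can't produce an update large enough to derail training.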
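Both mechanisms reduce to a few lines of arithmetic. The sketch below restates them in plain Python. The schedule function follows the multiplier formula behind get_linear_schedule_with_warmup, and the clipping function treats all gradients as one flat list of numbers, as clip_grad_norm_ does across parameters. The step counts and base learning rate are illustrative.

```python
import math

def linear_warmup_factor(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier: linear warm-up from 0 to 1, then linear decay to 0."""
    if step < num_warmup_steps:
        # Ramp up during warm-up
        return step / max(1, num_warmup_steps)
    # Decay linearly, reaching 0 at the final training step
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale gradients so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = max_norm / (total_norm + 1e-6)
    if scale < 1.0:
        return [g * scale for g in grads]
    # Small gradients are left untouched
    return list(grads)

base_lr = 2e-5
for step in (0, 50, 100, 550, 1000):
    print(step, base_lr * linear_warmup_factor(step, 100, 1000))
print(clip_by_global_norm([3.0, 4.0]))  # norm 5 is scaled down to roughly [0.6, 0.8]
```

Clipping preserves the direction of the gradient and only shortens it, so the update still points downhill, just with a bounded step size.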
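Once trained, how do we know if the model is any good? Accuracy alone can be misleading, especially if some categories are rare. If 90% of the feedback is positive, a model that always predicts "positive" scores 90% accuracy while never catching a single complaint. We need an evaluation that also reports precision and recall for each class.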
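Here is that 90% example worked through in plain Python. The function computes the same per-class precision and recall that classification_report reports. It is a sketch for illustration, with class 0 standing for positive feedback and class 1 for complaints.

```python
def per_class_metrics(true_labels, predictions):
    """Accuracy plus precision and recall for every class seen in either list."""
    classes = sorted(set(true_labels) | set(predictions))
    correct = sum(t == p for t, p in zip(true_labels, predictions))
    accuracy = correct / len(true_labels)
    metrics = {}
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(true_labels, predictions))
        predicted_c = sum(p == c for p in predictions)
        actual_c = sum(t == c for t in true_labels)
        # Report 0.0 instead of dividing by zero when a class is never predicted/present
        precision = tp / predicted_c if predicted_c else 0.0
        recall = tp / actual_c if actual_c else 0.0
        metrics[c] = {"precision": precision, "recall": recall}
    return accuracy, metrics

# 9 positive reviews (0) and 1 complaint (1); the model always predicts 0
true_labels = [0] * 9 + [1]
predictions = [0] * 10
accuracy, metrics = per_class_metrics(true_labels, predictions)
print(accuracy)    # 0.9
print(metrics[1])  # {'precision': 0.0, 'recall': 0.0}
```

The headline number looks fine, but a recall of 0.0 for class 1 shows that the model misses every complaint.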
```python
import torch
from sklearn.metrics import classification_report, accuracy_score

def evaluate_model(model, data_loader, device):
    model.eval()
    predictions = []
    true_labels = []
    with torch.no_grad():  # Turn off gradient tracking during evaluation
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Without labels, the model returns only the logits
            logits = model(input_ids, attention_mask)
            preds = torch.argmax(logits, dim=1)
            predictions.extend(preds.cpu().tolist())
            true_labels.extend(batch['labels'].tolist())
    # Generate a detailed performance report
    report = classification_report(true_labels, predictions, output_dict=True, zero_division=0)
    accuracy = accuracy_score(true_labels, predictions)
    print(f"Accuracy: {accuracy:.4f}")
    for label_id, metrics in report.items():
        if label_id.isdigit():  # Per-class entries are keyed by the label as a string
            print(f"Class {label_id} - Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}")
    return report
```
The key here is model.eval() and torch.no_grad(). They tell PyTorch we are testing, not training. model.eval() switches layers such as dropout into inference mode, so the same input always produces the same prediction. torch.no_grad() stops PyTorch from recording operations for gradient computation, which saves memory and makes evaluation faster.
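To see why eval mode matters, here is a plain-Python sketch of the "inverted dropout" rule that nn.Dropout implements. In PyTorch, model.eval() sets every submodule's training flag to False, so each dropout layer takes the pass-through branch.

```python
import random

def dropout(values, p, training, rng=random):
    """Inverted dropout.

    Training: each value is zeroed with probability p and survivors are scaled
    by 1 / (1 - p), keeping the expected value unchanged.
    Eval: the input passes through unchanged.
    """
    if not training or p == 0.0:
        return list(values)
    if p >= 1.0:
        return [0.0] * len(values)
    keep_scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * keep_scale for v in values]

x = [1.0, 2.0, 3.0, 4.0]
print(dropout(x, p=0.1, training=False))                       # [1.0, 2.0, 3.0, 4.0], every time
print(dropout(x, p=0.5, training=True, rng=random.Random(0)))  # some zeros, the rest doubled
```

If a model were left in training mode during evaluation, each call would drop a different random subset of activations, and predictions for the same text could change from run to run.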
Finally, a model is only useful if others can use it, so we need a simple way to serve predictions. A small API built with FastAPI does the job with very little code.
```python
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import BertTokenizer

app = FastAPI()

class PredictionRequest(BaseModel):
    text: str

# Load the model and tokenizer once when the app starts.
# The file is assumed to hold the whole model saved with torch.save(model, ...);
# unpickling it requires the BertForTextClassification class to be importable here.
# PyTorch 2.6+ defaults to weights_only=True, which rejects full-model files,
# so weights_only=False is needed. Only load files you trust.
model = torch.load('models/best_model.pt', map_location=torch.device('cpu'), weights_only=False)
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

@app.post("/predict")
def predict(request: PredictionRequest):
    try:
        # Tokenize the incoming text the same way as during training
        encoding = tokenizer.encode_plus(
            request.text,
            add_special_tokens=True,
            max_length=128,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # Run the model
        with torch.no_grad():
            logits = model(encoding['input_ids'], encoding['attention_mask'])
            prediction = torch.argmax(logits, dim=1)
        return {"predicted_class": int(prediction[0])}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
```
This API wraps our model in a web service. It defines what input it expects (a JSON object with a text field), processes the text through the same tokenization pipeline we used for training, runs the model, and returns the result. It’s now ready to be integrated into a website or application.
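On the calling side, any HTTP client can send the JSON body the endpoint expects. Here is a minimal sketch using only the standard library. It assumes the app runs locally, for example started with uvicorn app:app from a hypothetical app.py, which listens on port 8000 by default. The function names are illustrative.

```python
import json
import urllib.request

def build_predict_request(text, url="http://localhost:8000/predict"):
    """Build a POST request carrying the JSON body the /predict endpoint expects."""
    body = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def predict_remote(text, url="http://localhost:8000/predict"):
    """Send the request to a running server and return the decoded JSON response."""
    with urllib.request.urlopen(build_predict_request(text, url), timeout=10) as response:
        return json.loads(response.read().decode("utf-8"))

req = build_predict_request("This is fast")
print(req.get_method(), req.full_url)  # POST http://localhost:8000/predict
```

With the server running, predict_remote("This is fast") returns a dictionary with a single predicted_class key holding the class index.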
Building a system like this changed how we handled customer data. It moved from a manual, error-prone task to an automated, insightful process. The journey from raw text to a live API involves several clear steps: preparing data, constructing a model, training it carefully, evaluating it thoroughly, and finally serving it reliably. Each piece builds on the last. Start simple, get a basic version working, and then iterate. What’s the first text classification problem you would solve with this?
If this guide helped you connect the pieces between theory and a working system, please share it with a colleague or team who might be facing a similar challenge. Have you tried implementing a BERT model before? What was your biggest hurdle? Let me know in the comments below.
As a best-selling author, I invite you to explore my books on Amazon. Don’t forget to follow me on Medium and show your support. Thank you! Your support means the world!
101 Books
101 Books is an AI-driven publishing company co-founded by author Aarav Joshi. By leveraging advanced AI technology, we keep our publishing costs incredibly low—some books are priced as low as $4—making quality knowledge accessible to everyone.