Training Overview
SeisPolarity provides a unified training interface with checkpoint saving, early stopping mechanisms, and flexible configuration.
Basic Training
Quick Start
from seispolarity.models import PPNet
from seispolarity.training import Trainer, TrainingConfig
from seispolarity import WaveformDataset

# Load dataset
dataset = WaveformDataset(path="data.hdf5", name="SCSN", preload=True)

# Create model
model = PPNet(num_fm_classes=3)

# Configure training
config = TrainingConfig(
    batch_size=256,
    epochs=50,
    learning_rate=1e-4,
    device="cuda",  # or "cpu", "mps"
)

# Create trainer
trainer = Trainer(model=model, dataset=dataset, config=config)

# Train
trainer.train()
TrainingConfig
The TrainingConfig class provides all training parameters:
from seispolarity.training import TrainingConfig

config = TrainingConfig(
    # Data parameters
    batch_size=256,
    num_workers=4,
    pin_memory=True,

    # Optimization parameters
    epochs=50,
    learning_rate=1e-4,
    optimizer="adam",  # or "adamw", "sgd"
    weight_decay=1e-5,

    # Learning rate scheduler
    lr_scheduler=None,  # or "step", "cosine", "reduce_on_plateau"
    lr_scheduler_params=None,

    # Training behavior
    gradient_clip_value=None,  # Gradient clipping
    early_stopping_patience=10,
    early_stopping_min_delta=1e-4,

    # Checkpoints
    save_dir="./checkpoints",
    save_every=5,  # Save every N epochs
    save_best_only=True,

    # Validation
    validation_split=0.1,  # Fraction of data for validation
    validation_every=1,  # Validate every N epochs

    # Device
    device="cuda",  # or "cpu", "mps"

    # Logging
    log_every=100,  # Log every N batches
    tensorboard_dir=None,  # TensorBoard log directory
)
Optimizers
SeisPolarity supports multiple optimizers:
from seispolarity.training import TrainingConfig

# Adam (default)
config = TrainingConfig(optimizer="adam", learning_rate=1e-4)

# AdamW
config = TrainingConfig(optimizer="adamw", learning_rate=1e-4, weight_decay=1e-5)

# SGD
config = TrainingConfig(
    optimizer="sgd",
    learning_rate=1e-3,
    weight_decay=1e-4,
    momentum=0.9,
)
Learning Rate Scheduler
Step LR
config = TrainingConfig(
    lr_scheduler="step",
    lr_scheduler_params={
        "step_size": 10,
        "gamma": 0.5,
    },
)
Cosine Annealing
config = TrainingConfig(
    lr_scheduler="cosine",
    lr_scheduler_params={
        "T_max": 50,
    },
)
Reduce on Plateau
config = TrainingConfig(
    lr_scheduler="reduce_on_plateau",
    lr_scheduler_params={
        "mode": "min",
        "factor": 0.5,
        "patience": 5,
    },
)
Custom Training Loop
For greater control over the training process:
from seispolarity.models import PPNet
from seispolarity import WaveformDataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
epochs = 50
model = PPNet(num_fm_classes=3).to(device)
dataset = WaveformDataset(path="data.hdf5", name="SCSN")
loader = DataLoader(dataset, batch_size=256, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
model.train()
for epoch in range(epochs):
    total_loss = 0.0
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")

    for waveforms, labels in pbar:
        waveforms = waveforms.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(waveforms)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() copies the value back to the host; do it once per batch
        batch_loss = loss.item()
        total_loss += batch_loss
        pbar.set_postfix({"loss": batch_loss})

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
Validation
Manual Validation
import torch
from torch.utils.data import DataLoader, random_split

# Split dataset; a seeded generator makes the split reproducible across runs
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)
# Validation function
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: the network to evaluate; it is switched to eval mode and left
            there, so call ``model.train()`` before resuming training.
        val_loader: iterable of ``(waveforms, labels)`` batches.
        criterion: loss function, assumed to average over the batch
            (``reduction="mean"``, the ``nn.CrossEntropyLoss`` default).
        device: device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)``: the mean loss per sample and the
        classification accuracy in percent.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for waveforms, labels in val_loader:
            waveforms = waveforms.to(device)
            labels = labels.to(device)

            outputs = model(waveforms)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Re-weight the batch mean by its size so a smaller final batch
            # does not skew the overall average
            total_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        raise ValueError("val_loader yielded no samples")

    avg_loss = total_loss / total
    accuracy = 100. * correct / total

    return avg_loss, accuracy
Checkpoints
Saving Checkpoints
# `epoch` is the 0-based loop index; name the file with the 1-based epoch
# number so it matches the "Epoch N" printed by the training loops
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': avg_loss,
    'accuracy': accuracy,
}, f'checkpoint_epoch_{epoch + 1}.pth')
Loading Checkpoints
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
checkpoint = torch.load('checkpoint_epoch_50.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# To resume training, continue from the epoch after the saved one:
# for epoch in range(epoch + 1, epochs): ...
Early Stopping
import math


class EarlyStopping:
    """Signal when a monitored validation loss has stopped improving.

    Call the instance once per validation round with the current loss. It
    returns True once `patience` consecutive calls have failed to beat the
    best loss seen so far by at least `min_delta`.
    """

    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience    # allowed number of consecutive non-improving calls
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.counter = 0            # consecutive non-improving calls so far
        self.best_loss = None       # best loss seen; None until the first valid loss

    def __call__(self, val_loss):
        """Record `val_loss` and return True if training should stop."""
        # A NaN loss never counts as an improvement. Otherwise it would become
        # the new best, and every later comparison against NaN is False, which
        # silently disables early stopping.
        improved = not math.isnan(val_loss) and (
            self.best_loss is None or val_loss <= self.best_loss - self.min_delta
        )
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return False

        self.counter += 1
        return self.counter >= self.patience

# Usage
epochs = 50
early_stopping = EarlyStopping(patience=10)

for epoch in range(epochs):
    model.train()  # validate() leaves the model in eval mode
    # ... training code ...

    val_loss, _ = validate(model, val_loader, criterion, device)
    if early_stopping(val_loss):
        print(f"Early stopping at epoch {epoch+1}")
        break
TensorBoard Logging
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("./logs")

# Inside the epoch loop: log the training loss per batch, using a global step
# that keeps increasing across epochs
for batch_idx, (waveforms, labels) in enumerate(train_loader):
    # ... training code ...
    global_step = epoch * len(train_loader) + batch_idx
    writer.add_scalar('Loss/train', loss.item(), global_step)

# After validation: log the per-epoch metrics
writer.add_scalar('Loss/val', val_loss, epoch)
writer.add_scalar('Accuracy/val', accuracy, epoch)

# When training finishes, close the writer so pending events are written
writer.close()
Multi-GPU Training
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from seispolarity.models import PPNet


def train_worker(rank, world_size):
    """Train on GPU `rank` as one of `world_size` processes started by mp.spawn."""
    # The default env:// rendezvous reads these variables; mp.spawn does not set them.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    try:
        model = PPNet(num_fm_classes=3).to(rank)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        # Give each process its own shard of the data, e.g. with
        # torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        # passed as `sampler=` to the DataLoader, and call sampler.set_epoch(epoch)
        # at the start of every epoch.

        # ... training code ...
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train_worker, args=(world_size,), nprocs=world_size, join=True)
Training Tips
Data Loading: Use `num_workers > 0` and `pin_memory=True` to speed up data loading.
Batch Size: Start with a moderate batch size (256–512) and increase it if GPU memory allows.
Learning Rate: Use smaller learning rates (1e-4 to 1e-5) when fine-tuning a pretrained model.
Validation: Always hold out a validation set to monitor overfitting.
Checkpoints: Save checkpoints regularly so that an interrupted run can be resumed without losing progress.
Example: Complete Training Script
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from seispolarity.models import PPNet
from seispolarity import WaveformDataset

# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 256
EPOCHS = 50
LEARNING_RATE = 1e-4
VALIDATION_SPLIT = 0.1
NUM_WORKERS = 4
SEED = 42


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over `loader`; return the mean loss per sample."""
    model.train()
    total_loss = 0.0
    total = 0
    for waveforms, labels in loader:
        waveforms = waveforms.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(waveforms)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion averages over the batch; weight by batch size so a
        # smaller final batch does not skew the epoch mean
        total_loss += loss.item() * labels.size(0)
        total += labels.size(0)
    return total_loss / max(total, 1)


def evaluate(model, loader, criterion, device):
    """Return (mean loss per sample, accuracy in %) of `model` on `loader`."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for waveforms, labels in loader:
            waveforms = waveforms.to(device)
            labels = labels.to(device)

            outputs = model(waveforms)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("validation loader yielded no samples")
    return total_loss / total, 100. * correct / total


def main():
    """Train PPNet and save the weights with the best validation accuracy."""
    # Load dataset
    dataset = WaveformDataset(path="data.hdf5", name="SCSN", preload=True)

    # Split train/validation (seeded so the split is reproducible)
    train_size = int((1 - VALIDATION_SPLIT) * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(SEED),
    )

    # Data loaders; pinned host memory speeds up host-to-GPU copies
    pin_memory = DEVICE == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=pin_memory)

    # Model
    model = PPNet(num_fm_classes=3).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

    # TensorBoard
    writer = SummaryWriter("./logs")
    best_accuracy = 0.0

    try:
        for epoch in range(EPOCHS):
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
            val_loss, accuracy = evaluate(model, val_loader, criterion, DEVICE)

            # Logging
            print(f"Epoch {epoch+1}/{EPOCHS}")
            print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {accuracy:.2f}%")

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('Accuracy/val', accuracy, epoch)

            # Save best model
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                torch.save(model.state_dict(), 'best_model.pth')
    finally:
        # Close the writer even if training is interrupted
        writer.close()


# The guard is required when DataLoader workers are started with the "spawn"
# method (the default on Windows and macOS), because each worker re-imports
# this module.
if __name__ == "__main__":
    main()
For detailed API documentation, see the API Reference.