Author: He XingChen
Last Updated: 2026-02-07

Training Overview

SeisPolarity provides a unified training interface with checkpoint saving, early stopping, and flexible configuration.

Basic Training

Quick Start

 1from seispolarity.models import PPNet
 2from seispolarity.training import Trainer, TrainingConfig
 3from seispolarity import WaveformDataset
 4
 5# Load dataset
 6dataset = WaveformDataset(path="data.hdf5", name="SCSN", preload=True)
 7
 8# Create model
 9model = PPNet(num_fm_classes=3)
10
11# Configure training
12config = TrainingConfig(
13    batch_size=256,
14    epochs=50,
15    learning_rate=1e-4,
16    device="cuda"
17)
18
19# Create trainer
20trainer = Trainer(model=model, dataset=dataset, config=config)
21
22# Train
23trainer.train()

TrainingConfig

The TrainingConfig class provides all training parameters:

 1from seispolarity.training import TrainingConfig
 2
 3config = TrainingConfig(
 4    # Data parameters
 5    batch_size=256,
 6    num_workers=4,
 7    pin_memory=True,
 8
 9    # Optimization parameters
10    epochs=50,
11    learning_rate=1e-4,
12    optimizer="adam",           # or "adamw", "sgd"
13    weight_decay=1e-5,
14
15    # Learning rate scheduler
16    lr_scheduler=None,          # or "step", "cosine", "reduce_on_plateau"
17    lr_scheduler_params=None,
18
19    # Training behavior
20    gradient_clip_value=None,   # Gradient clipping
21    early_stopping_patience=10,
22    early_stopping_min_delta=1e-4,
23
24    # Checkpoints
25    save_dir="./checkpoints",
26    save_every=5,               # Save every N epochs
27    save_best_only=True,
28
29    # Validation
30    validation_split=0.1,       # Fraction of data for validation
31    validation_every=1,         # Validate every N epochs
32
33    # Device
34    device="cuda",              # or "cpu", "mps"
35
36    # Logging
37    log_every=100,              # Log every N batches
38    tensorboard_dir=None,       # TensorBoard log directory
39)

Optimizers

SeisPolarity supports multiple optimizers:

 1from seispolarity.training import TrainingConfig
 2
 3# Adam (default)
 4config = TrainingConfig(optimizer="adam", learning_rate=1e-4)
 5
 6# AdamW
 7config = TrainingConfig(optimizer="adamw", learning_rate=1e-4, weight_decay=1e-5)
 8
 9# SGD
10config = TrainingConfig(
11    optimizer="sgd",
12    learning_rate=1e-3,
13    weight_decay=1e-4,
14    momentum=0.9
15)

Learning Rate Scheduler

Step LR

1config = TrainingConfig(
2    lr_scheduler="step",
3    lr_scheduler_params={
4        "step_size": 10,
5        "gamma": 0.5
6    }
7)

Cosine Annealing

1config = TrainingConfig(
2    lr_scheduler="cosine",
3    lr_scheduler_params={
4        "T_max": 50
5    }
6)

Reduce on Plateau

1config = TrainingConfig(
2    lr_scheduler="reduce_on_plateau",
3    lr_scheduler_params={
4        "mode": "min",
5        "factor": 0.5,
6        "patience": 5
7    }
8)

Custom Training Loop

For greater control over the training process, you can write the training loop yourself in plain PyTorch instead of using Trainer:

 1from seispolarity.models import PPNet
 2from seispolarity import WaveformDataset
 3import torch
 4import torch.nn as nn
 5import torch.optim as optim
 6from torch.utils.data import DataLoader
 7from tqdm import tqdm
 8
 9# Setup
10device = "cuda"
11model = PPNet(num_fm_classes=3).to(device)
12dataset = WaveformDataset(path="data.hdf5", name="SCSN")
13loader = DataLoader(dataset, batch_size=256, shuffle=True)
14
15criterion = nn.CrossEntropyLoss()
16optimizer = optim.Adam(model.parameters(), lr=1e-4)
17
18# Training loop
19model.train()
20for epoch in range(50):
21    total_loss = 0
22    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/50")
23
24    for waveforms, labels in pbar:
25        waveforms = waveforms.to(device)
26        labels = labels.to(device)
27
28        # Forward pass
29        outputs = model(waveforms)
30        loss = criterion(outputs, labels)
31
32        # Backward pass
33        optimizer.zero_grad()
34        loss.backward()
35        optimizer.step()
36
37        total_loss += loss.item()
38        pbar.set_postfix({"loss": loss.item()})
39
40    avg_loss = total_loss / len(loader)
41    print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

Validation

Manual Validation

 1from torch.utils.data import random_split
 2
 3# Split dataset
 4train_size = int(0.9 * len(dataset))
 5val_size = len(dataset) - train_size
 6train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
 7
 8# Create data loaders
 9train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
10val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)
11
12# Validation function
13def validate(model, val_loader, criterion, device):
14    model.eval()
15    total_loss = 0
16    correct = 0
17    total = 0
18
19    with torch.no_grad():
20        for waveforms, labels in val_loader:
21            waveforms = waveforms.to(device)
22            labels = labels.to(device)
23
24            outputs = model(waveforms)
25            loss = criterion(outputs, labels)
26
27            total_loss += loss.item()
28            _, predicted = outputs.max(1)
29            total += labels.size(0)
30            correct += predicted.eq(labels).sum().item()
31
32    avg_loss = total_loss / len(val_loader)
33    accuracy = 100. * correct / total
34
35    return avg_loss, accuracy

Checkpoints

Saving Checkpoints

1torch.save({
2    'epoch': epoch,
3    'model_state_dict': model.state_dict(),
4    'optimizer_state_dict': optimizer.state_dict(),
5    'loss': avg_loss,
6    'accuracy': accuracy
7}, f'checkpoint_epoch_{epoch}.pth')

Loading Checkpoints

1checkpoint = torch.load('checkpoint_epoch_50.pth')
2model.load_state_dict(checkpoint['model_state_dict'])
3optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
4epoch = checkpoint['epoch']
5loss = checkpoint['loss']

Early Stopping

 1class EarlyStopping:
 2    def __init__(self, patience=10, min_delta=1e-4):
 3        self.patience = patience
 4        self.min_delta = min_delta
 5        self.counter = 0
 6        self.best_loss = None
 7
 8    def __call__(self, val_loss):
 9        if self.best_loss is None:
10            self.best_loss = val_loss
11        elif val_loss > self.best_loss - self.min_delta:
12            self.counter += 1
13            if self.counter >= self.patience:
14                return True  # Stop training
15        else:
16            self.best_loss = val_loss
17            self.counter = 0
18        return False
19
20# Usage
21early_stopping = EarlyStopping(patience=10)
22
23for epoch in range(epochs):
24    # ... training code ...
25
26    val_loss, _ = validate(model, val_loader, criterion, device)
27    if early_stopping(val_loss):
28        print(f"Early stopping at epoch {epoch}")
29        break

TensorBoard Logging

 1from torch.utils.tensorboard import SummaryWriter
 2
 3writer = SummaryWriter("./logs")
 4
 5# Log loss during training
 6for batch_idx, (waveforms, labels) in enumerate(train_loader):
 7    # ... training code ...
 8    writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + batch_idx)
 9
10# Log validation metrics
11writer.add_scalar('Loss/val', val_loss, epoch)
12writer.add_scalar('Accuracy/val', accuracy, epoch)
13
14writer.close()

Multi-GPU Training

 1import torch.multiprocessing as mp
 2import torch.distributed as dist
 3
 4def train_worker(rank, world_size):
 5    dist.init_process_group("nccl", rank=rank, world_size=world_size)
 6    model = PPNet(num_fm_classes=3)
 7    model = model.to(rank)
 8    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
 9
10    # ... training code ...
11
12if __name__ == "__main__":
13    world_size = torch.cuda.device_count()
14    mp.spawn(train_worker, args=(world_size,), nprocs=world_size, join=True)

Training Tips

  1. Data Loading: Use num_workers > 0 to load batches in parallel worker processes, and pin_memory=True to speed up host-to-GPU transfers when training on CUDA

  2. Batch Size: Start with a moderate batch size (256-512) and increase it if memory allows

  3. Learning Rate: Use a small learning rate (1e-5 to 1e-4) when fine-tuning

  4. Validation: Always hold out a validation set so that overfitting can be detected

  5. Checkpoints: Save checkpoints regularly to avoid losing progress

Example: Complete Training Script

 1import torch
 2import torch.nn as nn
 3import torch.optim as optim
 4from torch.utils.data import DataLoader, random_split
 5from torch.utils.tensorboard import SummaryWriter
 6from seispolarity.models import PPNet
 7from seispolarity import WaveformDataset
 8
 9# Configuration
10DEVICE = "cuda"
11BATCH_SIZE = 256
12EPOCHS = 50
13LEARNING_RATE = 1e-4
14VALIDATION_SPLIT = 0.1
15
16# Load dataset
17dataset = WaveformDataset(path="data.hdf5", name="SCSN", preload=True)
18
19# Split train/validation
20train_size = int((1 - VALIDATION_SPLIT) * len(dataset))
21val_size = len(dataset) - train_size
22train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
23
24# Data loaders
25train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
26val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
27
28# Model
29model = PPNet(num_fm_classes=3).to(DEVICE)
30criterion = nn.CrossEntropyLoss()
31optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
32
33# TensorBoard
34writer = SummaryWriter("./logs")
35
36best_accuracy = 0
37
38for epoch in range(EPOCHS):
39    # Training
40    model.train()
41    train_loss = 0
42    for waveforms, labels in train_loader:
43        waveforms = waveforms.to(DEVICE)
44        labels = labels.to(DEVICE)
45
46        optimizer.zero_grad()
47        outputs = model(waveforms)
48        loss = criterion(outputs, labels)
49        loss.backward()
50        optimizer.step()
51
52        train_loss += loss.item()
53
54    # Validation
55    model.eval()
56    val_loss = 0
57    correct = 0
58    total = 0
59
60    with torch.no_grad():
61        for waveforms, labels in val_loader:
62            waveforms = waveforms.to(DEVICE)
63            labels = labels.to(DEVICE)
64
65            outputs = model(waveforms)
66            loss = criterion(outputs, labels)
67
68            val_loss += loss.item()
69            _, predicted = outputs.max(1)
70            total += labels.size(0)
71            correct += predicted.eq(labels).sum().item()
72
73    # Metrics
74    train_loss /= len(train_loader)
75    val_loss /= len(val_loader)
76    accuracy = 100. * correct / total
77
78    # Logging
79    print(f"Epoch {epoch+1}/{EPOCHS}")
80    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {accuracy:.2f}%")
81
82    writer.add_scalar('Loss/train', train_loss, epoch)
83    writer.add_scalar('Loss/val', val_loss, epoch)
84    writer.add_scalar('Accuracy/val', accuracy, epoch)
85
86    # Save best model
87    if accuracy > best_accuracy:
88        best_accuracy = accuracy
89        torch.save(model.state_dict(), 'best_model.pth')
90
91writer.close()

For detailed API documentation, please refer to the API Reference.