Models Overview
SeisPolarity provides a collection of deep learning models for seismic polarity classification. All models are implemented in PyTorch and support GPU acceleration.
Available Models
| Model | Input Length (samples) | Classes |
|---|---|---|
| Ross (SCSN) | 400 | 3 (U/D/N) |
| Eqpolarity | 600 | 2 (U/D) |
| DiTingMotion | 128 | 3 (U/D/N) |
| CFM | 160 | 2 (U/D) |
| RPNet | 400 | 2 (U/D) |
| PolarCAP | 64 | 2 (U/D) |
| APP | 400 | 3 (U/D/N) |

U = Up, D = Down, N = Unknown.
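The input lengths in the table are sample counts, so the time window a model sees depends on the sampling rate (400 samples at 100 Hz is 4 seconds). The sketch below converts lengths to durations and cuts a fixed-length window out of a longer trace. It assumes a 100 Hz sampling rate and a window centred on the P-wave pick; `window_duration` and `crop_around_pick` are hypothetical helpers, not part of the SeisPolarity API, and the library's own preprocessing may align windows differently.

```python
# Hypothetical helpers (not part of SeisPolarity) for fitting a trace to a model's input length.

# Input lengths in samples, copied from the table above.
INPUT_LENGTHS = {
    "Ross (SCSN)": 400,
    "Eqpolarity": 600,
    "DiTingMotion": 128,
    "CFM": 160,
    "RPNet": 400,
    "PolarCAP": 64,
    "APP": 400,
}


def window_duration(n_samples, sampling_rate=100.0):
    """Length of an n-sample window in seconds (100 Hz assumed by default)."""
    return n_samples / sampling_rate


def crop_around_pick(trace, pick_index, n_samples):
    """Return n_samples values centred on pick_index (an assumed alignment).

    Positions that fall outside the trace are filled with 0.0.
    """
    start = pick_index - n_samples // 2
    return [
        trace[i] if 0 <= i < len(trace) else 0.0
        for i in range(start, start + n_samples)
    ]


for name, n in INPUT_LENGTHS.items():
    print(f"{name}: {n} samples = {window_duration(n):.2f} s at 100 Hz")

trace = [float(i) for i in range(1000)]  # stand-in for a real waveform
window = crop_around_pick(trace, pick_index=500, n_samples=INPUT_LENGTHS["RPNet"])
print(len(window), window[0], window[-1])
```

Zero-padding keeps the output length fixed even when a pick sits near the start or end of a record, which is what a fixed-input network requires.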
Loading Models
Pre-trained Models
Load pre-trained models from Hugging Face:
```python
from seispolarity.models import PPNet, RossNet, EqpolarityNet

# Ross model
model = RossNet(num_fm_classes=3)

# Eqpolarity model
model = EqpolarityNet()

# Generic PPNet (for SCSN)
model = PPNet(num_fm_classes=3)
```
Loading Custom Weights
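```python
import torch
from seispolarity.models import PPNet

model = PPNet(num_fm_classes=3)

# Load from a training checkpoint that stores the weights under "model_state_dict";
# map_location="cpu" lets files saved on a GPU load on CPU-only machines
checkpoint = torch.load("checkpoints/model.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

# Or load a file that contains only the state dict
model.load_state_dict(torch.load("model_weights.pth", map_location="cpu"))
```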
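Checkpoints produced by different training scripts are not always laid out the same way. A small normalization step can accept both formats shown above, and can also strip the `module.` prefix that PyTorch's `nn.DataParallel` and `DistributedDataParallel` wrappers add to every parameter name. `extract_state_dict` is a hypothetical helper, not a SeisPolarity function; it is demonstrated with plain dictionaries so it runs without a model.

```python
from collections import OrderedDict


def extract_state_dict(checkpoint, prefix="module."):
    """Return a plain state dict from a training checkpoint or a bare state dict.

    Hypothetical helper: unwraps the "model_state_dict" key if present and
    removes a leading wrapper prefix (e.g. "module.") from every key.
    """
    if "model_state_dict" in checkpoint:
        checkpoint = checkpoint["model_state_dict"]
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint.items()
    )


# Plain dicts stand in for the tensors that torch.load would return.
wrapped = {"epoch": 10, "model_state_dict": {"module.conv1.weight": 1, "module.fc.bias": 2}}
bare = {"conv1.weight": 1, "fc.bias": 2}
print(dict(extract_state_dict(wrapped)) == dict(extract_state_dict(bare)))
```

With a real model, the call would look like `model.load_state_dict(extract_state_dict(torch.load(path, map_location="cpu")))`.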
Model API
All models in SeisPolarity expose a unified PyTorch interface:
```python
import torch
from seispolarity.models import PPNet

# Create model and switch to inference mode (dropout off, batch-norm uses running statistics)
model = PPNet(num_fm_classes=3)
model.eval()

# Prepare input
waveforms = torch.randn(10, 1, 400)  # (batch, channels, length)

# Forward pass without tracking gradients
with torch.no_grad():
    logits = model(waveforms)
    predictions = logits.argmax(dim=1)  # Get predicted classes
```
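Because each model expects a different window length (see the table above), a quick shape check before the forward pass gives a clearer error than a size mismatch deep inside the network. The sketch below is a hypothetical helper, not part of SeisPolarity; it assumes single-channel input of shape (batch, 1, length) as in the example above. It works on plain tuples, and since `torch.Size` is a subclass of `tuple`, `waveforms.shape` can be passed to it directly.

```python
def check_input_shape(shape, expected_length, channels=1):
    """Raise ValueError unless shape is (batch, channels, expected_length).

    Hypothetical helper; assumes the (batch, channels, length) layout used above.
    """
    shape = tuple(shape)
    if len(shape) != 3:
        raise ValueError(f"expected a 3-D (batch, channels, length) shape, got {shape}")
    _, n_channels, length = shape
    if n_channels != channels:
        raise ValueError(f"expected {channels} channel(s), got {n_channels}")
    if length != expected_length:
        raise ValueError(f"expected {expected_length} samples, got {length}")


check_input_shape((10, 1, 400), expected_length=400)  # Ross (SCSN): passes
try:
    check_input_shape((10, 1, 400), expected_length=600)  # Eqpolarity expects 600
except ValueError as err:
    print(err)
```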
Model Architectures
Ross (SCSN)
The Ross model is a convolutional neural network (CNN) designed for data from the Southern California Seismic Network (SCSN).
Input: 400 samples (4 seconds at 100 Hz sampling rate)
Architecture:
- Convolutional layers with batch normalization
- Max pooling
- Fully connected layers
- Dropout for regularization
```python
from seispolarity.models import PPNet

# Generic PPNet with 3 classes (U/D/N), as used for SCSN
model = PPNet(num_fm_classes=3)
```
Eqpolarity
Eqpolarity is a deep CNN model for polarity classification.
Input: 600 samples (6 seconds at 100 Hz sampling rate)
```python
from seispolarity.models import EqpolarityNet

model = EqpolarityNet()
```
DiTingMotion
First-motion polarity classification model.
Input: 128 samples (1.28 seconds at 100 Hz sampling rate)
```python
from seispolarity.models import DiTingMotionNet

model = DiTingMotionNet()
```
CFM
Custom architecture for polarity detection.
Input: 160 samples
```python
from seispolarity.models import CFM

model = CFM()
```
RPNet
Residual Polarity Network.
Input: 400 samples (4 seconds at 100 Hz sampling rate)
```python
from seispolarity.models import RPNet

model = RPNet()
```
PolarCAP
Lightweight model for polarity classification.
Input: 64 samples (0.64 seconds at 100 Hz sampling rate)
```python
from seispolarity.models import PolarCAP, PolarCAPLoss

model = PolarCAP()
loss_fn = PolarCAPLoss()  # PolarCAP's dedicated training loss
```
APP
Adaptive Polarity Predictor.
Input: 400 samples
```python
from seispolarity.models import APP

model = APP()
```
Inference
Batch Inference
```python
import torch
from seispolarity.models import PPNet
from seispolarity import WaveformDataset

# Use the GPU when available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model
model = PPNet(num_fm_classes=3)
model.eval()
model.to(device)

# Load dataset
dataset = WaveformDataset(path="data.hdf5", name="SCSN")
loader = dataset.get_dataloader(batch_size=1024)

# Inference
all_predictions = []
with torch.no_grad():
    for waveforms, _ in loader:
        waveforms = waveforms.to(device)
        logits = model(waveforms)
        predictions = logits.argmax(dim=1)
        all_predictions.append(predictions.cpu())

# Concatenate the per-batch predictions into a single tensor
predictions = torch.cat(all_predictions)
```
Single Sample Inference
```python
import numpy as np
import torch
from seispolarity.models import PPNet

# Load model
model = PPNet(num_fm_classes=3)
model.eval()

# Single 400-sample waveform
waveform = np.random.randn(400)
# Convert to float32 and add batch and channel dimensions: shape (1, 1, 400)
waveform = torch.from_numpy(waveform).float().view(1, 1, -1)

# Inference
with torch.no_grad():
    logits = model(waveform)
    prediction = logits.argmax(dim=1).item()

# Interpret result
label_map = {0: "Up", 1: "Down", 2: "Unknown"}
print(f"Predicted polarity: {label_map[prediction]}")
```
Model Output
Models output logits for each class:
```python
import torch.nn.functional as F

# Assumes `model` and `waveforms` from the examples above
# Raw logits
logits = model(waveforms)  # Shape: (batch_size, num_classes)

# Probabilities (softmax over the class dimension)
probabilities = F.softmax(logits, dim=1)

# Predicted classes
predictions = logits.argmax(dim=1)
```
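To make the relationship between logits, probabilities, and predicted classes concrete, here is the same computation for a single sample in plain Python. Subtracting the largest logit before exponentiating keeps `exp` from overflowing without changing the result, and because softmax is monotonic, the argmax of the probabilities equals the argmax of the logits; this is why the snippet above can take `argmax` on the logits directly. The logit values are made up for illustration, and the label map is the 3-class mapping used under Single Sample Inference.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one sample's logits."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


logits = [2.0, -1.0, 0.5]  # hypothetical logits for one 3-class sample
probs = softmax(logits)

label_map = {0: "Up", 1: "Down", 2: "Unknown"}
print(label_map[argmax(probs)], round(probs[argmax(probs)], 3))
```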
Model Download
Pre-trained models are available on Hugging Face:
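```python
import torch
from huggingface_hub import hf_hub_download
from seispolarity.models import PPNet

# Download model weights
model_path = hf_hub_download(
    repo_id="HeXingChen/SeisPolarity-Model",
    filename="ross_scsn.pth",
)

# Create the matching architecture (Ross/SCSN uses PPNet) and load the weights
model = PPNet(num_fm_classes=3)
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
```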
Custom Models
To create a custom model:
```python
import torch.nn as nn
from seispolarity.models.base import BasePolarityModel

class CustomModel(BasePolarityModel):
    def __init__(self, num_fm_classes=3):
        super().__init__(num_fm_classes)
        # Define your architecture here
        self.conv1 = nn.Conv1d(1, 32, kernel_size=5)
        # A 400-sample input shrinks to 400 - 5 + 1 = 396 samples after conv1
        self.fc = nn.Linear(32 * 396, num_fm_classes)

    def forward(self, x):
        # x: (batch, 1, 400)
        x = self.conv1(x)          # (batch, 32, 396)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32 * 396)
        x = self.fc(x)             # logits: (batch, num_fm_classes)
        return x
```
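The `32 * 396` passed to `nn.Linear` is the flattened size of the feature map after `conv1`. The helpers below reproduce the output-length formulas from the PyTorch documentation for `nn.Conv1d` and `nn.MaxPool1d` (with `ceil_mode=False`), which is a convenient way to size the fully connected layer when adding layers or targeting a different input length. They are illustrative helpers, not part of SeisPolarity.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of nn.Conv1d, per the formula in the PyTorch docs."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def maxpool1d_out_len(length, kernel_size, stride=None, padding=0, dilation=1):
    """Output length of nn.MaxPool1d with ceil_mode=False; stride defaults to kernel_size."""
    stride = kernel_size if stride is None else stride
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# CustomModel above: Conv1d(1, 32, kernel_size=5) applied to a 400-sample input
n = conv1d_out_len(400, kernel_size=5)
print(n, 32 * n)  # prints: 396 12672

# Adding nn.MaxPool1d(2) after conv1 would halve the length before flattening
print(maxpool1d_out_len(n, kernel_size=2))  # prints: 198
```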
For detailed API documentation, see the API Reference.