Author: He XingChen
Last Updated: 2026-02-07

Models Overview

SeisPolarity provides a collection of deep learning models for seismic polarity classification. All models are implemented in PyTorch and support GPU acceleration.

Available Models

Model          Input Length (samples)   Classes
-------------  -----------------------  ----------
Ross (SCSN)    400                      3 (U/D/N)
Eqpolarity     600                      2 (U/D)
DiTingMotion   128                      3 (U/D/N)
CFM            160                      2 (U/D)
RPNet          400                      2 (U/D)
PolarCAP       64                       2 (U/D)
APP            400                      3 (U/D/N)
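The table can also be kept as a small lookup for checking a waveform window before it is passed to a model. The following is a minimal sketch: `MODEL_SPECS` and `check_window` are hypothetical names that are not part of SeisPolarity, and the sketch assumes each model requires exactly the input length listed above.

```python
# Input length and class labels per model, copied from the table above.
# MODEL_SPECS is an illustrative structure, not a SeisPolarity object.
MODEL_SPECS = {
    "Ross (SCSN)": {"input_length": 400, "classes": ("U", "D", "N")},
    "Eqpolarity": {"input_length": 600, "classes": ("U", "D")},
    "DiTingMotion": {"input_length": 128, "classes": ("U", "D", "N")},
    "CFM": {"input_length": 160, "classes": ("U", "D")},
    "RPNet": {"input_length": 400, "classes": ("U", "D")},
    "PolarCAP": {"input_length": 64, "classes": ("U", "D")},
    "APP": {"input_length": 400, "classes": ("U", "D", "N")},
}


def check_window(model_name, n_samples):
    """Raise ValueError if a window does not match the listed input length."""
    expected = MODEL_SPECS[model_name]["input_length"]
    if n_samples != expected:
        raise ValueError(
            f"{model_name} expects {expected} samples, got {n_samples}"
        )
    return True


check_window("PolarCAP", 64)  # passes silently
```

Calling `check_window("Eqpolarity", 400)` raises a `ValueError`, which surfaces a length mismatch before it turns into a less readable shape error inside the network.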

Loading Models

Pre-trained Models

Load pre-trained models from Hugging Face:

    from seispolarity.models import PPNet, RossNet, EqpolarityNet

    # Ross model
    model = RossNet(num_fm_classes=3)

    # Eqpolarity model
    model = EqpolarityNet()

    # Generic PPNet (for SCSN)
    model = PPNet(num_fm_classes=3)

Loading Custom Weights

    import torch
    from seispolarity.models import PPNet

    model = PPNet(num_fm_classes=3)

    # Option 1: load from a training checkpoint that stores the weights
    # under the "model_state_dict" key
    checkpoint = torch.load("checkpoints/model.pth", map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

    # Option 2: load a file that contains only the state dict
    model.load_state_dict(torch.load("model_weights.pth", map_location="cpu"))
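Because the two file layouts above differ only in whether the weights sit under a `"model_state_dict"` key, one small helper can accept both. This is a sketch under that assumption: `extract_state_dict` is a hypothetical helper, not a SeisPolarity function, and plain dictionaries stand in for real checkpoint files so that it runs without PyTorch.

```python
def extract_state_dict(loaded, key="model_state_dict"):
    """Return the weights mapping from either checkpoint layout.

    `loaded` is whatever the file contained: a full checkpoint dict that
    stores the weights under `key`, or the state dict itself.
    """
    if isinstance(loaded, dict) and key in loaded:
        return loaded[key]
    return loaded


# Stand-in dictionaries so the sketch runs without any weight files.
full_checkpoint = {"epoch": 12, "model_state_dict": {"conv1.weight": [0.1, 0.2]}}
bare_state_dict = {"conv1.weight": [0.1, 0.2]}

same = extract_state_dict(full_checkpoint) == extract_state_dict(bare_state_dict)
print(same)
```

With real files, the result of `torch.load(path, map_location="cpu")` would be passed through the helper before calling `model.load_state_dict`.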

Model API

All models in SeisPolarity expose a unified PyTorch interface:

    import torch
    from seispolarity.models import PPNet

    # Create model
    model = PPNet(num_fm_classes=3)
    model.eval()

    # Prepare input
    waveforms = torch.randn(10, 1, 400)  # (batch, channels, length)

    # Forward pass
    with torch.no_grad():
        logits = model(waveforms)
        predictions = logits.argmax(dim=1)  # Get predicted classes

Model Architectures

Ross (SCSN)

[Figure: Ross (SCSN) architecture, ../../_images/ross.png]

The Ross model is a CNN-based architecture optimized for SCSN data.

Input: 400 samples (4 seconds at 100 Hz sampling rate)

Architecture:

  • Convolutional layers with batch normalization

  • Max pooling

  • Fully connected layers

  • Dropout for regularization

    from seispolarity.models import PPNet

    model = PPNet(num_fm_classes=3)
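The input lengths quoted on this page convert to window durations by dividing by the sampling rate. The sketch below shows that arithmetic; `window_seconds` and `window_samples` are hypothetical helper names, and the 100 Hz default is the rate stated for the models that give one.

```python
def window_seconds(n_samples, sampling_rate_hz=100.0):
    """Duration in seconds of a window of n_samples at the given rate."""
    return n_samples / sampling_rate_hz


def window_samples(duration_s, sampling_rate_hz=100.0):
    """Number of samples needed to cover duration_s seconds."""
    return round(duration_s * sampling_rate_hz)


# Input lengths taken from the model descriptions on this page.
for name, n in [("Ross (SCSN)", 400), ("Eqpolarity", 600),
                ("DiTingMotion", 128), ("PolarCAP", 64)]:
    print(f"{name}: {n} samples = {window_seconds(n):.2f} s")
```

Data recorded at a different rate would need resampling, or a different sample count, to cover the same duration. For example, `window_samples(4, sampling_rate_hz=50.0)` gives 200.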

Eqpolarity

[Figure: Eqpolarity architecture, ../../_images/eqpolarity.png]

Eqpolarity is a deep CNN model for polarity classification.

Input: 600 samples (6 seconds at 100 Hz sampling rate)

    from seispolarity.models import EqpolarityNet

    model = EqpolarityNet()

DiTingMotion

[Figure: DiTingMotion architecture, ../../_images/ditingmotion.png]

Motion-based polarity classification model.

Input: 128 samples (1.28 seconds at 100 Hz sampling rate)

    from seispolarity.models import DiTingMotionNet

    model = DiTingMotionNet()

CFM

[Figure: CFM architecture, ../../_images/cfm.png]

Custom architecture for polarity detection.

Input: 160 samples

    from seispolarity.models import CFM

    model = CFM()

RPNet

[Figure: RPNet architecture, ../../_images/rpnet.png]

Residual Polarity Network.

Input: 400 samples (4 seconds at 100 Hz sampling rate)

    from seispolarity.models import RPNet

    model = RPNet()

PolarCAP

[Figure: PolarCAP architecture, ../../_images/polarcap.png]

Lightweight model for polarity classification.

Input: 64 samples (0.64 seconds at 100 Hz sampling rate)

    from seispolarity.models import PolarCAP, PolarCAPLoss

    model = PolarCAP()
    loss_fn = PolarCAPLoss()

APP

[Figure: APP architecture, ../../_images/app.png]

Adaptive Polarity Predictor.

Input: 400 samples

    from seispolarity.models import APP

    model = APP()

Inference

Batch Inference

    import torch
    from seispolarity import WaveformDataset
    from seispolarity.models import PPNet

    # Use the GPU when one is available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = PPNet(num_fm_classes=3)
    model.to(device)
    model.eval()

    # Load dataset
    dataset = WaveformDataset(path="data.hdf5", name="SCSN")
    loader = dataset.get_dataloader(batch_size=1024)

    # Inference: collect the per-batch predictions on the CPU
    all_predictions = []
    with torch.no_grad():
        for waveforms, _ in loader:
            waveforms = waveforms.to(device)
            logits = model(waveforms)
            predictions = logits.argmax(dim=1)
            all_predictions.append(predictions.cpu())

    # Concatenate into a single 1-D tensor of class indices
    predictions = torch.cat(all_predictions)
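The loop above follows a common pattern: split the data into consecutive chunks, predict on each chunk, and join the results. The plain-Python sketch below shows only that pattern. `batch_slices` and `predict_in_batches` are hypothetical names, and the sketch assumes batches arrive in order, which the SeisPolarity loader may or may not guarantee.

```python
def batch_slices(n_items, batch_size):
    """Yield (start, stop) index pairs covering n_items in order."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, n_items, batch_size):
        yield start, min(start + batch_size, n_items)


def predict_in_batches(items, predict_fn, batch_size):
    """Apply predict_fn to consecutive chunks and concatenate the results."""
    results = []
    for start, stop in batch_slices(len(items), batch_size):
        results.extend(predict_fn(items[start:stop]))
    return results


# 2500 items with batch_size=1024 give two full batches and one partial batch.
print(list(batch_slices(2500, 1024)))
```

The last batch is usually smaller than `batch_size`, so code that consumes the per-batch outputs should not assume a fixed batch dimension.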

Single Sample Inference

    import numpy as np
    import torch
    from seispolarity.models import PPNet

    # Load model
    model = PPNet(num_fm_classes=3)
    model.eval()

    # Single waveform of 400 samples (random data as a placeholder)
    waveform = np.random.randn(400)
    # Convert to float32 and add batch and channel dimensions: (1, 1, 400)
    waveform = torch.from_numpy(waveform).float().unsqueeze(0).unsqueeze(0)

    # Inference
    with torch.no_grad():
        logits = model(waveform)
        prediction = logits.argmax(dim=1).item()

    # Interpret result
    label_map = {0: "Up", 1: "Down", 2: "Unknown"}
    print(f"Predicted polarity: {label_map[prediction]}")
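The same lookup extends to a whole batch of predicted indices. In the sketch below, the 3-class map is the one used in the example above, while the 2-class map is an assumption (the same order with the "Unknown" class dropped) that should be checked against each model's training labels. `LABEL_MAPS` and `decode_predictions` are hypothetical names.

```python
from collections import Counter

# The 3-class order follows the example above; the 2-class order is an
# assumption and is not confirmed by the SeisPolarity documentation.
LABEL_MAPS = {
    3: {0: "Up", 1: "Down", 2: "Unknown"},
    2: {0: "Up", 1: "Down"},
}


def decode_predictions(class_indices, num_classes):
    """Map integer class indices to polarity names."""
    label_map = LABEL_MAPS[num_classes]
    return [label_map[int(i)] for i in class_indices]


labels = decode_predictions([0, 2, 1, 0], num_classes=3)
print(labels)           # ['Up', 'Unknown', 'Down', 'Up']
print(Counter(labels))  # per-class counts for a quick sanity check
```

For a tensor of predictions, `predictions.tolist()` yields a list of integers that can be passed in directly.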

Model Output

Models output logits for each class:

    import torch.nn.functional as F

    # `model` and `waveforms` are defined as in the Model API example

    # Raw logits
    logits = model(waveforms)  # Shape: (batch_size, num_classes)

    # Probabilities (softmax over the class dimension)
    probabilities = F.softmax(logits, dim=1)

    # Predicted classes
    predictions = logits.argmax(dim=1)
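To make the relation between logits, probabilities, and predicted class concrete, the sketch below reimplements softmax for one row of logits in plain Python. The `min_confidence` threshold is an optional illustration rather than a SeisPolarity feature, and the logit values are made up.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    shift = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify(logits, min_confidence=0.0):
    """Return (class_index, probability); index is None below min_confidence."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] < min_confidence:
        return None, probs[best]
    return best, probs[best]


row = [2.0, 0.5, -1.0]  # made-up logits for a 3-class model
probs = softmax(row)
print([round(p, 3) for p in probs])
print(classify(row, min_confidence=0.9))
```

Softmax preserves the ordering of the logits, so `argmax` over the logits and over the probabilities selects the same class. The probabilities matter only when a confidence value is needed.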

Model Download

Pre-trained models are available on Hugging Face:

    from huggingface_hub import hf_hub_download
    import torch

    # Download model weights (cached locally after the first call)
    model_path = hf_hub_download(
        repo_id="HeXingChen/SeisPolarity-Model",
        filename="ross_scsn.pth",
    )

    # Load weights into an already constructed model of the matching architecture
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

Custom Models

To create a custom model, subclass BasePolarityModel and implement forward:

    import torch.nn as nn
    from seispolarity.models.base import BasePolarityModel

    class CustomModel(BasePolarityModel):
        def __init__(self, num_fm_classes=3):
            super().__init__(num_fm_classes)
            # Define your architecture here
            self.conv1 = nn.Conv1d(1, 32, kernel_size=5)
            # A 400-sample input leaves 400 - 5 + 1 = 396 samples per channel
            self.fc = nn.Linear(32 * 396, num_fm_classes)

        def forward(self, x):
            # Forward pass
            x = self.conv1(x)           # (batch, 32, 396)
            x = x.view(x.size(0), -1)   # flatten to (batch, 32 * 396)
            x = self.fc(x)              # (batch, num_fm_classes)
            return x
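The `32 * 396` passed to `nn.Linear` depends on the input length and the convolution settings, so it has to be recomputed whenever either changes. The sketch below applies the output-length formula documented for `torch.nn.Conv1d`; `conv1d_out_length` is a hypothetical helper name.

```python
def conv1d_out_length(length_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1-D convolution (formula from the Conv1d docs)."""
    return (length_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Values used by CustomModel above: 400-sample input, kernel_size=5.
length = conv1d_out_length(400, kernel_size=5)  # 396
flat_features = 32 * length                     # 12672, in_features of self.fc
print(length, flat_features)
```

For a 600-sample input the same layer would produce 596 samples per channel, and the linear layer would then need `32 * 596` input features.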

For detailed API documentation, see the API Reference.