Contents Menu Expand Light mode Dark mode Auto light/dark, in light mode Auto light/dark, in dark mode
Back to top

Example

src/model.py
from torch.nn import BatchNorm2d, Conv2d, Dropout, Dropout2d, Linear, Module
from torch.nn import functional as F
 3
 4
 5class Model(Module):
 6    def __init__(self):
 7        super(Model, self).__init__()
 8
 9        self.conv1 = Conv2d(1, 32, (3,3))
10        self.conv2 = Conv2d(32, 64, (3,3))
11
12        self.bn = BatchNorm2d(64)
13
14        self.drop1 = Dropout2d(0.25)
15        self.drop2 = Dropout2d(0.5)
16
17        self.fc1 = Linear(9216, 128)
18        self.fc2 = Linear(128, 10)
19
20    def forward(self, x):
21        x = F.relu(self.conv1(x))
22        x = F.relu(self.conv2(x))
23
24        x = self.bn(x)
25
26        x = F.dropout(F.max_pool2d(x, 2), 0.25)
27
28        x = x.flatten(1)
29        x = F.relu(self.fc1(x))
30
31        x = F.dropout(x, 0.5)
32        x = self.fc2(x)
33
34        return F.log_softmax(x, 1)
src/train.py
 1from ddpw import functional as DF
 2from torch import distributed as dist
 3from torch import empty
 4from torch.nn import functional as F
 5from torch.optim import SGD
 6from torch.utils.data import DataLoader
 7from src import MNISTModel
 8
 9
def train(*args, **kwargs):
    """Distributed training loop for the MNIST model.

    Positional args: (batch_size, epochs, dataset).
    Keyword args (injected by the ddpw wrapper): "global_rank",
    "local_rank", "platform".
    """
    batch_size, epochs, dataset = args
    global_rank, local_rank, platform = (
        kwargs["global_rank"],
        kwargs["local_rank"],
        kwargs["platform"],
    )
    print(f"Global rank {global_rank}; local rank {local_rank}")

    # set the current device
    DF.set_device(local_rank, platform)

    # move the (distributed) model to correct device (MPS/GPU)
    model = DF.to(MNISTModel(), local_rank)

    # sample the dataset
    sampler = DF.get_dataset_sampler(dataset, global_rank, platform)
    data = DataLoader(dataset, batch_size, sampler=sampler, pin_memory=True)

    # the optimiser
    optim = SGD(model.parameters(), lr=1e-2)

    # training over epochs...
    for e in range(epochs):
        # reshuffle the distributed sampler each epoch (mirrors evaluate.py)
        if platform.world_size > 1:
            sampler.set_epoch(e)

        # BUG FIX: the original allocated the loss once with `empty`
        # (uninitialised memory) and never accumulated the per-batch loss,
        # so the printed value was garbage. Reset to zero every epoch.
        # NOTE(review): tensors are placed on `local_rank` (the device index
        # on this node) rather than `global_rank`, which is not a valid
        # device ordinal on multi-node setups — confirm against ddpw docs.
        training_loss = empty((1,)).to(local_rank).zero_()

        # ...and over batches
        for imgs, labels in data:
            optim.zero_grad()
            preds = model(imgs.to(local_rank))
            loss = F.nll_loss(preds, labels.to(local_rank))
            loss.backward()
            optim.step()
            # detach: accumulate a value, not the autograd graph
            training_loss += loss.detach()
        training_loss /= len(data)

        # average the loss across processes
        if platform.requires_ipc:
            dist.all_reduce(training_loss, dist.ReduceOp.SUM)
            training_loss /= platform.world_size

        # logging: console/tensorboard/wandb/etc.
        if global_rank == 0:
            print(training_loss.item())
src/evaluate.py
 1from ddpw import functional as DF
 2from torch import distributed as dist
 3from torch import empty, load, no_grad
 4from torch.backends import cudnn
 5from torch.utils.data import DataLoader
 6from src import MNISTModel
 7
 8cudnn.deterministic = True
 9
10
@no_grad()
def evaluate(*args, **kwargs):
    """Distributed evaluation: prints top-1 accuracy as a percentage.

    Positional args: (batch_size, dataset, ckptfile).
    Keyword args (injected by the ddpw wrapper): "global_rank", "platform".
    """
    batch_size, dataset, ckptfile = args
    global_rank, platform = kwargs["global_rank"], kwargs["platform"]

    # set the current device
    DF.set_device(global_rank, platform)

    # load the checkpoint on CPU first so every rank can restore it no
    # matter which device it was saved from, then move the model across
    model = MNISTModel()
    model.load_state_dict(load(ckptfile, map_location="cpu"))
    model = DF.to(model, global_rank).eval()

    # sample the dataset
    sampler = DF.get_dataset_sampler(dataset, global_rank, platform)
    if platform.world_size > 1:
        sampler.set_epoch(0)
    data = DataLoader(dataset, batch_size, sampler=sampler, pin_memory=True)

    # BUG FIX: `empty` returns uninitialised memory; start from zero
    accuracy = empty((1,)).to(global_rank).zero_()

    # evaluation in batches; BUG FIX: divide by the actual batch length so
    # a short final batch does not deflate the score (was `batch_size`)
    for imgs, labels in data:
        labels = labels.to(global_rank)
        preds = model(imgs.to(global_rank))
        accuracy += (preds.argmax(-1) == labels).sum() / labels.shape[0]
    accuracy /= len(data) / 100  # mean per-batch accuracy, in percent

    # average the accuracy across processes
    if platform.requires_ipc:
        dist.all_reduce(accuracy, dist.ReduceOp.SUM)
        accuracy /= platform.world_size

    if global_rank == 0:
        print(accuracy.item())
main.py
 1from ddpw import Wrapper, Platform
 2from torchvision.datasets.mnist import MNIST
 3
 4from src import train, evaluate
 5
 6
 7if __name__ == "__main__":
 8    epochs = ...
 9    batch_size = ...
10    dataset = MNIST(root="./input/datasets/MNIST/", train=True, transform=...)
11
12    platform = Platform(...)
13    Wrapper(platform).start(train, model, batch_size, epochs, dataset)