Example
src/model.py
from torch.nn import BatchNorm2d, Conv2d, Dropout, Dropout2d, Linear, Module
from torch.nn import functional as F


class MNISTModel(Module):
    """A small CNN for MNIST: two convolutions, batch norm, dropout, two linear layers."""

    def __init__(self):
        super().__init__()

        self.conv1 = Conv2d(1, 32, (3, 3))
        self.conv2 = Conv2d(32, 64, (3, 3))

        self.bn = BatchNorm2d(64)

        self.drop1 = Dropout2d(0.25)  # channel-wise dropout on the feature maps
        self.drop2 = Dropout(0.5)     # element-wise dropout on the flattened features

        self.fc1 = Linear(9216, 128)  # 64 channels x 12 x 12 after pooling
        self.fc2 = Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))

        x = self.bn(x)

        x = self.drop1(F.max_pool2d(x, 2))

        x = x.flatten(1)
        x = F.relu(self.fc1(x))

        x = self.drop2(x)
        x = self.fc2(x)

        return F.log_softmax(x, 1)
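A quick, standalone sanity check of the architecture (not part of the ddpw example itself): each 1×28×28 MNIST image shrinks to 26×26 and then 24×24 under the two 3×3 convolutions, and to 12×12 after the 2×2 max-pool, so the flattened vector has 64 × 12 × 12 = 9216 features, which is exactly what fc1 expects.

import torch

from src import MNISTModel

model = MNISTModel().eval()
with torch.no_grad():
    out = model(torch.randn(1, 1, 28, 28))  # one fake MNIST-sized image
print(out.shape)              # torch.Size([1, 10])
print(out.exp().sum().item()) # ~1.0, since the outputs are log-probabilities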
src/train.py
from ddpw import functional as DF
from torch import distributed as dist
from torch import zeros
from torch.nn import functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader
from src import MNISTModel


def train(*args, **kwargs):
    batch_size, epochs, dataset = args
    global_rank, local_rank, platform = (
        kwargs["global_rank"],
        kwargs["local_rank"],
        kwargs["platform"],
    )
    print(f"Global rank {global_rank}; local rank {local_rank}")

    # set the current device
    DF.set_device(local_rank, platform)

    # move the (distributed) model to the correct device (MPS/GPU)
    model = DF.to(MNISTModel(), local_rank)

    # sample the dataset
    sampler = DF.get_dataset_sampler(dataset, global_rank, platform)
    data = DataLoader(dataset, batch_size, sampler=sampler, pin_memory=True)

    # the optimiser
    optim = SGD(model.parameters(), lr=1e-2)

    # per-epoch loss accumulator, on the same (local) device as the model
    training_loss = zeros((1,)).to(local_rank)

    # training over epochs...
    for e in range(epochs):
        training_loss.zero_()

        # ...and over batches
        for imgs, labels in data:
            optim.zero_grad()
            preds = model(imgs.to(local_rank))
            loss = F.nll_loss(preds, labels.to(local_rank))
            loss.backward()
            optim.step()
            training_loss += loss.detach()
        training_loss /= len(data)

        # average the epoch loss across all processes
        if platform.requires_ipc:
            dist.all_reduce(training_loss, dist.ReduceOp.SUM)
            training_loss /= platform.world_size

        # logging: console/tensorboard/wandb/etc.
        if global_rank == 0:
            print(training_loss.item())
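One gap worth noting: evaluate.py below loads a state dict from ckptfile, while train.py as written never saves one. A minimal sketch of saving a checkpoint at the end of train(), assuming only rank 0 writes and that DF.to may have wrapped the model in DistributedDataParallel; the file name is illustrative.

from torch import save

if global_rank == 0:
    to_save = getattr(model, "module", model)  # unwrap DistributedDataParallel if present
    save(to_save.state_dict(), "./checkpoint.pt")  # illustrative path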
src/evaluate.py
from ddpw import functional as DF
from torch import distributed as dist
from torch import load, no_grad, zeros
from torch.backends import cudnn
from torch.utils.data import DataLoader
from src import MNISTModel

cudnn.deterministic = True


@no_grad()
def evaluate(*args, **kwargs):
    batch_size, dataset, ckptfile = args
    global_rank, platform = kwargs["global_rank"], kwargs["platform"]

    # set the current device
    DF.set_device(global_rank, platform)

    # load the checkpoint and move the (distributed) model to the correct device (MPS/GPU)
    model = MNISTModel()
    model.load_state_dict(load(ckptfile))
    model = DF.to(model, global_rank).eval()

    # sample the dataset
    sampler = DF.get_dataset_sampler(dataset, global_rank, platform)
    if platform.world_size > 1:
        sampler.set_epoch(0)
    data = DataLoader(dataset, batch_size, sampler=sampler, pin_memory=True)

    # evaluation metric: accuracy, accumulated batch by batch
    accuracy = zeros((1,)).to(global_rank)

    # evaluation in batches
    for imgs, labels in data:
        preds = model(imgs.to(global_rank))
        accuracy += (preds.argmax(-1) == labels.to(global_rank)).sum() / batch_size

    # mean per-batch accuracy, as a percentage
    accuracy = 100 * accuracy / len(data)

    # average across all processes
    if platform.requires_ipc:
        dist.all_reduce(accuracy, dist.ReduceOp.SUM)
        accuracy /= platform.world_size

    if global_rank == 0:
        print(accuracy.item())
main.py
from ddpw import Wrapper, Platform
from torchvision.datasets.mnist import MNIST

from src import train, evaluate


if __name__ == "__main__":
    epochs = ...
    batch_size = ...
    dataset = MNIST(root="./input/datasets/MNIST/", train=True, transform=...)

    platform = Platform(...)

    # launch the training task across the configured platform
    Wrapper(platform).start(train, batch_size, epochs, dataset)
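main.py imports evaluate but only launches training. Assuming Wrapper.start forwards positional arguments to the task in the same way as for train above, an evaluation run might look like the following sketch; the test split, transform, and checkpoint path are placeholders.

# hypothetical evaluation launch, mirroring the train call above
test_dataset = MNIST(root="./input/datasets/MNIST/", train=False, transform=...)
Wrapper(platform).start(evaluate, batch_size, test_dataset, "./checkpoint.pt")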