# Example

## main.py
# Entry point: train a small MNIST model across GPUs with ddpw.
from ddpw import Platform, Wrapper

from torchvision.datasets.mnist import MNIST
from torchvision import transforms as T

from src.model import Model
from src.example import Example


# Standard MNIST normalisation constants (mean, std).
transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
mnist = MNIST(root='./datasets/MNSIT', train=True, download=True, transform=transform)

net = Model()

# Request four GPUs; ddpw spawns one process per device.
gpu_platform = Platform(device='gpu', n_gpus=4)
task = Example(net, mnist, platform=gpu_platform, batch_size=32, epochs=2)

Wrapper(platform=gpu_platform).start(task)
## src/model.py
1import torch
2from torch.nn import functional as F
3
4
class Model(torch.nn.Module):
    """A small convolutional classifier for 28x28 single-channel (MNIST) input.

    Two 3x3 conv layers, batch norm, 2x2 max-pooling, two dropout stages, and
    two fully-connected layers producing log-probabilities over 10 classes.
    """

    def __init__(self):
        super(Model, self).__init__()

        self.conv1 = torch.nn.Conv2d(1, 32, (3, 3))
        self.conv2 = torch.nn.Conv2d(32, 64, (3, 3))

        self.bn = torch.nn.BatchNorm2d(64)

        # Element-wise dropout, matching the functional dropout the forward
        # pass applied. Registering the modules (instead of calling
        # F.dropout inline) means model.eval() disables them automatically.
        self.drop1 = torch.nn.Dropout(0.25)
        self.drop2 = torch.nn.Dropout(0.5)

        # 64 channels x 12 x 12 spatial = 9216, for a 28x28 input after two
        # unpadded 3x3 convs (28 -> 26 -> 24) and one 2x2 pool (24 -> 12).
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Return log-softmax class scores for a (N, 1, 28, 28) batch."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))

        x = self.bn(x)

        # Bug fix: the previous code used F.dropout(...) without
        # training=self.training; F.dropout defaults to training=True, so
        # dropout stayed active in eval mode and the registered dropout
        # modules were never used. Calling the modules respects train/eval.
        x = self.drop1(F.max_pool2d(x, 2))

        x = x.flatten(1)
        x = F.relu(self.fc1(x))

        x = self.drop2(x)
        x = self.fc2(x)

        return F.log_softmax(x, 1)
## src/example.py
1from tqdm import tqdm
2import torch
3from torch import distributed as dist
4from torch.nn import functional as F
5
6from ddpw import functional as DF
7from torch.utils.data import DataLoader
8
9
class Example:
    """A minimal ddpw task: trains `model` on `dataset` with SGD + NLL loss.

    Instances are callable; ddpw invokes the instance on each spawned
    process with that process's global and local ranks.
    """

    def __init__(self, model, dataset, platform, batch_size, epochs):
        self.model = model            # an nn.Module producing log-probabilities
        self.dataset = dataset        # map-style dataset of (image, label) pairs
        self.platform = platform      # ddpw Platform (device/world configuration)

        self.batch_size = batch_size  # per-process batch size
        self.epochs = epochs          # number of passes over the dataset

    def __call__(self, global_rank, local_rank):
        """Run the training loop on the process identified by the two ranks."""
        print(f'Global rank {global_rank}; local rank {local_rank}')

        # Bind this process to its GPU before any CUDA allocation. The
        # original called set_device unconditionally (and after the model
        # move), which raises on CPU-only runs.
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)

        model = DF.to(self.model, local_rank, device=self.platform.device)
        dataloader = DataLoader(
            self.dataset,
            sampler=DF.get_dataset_sampler(self.dataset, global_rank, self.platform),
            batch_size=self.batch_size,
            pin_memory=True
        )
        optimiser = torch.optim.SGD(model.parameters(), lr=1e-3)

        training_loss = torch.Tensor([0.0]).to(DF.device(model))
        print(f'Model on device {DF.device(model)}; dataset size: {len(dataloader) * self.batch_size}')

        # for every epoch
        for e in range(self.epochs):
            print(f'Epoch {e} of {self.epochs}')

            for imgs, labels in tqdm(dataloader, position=local_rank):
                optimiser.zero_grad()

                preds = model(imgs.to(DF.device(model)))
                loss = F.nll_loss(preds, labels.to(DF.device(model)))
                # Detach before accumulating: `training_loss += loss` kept
                # each batch's autograd graph alive for the whole run,
                # steadily growing memory.
                training_loss += loss.detach()
                loss.backward()

                optimiser.step()

        # Average over every batch seen. The original divided only by
        # len(dataloader) even though the sum spans all epochs, overstating
        # the mean loss by a factor of `epochs`.
        training_loss /= max(len(dataloader) * self.epochs, 1)

        # synchronise metrics across processes
        if self.platform.requires_ipc:
            dist.all_reduce(training_loss, dist.ReduceOp.SUM)
            training_loss /= dist.get_world_size()

        if global_rank == 0:
            # code for storing logs and saving state
            print(training_loss.item())