Contents Menu Expand Light mode Dark mode Auto light/dark, in light mode Auto light/dark, in dark mode
Back to top

Example

main.py
 1from ddpw import Platform, Wrapper
 2
 3from torchvision.datasets.mnist import MNIST
 4from torchvision import transforms as T
 5
 6from src.model import Model
 7from src.example import Example
 8
 9
10model = Model()
11t = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
12dataset = MNIST(root='./datasets/MNSIT', train=True, download=True, transform=t)
13
14platform = Platform(device='gpu', n_gpus=4)
15example = Example(model, dataset, platform=platform, batch_size=32, epochs=2)
16
17wrapper = Wrapper(platform=platform)
18wrapper.start(example)
src/model.py
 1import torch
 2from torch.nn import functional as F
 3
 4
 5class Model(torch.nn.Module):
 6    def __init__(self):
 7        super(Model, self).__init__()
 8
 9        self.conv1 = torch.nn.Conv2d(1, 32, (3,3))
10        self.conv2 = torch.nn.Conv2d(32, 64, (3,3))
11
12        self.bn = torch.nn.BatchNorm2d(64)
13
14        self.drop1 = torch.nn.Dropout2d(0.25)
15        self.drop2 = torch.nn.Dropout2d(0.5)
16
17        self.fc1 = torch.nn.Linear(9216, 128)
18        self.fc2 = torch.nn.Linear(128, 10)
19
20    def forward(self, x):
21        x = F.relu(self.conv1(x))
22        x = F.relu(self.conv2(x))
23
24        x = self.bn(x)
25
26        x = F.dropout(F.max_pool2d(x, 2), 0.25)
27
28        x = x.flatten(1)
29        x = F.relu(self.fc1(x))
30
31        x = F.dropout(x, 0.5)
32        x = self.fc2(x)
33
34        return F.log_softmax(x, 1)
src/example.py
 1from tqdm import tqdm
 2import torch
 3from torch import distributed as dist
 4from torch.nn import functional as F
 5
 6from ddpw import functional as DF
 7from torch.utils.data import DataLoader
 8
 9
class Example:
    """Distributed training task for ``ddpw.Wrapper``: trains ``model`` on
    ``dataset`` with NLL loss and SGD, one call per spawned process.
    """

    def __init__(self, model, dataset, platform, batch_size, epochs):
        self.model = model        # torch.nn.Module to train
        self.dataset = dataset    # map-style dataset yielding (image, label) pairs
        self.platform = platform  # ddpw.Platform describing devices/world size

        self.batch_size = batch_size
        self.epochs = epochs

    def __call__(self, global_rank, local_rank):
        """Per-process entry point invoked by the wrapper.

        :param global_rank: rank of this process across all nodes.
        :param local_rank: rank of this process on the local node.
        """
        print(f'Global rank {global_rank}; local rank {local_rank}')

        # Bind this process to its GPU *before* any CUDA work (the original
        # set the device only after the model had already been moved), and
        # guard it so CPU-only runs do not crash.
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)

        model = DF.to(self.model, local_rank, device=self.platform.device)
        dataloader = DataLoader(
            self.dataset,
            sampler=DF.get_dataset_sampler(self.dataset, global_rank, self.platform),
            batch_size=self.batch_size,
            pin_memory=True
        )
        optimiser = torch.optim.SGD(model.parameters(), lr=1e-3)

        device = DF.device(model)  # hoisted: queried once instead of per batch
        print(f'Model on device {device}; dataset size: {len(dataloader) * self.batch_size}')

        # for every epoch
        for e in range(self.epochs):
            print(f'Epoch {e} of {self.epochs}')

            # Fresh accumulator each epoch: the original initialised this once,
            # so epoch e's averaged loss leaked into epoch e+1's sum.
            training_loss = torch.tensor([0.0], device=device)

            for imgs, labels in tqdm(dataloader, position=local_rank):
                optimiser.zero_grad()

                preds = model(imgs.to(device))
                loss = F.nll_loss(preds, labels.to(device))
                # detach(): accumulating the live loss tensor would keep every
                # iteration's autograd graph alive and leak memory.
                training_loss += loss.detach()
                loss.backward()

                optimiser.step()

            training_loss /= len(dataloader)

            # synchronise metrics: average the per-process loss across the world
            if self.platform.requires_ipc:
                dist.all_reduce(training_loss, dist.ReduceOp.SUM)
                training_loss /= dist.get_world_size()

            if global_rank == 0:
                # code for storing logs and saving state
                print(training_loss.item())