【FedAvg论文笔记】&【代码复现】

目录

n


一、FedAvg原始论文笔记

联邦平均算法经典论文:McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.

我们知道联邦学习的思想就在于分布式的机器学习,同时兼顾了数据安全问题。而联邦平均算法是其中最典型的算法之一,FedAvg算法将每个客户端上的本地随机梯度下降和执行模型的平均服务器结合在一起。

1、联邦优化问题: 

 1、数据非独立同分布

 2、数据分布的不平衡性

 3、用户规模大

 4、通信有限

其中最重要的就是要理解什么是客户端数据集非独立同分布

举个栗子,假设某数据集A的train data中有5(1-5)个类别的手写数字250张,client1 本地数据集只有1、2手写数字50张(此时的1数据集占比为1/5),client2拥有的2、3、4、5手写图片张200(4/5),可想而知他们利用本地数据集进行学习,client1只能学习到1,2。client2只能学习到2、3、4再通过依靠数据集占比的权重聚合后,所得到的全局模型对1的学习能力会变得更弱。从这个例子来看,客户端数据集非独立同分布提现了样本类别少,不能代表全局样本的分布。

更有复杂的情况,样本标签混乱,不单一的情况下,数据集非独立同分布情况会更严重。

2、联邦平均算法:

我们需要注意的是,相比于传统的数据中心处理模式,在联邦学习中,客户端本地的计算量和服务器中聚合模型所花费的计算量是花费很小的,但客户端与服务器之间的通信代价较大,故文中提出两种方法以降低通信成本:

1、增加并行性(即使用更多的客户端独立训练模型

2、增加每个客户端计算量

首先本文提出FedSGD算法:

FedSGD算法

对K个客户端的数据计算其损失梯度,(F(Wt)表示在模型wt下数据的损失函数):

聚合K客户端的损失梯度,得到t+1轮模型参数:

而FedAvg算法就是在在本地执行了多次的FedSGD,在选定一定比例的客户端参加训练,而不是全部(实验部分会指出,全部的客户端参加比部分客户端才加的收敛速度慢,模型精度低。)

FedAvg算法:

在客户端进行局部模型的更新:

在服务器将局部模型上传,只进行一个平均算法:

可以看出,该算法将计算量放在了本地客户端,而服务器只用于聚合平均。故我们可以在平均步骤之前进行多次局部模型的更新。(这儿不防思考一下,这个次数是不是越多越好,我们知道过少本地数据集样本,过多的本地迭代轮次会造成什么问题?————过拟合

而上述计算量的大小由三个参数控制,即为C(客户端随机选取的比例)、E(客户端在第t轮通过本地数据集训练的次数)、B(参与本地局部模型更新所需的数据批量size)

所以,上述的FedSGD算法中有:C=1,E=1,B=无穷大

故,对于第K个客户端本地数据集大小为nk时,可得到这个客户端每轮的本地更新数为:

ps:客户端本地数据集与局部训练轮次的乘积/批量处理大小,为这个本轮客户端本地SGD的次数,FedAvg的伪代码如下:

实验结果:

1、基于mnist数据集手写照片的数字识别任务:

MNIST 2NN :一个简单的多层感知器,2个隐藏层,每个隐藏层200个单元,使用ReLu激活(199210个参数)

CNN:由两个5×5卷积层的CNN层(第一层有32个通道,第二个有64个,每个之后是2×2 max池化),一个全连接层(有512个单元)和ReLu激活,最后是一个softmax输出层(1663370个参数)

增加并行性实验:使用比例C控制并行处理的客户端数量

增加本地计算量实验结果:使用B(更新数据批量大小)和E(本地数据训练次数)来控制本地计算量

下图

可以看到随着比例C的增大,训练轮数在减小,C过大时会在指定时间内达不到希望的准确度。

 

(左图)可以看到对于独立同分布数据,B=无穷大、E=1(数据大批量更新、本地训练1次时)的效果最差,B=10、E=20时(小批量数据更新、本地训练数据次数20次)精度最高。

(右图)类似于右图效果。

而在下图中:可以清楚的看到并不是局部模型更新的次数越高越好,E=1比E=5的训练效果要好得多。

思考:FedAvg算法的局限性主要在于:对于网络的连通性要求十分严格,不同的客户端规定采用一致的局部模型更新次数的做法过于死板,可能会导致模型过拟合。

但是,FedAvg会“抛弃落后者”或者合并“落后者”信息,即直接丢弃无法完成指定计算轮数E的设备,或者将未完成的设备信息聚合,会影响模型的收敛,加大计算量。(后面的prox算法,主要会解决这个问题)

3、代码解释

在此之前,看到这儿的人一定要懂得,跑一个项目,就一定得看项目的readme文件,这个文件里面几乎什么都会写到,比如这个项目所依赖的环境。配置环境不难 但就是很烦人。

代码是在Git上获取的:federated-learning · GitHub

 3.1、main_fed.py主函数

首先,前面一大段基本是导入工具包的过程,这个不重要:

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img

接下来是main函数:

首先传参,接下来调用设备,首选cuda 其次cpu 

if __name__ == '__main__':
    # parse args
    args = args_parser()#用于调用option.py的函数
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

接下来是加载数据集,划分数据集,这儿注意,‘../data/mnist/’的意思是 将mnist数据集下载到一级文件夹下的data文件夹中,也可以手动指定。。

    if args.dataset == 'mnist':
        #tensor就是个多维数组
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        # trans_mnist处理方式 将图片转化为tensor张量类型,进行归一化处理
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # 数据集的训练和测试调用datasets库 数据集内容被下载到data文件夹中的cifar和mnist文件夹

        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
            # 数据划分方式将数据分为 iid 和 non-iid 两种

    elif args.dataset == 'cifar':#类似对mnist上面的操作
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

接下来是build model 阶段:

# build model
    #这儿得使用model文件夹下定义的nets.py中的神经网络模型
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)#打印具体网络结构
    net_glob.train()#对网络进行训练

接下来是复制权重与训练过程:

# copy weights复制权重
    w_glob = net_glob.state_dict()

    # training
    #fedavg 核心代码
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0 # 预测损失,计数器
    net_best = None
    best_loss = None 
    val_acc_list, net_list = [], []# 刚开始 先置空

    if args.all_clients:
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]# 给参与训练的局部下发全局初始模型
    for iter in range(args.epochs):# epochs 局部迭代轮次
        loss_locals = [] # 局部预测损失
        if not args.all_clients:
            w_locals = []
        
        m = max(int(args.frac * args.num_users), 1)#每轮被选参与联邦学习的用户比例frac
        #sample client
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)#随机选取用户
        
        for idx in idxs_users:
            #local model training process
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            # 初始的本地模型利用deepcopy函数 深复制来源于 全局下发的初始模型 net_glob 传给(args.device)计算局部损失
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))# 局部损失以列表的形式往后添加
            #w_locals以列表的形式汇总本地客户端训练权重结果
            
        # update global weights全局更新
        w_glob = FedAvg(w_locals)# 调用FedAvg函数进行更新聚合 得到全局模型

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)#复制权重 准备下发

        # print loss在每轮后打印输出全局训练损失
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

接下来就是,测试:

# testing
    net_glob.eval()# eavl()函数 关闭batch normalization与dropout 处理
    acc_train, loss_train = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))

3.2、Fed.py:

FedAvg函数定义如下:

def FedAvg(w):
    w_avg = copy.deepcopy(w[0]) # 利用深拷贝获取初始w_0
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k] # 累加
        w_avg[k] = torch.div(w_avg[k], len(w)) #平均
    return w_avg

3.3、Nets.py:模型定义

继承nn.Module类构造自己的神经网络,定义输入、隐藏、输出层,利用nn.linear设置网络中的全连接。定义前向传播 forward()

import torch
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):#多层感知机
    def __init__(self,dim_in,dim_hidden,dim_out):#定义
        super(MLP,self).__init__()#进行初始化
        self.layer_input = nn.Linear(dim_in, dim_hidden)#nn.linear线性变换
        self.relu = nn.ReLU()#激活函数
        self.dropout = nn.Dropout()#防止过拟合而设置的
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        #shape快速读取矩阵向量的形状,将其传入全连接层
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x

定义处理mnist、cifar数据集的CNN:这个也是继承nn.module:

class CNNMnist(nn.Module):#处理MNIST的CNN
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        #两个卷积层
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        #卷积核大小为5*5,nn.conv2d为2维卷积神经网络
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        #in_channel=10,out_channel=20
        self.conv2_drop = nn.Dropout2d()
        #全连接层
        self.fc1 = nn.Linear(320, 50)#输入特征和输出特征数
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        #卷积层-》池化层-》激活函数
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])#展开数据,将要输入全连接层
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x


class CNNCifar(nn.Module):#卷积神经网络
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        #两个卷积层
        self.conv1 = nn.Conv2d(3, 6, 5)#输入三个通道图片,产生6个特征
        self.pool = nn.MaxPool2d(2, 2)#最大池化层2*2
        self.conv2 = nn.Conv2d(6, 16, 5)#产生16个更深层次的特征
        self.fc1 = nn.Linear(16 * 5 * 5, 120)#添加全连接层
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)#平铺图片为16*5*5
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

3.4、option.py超参数设置

python文件中,实验参数可在这儿修改 也可以在终端运行的时候直接键入。

import argparse

def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")

    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    args = parser.parse_args()
    return args

3.5、sampling.py:

将数据集中的数据样本划分成iid/non-iid数据样本,分配给Client。

对于独立同分布情况,将数据集中的数据打乱,为每个Client随机分配600个。

对于non-iid情况,根据数据集标签将数据集排序,将其划分为200组大小为300的数据切片,每个client分配两个切片。

import numpy as np
from torchvision import datasets, transforms

def mnist_iid(dataset, num_users): # mnist独立同分布数据采样
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users) # num_items=MINIST数据集大小/用户数量
    # 数据集以矩阵形式存在,行为user,列为iterm,则有:len(Dataset)=num_user*num_item
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]

    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        # 从序列中随机采样,且不重用
        all_idxs = list(set(all_idxs) - dict_users[i])
        # all_idxs 作为序列顺序
    return dict_users

def mnist_noniid(dataset, num_users): # mnist非独立同分布数据采样
    """
    Sample non-I.I.D client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return:
    """
    num_shards, num_imgs = 200, 300
    # num_shards 200分片索引
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs) # idxs1~6000
    labels = dataset.train_labels.numpy()
    # 用numpy 将mnist数据转化成张量tensor格式

    # sort labels 标签分类
    idxs_labels = np.vstack((idxs, labels))
    # 按垂直方向将idxs 与 labels堆叠构成一个新的数组
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]# 排序
    idxs = idxs_labels[0,:]

    # divide and assign 分配
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        # 从idx中随机选择2个 分配给客户端,不重复
        idx_shard = list(set(idx_shard) - rand_set) # idx_shard序列0~...
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                axis=0)# 行拼接
            # concatenate() 对应数组拼接
            # idxs 存下标 num_imgs=300 当rand=8时,idxs[2400:2700]
            # dict_users[i]=【dict_user[i],300】 每个dict_users[i]有被随机分配300个下标数据
    return dict_users


def cifar_iid(dataset, num_users):# cifar 独立同分布数据
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


if __name__ == '__main__':
    dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       # 将照片格式转化成张量形式 
                                       #进行归一化处理
                                   ]))
    num = 100
    d = mnist_noniid(dataset_train, num)

3.6、update.py :局部更新

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics


class DatasetSplit(Dataset): # 数据集划分
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self): # 数据集大小
        return len(self.idxs)

    def __getitem__(self, item):
        # sampling中idxs
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss() # 交叉熵损失函数
        self.selected_clients = [] # 用户选取
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
        # 将划分的数据集当做本地数据集 进行小批量更新 batch_size=local_bs
        # shuffle 用于打乱数据集,每次都会以不同的顺序返回

    def train(self, net):
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        # 优化器 SGD,加入动量momentum 学习率:lr

        epoch_loss = [] # 每迭代一次的损失
        for iter in range(self.args.local_ep):
            batch_loss = [] # 为了提高计算效率,不会对每个client进行loss统计,统计batch_loss
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                # enumerate()函数将()里面的内容 转化成为一个序列,一个一个的取出 batch_size大小的数据,训练
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad() # 将其所有参数(包括子模块的参数)的梯度设置为零
                log_probs = net(images) # 获得前向传播结果
                loss = self.loss_func(log_probs, labels) #计算损失
                loss.backward() # 反向传播损失
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter,
                        batch_idx * len(images), 
                        len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train),
                        loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
            # 总的批量损失/批量个数=一个epoch的损失
            # 一行一行附加到epoch_loss序列中
            
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
        # 局部迭代loss之和/迭代轮次=平均每epoch损失

3.7、main_nn.py对照组 普通的nn

注意,这儿Git上的的main_nn.py中定义了text函数,这与调用的pytest发生了矛盾,所以我将text()改成了ceshi()

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms

from utils.options import args_parser
from models.Nets import MLP, CNNMnist, CNNCifar

# main_nn.py普通nn对比main_Fed.py
# 运行测试集并输出准确率与Loss大小(交叉熵函数,适用于多标签分类任务)
def ceshi(net_g, data_loader):
    # testing
    net_g.eval() # 关闭归一化化与dropout
    test_loss = 0
    correct = 0
    l = len(data_loader) # 载入数据集大小
    for idx, (data, target) in enumerate(data_loader):# 一个一个取出载入的数据
        data, target = data.to(args.device), target.to(args.device) # 传到设备
        log_probs = net_g(data) # 获得前向传播结果
        test_loss += F.cross_entropy(log_probs, target).item()
        # 取出item的结果 计算交叉损失熵 付给test_loss
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        # 最大值得索引位置为y_pred
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
        # 通过与真实值的索引位置来对比


    test_loss /= len(data_loader.dataset)
    print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))

    return correct, test_loss

# 与main_fed.py中的main函数相比,不调用fed.py即可
if __name__ == '__main__':
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    torch.manual_seed(args.seed)

    # load dataset and split users
    #分别对mnist cifar数据集载入 划分
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        img_size = dataset_train[0][0].shape
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True)
        img_size = dataset_train[0][0].shape
    else:
        exit('Error: unrecognized dataset')

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)

    # training
    optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
    train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)

    list_loss = []
    net_glob.train()
    for epoch in range(args.epochs):
        batch_loss = []
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(args.device), target.to(args.device)
            optimizer.zero_grad()
            output = net_glob(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 50 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item()))
            batch_loss.append(loss.item())
        loss_avg = sum(batch_loss)/len(batch_loss)
        print('\nTrain loss:', loss_avg)
        list_loss.append(loss_avg)

    # plot loss
    plt.figure()
    plt.plot(range(len(list_loss)), list_loss)
    plt.xlabel('epochs')
    plt.ylabel('train loss')
    plt.savefig('./log/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))

    # testing
    if args.dataset == 'mnist':
        dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True)
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    else:
        exit('Error: unrecognized dataset')

    print('test on', len(dataset_test), 'samples')
    test_acc, test_loss = ceshi(net_glob, test_loader)

参考:

联邦学习方法FedAvg实战(Pytorch) – 知乎 (zhihu.com)

FedAvg源码学习_mnist_iid_idkmn_的博客-CSDN博客

机器学习中的独立同分布_半夜起来敲代码的博客-CSDN博客_机器学习 独立同分布

从零开始 | FedAvg 代码实现详解 – 知乎 (zhihu.com)

pytorch教程之nn.Module类详解——使用Module类来自定义模型_LoveMIss-Y的博客-CSDN博客

【代码解析(3)】Communication-Efficient Learning of Deep Networks from Decentralized Data_enumerate(self.trainloader)_缄默的天空之城的博客-CSDN博客

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2023年12月6日
下一篇 2023年12月6日

相关推荐