torch之并行计算

酥酥 发布于 2022-04-20 50 次阅读


命令式和符号式混合编程

				
					def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

fancy_func(1, 2, 3, 4)
				
			
10
				
					def add_str():
    return '''
def add(a, b):
    return a + b
'''

def fancy_func_str():
    return '''
def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g
'''

def evoke_str():
    return add_str() + fancy_func_str() + '''
print(fancy_func(1, 2, 3, 4))
'''

prog = evoke_str()
print(prog)
y = compile(prog, '', 'exec')
exec(y)
				
			
def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

print(fancy_func(1, 2, 3, 4))

10
				
import torch
import time

# The multi-GPU timing experiments below require at least two CUDA devices.
assert torch.cuda.device_count() >= 2
				
			
				
					x_gpu1 = torch.rand(size=(100, 100), device='cuda:0')
x_gpu2 = torch.rand(size=(100, 100), device='cuda:
				
			
				
					class Benchmark():  # 本类已保存在d2lzh_pytorch包中方便以后使用
    def __init__(self, prefix=None):
        self.prefix = prefix + ' ' if prefix else ''

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, *args):
        print('%stime: %.4f sec' % (self.prefix, time.time() - self.start))
				
			
				
					def run(x):
    for _ in range(20000):
        y = torch.mm(x, x)
				
			
				
# Time the workload on each GPU sequentially. torch.cuda.synchronize() blocks
# until all queued kernels have finished, so the measured time covers the
# actual GPU computation rather than just the (asynchronous) kernel launches.
with Benchmark('Run on GPU1.'):
    run(x_gpu1)
    torch.cuda.synchronize()

with Benchmark('Then run on GPU2.'):
    run(x_gpu2)
    torch.cuda.synchronize()
				
			
Run on GPU1. time: 0.2989 sec
Then run on GPU2. time: 0.3518 sec
				
# Queue the workload on both GPUs before synchronizing. Kernel launches are
# asynchronous, so the two devices compute concurrently and the total time is
# less than the sum of the two sequential runs above.
with Benchmark('Run on both GPU1 and GPU2 in parallel.'):
    run(x_gpu1)
    run(x_gpu2)
    torch.cuda.synchronize()
				
			
Run on both GPU1 and GPU2 in parallel. time: 0.5076 sec
				
					!nvidia-smi
				
			
				
import torch
# Build a single-layer linear model and move its parameters to the default GPU.
net = torch.nn.Linear(10, 1).cuda()
net
				
			
Linear(in_features=10, out_features=1, bias=True)
				
# Wrap the model for multi-GPU data parallelism. The original model becomes
# the wrapper's `module` attribute, so every key in state_dict() is now
# prefixed with "module." — this matters for saving/loading below.
net = torch.nn.DataParallel(net)
net
				
			
DataParallel(
  (module): Linear(in_features=10, out_features=1, bias=True)
)
				
# Save the DataParallel-wrapped state dict (its keys carry a "module." prefix).
torch.save(net.state_dict(), "./8.4_model.pt")
				
			
				
new_net = torch.nn.Linear(10, 1)
# Loading fails here: the saved keys are prefixed with "module." by
# DataParallel, which a plain Linear's state_dict does not expect.
# new_net.load_state_dict(torch.load("./8.4_model.pt")) # fails to load
				
			
				
# Fix 1: save the unwrapped model's state dict (net.module) so the keys have
# no "module." prefix and match a plain Linear.
torch.save(net.module.state_dict(), "./8.4_model.pt")
new_net.load_state_dict(torch.load("./8.4_model.pt")) # loads successfully
				
			
				
# Fix 2: keep the prefixed state dict, but wrap the new model in DataParallel
# as well so its expected keys carry the same "module." prefix.
torch.save(net.state_dict(), "./8.4_model.pt")
new_net = torch.nn.Linear(10, 1)
new_net = torch.nn.DataParallel(new_net)
new_net.load_state_dict(torch.load("./8.4_model.pt")) # loads successfully