在今天的教程中,我们将带你了解torch.compile
的基本使用方法,并展示它相比PyTorch
之前的编译解决方案(如TorchScript
和FX Tracing
)的优势。
基础使用
可以将任意的 Python 函数传递给 torch.compile
进行优化。优化后的函数可以替代原始函数进行调用。
使用方法1
def foo(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
使用方法2
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)
@torch.compile
def opt_foo2(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
print(opt_foo2(t1, t2))
使用方法3
t = torch.randn(10, 100)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(100, 10)
def forward(self, x):
return torch.nn.functional.relu(self.lin(x))
mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(t))
torch.compile 处理嵌套调用
当你使用 torch.compile 装饰一个函数时,PyTorch 会尝试编译该函数及其所有嵌套函数。这意味着,如果你的函数中调用了其他函数,这些内部函数也会被编译和优化,从而提高整体性能。
def nested_function(x):
return torch.sin(x)
@torch.compile
def outer_function(x, y):
a = nested_function(x)
b = torch.cos(y)
return a + b
print(outer_function(t1, t2))
当你编译一个模块时,所有子模块和方法(除非它们在跳过列表中)也会被编译。
class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.inner_module = MyModule()
self.outer_lin = torch.nn.Linear(10, 2)
def forward(self, x):
x = self.inner_module(x)
return torch.nn.functional.relu(self.outer_lin(x))
outer_mod = OuterModule()
opt_outer_mod = torch.compile(outer_mod)
print(opt_outer_mod(t))
你可以在函数内部使用 torch.compiler.disable
来禁用编译。通过设置 recursive=False
,你可以仅禁用当前函数的编译,而不会影响其嵌套函数。如果设置 recursive=True
(默认值),则会禁用当前函数及其所有嵌套函数的编译。
def complex_conjugate(z):
return torch.conj(z)
@torch.compiler.disable(recursive=False)
def complex_function(real, imag):
# Assuming this function cause problems in the compilation
z = torch.complex(real, imag)
return complex_conjugate(z)
def outer_function():
real = torch.tensor([2, 3], dtype=torch.float32)
imag = torch.tensor([4, 5], dtype=torch.float32)
z = complex_function(real, imag)
return torch.abs(z)
# Try to compile the outer_function
try:
opt_outer_function = torch.compile(outer_function)
print(opt_outer_function())
except Exception as e:
print("Compilation of outer_function failed:", e)
torch.compile 实践建议
顶层编译:一种方法是尽可能在最顶层进行编译(即在顶层模块初始化/调用时),并在遇到过多的图断裂或错误时有选择地禁用编译。如果仍然存在许多编译问题,改为单独编译各个子组件。 模块化测试:在将各个函数和模块集成到更大的模型之前,先单独使用 torch.compile
进行测试,以隔离潜在问题。有选择地禁用编译:如果某些函数或子模块不能被 torch.compile
处理,使用torch.compiler.disable
上下文管理器递归地将它们从编译中排除。
# 学习大模型 & 讨论Kaggle #
每天大模型、算法竞赛、干货资讯