Kaggle知识点:torch模型编译与加速

学术   2024-10-17 20:34   北京  

在今天的教程中,我们将带你了解torch.compile的基本使用方法,并展示它相比PyTorch之前的编译解决方案(如TorchScriptFX Tracing)的优势。

unsetunset基础使用unsetunset

可以将任意的 Python 函数传递给 torch.compile 进行优化。优化后的函数可以替代原始函数进行调用。

  • 使用方法1
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
  • 使用方法2
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)

@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
print(opt_foo2(t1, t2))
  • 使用方法3
t = torch.randn(10, 100)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(t))

unsetunsettorch.compile 处理嵌套调用unsetunset

当你使用 torch.compile 装饰一个函数时,PyTorch 会尝试编译该函数及其所有嵌套函数。这意味着,如果你的函数中调用了其他函数,这些内部函数也会被编译和优化,从而提高整体性能。

def nested_function(x):
    return torch.sin(x)

@torch.compile
def outer_function(x, y):
    a = nested_function(x)
    b = torch.cos(y)
    return a + b

print(outer_function(t1, t2))

当你编译一个模块时,所有子模块和方法(除非它们在跳过列表中)也会被编译。

class OuterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner_module = MyModule()
        self.outer_lin = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.inner_module(x)
        return torch.nn.functional.relu(self.outer_lin(x))

outer_mod = OuterModule()
opt_outer_mod = torch.compile(outer_mod)
print(opt_outer_mod(t))

你可以在函数内部使用 torch.compiler.disable 来禁用编译。通过设置 recursive=False,你可以仅禁用当前函数的编译,而不会影响其嵌套函数。如果设置 recursive=True(默认值),则会禁用当前函数及其所有嵌套函数的编译。

def complex_conjugate(z):
    return torch.conj(z)

@torch.compiler.disable(recursive=False)
def complex_function(real, imag):
    # Assuming this function cause problems in the compilation
    z = torch.complex(real, imag)
    return complex_conjugate(z)

def outer_function():
    real = torch.tensor([2, 3], dtype=torch.float32)
    imag = torch.tensor([4, 5], dtype=torch.float32)
    z = complex_function(real, imag)
    return torch.abs(z)

# Try to compile the outer_function
try:
    opt_outer_function = torch.compile(outer_function)
    print(opt_outer_function())
except Exception as e:
    print("Compilation of outer_function failed:", e)

unsetunsettorch.compile 实践建议unsetunset

  • 顶层编译:一种方法是尽可能在最顶层进行编译(即在顶层模块初始化/调用时),并在遇到过多的图断裂或错误时有选择地禁用编译。如果仍然存在许多编译问题,改为单独编译各个子组件。
  • 模块化测试:在将各个函数和模块集成到更大的模型之前,先单独使用 torch.compile 进行测试,以隔离潜在问题。
  • 有选择地禁用编译:如果某些函数或子模块不能被 torch.compile 处理,使用 torch.compiler.disable 上下文管理器递归地将它们从编译中排除。

 学习大模型 & 讨论Kaggle  #


△长按添加竞赛小助手

每天大模型、算法竞赛、干货资讯

与 36000+来自竞赛爱好者一起交流~

Coggle数据科学
Coggle全称Communication For Kaggle,专注数据科学领域竞赛相关资讯分享。
 最新文章