PyTorch announced the release of version 2.3, introducing several new features and performance improvements aimed at large language models and sparse inference.
The release, which consists of 3,393 commits from 426 contributors, brings support for user-defined Triton kernels in torch.compile. This feature allows users to migrate their existing Triton kernels without experiencing performance regressions or graph breaks.
The feature also allows TorchInductor to precompile user-defined Triton kernels and organise the generated code more efficiently.
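The pattern this enables can be sketched as follows: a hand-written Triton kernel is launched from inside a function wrapped with torch.compile, without forcing a graph break. This is a minimal, hypothetical example (the kernel name `add_kernel` and the block size are illustrative, not from the release notes), and the Triton path itself only runs on a supported GPU, so the sketch guards on Triton being available.

```python
import torch

try:
    import triton
    import triton.language as tl

    # Hypothetical user-defined Triton kernel: element-wise addition.
    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    # The Triton kernel is called from inside a torch.compile'd function;
    # PyTorch 2.3 traces through this call instead of breaking the graph.
    @torch.compile
    def triton_add(x, y):
        out = torch.empty_like(x)
        n = out.numel()
        grid = (triton.cdiv(n, 1024),)
        add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
        return out
except ImportError:
    # Triton is not installed on this machine; the sketch is GPU-only anyway.
    triton_add = None
```

On a CUDA device, `triton_add(torch.randn(4096, device="cuda"), torch.randn(4096, device="cuda"))` would compile and run the kernel as part of the captured graph.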
Another addition is Tensor Parallelism for efficient training of large language models. It facilitates various tensor manipulations across GPUs and hosts and integrates with FSDP (Fully Sharded Data Parallel) for efficient 2D parallelism.
The PyTorch team has validated Tensor Parallelism on training runs for models with over 100 billion parameters, demonstrating its effectiveness in handling large-scale language models.
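The core idea behind tensor parallelism can be illustrated without any GPUs or distributed runtime. In column-wise sharding, each "device" holds a slice of a weight matrix's columns, computes a partial output independently, and the column blocks are concatenated to recover the full result. This is a pure-Python illustration of that idea, not the actual PyTorch API (which lives in torch.distributed.tensor.parallel):

```python
def matmul(x, w):
    # Plain nested-list matrix multiply: x is m x k, w is k x n.
    return [[sum(x[i][p] * w[p][j] for p in range(len(w)))
             for j in range(len(w[0]))] for i in range(len(x))]

def split_columns(w, shards):
    # Split weight w (k x n) into `shards` equal column blocks.
    n = len(w[0])
    step = n // shards
    return [[row[s * step:(s + 1) * step] for row in w] for s in range(shards)]

x = [[1.0, 2.0]]                  # 1 x 2 activation
w = [[1.0, 2.0, 3.0, 4.0],        # 2 x 4 weight matrix
     [5.0, 6.0, 7.0, 8.0]]

# Each shard computes its partial output independently ("on its own GPU")...
partials = [matmul(x, shard) for shard in split_columns(w, shards=2)]
# ...and the column blocks are concatenated to recover the full result.
out = [sum((p[i] for p in partials), []) for i in range(len(x))]
assert out == matmul(x, w)   # sharded result matches the unsharded one
```

Because no shard ever needs the full weight matrix, a model too large for one device's memory can be spread across many, which is what makes 100-billion-parameter training runs feasible.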
PyTorch 2.3 introduces support for semi-structured sparsity, specifically 2:4 sparsity, implemented as a tensor subclass. The feature achieves up to 1.6 times faster processing than dense matrix multiplication, supports mixing different data types during quantization, uses improved versions of the cuSPARSELt and CUTLASS kernels, and is compatible with torch.compile for more efficient computation.
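The 2:4 pattern itself is simple: in every contiguous group of four values, only two may be non-zero, which is what lets the hardware kernels skip exactly half the multiplications. The following pure-Python sketch illustrates the pruning step (keeping the two largest-magnitude entries per group, a common heuristic); it is an illustration of the pattern, not the actual PyTorch implementation:

```python
def prune_2_of_4(row):
    """Zero out all but the two largest-magnitude values in each group of 4."""
    out = []
    for i in range(0, len(row), 4):
        group = row[i:i + 4]
        # Indices of the two largest-magnitude entries in this group of four.
        keep = sorted(range(len(group)), key=lambda j: -abs(group[j]))[:2]
        out.extend(v if j in keep else 0.0 for j, v in enumerate(group))
    return out

row = [0.1, -3.0, 0.2, 5.0, 4.0, 0.0, -0.5, 1.0]
pruned = prune_2_of_4(row)
# Exactly half the entries survive: [0.0, -3.0, 0.0, 5.0, 4.0, 0.0, 0.0, 1.0]
```

In PyTorch itself, converting a weight tensor with torch.sparse.to_sparse_semi_structured produces the tensor subclass that dispatches matrix multiplications to the sparse cuSPARSELt/CUTLASS kernels on supported GPUs.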
Compared to the previous version, PyTorch 2.2, which brought advancements like the integration of FlashAttention-v2 and the introduction of AOTInductor, PyTorch 2.3 builds upon these improvements and introduces new features specifically targeted at large language models and sparse inference.
With significant contributions from a large and active community, this version brings features like user-defined Triton kernels and Tensor Parallelism to collectively improve performance, scalability, and flexibility.