functorch

Why functorch? | Install guide | Transformations | Documentation | Future Plans

This library is currently under heavy development. If you have suggestions on the API or use cases you'd like to be covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the library.

functorch is JAX-like composable function transforms for PyTorch.

It aims to provide composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance.

In addition, there is experimental functionality to trace through these transformations using FX in order to capture the results of these transforms ahead of time. This would allow us to compile the results of vmap or grad to improve performance.

Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

* computing per-sample-gradients (or other per-sample quantities)
* running ensembles of models on a single machine
* efficiently batching together tasks in the inner-loop of MAML
* efficiently computing Jacobians and Hessians
* efficiently computing batched Jacobians and Hessians

Composing vmap, grad, vjp, and jvp transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.

Install

There are two ways to install functorch:

1. functorch from source
2. functorch beta (compatible with PyTorch 1.11)

We recommend trying out the functorch beta first.

Installing functorch from source

Using Colab

Follow the instructions in this Colab notebook.

Locally

First, set up an environment. We will be installing a nightly PyTorch binary as well as functorch. If you're using conda, create a conda environment:

    conda create --name functorch
    conda activate functorch

If you wish to use venv instead:

    python -m venv functorch-env
    source functorch-env/bin/activate

Next, install one of the following PyTorch nightly binaries.

    # For CUDA 10.2
    pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html --upgrade
    # For CUDA 11.3
    pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html --upgrade
    # For CPU-only build
    pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --upgrade

If you already have a nightly of PyTorch installed and want to upgrade it (recommended!), keep the --upgrade flag shown in the commands above.

Install functorch:

    pip install ninja  # Makes the build go faster
    pip install --user "git+https://github.com/pytorch/functorch.git"

Run a quick sanity check in python:

    import torch
    from functorch import vmap
    x = torch.randn(3)
    y = vmap(torch.sin)(x)
    assert torch.allclose(y, x.sin())

functorch development setup

functorch is a PyTorch C++ Extension module. To install,

* Install PyTorch from source. functorch usually runs on the latest development version of PyTorch.
* Run python setup.py install. You can use DEBUG=1 to compile in debug mode.

Then, try to run some tests to make sure all is OK:

    pytest test/test_vmap.py -v
    pytest test/test_eager_transforms.py -v

To do devel install:

    pip install -e .

To install with optional dependencies, e.g. for AOTAutograd:

    pip install -e .[aot]

Installing functorch beta (compatible with PyTorch 1.11)

Using Colab

Follow the instructions here.

pip

Prerequisite: Install PyTorch 1.11

    pip install functorch

Finally, run a quick sanity check in python:

    import torch
    from functorch import vmap
    x = torch.randn(3)
    y = vmap(torch.sin)(x)
    assert torch.allclose(y, x.sin())

What are the transforms?

Right now, we support the following transforms:

* grad, vjp, jvp,
* jacrev, jacfwd, hessian
* vmap

Furthermore, we have some utilities for working with PyTorch modules.

* make_functional(model)
* make_functional_with_buffers(model)

vmap

Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.

vmap(func)(*inputs) is a transform that adds a dimension to all Tensor operations in func. vmap(func) returns a new function that maps func over some dimension (default: 0) of each Tensor in inputs.

vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func), leading to a simpler modeling experience:

    import torch
    from functorch import vmap
    batch_size, feature_size = 3, 5
    weights = torch.randn(feature_size, requires_grad=True)

    def model(feature_vec):
        # Very simple linear model with activation
        assert feature_vec.dim() == 1
        return feature_vec.dot(weights).relu()

    examples = torch.randn(batch_size, feature_size)
    result = vmap(model)(examples)

grad

grad(func)(*inputs) assumes func returns a single-element Tensor. It computes the gradients of the output of func with respect to inputs[0].

    import torch
    from functorch import grad
    x = torch.randn([])
    cos_x = grad(lambda x: torch.sin(x))(x)
    assert torch.allclose(cos_x, x.cos())

    # Second-order gradients
    neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
    assert torch.allclose(neg_sin_x, -x.sin())

When composed with vmap, grad can be used to compute per-sample-gradients:

    import torch
    from functorch import grad, vmap
    batch_size, feature_size = 3, 5

    def model(weights, feature_vec):
        # Very simple linear model with activation
        assert feature_vec.dim() == 1
        return feature_vec.dot(weights).relu()

    def compute_loss(weights, example, target):
        y = model(weights, example)
        return ((y - target) ** 2).mean()  # MSELoss

    weights = torch.randn(feature_size, requires_grad=True)
    examples = torch.randn(batch_size, feature_size)
    targets = torch.randn(batch_size)
    inputs = (weights, examples, targets)
    # in_dims=(None, 0, 0): share weights, map over examples and targets
    grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

vjp

The vjp transform applies func to inputs and returns a new function that computes vjps given some cotangent Tensors.
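
As a concrete illustration of the vjp description above, here is a minimal sketch. The function f, the input x, and the all-ones cotangent are illustrative choices that do not come from the README. The sketch assumes that vjp takes the primal input as a positional argument and that the returned vjp_fn yields a tuple with one gradient per input.

```python
import torch
from functorch import vjp

def f(x):
    # Elementwise function, so its Jacobian at x is diag(cos(x)).
    return torch.sin(x)

x = torch.randn(5)

# vjp evaluates f(x) once and returns the output together with a
# function that multiplies a cotangent vector by the Jacobian of f at x.
output, vjp_fn = vjp(f, x)

# The cotangent must have the same shape as the output of f.
cotangent = torch.ones_like(output)
(grad_x,) = vjp_fn(cotangent)

assert torch.allclose(output, x.sin())
assert torch.allclose(grad_x, cotangent * x.cos())
```

Because the Jacobian of an elementwise function is diagonal, the vector-Jacobian product here reduces to an elementwise multiply by cos(x), which is what the final assertion checks.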

    from functorch import vjp
    # func, inputs, and cotangents are placeholders for your own function and Tensors.
    # inputs are passed as positional arguments; cotangents must match the structure of outputs.
    outputs, vjp_fn = vjp(func, *inputs)
    vjps = vjp_fn(cotangents)

jvp

The jvp transform computes Jacobian-vector-products and is also known as "forward-mode AD". Unlike most other transforms, it is not a higher-order function: it returns the outputs of func(inputs) as well as the jvps.

    import torch
    from functorch import jvp
    x = torch.randn(5)
    y = torch.randn(5)
    f = lambda x, y: (x * y)
    _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
    assert torch.allclose(output, x + y)

jacrev, jacfwd, and hessian

The jacrev transform takes a function and returns a new function that takes in x and returns the Jacobian of the original function with respect to x using reverse-mode AD. In the example below, the function is torch.sin.

    import torch
    from functorch import jacrev
    x = torch.randn(5)
    jacobian = jacrev(torch.sin)(x)
    expected = torch.diag(torch.cos(x))
    assert torch.allclose(jacobian, expected)

jacrev can be composed with vmap to produce batched Jacobians:

    import torch
    from functorch import jacrev, vmap
    x = torch.randn(64, 5)
    jacobian = vmap(jacrev(torch.sin))(x)
    assert jacobian.shape == (64, 5, 5)

jacfwd is a drop-in replacement for jacrev that computes Jacobians using forward-mode AD:

    import torch
    from functorch import jacfwd
    x = torch.randn(5)
    jacobian = jacfwd(torch.sin)(x)
    expected = torch.diag(torch.cos(x))
    assert torch.allclose(jacobian, expected)

Composing jacrev with itself or jacfwd can produce Hessians:

    import torch
    from functorch import jacfwd, jacrev

    def f(x):
        return x.sin().sum()

    x = torch.randn(5)
    hessian0 = jacrev(jacrev(f))(x)
    hessian1 = jacfwd(jacrev(f))(x)

hessian is a convenience function that combines jacfwd and jacrev:

    import torch
    from functorch import hessian

    def f(x):
        return x.sin().sum()

    x = torch.randn(5)
    hess = hessian(f)(x)

Tracing through the transformations

We can also trace through these transformations in order to capture the results as new code using make_fx. There is also experimental integration with the NNC compiler (only works on CPU for now!).

    import torch
    from functorch import make_fx, grad

    def f(x):
        return torch.sin(x).sum()

    x = torch.randn(100)
    grad_f = make_fx(grad(f))(x)
    print(grad_f.code)

which prints:

    def forward(self, x_1):
        sin = torch.ops.aten.sin(x_1)
        sum_1 = torch.ops.aten.sum(sin, None); sin = None
        cos = torch.ops.aten.cos(x_1); x_1 = None
        _tensor_constant0 = self._tensor_constant0
        mul = torch.ops.aten.mul(_tensor_constant0, cos); _tensor_constant0 = cos = None
        return mul

Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters and/or buffers of an nn.Module. This can happen for example in:

* model ensembling, where all of your weights and buffers have an additional dimension
* per-sample-gradient computation where you want to compute per-sample-grads of the loss with respect to the model parameters

Our solution to this right now is an API that, given an nn.Module, creates a stateless version of it that can be called like a function.

* make_functional(model) returns a functional version of model and the model.parameters()
* make_functional_with_buffers(model) returns a functional version of model and the model.parameters() and model.buffers().

Here's an example where we compute per-sample-gradients using an nn.Linear layer:

    import torch
    from functorch import make_functional, vmap, grad

    model = torch.nn.Linear(3, 3)
    data = torch.randn(64, 3)
    targets = torch.randn(64, 3)

    func_model, params = make_functional(model)

    def compute_loss(params, data, targets):
        preds = func_model(params, data)
        return torch.mean((preds - targets) ** 2)

    # (None, 0, 0): share params, map over data and targets
    per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)

If you're making an ensemble of models, you may find combine_state_for_ensemble useful.

Documentation

For more documentation, see our docs website.

Debugging

* functorch._C.dump_tensor: dumps the dispatch keys on the stack.
* functorch._C._set_vmap_fallback_warning_enabled(False): disables the vmap fallback warning if the warning spam bothers you.
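
Beyond the helpers listed above, one general sanity check when debugging a vmap-ed function is to compare its result against an explicit Python loop over the batch dimension. This is a suggestion rather than a functorch utility, and the names model, weights, and examples below are illustrative.

```python
import torch
from functorch import vmap

torch.manual_seed(0)
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)

def model(feature_vec):
    # Per-example function: expects a single 1-D feature vector.
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)

# Batched evaluation through vmap.
batched = vmap(model)(examples)

# Reference evaluation: call the function once per example and stack.
looped = torch.stack([model(example) for example in examples])

assert batched.shape == (batch_size,)
assert torch.allclose(batched, looped, atol=1e-6)
```

If the two results disagree, the per-example function likely depends on something vmap does not map over in the way you expect, such as the wrong in_dims setting.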

Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the design details. To figure out the details, we need your help: please send us your use cases by starting a conversation in the issue tracker or trying our project out.

License

functorch has a BSD-style license, as found in the LICENSE file.

Citing functorch

If you use functorch in your publication, please cite it by using the following BibTeX entry.

    @Misc{functorch2021,
      author = {Horace He and Richard Zou},
      title = {functorch: JAX-like composable function transforms for PyTorch},
      howpublished = {\url{https://github.com/pytorch/functorch}},
      year = {2021}
    }