Add Support for KFAC Optimization in LSTM and GRU Layers #188

Open · neuronphysics opened this issue Nov 22, 2023 · 4 comments

@neuronphysics

Feature

I would like to request support for the Kronecker-Factored Approximate Curvature (KFAC) optimization technique in LSTM and GRU layers within the existing KFAC optimizer. Currently, most of the KFAC optimizer classes are tailored to linear and 2D convolution layers; extending them to RNN layers would be a significant enhancement.

Proposal

The proposal is to integrate KFAC support for LSTM and GRU layers into the KFAC optimizer. This would involve adapting the optimizer to compute the requisite statistics, including the chain-structured linear Gaussian graphical model approximation for recurrent layers, for which I could not find any public implementation (a sketch of a simpler related approximation follows below).
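
For concreteness, the simplest recurrent KFAC variant (the "independence across time" approximation in Martens et al., ICLR 2018) just averages the usual Kronecker factors over timesteps; the chain-structured linear Gaussian model refines this by also modeling correlations between timesteps. Below is a minimal sketch of the simpler variant, assuming per-timestep layer inputs and backpropagated output gradients have already been collected. The function names here are hypothetical and this is not kfac_jax's API:

```python
# A minimal sketch (NOT kfac_jax's API) of the simplest KFAC approximation
# for a weight matrix shared across T timesteps: treat timesteps as i.i.d.
# and average the Kronecker factors over time.
import jax.numpy as jnp

def kfac_rnn_factors(acts, grads):
    """Estimate Kronecker factors for a shared recurrent weight W (out x in).

    acts:  layer inputs per timestep, shape (T, B, in)
    grads: backpropagated output gradients per timestep, shape (T, B, out)
    """
    T, B, _ = acts.shape
    # A ~ average over timesteps and batch of a_t a_t^T (input covariance)
    A = jnp.einsum('tbi,tbj->ij', acts, acts) / (T * B)
    # G ~ average over timesteps and batch of g_t g_t^T (gradient covariance)
    G = jnp.einsum('tbi,tbj->ij', grads, grads) / (T * B)
    return A, G

def kfac_precondition(dW, A, G, damping=1e-3):
    """Apply the Kronecker-factored inverse to dW (out x in): G^{-1} dW A^{-1}."""
    A_damped = A + damping * jnp.eye(A.shape[0])
    G_damped = G + damping * jnp.eye(G.shape[0])
    return jnp.linalg.solve(G_damped, jnp.linalg.solve(A_damped, dW.T).T)
```

The chain-structured model would replace the simple time-averaging above with factors that account for the temporal dependence of the statistics, which is the part I could not find implemented anywhere.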

Motivation

LSTM and GRU layers are foundational components for sequential data and time-series analysis. I wonder how much KFAC could improve the training of models built on LSTM and GRU layers by providing a more accurate approximation of the Fisher information matrix. By integrating support for these layers into the KFAC optimizer, researchers could apply KFAC to a wider array of models, including reinforcement learning algorithms.

Additional Context

I have full confidence that the repository maintainers, particularly the first author of the original KFAC paper, are well placed to implement this.

I appreciate your consideration of this feature request. Thank you.

@james-martens
Collaborator

Yeah, support for recurrent networks is something we have partially implemented internally. If there's interest, I guess we could try to get this out sooner.

@neuronphysics
Author

neuronphysics commented Nov 23, 2023

Great to hear that support for recurrent networks is partially implemented. There's definitely interest in this feature, and making it public sooner would be much appreciated, especially for applications in RL models, which are my main interest.

@neuronphysics
Author

Is there any update on publishing the KFAC code for RNNs?

@james-martens
Collaborator

Sorry, no. Others and I have been very busy and haven't had time. If you're interested in a Kronecker-factored method that is compatible with RNNs out of the box, you could try Shampoo or TNT, which make fewer assumptions about the structure of the network. I imagine these are implemented in some open-source library, but I don't know specifically. We might eventually release support for these approaches in kfac_jax, but I have no timeline for that.
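
For anyone who lands here looking for a concrete starting point, here is a minimal sketch of the basic Shampoo update for a single matrix-shaped parameter (Gupta, Koren & Singer, ICML 2018). It is illustrative code rather than any particular library's implementation; it shows why Shampoo needs no layer-specific assumptions: the gradient is preconditioned on both sides using statistics accumulated from the gradients themselves.

```python
# A minimal sketch of the Shampoo update for one matrix parameter.
# Not tied to any library's API; state (L, R) is carried explicitly.
import jax.numpy as jnp

def _inv_fourth_root(M, eps=1e-6):
    """Compute M^{-1/4} for a symmetric PSD matrix via eigendecomposition."""
    w, V = jnp.linalg.eigh(M)
    return (V * jnp.clip(w, eps, None) ** -0.25) @ V.T

def shampoo_step(W, grad, L, R, lr=1e-2):
    """One Shampoo step for a matrix parameter W with gradient `grad`.

    L and R accumulate left/right gradient statistics across steps:
        L += G G^T,  R += G^T G
    and the update preconditions the gradient on both sides:
        W <- W - lr * L^{-1/4} G R^{-1/4}
    """
    L = L + grad @ grad.T
    R = R + grad.T @ grad
    W = W - lr * _inv_fourth_root(L) @ grad @ _inv_fourth_root(R)
    return W, L, R
```

Initializing L and R to zero matrices works here because eigenvalues are clipped to `eps` before the inverse root is taken; nothing in the update depends on whether W is a recurrent weight, which is the "fewer assumptions" point above.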
