From 77523880704332c4d7bdb3ac656950f377eabfa6 Mon Sep 17 00:00:00 2001 From: albertfgu Date: Wed, 19 Aug 2020 18:03:42 -0700 Subject: [PATCH] Initial commit --- .gitignore | 126 ++++++++++ LICENSE | 201 +++++++++++++++ README.md | 82 ++++++ assets/hippo.png | Bin 0 -> 44310 bytes cfg/config.yaml | 26 ++ cfg/dataset/adding.yaml | 5 + cfg/dataset/copying.yaml | 9 + cfg/dataset/ct.yaml | 14 ++ cfg/dataset/imdb.yaml | 10 + cfg/dataset/mnist.yaml | 3 + cfg/runner/pl.yaml | 4 + csrc/hippo.cpp | 21 ++ csrc/hippolegs.cpp | 167 +++++++++++++ csrc/hippolegt.cpp | 59 +++++ csrc/setup.py | 11 + datasets/__init__.py | 232 +++++++++++++++++ datasets/adding.py | 28 +++ datasets/copying.py | 47 ++++ datasets/tasks.py | 60 +++++ datasets/uea.py | 390 +++++++++++++++++++++++++++++ datasets/utils.py | 48 ++++ model/components.py | 86 +++++++ model/exprnn/expm32.py | 315 +++++++++++++++++++++++ model/exprnn/initialization.py | 67 +++++ model/exprnn/orthogonal.py | 107 ++++++++ model/exprnn/parametrization.py | 127 ++++++++++ model/exprnn/trivializations.py | 24 ++ model/memory.py | 426 ++++++++++++++++++++++++++++++++ model/model.py | 113 +++++++++ model/op.py | 266 ++++++++++++++++++++ model/opcell.py | 107 ++++++++ model/orthogonalcell.py | 84 +++++++ model/rnn.py | 157 ++++++++++++ model/rnncell.py | 269 ++++++++++++++++++++ pl_runner.py | 34 +++ requirements.txt | 11 + tensorflow/hippo.py | 390 +++++++++++++++++++++++++++++ tests/test_legs_extension.py | 229 +++++++++++++++++ tests/test_legt_extension.py | 109 ++++++++ train.py | 124 ++++++++++ utils.py | 22 ++ 41 files changed, 4610 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 assets/hippo.png create mode 100644 cfg/config.yaml create mode 100644 cfg/dataset/adding.yaml create mode 100644 cfg/dataset/copying.yaml create mode 100644 cfg/dataset/ct.yaml create mode 100644 cfg/dataset/imdb.yaml create mode 100644 cfg/dataset/mnist.yaml create mode 100644 cfg/runner/pl.yaml create mode 100644 csrc/hippo.cpp create mode 100644 csrc/hippolegs.cpp create mode 100644 csrc/hippolegt.cpp create mode 100644 csrc/setup.py create mode 100644 datasets/__init__.py create mode 100644 datasets/adding.py create mode 100644 datasets/copying.py create mode 100644 datasets/tasks.py create mode 100644 datasets/uea.py create mode 100644 datasets/utils.py create mode 100644 model/components.py create mode 100644 model/exprnn/expm32.py create mode 100644 model/exprnn/initialization.py create mode 100644 model/exprnn/orthogonal.py create mode 100644 model/exprnn/parametrization.py create mode 100644 model/exprnn/trivializations.py create mode 100644 model/memory.py create mode 100644 model/model.py create mode 100644 model/op.py create mode 100644 model/opcell.py create mode 100644 model/orthogonalcell.py create mode 100644 model/rnn.py create mode 100644 model/rnncell.py create mode 100644 pl_runner.py create mode 100644 requirements.txt create mode 100644 tensorflow/hippo.py create mode 100644 tests/test_legs_extension.py create mode 100644 tests/test_legt_extension.py create mode 100644 train.py create mode 100644 utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0f9244f --- /dev/null +++ b/.gitignore @@ -0,0 +1,126 @@ +config/ +lightning_logs/ +results/ +outputs/ +wandb/ +datasets/mnist/ +datasets/cifar/ +datasets/eegseizure/ +datasets/imdb/ +datasets/timit/ +datasets/timit +datasets/vctk/ +datasets/wikitext-2/ +ray.sh +ray_config/redis_address +ray_config/redis_password 
+ +# Created by https://www.gitignore.io/api/python +# Edit at https://www.gitignore.io/?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# End of https://www.gitignore.io/api/python diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e26dc4c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,82 @@
+# HiPPO
+![HiPPO Framework](assets/hippo.png "HiPPO Framework")
+> **HiPPO: Recurrent Memory with Optimal Polynomial Projections**\
+> Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré\
+> Stanford University\
+> Paper: https://arxiv.org/abs/2008.07669
+
+
+
+> **Abstract.** A central problem in learning from sequential data is representing cumulative history in an incremental fashion as more data is processed. We introduce a general framework (HiPPO) for the online compression of continuous signals and discrete time series by projection onto polynomial bases. Given a measure that specifies the importance of each time step in the past, HiPPO produces an optimal solution to a natural online function approximation problem.
As special cases, our framework yields a short derivation of the recent Legendre Memory Unit (LMU) from first principles, and generalizes the ubiquitous gating mechanism of recurrent neural networks such as GRUs. This formal framework yields a new memory update mechanism (HiPPO-LegS) that scales through time to remember all history, avoiding priors on the timescale. HiPPO-LegS enjoys the theoretical benefits of timescale robustness, fast updates, and bounded gradients. By incorporating the memory dynamics into recurrent neural networks, HiPPO RNNs can empirically capture complex temporal dependencies. On the benchmark permuted MNIST dataset, HiPPO-LegS sets a new state-of-the-art accuracy of 98.3%. Finally, on a novel trajectory classification task testing robustness to out-of-distribution timescales and missing data, HiPPO-LegS outperforms RNN and neural ODE baselines by 25-40% accuracy.
+
+## Setup
+
+### Requirements
+This repository requires Python 3.7+ and PyTorch 1.4+.
+Other packages are listed in `requirements.txt`.
+
+
+## Experiments
+
+Launch experiments using `train.py`.
+
+Pass in `dataset=` to specify the dataset, whose default options are specified by the Hydra configs in `cfg/`. See, for example, `cfg/dataset/mnist.yaml`.
+
+Pass in `model.cell=` to specify the RNN cell. Default model options can be found in the initializers of the model classes.
+
+The following example command lines reproduce the experiments in Sections 4.1 and 4.2 for the HiPPO-LegS model. The `model.cell` argument can be changed to any other model defined in `model/` (e.g. `lmu`, `lstm`, `gru`) for different types of RNN cells.
+
+### Permuted MNIST
+
+```
+python train.py runner=pl runner.ntrials=5 dataset=mnist dataset.permute=True model.cell=legs model.cell_args.hidden_size=512 train.epochs=50 train.batch_size=100 train.lr=0.001
+```
+
+### CharacterTrajectories
+
+See the documentation in `datasets.uea.postprocess_data` for an explanation of the flags.
+
+100Hz -> 200Hz:
+```
+python train.py runner=pl runner.ntrials=2 dataset=ct dataset.timestamp=False dataset.train_ts=1 dataset.eval_ts=1 dataset.train_hz=0.5 dataset.eval_hz=1 dataset.train_uniform=True dataset.eval_uniform=True model.cell=legs model.cell_args.hidden_size=256 train.epochs=100 train.batch_size=100 train.lr=0.001
+```
+Use `dataset.train_hz=1 dataset.eval_hz=0.5` instead for the 200Hz -> 100Hz experiment.
+
+
+Missing values upsample:
+```
+python train.py runner=pl runner.ntrials=3 dataset=ct dataset.timestamp=True dataset.train_ts=0.5 dataset.eval_ts=1 dataset.train_hz=1 dataset.eval_hz=1 dataset.train_uniform=False dataset.eval_uniform=False model.cell=tlsi model.cell_args.hidden_size=256 train.epochs=100 train.batch_size=100 train.lr=0.001
+```
+Use `dataset.train_ts=1 dataset.eval_ts=0.5` instead for the downsampling experiment.
+
+Note that the model cell is called `tlsi` (short for "timestamped linear scale invariant") to denote a HiPPO-LegS model that additionally uses the timestamps.
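[Editor's note] For reference, the core HiPPO-LegS memory update that the RNN cells and the C++ kernels later in this patch implement is a discretized linear ODE; under forward Euler it reads `newmem = (I + dt A) mem + dt B input` (see the docstrings in `csrc/hippolegs.cpp` below). The following is a minimal NumPy sketch of the dense version of that step. It is not part of the repository: the helper names are illustrative, and the `(A, B)` definition is read off the streaming loop in `csrc/hippolegs.cpp`, which computes the same update in a single pass without materializing `A`.

```python
import numpy as np

def legs_transition(N):
    """Scaled Legendre (LegS) transition, as used by csrc/hippolegs.cpp:
    A[n, k] = -sqrt(2n+1)*sqrt(2k+1) for k < n, -(n+1) for k == n, 0 for k > n
    B[n]    =  sqrt(2n+1)
    """
    q = np.sqrt(2 * np.arange(N) + 1)
    A = -np.tril(np.outer(q, q), -1) - np.diag(np.arange(N) + 1)
    B = q
    return A, B

def euler_forward_dense(mem, inp, dt):
    """One forward-Euler step: newmem = (I + dt A) mem + dt B inp.
    mem: (batch, memsize, memorder), inp: (batch, memsize), dt: float.
    """
    A, B = legs_transition(mem.shape[-1])
    return mem + dt * (mem @ A.T + inp[..., None] * B)

# Shape check only; values depend on dt and the input signal.
mem = np.zeros((2, 4, 8))
inp = np.random.randn(2, 4)
new_mem = euler_forward_dense(mem, inp, dt=1e-2)
assert new_mem.shape == mem.shape
```

The `legs_euler_backward` and `legs_trapezoidal` kernels replace this explicit Euler step with backward-Euler and bilinear (trapezoidal) discretizations of the same ODE, per their docstrings.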
+
+
+### HiPPO-LegS multiplication in C++
+To compile:
+```
+cd csrc
+python setup.py install
+```
+To test:
+```
+pytest tests/test_legs_extension.py
+```
+To benchmark:
+```
+python tests/test_legs_extension.py
+```
+
+
+
+## Citation
+If you use this codebase, or otherwise find our work valuable, please cite:
+```
+@article{hippo,
+  title={HiPPO: Recurrent Memory with Optimal Polynomial Projections},
+  author={Albert Gu and Tri Dao and Stefano Ermon and Atri Rudra and Christopher R\'{e}},
+  journal={arXiv preprint arXiv:2008.07669},
+  year={2020}
+}
+```
diff --git a/assets/hippo.png b/assets/hippo.png
new file mode 100644
index 0000000000000000000000000000000000000000..cb91017e4ace82dfc65e5dee05f4b9c74f86e243
GIT binary patch
literal 44310
[binary image data omitted]
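[Editor's note] The README instructions above build the `hippo` C++ extension whose bindings are declared in `csrc/hippo.cpp` below. As a rough usage sketch (not part of the patch; function names and signatures are taken from the `PYBIND11_MODULE` definitions and the docstrings in `csrc/hippolegs.cpp`, while the shapes and values here are illustrative), the installed module can be driven from Python roughly like this:

```python
import torch
import hippo  # built and installed via `cd csrc && python setup.py install`

batch_size, memsize, memorder = 2, 4, 64
mem = torch.zeros(batch_size, memsize, memorder)  # polynomial coefficients c_t
inp = torch.randn(batch_size, memsize)            # new input f_t
dt = 1.0 / 100                                    # step size

# One forward-Euler HiPPO-LegS update: newmem = (I + dt A) mem + dt B inp
new_mem = hippo.legs_euler_forward(mem, inp, dt)

# Bilinear (trapezoidal) discretization of the same ODE
new_mem_trap = hippo.legs_trapezoidal(mem, inp, dt)

# Project a 1-D signal onto the first `memorder` basis functions
signal = torch.randn(1000)
coeffs = hippo.legs_function_approx_trapezoidal(signal, memorder)
```

Note that the kernels check for CPU tensors (`CHECK_DEVICE` asserts `torch::kCPU`), so the inputs above must not be moved to GPU.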
zC8&MUjO&lDk1IuiY!|z=VaJ``LqFHX$6J8IDy@_L$35R(7teTE`EDl0m0Uxox=x`l z7pCTum!hBMv<`G5>cx^_{HtU0HLg$kl=w$H1Yp*eK3gp81->D9v5`qPyfQwr+#_&WrngR-G( zmhP|o_?bJUOc(jionpFv#jbvX$bjo<$*!@>s>ps zTtXHUMVdh7dPthLKz)l;TJJC{QN(GHIq1Dl zVsrM>tvnfD031lLO;V8($`ex-AEe_woR33msvz7Sy~>vwOWyM}TT_?wX_=ue6ygT9 z1=?koHci@mZk{?AS=Gb_Y;)hYPO!5eRCPXs`m{yo^sNL+@bECE^zsXDo|mte-?xsL z$+1~3*h@L9Ja6V+HvfQ%3QsxNFYHuB`THwz8|uf>lPm>l_|A&s*)jZWMm2 zXuVCrt@DEe^^1Gn;o*9bk@FHQ&n?KQP=sY+35nC?MC_gSTphvgZfX$naS=%LG%=-k zB07WM_$lE9^&@5`B$WqPM}_&fw38NDSs>;WFaab?p%j-)F*=Hb(as}~KC=()?ebU4 zrh^$27~lh>=XgOO)bpG~g)RwlB(1zH8*k>Z&yvCGeALB?cq6Q-3k_#aQ3G!5_;IIY zBy%g({*zYi2Gm)g?u_@-^SC9xn_z^8uAGAOlGo!qHrJRpA;4f}qkb%B$ZKrOe}@4+e(3k(eBU+jXUy6I$-PhKSz=oE6ugc+A;d3G&l zaC6yEQQ}r7!vD;yYbN*syw_z|_F)I$&Xd%23E%i6NDhi~C%R08IN;W$`Q!aJVPrTB z($$R*c`=WjBm`Wo zm3#kw#-B|vFCRWxLoX$L+C1%B;rzuwv~-#~iU(S(jpj*|U5SK6E1?G zLtA81I*53;YqKVFwF@&>db})HYdx{++d0;^oo%|EW8%Sd@%gI86eINWC7Ne<$88=S zeID@iblsp=3v^P8`1@j<9jaJK(Dh91g=UT2SWdS2cf2()Po$LE9MPAE{X@rsWhd{4 zg+hzjeg_czH;AaHN=_r3z2oev6b%%g!y6C*C55&PMA|zg0_Zk8h`gTVrzv`5rjzGB zXlPHky6d^F8>f~h3RWun#SOF)A0srk*af zJ=!8eR-rjU&ivXNoVCnfus{@j*M5?VK;qaUkL zPG7JXh=(>esD2&s;x#K`mo>{kepI{3LA5ETV`iZnT5(qozS-@9Lx1ldZL4G#IBK@v zp;YyW5we3^wyFKzf8BpH$p6^ROS-bw$+8D`Xx`Ldf0_%n;EY;n=n!|M%AyqpFFfP& zKNf9N@=uGcWT^6dW_B=0^KjB)4|!K^(-9H4y;tZ;Ar%p8W_uQK>;kM zf0Eq9fqcwlH7$w58}N^i|9eDpvfC-y1YC_QNhYfm$bwNzAA@71mmc-TC6 zP3h+CMHw!>_D!9@>K=WgZ&}5v<5qYvS!PVGpkW*U4f8WHBd?Te zK@h@>RSM5+d+%-6$j}cf+I3{fF-*G>Y?DA;fj&Yc2gqx05l_#9{wgx3P?My)NUG@; zN$5$W^BmLJD-Je|gcF&1YvrI7oijCQll2y<80c3{YV6K)AQd1|Wl;w^F{={8}YTf_Jd7K{=L+%M%*fd8h0yYSTzk&S1F=JY?IRB<02 zNs6cqhK8DQf8q`Ku?PauQR8|Ct(WVwZ9Kf|KQ=VCX3k{|*BNhF5{i(Q%Gn22XRU=@eKwL1FKjE2uFIXrbyIA0a@~5@XvVz_{pgJ0K1Ek53sLn3ZO{;q$30l&<+&-gq0|of1Tz}A~S#>84V5v(l z+Bz_xn}pwxGSI3zC~4f9Ss@5`#2D)?y8IaSA|4Q4H7m0u&BzZepqnW-V0|uci&<_Y z4P_3`ZNr%lgXKV9AI@c6nGdmuMtJ8dqfNUqql6%~K{n{0Muv1AjoWre7l9IgDWxBY zGxQPq%rsUrg6ac6X1=FirKC5LAh@y%owX#FJ)j*Q0>Yv=jANsbN=d;HKk8ffq!2$# z@0uzjf~G5Y7-_6r3~#=!r1sU2Cdx63pCHz>Px@&75gs)kSgaZ25k$$qy7n9kFz1{B z&+VMz(BU<9#zGhpdR;b`5~@z)Jszu;-lYvUvWq0}fwtiDRGoKr#Y8$~aY^#CSxxYz zR18~cxGNt}tbk9dLI`@ABH2NL{7$m|_KGujCK9PMI-YT6F{>&VOE8xj4yH&cS@$~b zwN$hDUpa$UnZ3oz(zEMHy(a1>S{v}-N$;Faeand{*G@R5QD1_aM}75PSktedlo7FL z<#0Cc%j0x&?YG$V6@I&O*zL7*~v-F3;T-KO9QqFIXaMd zOK7B1BJ%H7Vo@IsM#25tUC)5(vIotao8KIvi~l_`fWj&U-W7Y{1TO1otH-=hQd(vjC*u<{%lyM30O!D_>LoJ}Ghp$E?4}^Op0+7dxoeip0 z%;`WA-haGzYt$z59>iub;vsnaeMD`x+wZH1S%AR%m%NnY&IK}lYQiV$7oyL8wkE8+ z8NReD?6E(SYLG?LJBj!Fri+sPF6U?a*_NYT6Z!&i4af52_2H(P(+s#mk$Jj^cgQ5u1P?e*uaX^fVEprZtbPk<&OysebkO*e886z zqK&}m{7BBuzA~lhpoo5%VuOQ+#cY|Mb8Xu*Eb#61TXd)EY+L68c^)wpTEKjmmr6wa zZ(liaGN#G`bl5l2^xu@-|K%ZQ8d21~j6X2g|Hwa!MSxeG%JwpF1pe9oM93g80Qbj_ zaC~cM_dg0zubaTY + +namespace legs { + at::Tensor euler_forward(const torch::Tensor& mem, const torch::Tensor& input, const float dt); + at::Tensor euler_backward(const torch::Tensor& mem, const torch::Tensor& input, const float dt); + at::Tensor trapezoidal(const torch::Tensor& mem, const torch::Tensor& input, const float dt); + at::Tensor function_approx_trapezoidal(const torch::Tensor& input, const int memorder); +} + +namespace legt { + at::Tensor euler_forward(const torch::Tensor& mem, const torch::Tensor& input, const float dt); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("legs_euler_forward", &legs::euler_forward, "Euler forward for Hippo-LegS"); + m.def("legs_euler_backward", &legs::euler_backward, "Euler backward for Hippo-LegS"); + m.def("legs_trapezoidal", &legs::trapezoidal, "Trapezoidal for Hippo-LegS"); + 
m.def("legs_function_approx_trapezoidal", &legs::function_approx_trapezoidal, "Function approx trapezoidal for Hippo-LegS"); + + m.def("legt_euler_forward", &legt::euler_forward, "Euler forward for Hippo-LegT"); +} diff --git a/csrc/hippolegs.cpp b/csrc/hippolegs.cpp new file mode 100644 index 0000000..ee10ecf --- /dev/null +++ b/csrc/hippolegs.cpp @@ -0,0 +1,167 @@ +#include +#include +#include +#include + +#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCPU, #x " must be on CPU") + +namespace legs { + +at::Tensor euler_forward(const torch::Tensor& mem, const torch::Tensor& input, const float dt) { + /* newmem = (I + dt A) mem + dt B input + Parameters: + mem: (batch_size, memsize, memorder) + input: (batch_size, memsize) + dt: float + Returns: + newmem: (batch_size, memsize, memorder) + */ + const auto batch_size = mem.size(0); + const auto memsize = mem.size(1); + const auto N = mem.size(2); + TORCH_CHECK(mem.dim() == 3, "legs::euler_forward: mem must have dimension 3"); + TORCH_CHECK(input.dim() == 2, "legs::euler_forward: input must have dimension 2"); + CHECK_DEVICE(mem); + CHECK_DEVICE(input); + auto newmem = torch::empty_like(mem); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(mem.scalar_type(), "legs::euler_forward", [&] { + const auto mem_a = mem.accessor(); + const auto input_a = input.accessor(); + const scalar_t dt_a = dt; + auto newmem_a = newmem.accessor(); + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t msz = 0; msz < memsize; ++msz) { + scalar_t input_val_dt = input_a[b][msz] * dt_a; + scalar_t cumsum = 0; + for (int64_t n = 0; n < N; ++n) { + scalar_t x = mem_a[b][msz][n]; + scalar_t sqrt_scale = std::sqrt(2 * n + 1); + // cumsum += x / sqrt_scale * (2 * n + 1); + // newmem_a[b][msz][n] = x - dt_a * (cumsum - x / sqrt_scale * n) * sqrt_scale; + newmem_a[b][msz][n] = x - dt_a * (cumsum * sqrt_scale + x * (n + 1)) + input_val_dt * sqrt_scale; + cumsum += x * sqrt_scale; + } + } + } + }); + return newmem; +} + +at::Tensor euler_backward(const torch::Tensor& mem, const torch::Tensor& input, const float dt) { + /* newmem = (I - dt A)^{-1} (mem + dt B input) + Parameters: + mem: (batch_size, memsize, memorder) + input: (batch_size, memsize) + dt: float + Returns: + newmem: (batch_size, memsize, memorder) + */ + const auto batch_size = mem.size(0); + const auto memsize = mem.size(1); + const auto N = mem.size(2); + TORCH_CHECK(mem.dim() == 3, "legs::euler_backward: mem must have dimension 3"); + TORCH_CHECK(input.dim() == 2, "legs::euler_backward: input must have dimension 2"); + CHECK_DEVICE(mem); + CHECK_DEVICE(input); + auto newmem = torch::empty_like(mem); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(mem.scalar_type(), "legs::euler_backward", [&] { + const auto mem_a = mem.accessor(); + const auto input_a = input.accessor(); + const scalar_t dt_a = dt; + auto newmem_a = newmem.accessor(); + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t msz = 0; msz < memsize; ++msz) { + scalar_t input_val_dt = input_a[b][msz] * dt_a; + scalar_t cumsum = 0; + for (int64_t n = 0; n < N; ++n) { + scalar_t sqrt_scale = std::sqrt(2 * n + 1); + scalar_t x = mem_a[b][msz][n] + input_val_dt * sqrt_scale; + scalar_t y = (x - dt_a * cumsum * sqrt_scale) / (1 + (n + 1) * dt_a); + newmem_a[b][msz][n] = y; + cumsum += y * sqrt_scale; + } + } + } + }); + return newmem; +} + +at::Tensor trapezoidal(const torch::Tensor& mem, const torch::Tensor& input, const float dt) { + /* newmem = (I - dt/2 A)^{-1} ((I + dt/2 A) mem + dt B input) + Parameters: + mem: (batch_size, memsize, 
+           input: (batch_size, memsize)
+           dt: float
+       Returns:
+           newmem: (batch_size, memsize, memorder)
+    */
+    const auto batch_size = mem.size(0);
+    const auto memsize = mem.size(1);
+    const auto N = mem.size(2);
+    TORCH_CHECK(mem.dim() == 3, "legs::trapezoidal: mem must have dimension 3");
+    TORCH_CHECK(input.dim() == 2, "legs::trapezoidal: input must have dimension 2");
+    CHECK_DEVICE(mem);
+    CHECK_DEVICE(input);
+    auto newmem = torch::empty_like(mem);
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(mem.scalar_type(), "legs::trapezoidal", [&] {
+        const auto mem_a = mem.accessor<scalar_t, 3>();
+        const auto input_a = input.accessor<scalar_t, 2>();
+        const scalar_t dt_a = dt;
+        auto newmem_a = newmem.accessor<scalar_t, 3>();
+        for (int64_t b = 0; b < batch_size; ++b) {
+            for (int64_t msz = 0; msz < memsize; ++msz) {
+                scalar_t input_val_dt = input_a[b][msz] * dt_a;
+                scalar_t cumsum_fwd = 0;
+                scalar_t cumsum_bwd = 0;
+                for (int64_t n = 0; n < N; ++n) {
+                    scalar_t x = mem_a[b][msz][n];
+                    scalar_t sqrt_scale = std::sqrt(2 * n + 1);
+                    scalar_t out_fwd = x - dt_a / 2 * (cumsum_fwd * sqrt_scale + x * (n + 1)) + input_val_dt * sqrt_scale;
+                    cumsum_fwd += x * sqrt_scale;
+                    scalar_t y = (out_fwd - dt_a / 2 * cumsum_bwd * sqrt_scale) / (1 + (n + 1) * dt_a / 2);
+                    newmem_a[b][msz][n] = y;
+                    cumsum_bwd += y * sqrt_scale;
+                }
+            }
+        }
+    });
+    return newmem;
+}
+
+at::Tensor function_approx_trapezoidal(const torch::Tensor& input, const int memorder) {
+    /*
+       Parameters:
+           input: (length, )
+           memorder: int
+       Returns:
+           mem: (memorder, )
+    */
+    const auto length = input.size(0);
+    const auto N = memorder;
+    TORCH_CHECK(input.dim() == 1, "legs::function_approx_trapezoidal: input must have dimension 1");
+    CHECK_DEVICE(input);
+    auto mem = torch::zeros({N}, torch::dtype(input.dtype()).device(input.device()));
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "legs::function_approx_trapezoidal", [&] {
+        auto mem_a = mem.accessor<scalar_t, 1>();
+        const auto input_a = input.accessor<scalar_t, 1>();
+        mem_a[0] = input_a[0];
+        for (int64_t t = 1; t < length; ++t) {
+            const scalar_t dt = 1.0 / t;
+            scalar_t input_val_dt = input_a[t] * dt;
+            scalar_t cumsum_fwd = 0;
+            scalar_t cumsum_bwd = 0;
+            for (int64_t n = 0; n < N; ++n) {
+                scalar_t x = mem_a[n];
+                scalar_t sqrt_scale = std::sqrt(2 * n + 1);
+                scalar_t out_fwd = x - dt / 2 * (cumsum_fwd * sqrt_scale + x * (n + 1)) + input_val_dt * sqrt_scale;
+                cumsum_fwd += x * sqrt_scale;
+                scalar_t y = (out_fwd - dt / 2 * cumsum_bwd * sqrt_scale) / (1 + (n + 1) * dt / 2);
+                mem_a[n] = y;
+                cumsum_bwd += y * sqrt_scale;
+            }
+        }
+    });
+    return mem;
+}
+
+} // legs
diff --git a/csrc/hippolegt.cpp b/csrc/hippolegt.cpp
new file mode 100644
index 0000000..79d6bd4
--- /dev/null
+++ b/csrc/hippolegt.cpp
@@ -0,0 +1,59 @@
+#include <torch/extension.h>
+
+#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCPU, #x " must be on CPU")
+
+namespace legt {
+
+at::Tensor euler_forward(const torch::Tensor& mem, const torch::Tensor& input, const float dt) {
+    /* newmem = (I + dt A) mem + dt B input
+       Parameters:
+           mem: (batch_size, memsize, memorder)
+           input: (batch_size, memsize)
+           dt: float
+       Returns:
+           newmem: (batch_size, memsize, memorder)
+    */
+    const auto batch_size = mem.size(0);
+    const auto memsize = mem.size(1);
+    const auto N = mem.size(2);
+    TORCH_CHECK(mem.dim() == 3, "legt::euler_forward: mem must have dimension 3");
+    TORCH_CHECK(input.dim() == 2, "legt::euler_forward: input must have dimension 2");
+    CHECK_DEVICE(mem);
+    CHECK_DEVICE(input);
+    auto newmem = torch::empty_like(mem);
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(mem.scalar_type(), "hippolegt::euler_forward", [&] {
"hippolegt::euler_forward", [&] { + const auto mem_a = mem.accessor(); + const auto input_a = input.accessor(); + const scalar_t dt_a = dt; + auto newmem_a = newmem.accessor(); + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t msz = 0; msz < memsize; ++msz) { + scalar_t sum = 0; + for (int64_t n = 0; n < N; ++n) { + sum += mem_a[b][msz][n]; + } + scalar_t input_val_dt = input_a[b][msz] * dt_a; + scalar_t cumsum_even = 0, cumsum_odd = 0; + for (int64_t i = 0; i < N / 2; ++i) { + int64_t n_even = 2 * i; + scalar_t x_even = mem_a[b][msz][n_even]; + newmem_a[b][msz][n_even] = x_even + (dt_a * (-sum + 2 * cumsum_odd) + input_val_dt) * (2 * n_even + 1); + cumsum_even += x_even; + int64_t n_odd = 2 * i + 1; + scalar_t x_odd = mem_a[b][msz][n_odd]; + newmem_a[b][msz][n_odd] = x_odd + (dt_a * (-sum + 2 * cumsum_even) - input_val_dt) * (2 * n_odd + 1); + cumsum_odd += x_odd; + } + if (N % 2 == 1) { // Last element if there's an extra one + int64_t n_even = N - 1; + scalar_t x_even = mem_a[b][msz][n_even]; + newmem_a[b][msz][n_even] = x_even + (dt_a * (-sum + 2 * cumsum_odd) + input_val_dt) * (2 * n_even + 1); + } + } + } + }); + return newmem; +} + +} // legt + diff --git a/csrc/setup.py b/csrc/setup.py new file mode 100644 index 0000000..0f0882f --- /dev/null +++ b/csrc/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup +from torch.utils.cpp_extension import CppExtension, BuildExtension + +ext_modules = [] +extension = CppExtension('hippo', ['hippo.cpp', 'hippolegs.cpp', 'hippolegt.cpp'], extra_compile_args=['-march=native']) +ext_modules.append(extension) + +setup( + name='hippo', + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension}) diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..112f346 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,232 @@ +import os +dir_path = os.path.dirname(os.path.abspath(__file__)) + +import random + +import torch +from torch import nn +from torch.nn import functional as F +from torchvision import datasets, transforms + +from . import copying, adding +from . 
import utils +from .tasks import BinaryClassification, MulticlassClassification, MSERegression + + +class DatasetBase(): + registry = {} + + # https://www.python.org/dev/peps/pep-0487/#subclass-registration + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Only register classes with @name attribute + if hasattr(cls, 'name'): + cls.registry[cls.name] = cls + + def __init__(self, dataset_cfg, path=dir_path): + self.dataset_cfg = dataset_cfg + self.path = path + + def prepare_data(self): + raise NotImplementedError + + def split_train_val(self, ratio=0.9): + train_len = int(len(self.train) * ratio) + self.train, self.val = torch.utils.data.random_split(self.train, (train_len, len(self.train) - train_len)) + + def prepare_dataloader(self, batch_size, **kwargs): + self.train_loader = torch.utils.data.DataLoader(self.train, batch_size=batch_size, shuffle=True, **kwargs) + self.val_loader = torch.utils.data.DataLoader(self.val, batch_size=batch_size, shuffle=False, **kwargs) + self.test_loader = torch.utils.data.DataLoader(self.test, batch_size=batch_size, shuffle=False, **kwargs) + + def __str__(self): + return self.name if hasattr(self, 'name') else self.__name__ + + +class MNIST(DatasetBase, MulticlassClassification): + name = 'mnist' + input_size = 1 + output_size = 10 + output_len = 0 + N = 784 + + def prepare_data(self): + transform_list = [transforms.ToTensor(), + transforms.Lambda(lambda x: x.view(self.input_size, self.N).t())] # (N, input_size) + if self.dataset_cfg.permute: + # below is another permutation that other works have used + # permute = np.random.RandomState(92916) + # permutation = torch.LongTensor(permute.permutation(784)) + permutation = utils.bitreversal_permutation(self.N) + transform_list.append(transforms.Lambda(lambda x: x[permutation])) + transform = transforms.Compose(transform_list) + self.train = datasets.MNIST(f'{self.path}/{self.name}', train=True, download=True, transform=transform) + self.test = datasets.MNIST(f'{self.path}/{self.name}', train=False, transform=transform) + self.split_train_val() + + def __str__(self): + return f"{'p' if self.dataset_cfg.permute else 's'}{self.name}" + + +class Copying(DatasetBase, MulticlassClassification): + name = 'copying' + + def __init__(self, dataset_cfg, path=dir_path): + super().__init__(dataset_cfg, path) + self.input_size = dataset_cfg.A + self.output_size = dataset_cfg.A + self.output_len = dataset_cfg.M + self.N = dataset_cfg.L + 2 * dataset_cfg.M + + def prepare_data(self): + cfg = self.dataset_cfg + self.train = copying.copying_static_dataset(cfg.L, cfg.M, cfg.A, cfg.variable, cfg.samples) + self.test = copying.copying_static_dataset(cfg.L, cfg.M, cfg.A, cfg.variable, cfg.test_samples) + self.split_train_val() + + def __str__(self): + return f"{self.name}{self.dataset_cfg.L}{'v' if self.dataset_cfg.variable else ''}" + + +class Adding(DatasetBase, MSERegression): + name = 'adding' + + def __init__(self, dataset_cfg, path=dir_path): + super().__init__(dataset_cfg, path) + self.input_size = 2 + self.output_size = 1 + self.output_len = 0 + self.N = dataset_cfg.L + + def prepare_data(self): + cfg = self.dataset_cfg + self.train = adding.adding_static_dataset(cfg.L, cfg.samples) + self.test = adding.adding_static_dataset(cfg.L, cfg.test_samples) + self.split_train_val() + + def __str__(self): + return f"{self.name}{self.dataset_cfg.L}" + + +# Wrap the data loader with callback function +class LoaderWCallback: + def __init__(self, loader, callback_fn): + self.loader = loader + 
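+        # callback_fn takes each raw batch from the wrapped loader and converts it
+        # into the (x, y, ...) tuple format used by the other datasets in this file;
+        # see the IMDB postprocess hook below for the torchtext case.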
self.callback_fn = callback_fn + + def __len__(self): + return len(self.loader) + + def __iter__(self): + self.loader_iter = iter(self.loader) + return self + + def __next__(self): + return self.callback_fn(next(self.loader_iter)) + + +class IMDB(DatasetBase, BinaryClassification): + name = 'imdb' + output_size = 1 + output_len = 0 + + def __init__(self, dataset_cfg, path=dir_path): + super().__init__(dataset_cfg, path) + self.input_size = dataset_cfg.vocab_size + self.N = dataset_cfg.max_length + + # https://github.com/bentrevett/pytorch-sentiment-analysis/issues/6 + def tokenize_once(self): + import torchtext + from torchtext import data + TEXT = data.Field(tokenize='spacy') + LABEL = data.LabelField() + train_data, test_data = torchtext.datasets.IMDB.splits(TEXT, LABEL, root=f'{self.path}') + train_examples = [vars(t) for t in train_data] + test_examples = [vars(t) for t in test_data] + import json + with open(f'{self.path}/{self.name}/train.json', 'w+') as f: + for example in train_examples: + json.dump(example, f) + f.write('\n') + with open(f'{self.path}/{self.name}/test.json', 'w+') as f: + for example in test_examples: + json.dump(example, f) + f.write('\n') + + def prepare_data(self): + if not os.path.exists(f'{self.path}/{self.name}/train.json'): + self.tokenize_once() + import torchtext + from torchtext import data + TEXT = data.Field(batch_first=True, include_lengths=True) + LABEL = data.LabelField(dtype=torch.float) + fields = {'text': ('text', TEXT), 'label': ('label', LABEL)} + self.train, self.test = data.TabularDataset.splits( + path = f'{self.path}/{self.name}', + train = 'train.json', + test = 'test.json', + format = 'json', + fields = fields + ) + self.train, self.val = self.train.split(0.9) + TEXT.build_vocab(self.train, max_size=self.input_size - 2) # Need 2 extra for and + LABEL.build_vocab(self.train) + + def prepare_dataloader(self, batch_size, **kwargs): + from torchtext import data + self.train_loader, self.val_loader, self.test_loader = data.BucketIterator.splits( + (self.train, self.val, self.test), + shuffle=True, + sort_key=lambda ex: len(ex.text), + batch_size = batch_size) + + def postprocess(batch): # make the loader from torchtext compatible with Pytorch's loader + x, lens = batch.text + x = x[:self.N] + lens = torch.clamp(lens, max=self.N) + return x, batch.label, lens + + self.train_loader = LoaderWCallback(self.train_loader, postprocess) + self.val_loader = LoaderWCallback(self.val_loader, postprocess) + self.test_loader = LoaderWCallback(self.test_loader, postprocess) + + +class CharacterTrajectories(DatasetBase, MulticlassClassification): + """ CharacterTrajectories dataset from the UCI Machine Learning archive. + + See datasets.uea.postprocess_data for dataset configuration settings. 
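+    Each example is a variable-length pen trajectory (3 input channels, sampled at
+    200Hz) labeled as one of 20 character classes; if dataset_cfg.timestamp is set,
+    uea.postprocess_data prepends a time channel and input_size grows by one.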
+ """ + name = 'ct' + input_size = 3 + output_size = 20 + output_len = 0 + + + def __init__(self, dataset_cfg, path=dir_path): + super().__init__(dataset_cfg, path) + if self.dataset_cfg.timestamp: + self.input_size += 1 + + def prepare_data(self): + from datasets import uea + + cfg = self.dataset_cfg + *data, num_classes, input_channels = uea.get_data( + 'CharacterTrajectories', + intensity=False, + ) + train_dataset, val_dataset, test_dataset = uea.postprocess_data( + *data, + train_hz=cfg.train_hz, + eval_hz=cfg.eval_hz, + train_uniform=cfg.train_uniform, + eval_uniform=cfg.eval_uniform, + timestamp=cfg.timestamp, + train_ts=cfg.train_ts, + eval_ts=cfg.eval_ts, + ) + self.train = train_dataset + self.val = val_dataset + self.test = test_dataset + assert num_classes == self.output_size, f"Output size should be {num_classes}" diff --git a/datasets/adding.py b/datasets/adding.py new file mode 100644 index 0000000..349de41 --- /dev/null +++ b/datasets/adding.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +# from torch.utils.data.dataset import IterableDataset +import numpy as np + + +def torch_adding_data(L, batch_shape=()): + assert L >= 2 + mid = L//2 + idx0 = torch.randint(low=0, high=mid, size=batch_shape) + idx1 = torch.randint(low=0, high=L-mid, size=batch_shape) + + idx = torch.cat((F.one_hot(idx0, mid), F.one_hot(idx1, L-mid)), dim=-1).float() # (batch_shape, L) + unif = torch.empty(batch_shape+(L,)) + unif.uniform_(0., 1.) + + x = torch.stack((unif, idx), dim=-1) # (batch_shape, L, 2) + y = torch.sum(unif*idx, dim=-1, keepdim=True) # (batch_shape, 1) + + return x, y + +def adding_static_dataset(L, samples): + all_x, all_y = torch_adding_data(L, batch_shape=(samples,)) + print("Constructing Adding dataset of shape", all_x.shape) + ds = torch.utils.data.TensorDataset(all_x, all_y) + return ds + diff --git a/datasets/copying.py b/datasets/copying.py new file mode 100644 index 0000000..45c62a4 --- /dev/null +++ b/datasets/copying.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +# from torch.utils.data.dataset import IterableDataset +import numpy as np + + +def np_copying_data(L, M, A, batch_shape=()): + seq = np.random.randint(low=1, high=A-1, size=batch_shape+(M,)) + zeros_x = np.zeros(batch_shape+(L,)) + markers = (A-1) * np.ones(batch_shape+(M,)) + zeros_y = np.zeros(batch_shape+(M+L,)) + + x_ = np.concatenate([seq, zeros_x, markers], axis=-1) + y_ = np.concatenate([zeros_y, seq], axis=-1) + x = F.one_hot(torch.tensor(x_, dtype=torch.int64), A).float() + y = torch.tensor(y_, dtype=torch.int64) + return x, y + +def torch_copying_data(L, M, A, variable=False, batch_shape=()): + tokens = torch.randint(low=1, high=A-1, size=batch_shape+(M,)) + if variable: + total_batch = np.prod(batch_shape) + inds = torch.stack([ + torch.randperm(L+M)[:M] + for _ in range(total_batch) + ], 0) + inds = inds.reshape(batch_shape+(M,)) + inds, _ = inds.sort() + else: + inds = torch.arange(M).repeat(batch_shape+(1,)) + zeros_x = torch.zeros(batch_shape+(M+L,), dtype=torch.long) + zeros_x.scatter_(-1, inds, tokens) + markers = (A-1) * torch.ones(batch_shape+(M,), dtype=torch.long) + + x_ = torch.cat([zeros_x, markers], dim=-1) + y_ = torch.cat([tokens], dim=-1) + x = F.one_hot(x_, A).float() + y = y_ + return x, y + + +def copying_static_dataset(L, M, A, variable, samples): + all_x, all_y = torch_copying_data(L, M, A, variable, batch_shape=(samples,)) + print("Constructing Copying dataset of shape", all_x.shape) + ds = 
torch.utils.data.TensorDataset(all_x, all_y) + return ds diff --git a/datasets/tasks.py b/datasets/tasks.py new file mode 100644 index 0000000..7d56ae4 --- /dev/null +++ b/datasets/tasks.py @@ -0,0 +1,60 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class Task: + @staticmethod + def metrics(outs, y, len_batch=None): + return {} + + @staticmethod + def metrics_epoch(outs, y, len_batch=None): + return {} + + +class BinaryClassification(Task): + @staticmethod + def loss(logits, y, len_batch=None): + # BCE loss requires squeezing last dimension of logits so it has the same shape as y + return F.binary_cross_entropy_with_logits(logits.squeeze(-1), y.float()) + + @staticmethod + def metrics(logits, y, len_batch=None): + return {'accuracy': torch.eq(logits.squeeze(-1) >= 0, y).float().mean()} + + @staticmethod + def metrics_epoch(logits, y, len_batch=None): + return BinaryClassification.metrics(torch.cat(logits), torch.cat(y), len_batch) + + + +class MulticlassClassification(Task): + @staticmethod + def loss(logits, y, len_batch=None): + return F.cross_entropy(logits, y) + + @staticmethod + def metrics(logits, y, len_batch=None): + return {'accuracy': torch.eq(torch.argmax(logits, dim=-1), y).float().mean()} + + @staticmethod + def metrics_epoch(logits, y, len_batch=None): + return MulticlassClassification.metrics(torch.cat(logits, dim=0), torch.cat(y, dim=0), len_batch) + + +class MSERegression(Task): + @staticmethod + def loss(outs, y, len_batch=None): + if len_batch is None: + return F.mse_loss(outs, y) + else: + # Computes the loss of the first `lens` items in the batches + mask = torch.zeros_like(outs, dtype=torch.bool) + for i, l in enumerate(len_batch): + mask[i, :l, :] = 1 + outs_masked = torch.masked_select(outs, mask) + y_masked = torch.masked_select(y, mask) + return F.mse_loss(outs_masked, y_masked) + + diff --git a/datasets/uea.py b/datasets/uea.py new file mode 100644 index 0000000..c93fd7a --- /dev/null +++ b/datasets/uea.py @@ -0,0 +1,390 @@ +""" Load data for UEA datasets, in particular CharacterTrajectories + +Adapted from https://github.com/patrick-kidger/NeuralCDE/blob/master/experiments/datasets/uea.py +""" +import os +import pathlib +import urllib.request +import zipfile +import sklearn.model_selection +import sktime.utils.load_data +import numpy as np +import torch +import collections as co + +# TODO deal with this path properly as an option +here = pathlib.Path(__file__).resolve().parent +valid_dataset_names = { + 'ArticularyWordRecognition', + 'FaceDetection', + 'NATOPS', + 'AtrialFibrillation', + 'FingerMovements', + 'PEMS - SF', + 'BasicMotions', + 'HandMovementDirection', + 'PenDigits', + 'CharacterTrajectories', + 'Handwriting', + 'PhonemeSpectra', + 'Cricket', + 'Heartbeat', + 'RacketSports', + 'DuckDuckGeese', + 'InsectWingbeat', + 'SelfRegulationSCP1', + 'EigenWorms', + 'JapaneseVowels', + 'SelfRegulationSCP2', + 'Epilepsy', + 'Libras', + 'SpokenArabicDigits', + 'ERing', + 'LSST', + 'StandWalkJump', + 'EthanolConcentration', + 'MotorImagery', + 'UWaveGestureLibrary', +} + +def download(): + """ Download data if not exists """ + base_base_loc = here / 'data' + base_loc = base_base_loc / 'UEA' + loc = base_loc / 'Multivariate2018_ts.zip' + if os.path.exists(loc): + return + if not os.path.exists(base_base_loc): + os.mkdir(base_base_loc) + if not os.path.exists(base_loc): + os.mkdir(base_loc) + urllib.request.urlretrieve('http://www.timeseriesclassification.com/Downloads/Archives/Multivariate2018_ts.zip', + str(loc)) + + with 
zipfile.ZipFile(loc, 'r') as f: + f.extractall(str(base_loc)) + +def load_data(dataset_name): + """ Load X, y numpy data for given dataset """ + assert dataset_name in valid_dataset_names, "Must specify a valid dataset name." + + base_filename = here / 'data' / 'UEA' / 'Multivariate_ts' / dataset_name / dataset_name + train_X, train_y = sktime.utils.load_data.load_from_tsfile_to_dataframe(str(base_filename) + '_TRAIN.ts') + test_X, test_y = sktime.utils.load_data.load_from_tsfile_to_dataframe(str(base_filename) + '_TEST.ts') + train_X = train_X.to_numpy() + test_X = test_X.to_numpy() + X = np.concatenate((train_X, test_X), axis=0) + y = np.concatenate((train_y, test_y), axis=0) + return X, y + +def save_data(dir, **tensors): + for tensor_name, tensor_value in tensors.items(): + torch.save(tensor_value, str(dir / tensor_name) + '.pt') + +def load_processed_data(dir): + tensors = {} + for filename in os.listdir(dir): + if filename.endswith('.pt'): + tensor_name = filename.split('.')[0] + tensor_value = torch.load(str(dir / filename)) + tensors[tensor_name] = tensor_value + return tensors + +def wrap_data(train_X, val_X, test_X, train_y, val_y, test_y, train_final_index, val_final_index, + test_final_index, + ): + """ Wrap data into Pytorch Dataset. """ + + train_dataset = torch.utils.data.TensorDataset(train_X, train_y, + # train_final_index + ) + val_dataset = torch.utils.data.TensorDataset(val_X, val_y, + # val_final_index + ) + test_dataset = torch.utils.data.TensorDataset(test_X, test_y, + # test_final_index + ) + + return train_dataset, val_dataset, test_dataset + +def split_data(tensor, stratify): + # 0.7/0.15/0.15 train/val/test split + (train_tensor, testval_tensor, + train_stratify, testval_stratify) = sklearn.model_selection.train_test_split(tensor, stratify, + train_size=0.7, + random_state=0, + shuffle=True, + stratify=stratify) + + val_tensor, test_tensor = sklearn.model_selection.train_test_split(testval_tensor, + train_size=0.5, + random_state=1, + shuffle=True, + stratify=testval_stratify) + return train_tensor, val_tensor, test_tensor + + +def normalize_data(X, y): + """ Normalize data by training statistics per channel. + + X: data tensor with channels as last dimension + """ + train_X, _, _ = split_data(X, y) + out = [] + for Xi, train_Xi in zip(X.unbind(dim=-1), train_X.unbind(dim=-1)): + train_Xi_nonan = train_Xi.masked_select(~torch.isnan(train_Xi)) + mean = train_Xi_nonan.mean() # compute statistics using only training data. + std = train_Xi_nonan.std() + out.append((Xi - mean) / (std + 1e-5)) + out = torch.stack(out, dim=-1) + return out + + +def preprocess_data( + X, y, + final_index, + # append_times, + append_intensity, + ): + X = normalize_data(X, y) + + # Append extra channels together. Note that the order here: time, intensity, original, is important, and some models + # depend on that order. + augmented_X = [] + # if append_times: + # augmented_X.append(times.unsqueeze(0).repeat(X.size(0), 1).unsqueeze(-1)) + if append_intensity: # Note this will append #channels copies of the same intensity + intensity = ~torch.isnan(X) # of size (batch, stream, channels) + intensity = intensity.to(X.dtype).cumsum(dim=1) + augmented_X.append(intensity) + augmented_X.append(X) + if len(augmented_X) == 1: + X = augmented_X[0] + else: + X = torch.cat(augmented_X, dim=2) + + train_X, val_X, test_X = split_data(X, y) # TODO split data should just return y? 
or list of indices corresponding to splits + train_y, val_y, test_y = split_data(y, y) + train_final_index, val_final_index, test_final_index = split_data(final_index, y) + + # train_coeffs = controldiffeq.natural_cubic_spline_coeffs(times, train_X) + # val_coeffs = controldiffeq.natural_cubic_spline_coeffs(times, val_X) + # test_coeffs = controldiffeq.natural_cubic_spline_coeffs(times, test_X) + + in_channels = X.size(-1) + + return ( + # times, + # train_coeffs, val_coeffs, test_coeffs, + train_X, val_X, test_X, + train_y, val_y, test_y, + train_final_index, val_final_index, test_final_index, + in_channels + ) + +def process_data(dataset_name, intensity): + # We begin by loading both the train and test data and using our own train/val/test split. + # The reason for this is that (a) by default there is no val split and (b) the sizes of the train/test splits are + # really janky by default. (e.g. LSST has 2459 training samples and 2466 test samples.) + + + X, y = load_data(dataset_name) + + lengths = torch.tensor([len(Xi[0]) for Xi in X]) + final_index = lengths - 1 + maxlen = lengths.max() + # X is now a numpy array of shape (batch, channel) + # Each channel is a pandas.core.series.Series object of length corresponding to the length of the time series + def _pad(channel, maxlen): + channel = torch.tensor(channel) + out = torch.full((maxlen,), channel[-1]) + out[:channel.size(0)] = channel + return out + X = torch.stack([torch.stack([_pad(channel, maxlen) for channel in batch], dim=0) for batch in X], dim=0) + # X is now a tensor of shape (batch, channel, length) + X = X.transpose(-1, -2) + # X is now a tensor of shape (batch, length, channel) + times = torch.linspace(0, X.size(1) - 1, X.size(1)) + + + # generator = torch.Generator().manual_seed(56789) + # for Xi in X: + # removed_points = torch.randperm(X.size(1), generator=generator)[:int(X.size(1) * missing_rate)].sort().values + # Xi[removed_points] = float('nan') + + # Now fix the labels to be integers from 0 upwards + targets = co.OrderedDict() + counter = 0 + for yi in y: + if yi not in targets: + targets[yi] = counter + counter += 1 + y = torch.tensor([targets[yi] for yi in y]) + + + (train_X, val_X, test_X, + train_y, val_y, test_y, + train_final_index, val_final_index, + test_final_index, + input_channels) = preprocess_data( + X, y, final_index, + # append_times=True, + append_intensity=intensity, + ) + + num_classes = counter + + assert num_classes >= 2, f"Have only {num_classes} classes." + + return ( + # times, + train_X, val_X, test_X, + train_y, val_y, test_y, + train_final_index, val_final_index, test_final_index, + num_classes, input_channels + ) + +def get_data( + dataset_name, + intensity, + train_hz=1, + eval_hz=1, + timestamp=False, + train_ts=1, + eval_ts=1, + ): + # We begin by loading both the train and test data and using our own train/val/test split. + # The reason for this is that (a) by default there is no val split and (b) the sizes of the train/test splits are + # really janky by default. (e.g. LSST has 2459 training samples and 2466 test samples.) + + assert dataset_name in valid_dataset_names, "Must specify a valid dataset name." 
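+    # Preprocessed tensors are cached under processed_data/UEA/<dataset_name>
+    # (with an '_intensity' suffix when intensity=True); if the cache is missing
+    # or unreadable, the raw UEA archive is downloaded, processed, and re-saved.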
+ + base_base_loc = here / 'processed_data' + base_loc = base_base_loc / 'UEA' + loc = base_loc / (dataset_name + ('_intensity' if intensity else '')) + try: + tensors = load_processed_data(loc) + train_X = tensors['train_X'] + val_X = tensors['val_X'] + test_X = tensors['test_X'] + train_y = tensors['train_y'] + val_y = tensors['val_y'] + test_y = tensors['test_y'] + train_final_index = tensors['train_final_index'] + val_final_index = tensors['val_final_index'] + test_final_index = tensors['test_final_index'] + num_classes = int(tensors['num_classes']) + input_channels = int(tensors['input_channels']) + except: + print(f"Could not find preprocessed data. Loading {dataset_name}...") + download() # download the UEA data if necessary + if not os.path.exists(base_base_loc): + os.mkdir(base_base_loc) + if not os.path.exists(base_loc): + os.mkdir(base_loc) + if not os.path.exists(loc): + os.mkdir(loc) + ( train_X, val_X, test_X, train_y, val_y, test_y, train_final_index, val_final_index, + test_final_index, num_classes, input_channels ) = process_data(dataset_name, intensity) + save_data( + loc, + train_X=train_X, val_X=val_X, test_X=test_X, + train_y=train_y, val_y=val_y, test_y=test_y, train_final_index=train_final_index, + val_final_index=val_final_index, test_final_index=test_final_index, + num_classes=torch.as_tensor(num_classes), input_channels=torch.as_tensor(input_channels), + ) + + return ( + train_X, val_X, test_X, + train_y, val_y, test_y, + train_final_index, val_final_index, test_final_index, + num_classes, input_channels, + ) + + +def _subsample(X, hz=1, uniform=True): + """ Subsample X non-uniformly at hz frequency, append timestamps """ + L = X.shape[1] + # create subsampler + if uniform: + removed_points = torch.arange(int(L*hz)) // hz + removed_points = removed_points.to(int) + time_gen = lambda: removed_points + else: + generator = torch.Generator().manual_seed(56789) + time_gen = lambda: torch.randperm(L, generator=generator)[:int(L*hz)].sort().values + + X_ = [] + T_ = [] + for Xi in X: + times = time_gen() + Xi_ = Xi[times] + times_ = times.to(torch.float32).unsqueeze(-1) + X_.append(Xi_) + T_.append(times_) + return torch.stack(X_, dim=0), torch.stack(T_, dim=0) + +def postprocess_data( + train_X, val_X, test_X, + train_y, val_y, test_y, + train_final_index, val_final_index, test_final_index, + train_hz=1, + eval_hz=1, + train_uniform=True, + eval_uniform=True, + timestamp=False, + train_ts=1, + eval_ts=1, + ): + """ + train_hz, eval_hz: subsampling multiplier of original data + e.g. train_hz=0.5 means data is sampled at half speed, so remove every other element of the sequence + Since the original data is sampled from a trajectory at 200Hz, this corresponds to a sampling rate of 100Hz + train_uniform, eval_uniform: whether subsampling is uniformly spaced or random + timestamp: data comes with timestamps + train_ts, eval_ts: timestamp multiplier + + Example configurations: + train_hz=1.0, eval_hz=0.5, {train,eval}_uniform=True, timestamp=False + - non-timestamped, uniformly sampled data, where evaluation sequences have every other element removed + + {train,eval}_uniform=False, timestamp=True, train_ts=1.0, eval_ts=0.5 + - timestamped, randomly sampled data, where evaluation sequences have timestamps halved + + Both of the above configurations test train->evaluation generalization of halving the timescale frequency, either from the measurement sampling rate decreasing (from 200Hz -> 100hz), or the subject drawing half as fast. 
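+    When timestamp=True, the (rescaled) sample times are prepended as an extra input
+    channel, which is why the CharacterTrajectories dataset adds one to input_size;
+    timestamp-aware cells (e.g. TimeMemoryCell in model/memory.py) then read
+    input[:, 0] as the timestamp and the remaining columns as the observation.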
+ """ + + + train_X, train_T = _subsample(train_X, train_hz, train_uniform) + val_X, val_T = _subsample(val_X, eval_hz, eval_uniform) + test_X, test_T = _subsample(test_X, eval_hz, eval_uniform) + + if timestamp: + train_X = torch.cat([train_ts*train_T, train_X], dim=-1) + val_X = torch.cat([eval_ts*val_T, val_X], dim=-1) + test_X = torch.cat([eval_ts*test_T, test_X], dim=-1) + + train_dataset, val_dataset, test_dataset = wrap_data( + train_X, val_X, test_X, + train_y, val_y, test_y, + train_final_index, val_final_index, test_final_index + ) + return train_dataset, val_dataset, test_dataset + + +if __name__ == '__main__': + *data, numclasses, input_channels = get_data( + 'CharacterTrajectories', + intensity=False, + ) + + train_dataset, val_dataset, test_dataset = postprocess_data( + *data, + train_hz=1, + eval_hz=0.5, + train_uniform=True, + eval_uniform=False, + timestamp=True, + train_ts=1, + eval_ts=0.5, + ) diff --git a/datasets/utils.py b/datasets/utils.py new file mode 100644 index 0000000..f7ed3f1 --- /dev/null +++ b/datasets/utils.py @@ -0,0 +1,48 @@ +import math +import numpy as np + +import torch + + +def bitreversal_po2(n): + m = int(math.log(n)/math.log(2)) + perm = np.arange(n).reshape(n,1) + for i in range(m): + n1 = perm.shape[0]//2 + perm = np.hstack((perm[:n1],perm[n1:])) + return perm.squeeze(0) + +def bitreversal_permutation(n): + m = int(math.ceil(math.log(n)/math.log(2))) + N = 1 << m + perm = bitreversal_po2(N) + return np.extract(perm < n, perm) + + +# For language modeling +# Adapted from https://github.com/salesforce/awd-lstm-lm/blob/master/utils.py + +def repackage_hidden(h): + """Wraps hidden states in new Tensors, + to detach them from their history.""" + if isinstance(h, torch.Tensor): + return h.detach() + else: + return tuple(repackage_hidden(v) for v in h) + + +def batchify(data, bsz): + # Work out how cleanly we can divide the dataset into bsz parts. + nbatch = data.size(0) // bsz + # Trim off any extra elements that wouldn't cleanly fit (remainders). + data = data.narrow(0, 0, nbatch * bsz) + # Evenly divide the data across the bsz batches. 
+ data = data.view(bsz, -1).t().contiguous() + return data + + +def get_batch(source, i, seq_len): + seq_len = min(seq_len, len(source) - 1 - i) + data = source[i:i+seq_len].t() + target = source[i+1:i+1+seq_len].t().reshape(-1) + return data, target diff --git a/model/components.py b/model/components.py new file mode 100644 index 0000000..0b6e225 --- /dev/null +++ b/model/components.py @@ -0,0 +1,86 @@ +from functools import partial +import torch +import torch.nn as nn + +from model.exprnn.orthogonal import modrelu + +def get_activation(activation, size): + if activation == 'id': + return nn.Identity() + elif activation == 'tanh': + return torch.tanh + elif activation == 'relu': + return torch.relu + elif activation == 'sigmoid': + return torch.sigmoid + elif activation == 'modrelu': + return Modrelu(size) + else: + raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) + + +def get_initializer(name, activation): + if activation in ['id', 'identity', 'linear', 'modrelu']: + nonlinearity = 'linear' + elif activation in ['relu', 'tanh', 'sigmoid']: + nonlinearity = activation + else: + assert False, f"get_initializer: activation {activation} not supported" + if name == 'uniform': + initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity) + elif name == 'normal': + initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity) + elif name == 'xavier': + initializer = torch.nn.init.xavier_normal_ + elif name == 'zero': + initializer = partial(torch.nn.init.constant_, val=0) + elif name == 'one': + initializer = partial(torch.nn.init.constant_, val=1) + else: + assert False, f"get_initializer: initializer type {name} not supported" + + return initializer + + + +class Modrelu(modrelu): + def reset_parameters(self): + self.b.data.uniform_(-0.0, 0.0) + + +def Linear_(input_size, output_size, bias, init='normal', zero_bias_init=False, **kwargs): + """ Returns a nn.Linear module with initialization options """ + l = nn.Linear(input_size, output_size, bias=bias, **kwargs) + get_initializer(init, 'linear')(l.weight) + if bias and zero_bias_init: + nn.init.zeros_(l.bias) + return l + + +class Gate(nn.Module): + """ Implements gating mechanisms. + + Mechanisms: + N - No gate + G - Standard sigmoid gate + """ + def __init__(self, size, preact_ctor, preact_args, mechanism='N'): + super().__init__() + self.size = size + self.mechanism = mechanism + + if self.mechanism == 'N': + pass + elif self.mechanism == 'G': + self.W_g = preact_ctor(*preact_args) + else: + assert False, f'Gating type {self.mechanism} is not supported.' + + def forward(self, *inputs): + if self.mechanism == 'N': + return 1.0 + + if self.mechanism == 'G': + g_preact = self.W_g(*inputs) + g = torch.sigmoid(g_preact) + return g diff --git a/model/exprnn/expm32.py b/model/exprnn/expm32.py new file mode 100644 index 0000000..69780ed --- /dev/null +++ b/model/exprnn/expm32.py @@ -0,0 +1,315 @@ +# Downloaded from https://github.com/Lezcano/expRNN + +""" +Adaptation of expm and expm_frechet in numpy for torch +""" + +# +# Authors: Travis Oliphant, March 2002 +# Anthony Scopatz, August 2012 (Sparse Updates) +# Jake Vanderplas, August 2012 (Sparse Updates) +# + +from __future__ import division, print_function, absolute_import + +import math + +import numpy as np + +import torch +import scipy.special + +def _onenorm_matrix_power_nnm(A, p): + """ + Compute the 1-norm of a non-negative integer power of a non-negative matrix. 
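+    The norm is obtained by repeatedly applying A^T to the all-ones vector, so the
+    matrix power A**p is never formed explicitly.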
+ + Parameters + ---------- + A : a square ndarray or matrix or sparse matrix + Input matrix with non-negative entries. + p : non-negative integer + The power to which the matrix is to be raised. + + Returns + ------- + out : float + The 1-norm of the matrix power p of A. + + """ + # check input + if int(p) != p or p < 0: + raise ValueError('expected non-negative integer p') + p = int(p) + if len(A.shape) != 2 or A.shape[0] != A.shape[1]: + raise ValueError('expected A to be like a square matrix') + + # Explicitly make a column vector so that this works when A is a + # numpy matrix (in addition to ndarray and sparse matrix). + v = torch.ones((A.shape[0], 1), dtype=A.dtype, device=A.device) + M = A.t() + for _ in range(p): + v = M.mm(v) + return torch.max(v).item() + + +def _onenorm(A): + return torch.norm(A, 1).item() + + +def _ident_like(A): + return torch.eye(A.shape[0], A.shape[1], dtype=A.dtype, device=A.device) + +class _ExpmPadeHelper(object): + """ + Help lazily evaluate a matrix exponential. + + The idea is to not do more work than we need for high expm precision, + so we lazily compute matrix powers and store or precompute + other properties of the matrix. + + """ + def __init__(self, A): + """ + Initialize the object. + + Parameters + ---------- + A : a dense or sparse square numpy matrix or ndarray + The matrix to be exponentiated. + """ + self.A = A + self._A2 = None + self._A4 = None + self._A6 = None + self._A8 = None + self._A10 = None + self._d4_exact = None + self._d6_exact = None + self._d8_exact = None + self._d10_exact = None + self._d4_approx = None + self._d6_approx = None + self._d8_approx = None + self._d10_approx = None + self.ident = _ident_like(A) + + @property + def A2(self): + if self._A2 is None: + self._A2 = self.A.mm(self.A) + return self._A2 + + @property + def A4(self): + if self._A4 is None: + self._A4 = self.A2.mm(self.A2) + return self._A4 + + @property + def A6(self): + if self._A6 is None: + self._A6 = self.A4.mm(self.A2) + return self._A6 + + @property + def A8(self): + if self._A8 is None: + self._A8 = self.A6.mm(self.A2) + return self._A8 + + @property + def A10(self): + if self._A10 is None: + self._A10 = self.A4.mm(self.A6) + return self._A10 + + @property + def d4_tight(self): + if self._d4_exact is None: + self._d4_exact = _onenorm(self.A4)**(1/4.) + return self._d4_exact + + @property + def d6_tight(self): + if self._d6_exact is None: + self._d6_exact = _onenorm(self.A6)**(1/6.) + return self._d6_exact + + @property + def d8_tight(self): + if self._d8_exact is None: + self._d8_exact = _onenorm(self.A8)**(1/8.) + return self._d8_exact + + @property + def d10_tight(self): + if self._d10_exact is None: + self._d10_exact = _onenorm(self.A10)**(1/10.) + return self._d10_exact + + @property + def d4_loose(self): + return self.d4_tight + + @property + def d6_loose(self): + return self.d6_tight + + @property + def d8_loose(self): + return self.d8_tight + + @property + def d10_loose(self): + return self.d10_tight + + def pade3(self): + b = (120., 60., 12., 1.) + U = self.A.mm(b[3]*self.A2 + b[1]*self.ident) + V = b[2]*self.A2 + b[0]*self.ident + return U, V + + def pade5(self): + b = (30240., 15120., 3360., 420., 30., 1.) + U = self.A.mm(b[5]*self.A4 + b[3]*self.A2 + b[1]*self.ident) + V = b[4]*self.A4 + b[2]*self.A2 + b[0]*self.ident + return U, V + + def pade7_scaled(self, s): + b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.) 
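+        # The 2**-s factors implement the "scaling" half of scaling-and-squaring:
+        # the order-7 Pade approximant is evaluated on A / 2**s, and _expm undoes
+        # the scaling afterwards with torch.matrix_power(X, 2**s).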
+ + B = self.A * 2**-s + B2 = self.A2 * 2**(-2*s) + B4 = self.A4 * 2**(-4*s) + B6 = self.A6 * 2**(-6*s) + + U = B.mm(b[7]*B6 + b[5]*B4 + b[3]*B2 + b[1]*self.ident) + V = b[6]*B6 + b[4]*B4 + b[2]*B2 + b[0]*self.ident + return U, V + + +def expm32(A): + """ + Compute the matrix exponential using Pade approximation. + + Parameters + ---------- + A : (M,M) array_like or sparse matrix + 2D Array or Matrix (sparse or dense) to be exponentiated + + Returns + ------- + expA : (M,M) ndarray + Matrix exponential of `A` + + Notes + ----- + This is algorithm (6.1) which is a simplification of algorithm (5.1). + + .. versionadded:: 0.12.0 + + References + ---------- + .. [1] Awad H. Al-Mohy and Nicholas J. Higham (2009) + "A New Scaling and Squaring Algorithm for the Matrix Exponential." + SIAM Journal on Matrix Analysis and Applications. + 31 (3). pp. 970-989. ISSN 1095-7162 + + """ + return _expm(A) + + +def _expm(A): + # Core of expm, separated to allow testing exact and approximate + # algorithms. + + # Avoid indiscriminate asarray() to allow sparse or other strange arrays. + if len(A.shape) != 2 or A.shape[0] != A.shape[1]: + raise ValueError('expected a square matrix') + + # Trivial case + if A.shape == (1, 1): + return torch.exp(A) + + # Track functions of A to help compute the matrix exponential. + h = _ExpmPadeHelper(A) + + # Try Pade order 3. + eta_1 = max(h.d4_loose, h.d6_loose) + theta3 = 4.2587300348979312e-001 + if eta_1 < theta3 and _ell(h.A, 3) == 0: + U, V = h.pade3() + return _solve_P_Q(U, V) + + # Try Pade order 5. + eta_2 = max(h.d4_tight, h.d6_loose) + theta5 = 1.8801526985337688e+000 + if eta_2 < theta5 and _ell(h.A, 5) == 0: + U, V = h.pade5() + return _solve_P_Q(U, V) + + theta_7 = 3.9257248464332842e+000 + eta_3 = max(h.d6_tight, h.d8_loose) + s = max(int(np.ceil(np.log2(eta_3 / theta_7))), 0) + + s += _ell(2**-s * h.A, 7) + U, V = h.pade7_scaled(s) + X = _solve_P_Q(U, V) + return torch.matrix_power(X, 2**s) + + +def _solve_P_Q(U, V): + P = U + V + Q = -U + V + return torch.solve(P, Q)[0] + + +def _ell(A, m): + """ + A helper function for expm_2009. + + Parameters + ---------- + A : linear operator + A linear operator whose norm of power we care about. + m : int + The power of the linear operator + + Returns + ------- + value : int + A value related to a bound. + + """ + if len(A.shape) != 2 or A.shape[0] != A.shape[1]: + raise ValueError('expected A to be like a square matrix') + + p = 2*m + 1 + + # The c_i are explained in (2.2) and (2.6) of the 2005 expm paper. + # They are coefficients of terms of a generating function series expansion. + choose_2p_p = scipy.special.comb(2*p, p, exact=True) + abs_c_recip = float(choose_2p_p * math.factorial(2*p + 1)) + + # This is explained after Eq. (1.2) of the 2009 expm paper. + # It is the "unit roundoff" of IEEE double precision arithmetic. + u = 2.**-24 + + # Compute the one-norm of matrix power p of abs(A). + A_abs_onenorm = _onenorm_matrix_power_nnm(abs(A), p) + + # Treat zero norm as a special case. 
+ if not A_abs_onenorm: + return 0 + + alpha = A_abs_onenorm / (_onenorm(A) * abs_c_recip) + return max(int(np.ceil(np.log2(alpha/u) / (2 * m))), 0) + +def differential(f, A, E): + """ Computes the differential of f at A when acting on E: (df)_A(E) """ + n = A.size(0) + M = torch.zeros(2*n, 2*n, dtype=A.dtype, device=A.device, requires_grad=False) + M[:n, :n] = A + M[n:, n:] = A + M[:n, n:] = E + return f(M)[:n, n:] diff --git a/model/exprnn/initialization.py b/model/exprnn/initialization.py new file mode 100644 index 0000000..db6b032 --- /dev/null +++ b/model/exprnn/initialization.py @@ -0,0 +1,67 @@ +# Downloaded from https://github.com/Lezcano/expRNN + +import torch +import numpy as np +import scipy.linalg as la + + +def henaff_init_(A): + size = A.size(0) // 2 + diag = A.new(size).uniform_(-np.pi, np.pi) + return create_diag_(A, diag) + + +def cayley_init_(A): + size = A.size(0) // 2 + diag = A.new(size).uniform_(0., np.pi / 2.) + diag = -torch.sqrt((1. - torch.cos(diag))/(1. + torch.cos(diag))) + return create_diag_(A, diag) + +# We include a few more initializations that could be useful for other problems +def haar_init_(A): + """ Haar initialization on SO(n) """ + torch.nn.init.orthogonal_(A) + with torch.no_grad(): + if A.det() < 0.: + # Go bijectively from O^-(n) to O^+(n) \iso SO(n) + idx = np.random.randint(0, A.size(0)) + A[idx] *= -1. + An = la.logm(A.data.cpu().numpy()).real + An = .5 * (An - An.T) + A.copy_(torch.tensor(An)) + return A + + +def haar_diag_init_(A): + """ Block-diagonal skew-symmetric matrix with eigenvalues distributed as those from a Haar """ + haar_init_(A) + with torch.no_grad(): + An = A.data.cpu().numpy() + eig = la.eigvals(An).imag + eig = eig[::2] + if A.size(0) % 2 == 1: + eig = eig[:-1] + eig = torch.tensor(eig) + return create_diag_(A, eig) + + +def normal_squeeze_diag_init_(A): + size = A.size(0) // 2 + diag = A.new(size).normal_(0, 1).fmod_(np.pi/8.) + return create_diag_(A, diag) + +def normal_diag_init_(A): + size = A.size(0) // 2 + diag = A.new(size).normal_(0, 1).fmod_(np.pi) + return create_diag_(A, diag) + + +def create_diag_(A, diag): + n = A.size(0) + diag_z = torch.zeros(n-1) + diag_z[::2] = diag + A_init = torch.diag(diag_z, diagonal=1) + A_init = A_init - A_init.T + with torch.no_grad(): + A.copy_(A_init) + return A diff --git a/model/exprnn/orthogonal.py b/model/exprnn/orthogonal.py new file mode 100644 index 0000000..3bbebb9 --- /dev/null +++ b/model/exprnn/orthogonal.py @@ -0,0 +1,107 @@ +# Adapted from https://github.com/Lezcano/expRNN + +import torch +import torch.nn as nn + +from .parametrization import Parametrization + + +class Orthogonal(Parametrization): + """ Class that implements optimization restricted to the Stiefel manifold """ + def __init__(self, input_size, output_size, initializer_skew, mode, param): + """ + mode: "static" or a tuple such that: + mode[0] == "dynamic" + mode[1]: int, K, the number of steps after which we should change the basis of the dyn triv + mode[2]: int, M, the number of changes of basis after which we should project back onto the manifold the basis. This is particularly helpful for small values of K. 
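+                e.g. mode=("dynamic", 100, 10) (illustrative values) changes the basis every
+                100 optimization steps and re-projects onto the manifold every 10 basis changes.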
+ + param: A parametrization of in terms of skew-symmetyric matrices + """ + max_size = max(input_size, output_size) + A = torch.empty(max_size, max_size) + base = torch.empty(input_size, output_size) + super(Orthogonal, self).__init__(A, base, mode) + self.input_size = input_size + self.output_size = output_size + self.param = param + self.init_A = initializer_skew + self.init_base = nn.init.eye_ + + self.reset_parameters() + + def reset_parameters(self): + self.init_A(self.A) + self.init_base(self.base) + + def forward(self, input): + return input.matmul(self.B) + + def retraction(self, A, base): + # This could be any parametrization of a tangent space + A = A.triu(diagonal=1) + A = A - A.t() + B = base.mm(self.param(A)) + if self.input_size != self.output_size: + B = B[:self.input_size, :self.output_size] + return B + + def project(self, base): + try: + # Compute the projection using the thin SVD decomposition + U, _, V = torch.svd(base, some=True) + return U.mm(V.t()) + except RuntimeError: + # If the svd does not converge, fallback to the (thin) QR decomposition + x = base + if base.size(0) < base.size(1): + x = base.t() + ret = torch.qr(x, some=True).Q + if base.size(0) < base.size(1): + ret = ret.t() + return ret + + +class modrelu(nn.Module): + def __init__(self, features): + # For now we just support square layers + super(modrelu, self).__init__() + self.features = features + self.b = nn.Parameter(torch.Tensor(self.features)) + self.reset_parameters() + + def reset_parameters(self): + self.b.data.uniform_(-0.01, 0.01) + + def forward(self, inputs): + norm = torch.abs(inputs) + biased_norm = norm + self.b + magnitude = nn.functional.relu(biased_norm) + phase = torch.sign(inputs) + + return phase * magnitude + + +class OrthogonalRNN(nn.Module): + def __init__(self, input_size, hidden_size, initializer_skew, mode, param): + super(OrthogonalRNN, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.recurrent_kernel = Orthogonal(hidden_size, hidden_size, initializer_skew, mode, param=param) + self.input_kernel = nn.Linear(in_features=self.input_size, out_features=self.hidden_size, bias=False) + self.nonlinearity = modrelu(hidden_size) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_normal_(self.input_kernel.weight.data, nonlinearity="relu") + + def default_hidden(self, input): + return input.new_zeros(input.size(0), self.hidden_size, requires_grad=False) + + def forward(self, input, hidden): + input = self.input_kernel(input) + hidden = self.recurrent_kernel(hidden) + out = input + hidden + out = self.nonlinearity(out) + + return out, out diff --git a/model/exprnn/parametrization.py b/model/exprnn/parametrization.py new file mode 100644 index 0000000..6a80939 --- /dev/null +++ b/model/exprnn/parametrization.py @@ -0,0 +1,127 @@ +# Downloaded from https://github.com/Lezcano/expRNN + +import torch +import torch.nn as nn + + +def get_parameters(model): + parametrized_params = [] + + def get_parametrized_params(mod): + nonlocal parametrized_params + if isinstance(mod, Parametrization): + parametrized_params.append(mod.A) + + def not_in(elem, l): + return all(elem is not x for x in l) + + model.apply(get_parametrized_params) + unconstrained_params = (param for param in model.parameters() if not_in(param, parametrized_params)) + return unconstrained_params, parametrized_params + + +class Parametrization(nn.Module): + """ + Implements the parametrization of a manifold in terms of a Euclidean space + + It gives the parametrized 
matrix through the attribute `B` + + To use it, subclass it and implement the method `retraction` and the method `forward` (and optionally `project`). See the documentation in these methods for details + + You can find an example in the file `orthogonal.py` where we implement the Orthogonal class to optimize over the Stiefel manifold using an arbitrary retraction + """ + + def __init__(self, A, base, mode): + """ + mode: "static" or a tuple such that: + mode[0] == "dynamic" + mode[1]: int, K, the number of steps after which we should change the basis of the dyn triv + mode[2]: int, M, the number of changes of basis after which we should project back onto the manifold the basis. This is particularly helpful for small values of K. + """ + super(Parametrization, self).__init__() + assert mode == "static" or (isinstance(mode, tuple) and len(mode) == 3 and mode[0] == "dynamic") + + self.A = nn.Parameter(A) + self.register_buffer("_B", None) + self.register_buffer('base', base) + # This is necessary, as it will be generated again the first time that self.B is called + # We still need to register the buffer though + + if mode == "static": + self.mode = mode + else: + self.mode = mode[0] + self.K = mode[1] + self.M = mode[2] + self.k = 0 + self.m = 0 + + # This implements the parametrization trick in a rather slick way. + # We put a hook on A, such that, whenever its gradients are computed, we + # get rid of self._B so that it has to be recomputed the next time that + # self.B is accessed + def hook(grad): + nonlocal self + self._B = None + self.A.register_hook(hook) + + + def rebase(self): + with torch.no_grad(): + self.base.data.copy_(self._B.data) + self.A.data.zero_() + + @property + def B(self): + not_B = self._B is None + if not_B or (not self._B.grad_fn and torch.is_grad_enabled()): + self._B = self.retraction(self.A, self.base) + # Just to be safe + self._B.requires_grad_() + # Now self._B it's not a leaf tensor, so we convert it into a leaf + self._B.retain_grad() + + # Increment the counters for the dyntriv algorithm if we have generated B + if self.mode == "dynamic" and not_B: + if self.k == 0: + self.rebase() + # Project the base back to the manifold every M changes of base + # Increment the counter before as we don't project the first time + self.m = (self.m + 1) % self.M + # It's optional to implement this method + if self.m == 0 and hasattr(self, "project"): + with torch.no_grad(): + self.base = self.project(self.base) + # Change the basis after K optimization steps + # Increment the counter afterwards as we change the basis in the first iteration + if self.K != "infty": + self.k = (self.k + 1) % self.K + else: + # Make sure that we just update the base once + if self.k == 0: + self.k = 1 + + return self._B + + def retraction(self, A, base): + """ + It computes r_{base}(A). + Notice that A will not always be in the tangent space of our manifold + For this reason, we first have to use A to parametrize the tangent space, + and then compute the retraction + When dealing with Lie groups, raw_A is always projected into the Lie algebra, as an optimization (cf. Section E in the paper) + """ + raise NotImplementedError + + def project(self, base): + """ + This method is OPTIONAL + It returns the projected base back into the manifold + """ + raise NotImplementedError + + def forward(self, input): + """ + It uses the attribute self.B to implement the layer itself (e.g. Linear, CNN, ...) 
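+        For instance, Orthogonal.forward in orthogonal.py simply returns input.matmul(self.B).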
+ """ + raise NotImplementedError diff --git a/model/exprnn/trivializations.py b/model/exprnn/trivializations.py new file mode 100644 index 0000000..2a52fb7 --- /dev/null +++ b/model/exprnn/trivializations.py @@ -0,0 +1,24 @@ +# Downloaded from https://github.com/Lezcano/expRNN + +import torch + +# from model.exprnn.expm32 import expm32, differential +from .expm32 import expm32, differential + +def cayley_map(X): + n = X.size(0) + Id = torch.eye(n, dtype=X.dtype, device=X.device) + return torch.solve(Id - X, Id + X)[0] + +class expm_class(torch.autograd.Function): + @staticmethod + def forward(ctx, A): + ctx.save_for_backward(A) + return expm32(A) + + @staticmethod + def backward(ctx, G): + (A,) = ctx.saved_tensors + return differential(expm32, A.t(), G) + +expm = expm_class.apply diff --git a/model/memory.py b/model/memory.py new file mode 100644 index 0000000..6ac9c56 --- /dev/null +++ b/model/memory.py @@ -0,0 +1,426 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import numpy as np +from scipy import signal +from scipy import linalg as la +from functools import partial + +from model.rnncell import RNNCell +from model.orthogonalcell import OrthogonalLinear +from model.components import Gate, Linear_, Modrelu, get_activation, get_initializer +from model.op import LegSAdaptiveTransitionManual, LegTAdaptiveTransitionManual, LagTAdaptiveTransitionManual, TLagTAdaptiveTransitionManual + + + +forward_aliases = ['euler', 'forward_euler', 'forward', 'forward_diff'] +backward_aliases = ['backward', 'backward_diff', 'backward_euler'] +bilinear_aliases = ['bilinear', 'tustin', 'trapezoidal', 'trapezoid'] +zoh_aliases = ['zoh'] + + +class MemoryCell(RNNCell): + + name = None + valid_keys = ['uxh', 'ux', 'uh', 'um', 'hxm', 'hx', 'hm', 'hh', 'bias', ] + + def default_initializers(self): + return { + 'uxh': 'uniform', + 'hxm': 'xavier', + 'hx': 'xavier', + 'hm': 'xavier', + + 'um': 'zero', + 'hh': 'xavier', + } + + + def default_architecture(self): + return { + 'ux': True, + # 'uh': True, + 'um': False, + 'hx': True, + 'hm': True, + 'hh': False, + 'bias': True, + } + + + def __init__(self, input_size, hidden_size, memory_size, memory_order, + memory_activation='id', + gate='G', # 'N' | 'G' | UR' + memory_output=False, + **kwargs + ): + self.memory_size = memory_size + self.memory_order = memory_order + + self.memory_activation = memory_activation + self.gate = gate + self.memory_output = memory_output + + super(MemoryCell, self).__init__(input_size, hidden_size, **kwargs) + + + self.input_to_hidden_size = self.input_size if self.architecture['hx'] else 0 + self.input_to_memory_size = self.input_size if self.architecture['ux'] else 0 + + # Construct and initialize u + self.W_uxh = nn.Linear(self.input_to_memory_size + self.hidden_size, self.memory_size, + bias=self.architecture['bias']) + # nn.init.zeros_(self.W_uxh.bias) + if 'uxh' in self.initializers: + get_initializer(self.initializers['uxh'], self.memory_activation)(self.W_uxh.weight) + if 'ux' in self.initializers: # Re-init if passed in + get_initializer(self.initializers['ux'], self.memory_activation)(self.W_uxh.weight[:, :self.input_size]) + if 'uh' in self.initializers: # Re-init if passed in + get_initializer(self.initializers['uh'], self.memory_activation)(self.W_uxh.weight[:, self.input_size:]) + + + # Construct and initialize h + self.memory_to_hidden_size = self.memory_size * self.memory_order if self.architecture['hm'] else 0 + preact_ctor = Linear_ + preact_args = [self.input_to_hidden_size + 
self.memory_to_hidden_size, self.hidden_size, + self.architecture['bias']] + + self.W_hxm = preact_ctor(*preact_args) + + if self.initializers.get('hxm', None) is not None: # Re-init if passed in + get_initializer(self.initializers['hxm'], self.hidden_activation)(self.W_hxm.weight) + if self.initializers.get('hx', None) is not None: # Re-init if passed in + get_initializer(self.initializers['hx'], self.hidden_activation)(self.W_hxm.weight[:, :self.input_size]) + if self.initializers.get('hm', None) is not None: # Re-init if passed in + get_initializer(self.initializers['hm'], self.hidden_activation)(self.W_hxm.weight[:, self.input_size:]) + + if self.architecture['um']: + # No bias here because the implementation is awkward otherwise, but probably doesn't matter + self.W_um = nn.Parameter(torch.Tensor(self.memory_size, self.memory_order)) + get_initializer(self.initializers['um'], self.memory_activation)(self.W_um) + + if self.architecture['hh']: + self.reset_hidden_to_hidden() + else: + self.W_hh = None + + if self.gate is not None: + if self.architecture['hh']: + print("input to hidden size, memory to hidden size, hidden size:", self.input_to_hidden_size, self.memory_to_hidden_size, self.hidden_size) + preact_ctor = Linear_ + preact_args = [self.input_to_hidden_size + self.memory_to_hidden_size + self.hidden_size, self.hidden_size, + self.architecture['bias']] + self.W_gxm = Gate(self.hidden_size, preact_ctor, preact_args, mechanism=self.gate) + + def reset_parameters(self): + # super().reset_parameters() + self.hidden_activation_fn = get_activation(self.hidden_activation, self.hidden_size) # TODO figure out how to remove this duplication + self.memory_activation_fn = get_activation(self.memory_activation, self.memory_size) + + def forward(self, input, state): + h, m, time_step = state + + input_to_hidden = input if self.architecture['hx'] else input.new_empty((0,)) + input_to_memory = input if self.architecture['ux'] else input.new_empty((0,)) + + # Construct the update features + memory_preact = self.W_uxh(torch.cat((input_to_memory, h), dim=-1)) # (batch, memory_size) + if self.architecture['um']: + memory_preact = memory_preact + (m * self.W_um).sum(dim=-1) + u = self.memory_activation_fn(memory_preact) # (batch, memory_size) + + # Update the memory + m = self.update_memory(m, u, time_step) # (batch, memory_size, memory_order) + + # Update hidden state from memory + if self.architecture['hm']: + memory_to_hidden = m.view(input.shape[0], self.memory_size*self.memory_order) + else: + memory_to_hidden = input.new_empty((0,)) + m_inputs = (torch.cat((input_to_hidden, memory_to_hidden), dim=-1),) + hidden_preact = self.W_hxm(*m_inputs) + + if self.architecture['hh']: + hidden_preact = hidden_preact + self.W_hh(h) + hidden = self.hidden_activation_fn(hidden_preact) + + + # Construct gate if necessary + if self.gate is None: + h = hidden + else: + if self.architecture['hh']: + m_inputs = torch.cat((m_inputs[0], h), -1), + g = self.W_gxm(*m_inputs) + h = (1.-g) * h + g * hidden + + next_state = (h, m, time_step + 1) + output = self.output(next_state) + + return output, next_state + + def update_memory(self, m, u, time_step): + """ + m: (B, M, N) [batch size, memory size, memory order] + u: (B, M) + + Output: (B, M, N) + """ + raise NotImplementedError + + def default_state(self, input, batch_size=None): + batch_size = input.size(0) if batch_size is None else batch_size + return (input.new_zeros(batch_size, self.hidden_size, requires_grad=False), + input.new_zeros(batch_size, 
self.memory_size, self.memory_order, requires_grad=False), + 0) + + def output(self, state): + """ Converts a state into a single output (tensor) """ + h, m, time_step = state + + if self.memory_output: + hm = torch.cat((h, m.view(m.shape[0], self.memory_size*self.memory_order)), dim=-1) + return hm + else: + return h + + def state_size(self): + return self.hidden_size + self.memory_size*self.memory_order + + def output_size(self): + if self.memory_output: + return self.hidden_size + self.memory_size*self.memory_order + else: + return self.hidden_size + + +class LTICell(MemoryCell): + """ A cell implementing Linear Time Invariant dynamics: c' = Ac + Bf. """ + + def __init__(self, input_size, hidden_size, memory_size, memory_order, + A, B, + trainable_scale=0., # how much to scale LR on A and B + dt=0.01, + discretization='zoh', + **kwargs + ): + super().__init__(input_size, hidden_size, memory_size, memory_order, **kwargs) + + + C = np.ones((1, memory_order)) + D = np.zeros((1,)) + dA, dB, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method=discretization) + + dA = dA - np.eye(memory_order) # puts into form: x += Ax + self.trainable_scale = np.sqrt(trainable_scale) + if self.trainable_scale <= 0.: + self.register_buffer('A', torch.Tensor(dA)) + self.register_buffer('B', torch.Tensor(dB)) + else: + self.A = nn.Parameter(torch.Tensor(dA / self.trainable_scale), requires_grad=True) + self.B = nn.Parameter(torch.Tensor(dB / self.trainable_scale), requires_grad=True) + + # TODO: proper way to implement LR scale is a preprocess() function that occurs once per unroll + # also very useful for orthogonal params + def update_memory(self, m, u, time_step): + u = u.unsqueeze(-1) # (B, M, 1) + if self.trainable_scale <= 0.: + return m + F.linear(m, self.A) + F.linear(u, self.B) + else: + return m + F.linear(m, self.A * self.trainable_scale) + F.linear(u, self.B * self.trainable_scale) + +class LSICell(MemoryCell): + """ A cell implementing Linear 'Scale' Invariant dynamics: c' = 1/t (Ac + Bf). 
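# Illustrative sketch (numpy/scipy only) of the discretization step LTICell
# performs above: a continuous system x' = Ax + Bu is discretized with
# cont2discrete and the identity is subtracted so a step can be written as an
# increment, m <- m + dA @ m + dB * u. The A, B here are stand-ins, not a
# HiPPO measure.
import numpy as np
from scipy import signal

N = 4
A = -np.tril(np.ones((N, N)))                    # stand-in transition matrix
B = np.ones((N, 1))
C, D = np.ones((1, N)), np.zeros((1,))
dA, dB, _, _, _ = signal.cont2discrete((A, B, C, D), dt=0.01, method='zoh')
dA = dA - np.eye(N)                              # incremental ("x += Ax") form

m, u = np.zeros(N), 1.0
m = m + dA @ m + dB[:, 0] * u                    # one recurrence step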
""" + + def __init__(self, input_size, hidden_size, memory_size, memory_order, + A, B, + init_t = 0, # 0 for special case at t=0 (new code), else old code without special case + max_length=1024, + discretization='bilinear', + **kwargs + ): + """ + # TODO: make init_t start at arbitrary time (instead of 0 or 1) + """ + + # B should have shape (N, 1) + assert len(B.shape) == 2 and B.shape[1] == 1 + + super().__init__(input_size, hidden_size, memory_size, memory_order, **kwargs) + + assert isinstance(init_t, int) + self.init_t = init_t + self.max_length = max_length + + A_stacked = np.empty((max_length, memory_order, memory_order), dtype=A.dtype) + B_stacked = np.empty((max_length, memory_order), dtype=B.dtype) + B = B[:,0] + N = memory_order + for t in range(1, max_length + 1): + At = A / t + Bt = B / t + if discretization in forward_aliases: + A_stacked[t - 1] = np.eye(N) + At + B_stacked[t - 1] = Bt + elif discretization in backward_aliases: + A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, np.eye(N), lower=True) + B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True) + elif discretization in bilinear_aliases: + A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True) + B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, Bt, lower=True) + elif discretization in zoh_aliases: + A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t))) + B_stacked[t - 1] = la.solve_triangular(A, A_stacked[t - 1] @ B - B, lower=True) + B_stacked = B_stacked[:, :, None] + + A_stacked -= np.eye(memory_order) # puts into form: x += Ax + self.register_buffer('A', torch.Tensor(A_stacked)) + self.register_buffer('B', torch.Tensor(B_stacked)) + + + def update_memory(self, m, u, time_step): + u = u.unsqueeze(-1) # (B, M, 1) + t = time_step - 1 + self.init_t + if t < 0: + return F.pad(u, (0, self.memory_order - 1)) + else: + if t >= self.max_length: t = self.max_length - 1 + return m + F.linear(m, self.A[t]) + F.linear(u, self.B[t]) + + +class TimeMemoryCell(MemoryCell): + """ MemoryCell with timestamped data """ + def __init__(self, input_size, hidden_size, memory_size, memory_order, **kwargs): + super().__init__(input_size-1, hidden_size, memory_size, memory_order, **kwargs) + def forward(self, input, state): + h, m, time_step = state + timestamp, input = input[:, 0], input[:, 1:] + + input_to_hidden = input if self.architecture['hx'] else input.new_empty((0,)) + input_to_memory = input if self.architecture['ux'] else input.new_empty((0,)) + + # Construct the update features + memory_preact = self.W_uxh(torch.cat((input_to_memory, h), dim=-1)) # (batch, memory_size) + if self.architecture['um']: + memory_preact = memory_preact + (m * self.W_um).sum(dim=-1) + u = self.memory_activation_fn(memory_preact) # (batch, memory_size) + + # Update the memory + m = self.update_memory(m, u, time_step, timestamp) # (batch, memory_size, memory_order) + + # Update hidden state from memory + if self.architecture['hm']: + memory_to_hidden = m.view(input.shape[0], self.memory_size*self.memory_order) + else: + memory_to_hidden = input.new_empty((0,)) + m_inputs = (torch.cat((input_to_hidden, memory_to_hidden), dim=-1),) + hidden_preact = self.W_hxm(*m_inputs) + + if self.architecture['hh']: + hidden_preact = hidden_preact + self.W_hh(h) + hidden = self.hidden_activation_fn(hidden_preact) + + + # Construct gate if necessary + if self.gate is None: + h = hidden + else: + if self.architecture['hh']: + m_inputs = torch.cat((m_inputs[0], h), -1), + g = 
self.W_gxm(*m_inputs) + h = (1.-g) * h + g * hidden + + next_state = (h, m, timestamp) + output = self.output(next_state) + + return output, next_state + +class TimeLSICell(TimeMemoryCell): + """ A cell implementing "Linear Scale Invariant" dynamics: c' = Ac + Bf with timestamped inputs. """ + + name = 'tlsi' + + def __init__(self, input_size, hidden_size, memory_size=1, memory_order=-1, + measure='legs', + measure_args={}, + method='manual', + discretization='bilinear', + **kwargs + ): + if memory_order < 0: + memory_order = hidden_size + + + super().__init__(input_size, hidden_size, memory_size, memory_order, **kwargs) + + assert measure in ['legs', 'lagt', 'tlagt', 'legt'] + assert method in ['manual', 'linear', 'toeplitz'] + if measure == 'legs': + if method == 'manual': + self.transition = LegSAdaptiveTransitionManual(self.memory_order) + kwargs = {'precompute': False} + if measure == 'legt': + if method == 'manual': + self.transition = LegTAdaptiveTransitionManual(self.memory_order) + kwargs = {'precompute': False} + elif measure == 'lagt': + if method == 'manual': + self.transition = LagTAdaptiveTransitionManual(self.memory_order) + kwargs = {'precompute': False} + elif measure == 'tlagt': + if method == 'manual': + self.transition = TLagTAdaptiveTransitionManual(self.memory_order, **measure_args) + kwargs = {'precompute': False} + + if discretization in forward_aliases: + self.transition_fn = partial(self.transition.forward_diff, **kwargs) + elif discretization in backward_aliases: + self.transition_fn = partial(self.transition.backward_diff, **kwargs) + elif discretization in bilinear_aliases: + self.transition_fn = partial(self.transition.bilinear, **kwargs) + else: assert False + + + def update_memory(self, m, u, t0, t1): + """ + m: (B, M, N) [batch, memory_size, memory_order] + u: (B, M) + t0: (B,) previous time + t1: (B,) current time + """ + + if torch.eq(t1, 0.).any(): + return F.pad(u.unsqueeze(-1), (0, self.memory_order - 1)) + else: + dt = ((t1-t0)/t1).unsqueeze(-1) + m = self.transition_fn(dt, m, u) + return m + +class TimeLTICell(TimeLSICell): + """ A cell implementing Linear Time Invariant dynamics: c' = Ac + Bf with timestamped inputs. """ + + name = 'tlti' + + def __init__(self, input_size, hidden_size, memory_size=1, memory_order=-1, + dt=1.0, + **kwargs + ): + if memory_order < 0: + memory_order = hidden_size + + self.dt = dt + + super().__init__(input_size, hidden_size, memory_size, memory_order, **kwargs) + + def update_memory(self, m, u, t0, t1): + """ + m: (B, M, N) [batch, memory_size, memory_order] + u: (B, M) + t0: (B,) previous time + t1: (B,) current time + """ + + dt = self.dt*(t1-t0).unsqueeze(-1) + m = self.transition_fn(dt, m, u) + return m diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000..f54a2a1 --- /dev/null +++ b/model/model.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +from functools import partial + +from model.rnn import RNN, RNNWrapper, LSTMWrapper +from model import rnncell, opcell # TODO: this is just to force cell_registry to update. 
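# One way to read the step size dt = (t1 - t0)/t1 used by TimeLSICell above
# (an interpretive note, numpy only): the scale-invariant zoh step elsewhere
# in this file is log(t1) - log(t0), and (t1 - t0)/t1 is its first-order
# approximation when the relative time increment is small.
import numpy as np

t0, t1 = 100.0, 101.0
exact = np.log(t1) - np.log(t0)                  # ~0.009950
approx = (t1 - t0) / t1                          # ~0.009901
assert abs(exact - approx) < 1e-4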
There is probably a better programming pattern for this +from model.rnncell import CellBase +from model.orthogonalcell import OrthogonalCell + +class Model(nn.Module): + + def __init__( + self, + input_size, + output_size, + output_len=0, + cell='lstm', + cell_args={}, + output_hiddens=[], + embed_args=None, + preprocess=None, + ff=False, + dropout=0.0, + split=0, + ): + super(Model, self).__init__() + + # Save arguments needed for forward pass + self.input_size = input_size + self.output_size = output_size + self.output_len = output_len + assert output_len >= 0, f"output_len {output_len} should be 0 to return just the state or >0 to return the last output tokens" + self.dropout = dropout + self.split = split + + cell_args['input_size'] = input_size + if embed_args is not None: + self.embed_dim = embed_args['embed_dim'] + self.embedding = nn.Embedding(input_size, self.embed_dim) + cell_args['input_size'] = self.embed_dim + + + ### Handle optional Hippo preprocessing + self.preprocess = preprocess + if self.preprocess is not None: + assert isinstance(self.preprocess, dict) + assert 'order' in self.preprocess + assert 'measure' in self.preprocess + self.hippo = VariableMemoryProjection(**self.preprocess) + cell_args['input_size'] *= (self.preprocess['order']+1) # will append this output to original channels + + ### Construct main RNN + if ff: # feedforward model + cell_args['input_size'] = input_size + self.rnn = QRNN(**cell_args) + else: + # Initialize proper cell type + if cell == 'lstm': + self.rnn = LSTMWrapper(**cell_args, dropout=self.dropout) + else: + if cell in CellBase.registry: + cell_ctor = CellBase.registry[cell] + elif cell == 'orthogonal': + cell_ctor = OrthogonalCell + else: + assert False, f"cell {cell} not supported" + + self.rnn = RNN(cell_ctor(**cell_args), dropout=self.dropout) + if self.split > 0: + self.initial_rnn = RNN(cell_ctor(**cell_args), dropout=self.dropout) + + + ### Construct output head + sizes = [self.rnn.output_size()] + output_hiddens + [output_size] + self.output_mlp = nn.Sequential(*[nn.Linear(sizes[i], sizes[i+1]) for i in range(len(sizes)-1)]) + + + # @profile + def forward(self, inputs, len_batch=None): + B, L, C = inputs.shape + inputs = inputs.transpose(0, 1) # .unsqueeze(-1) # (seq_length, batch, channels) + + # Apply Hippo preprocessing if necessary + if self.preprocess is not None: + p = self.hippo(inputs) + p = p.reshape(L, B, self.input_size * self.preprocess['order']) + inputs = torch.cat([inputs, p], dim=-1) + + # Handle embedding + if hasattr(self, 'embedding'): + inputs = self.embedding(inputs) + if len_batch is not None: + inputs = nn.utils.rnn.pack_padded_sequence(inputs, len_batch, enforce_sorted=False) + + # Option to have separate RNN for head of sequence, mostly for debugging gradients etc + if self.split > 0: + initial_inputs, inputs = inputs[:self.split], inputs[self.split:] + _, initial_state = self.initial_rnn(initial_inputs, return_output=False) + else: + initial_state = None + + # Apply main RNN + if self.output_len > 0: + outputs, _ = self.rnn(inputs, init_state=initial_state, return_output=True) + # get last output tokens + outputs = outputs[-self.output_len:,:,:] + outputs = outputs.transpose(0, 1) + return self.output_mlp(outputs) + else: + _, state = self.rnn(inputs, init_state=initial_state, return_output=False) + state = self.rnn.output(state) + return self.output_mlp(state) + diff --git a/model/op.py b/model/op.py new file mode 100644 index 0000000..a03b6f0 --- /dev/null +++ b/model/op.py @@ -0,0 +1,266 @@ +import torch 
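# Illustrative sketch (torch only) of the output-head construction in Model
# above: a size list [rnn_output, *hiddens, output] becomes a stack of Linear
# layers applied to the final RNN state.
import torch
import torch.nn as nn

rnn_output_size, output_hiddens, output_size = 64, [32, 16], 10
sizes = [rnn_output_size] + output_hiddens + [output_size]
output_mlp = nn.Sequential(*[nn.Linear(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)])
state = torch.randn(8, rnn_output_size)          # e.g. a batch of final RNN states
logits = output_mlp(state)                       # shape (8, 10)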
+import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from scipy import signal +from scipy import linalg as la +from scipy import special as ss + + +def transition(measure, N, **measure_args): + """ A, B transition matrices for different measures. + + measure: the type of measure + legt - Legendre (translated) + legs - Legendre (scaled) + glagt - generalized Laguerre (translated) + lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization + """ + # Laguerre (translated) + if measure == 'lagt': + b = measure_args.get('beta', 1.0) + A = np.eye(N) / 2 - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + if measure == 'tlagt': + # beta = 1 corresponds to no tilt + b = measure_args.get('beta', 1.0) + A = (1.-b)/2 * np.eye(N) - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + # Generalized Laguerre + # alpha 0, beta small is most stable (limits to the 'lagt' measure) + # alpha 0, beta 1 has transition matrix A = [lower triangular 1] + if measure == 'glagt': + alpha = measure_args.get('alpha', 0.0) + beta = measure_args.get('beta', 0.01) + A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) + B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] + + L = np.exp(.5 * (ss.gammaln(np.arange(N)+alpha+1) - ss.gammaln(np.arange(N)+1))) + A = (1./L[:, None]) * A * L[None, :] + B = (1./L[:, None]) * B * np.exp(-.5 * ss.gammaln(1-alpha)) * beta**((1-alpha)/2) + # Legendre (translated) + elif measure == 'legt': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1) ** .5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :] + B = R[:, None] + A = -A + # LMU: equivalent to LegT up to normalization + elif measure == 'lmu': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1)[:, None] # / theta + j, i = np.meshgrid(Q, Q) + A = np.where(i < j, -1, (-1.)**(i-j+1)) * R + B = (-1.)**Q[:, None] * R + # Legendre (scaled) + elif measure == 'legs': + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + + return A, B + + + +class AdaptiveTransition(nn.Module): + def precompute_forward(self): + raise NotImplementedError + + def precompute_backward(self): + raise NotImplementedError + + def forward_mult(self, u, delta): + """ Computes (I + delta A) u + + A: (n, n) + u: (..., n) + delta: (...) or scalar + + output: (..., n) + """ + raise NotImplementedError + + def inverse_mult(self, u, delta): # TODO swap u, delta everywhere + """ Computes (I - d A)^-1 u """ + raise NotImplementedError + + # @profile + def forward_diff(self, d, u, v, **kwargs): + """ Computes the 'forward diff' or Euler update rule: (I - d A)^-1 u + d B v + d: (...) + u: (..., n) + v: (...) + """ + # TODO F.linear should be replaced by broadcasting, self.B shouldl be shape (n) instead of (n, 1) + # x = self.forward_mult(u, d) + dt * F.linear(v.unsqueeze(-1), self.B) + v = d * v + v = v.unsqueeze(-1) * self.B + x = self.forward_mult(u, d, **kwargs) + x = x + v + return x + + # @profile + def backward_diff(self, d, u, v, **kwargs): + """ Computes the 'forward diff' or Euler update rule: (I - d A)^-1 u + d (I - d A)^-1 B v + d: (...) + u: (..., n) + v: (...) 
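# Concrete check (numpy only) of the 'legs' branch of transition() above: the
# resulting A is lower triangular with A[n, k] = -sqrt(2n+1)*sqrt(2k+1) for
# k < n and A[n, n] = -(n+1), and B[n] = sqrt(2n+1).
import numpy as np

N = 4
q = np.arange(N, dtype=np.float64)
col, row = np.meshgrid(q, q)
r = 2 * q + 1
M = -(np.where(row >= col, r, 0) - np.diag(q))
T = np.sqrt(np.diag(2 * q + 1))
A = T @ M @ np.linalg.inv(T)
B = np.diag(T)[:, None]
assert np.allclose(np.triu(A, 1), 0)             # lower triangular
assert np.allclose(np.diag(A), -(q + 1))         # diagonal entries -(n+1)
assert np.allclose(B[:, 0], np.sqrt(2 * q + 1))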
+ """ + v = d * v + v = v.unsqueeze(-1) * self.B + x = u + v + x = self.inverse_mult(x, d, **kwargs) + return x + + # @profile + def bilinear(self, dt, u, v, alpha=.5, **kwargs): + """ Computes the bilinear (aka trapezoid or Tustin's) update rule. + + (I - d/2 A)^-1 (I + d/2 A) u + d B (I - d/2 A)^-1 B v + """ + x = self.forward_mult(u, (1-alpha)*dt, **kwargs) + v = dt * v + v = v.unsqueeze(-1) * self.B + x = x + v + x = self.inverse_mult(x, (alpha)*dt, **kwargs) + return x + + def zoh(self, dt, u, v): + raise NotImplementedError + + def precompute(self, deltas): + """ deltas: list of step sizes """ + for delta in deltas: + # self.forward_cache[delta] = self.precompute_forward(delta) + # self.backward_cache[delta] = self.precompute_backward(delta) + # TODO being lazy here; should check whether bilinear rule is being used + self.forward_cache[delta/2] = self.precompute_forward(delta/2) + self.backward_cache[delta/2] = self.precompute_backward(delta/2) + + + +class ManualAdaptiveTransition(AdaptiveTransition): + def __init__(self, N, **kwargs): + """ Slow (n^3, or n^2 if step sizes are cached) version via manual matrix mult/inv + + delta: optional list of step sizes to cache the transitions for + """ + super().__init__() + A, B = transition(type(self).measure, N, **kwargs) + self.N = N + self.register_buffer('A', torch.Tensor(A)) + self.register_buffer('B', torch.Tensor(B[:, 0])) + self.register_buffer('I', torch.eye(self.N)) + + # Precompute stacked A, B matrix for zoh computation + AB = torch.cat((self.A, self.B.unsqueeze(-1)), dim=-1) + AB = torch.cat((AB, torch.zeros((1, N+1))), dim=0) + self.register_buffer('AB', AB) + + self.forward_cache = {} + self.backward_cache = {} + + print(f"ManualAdaptiveTransition:\n A {self.A}\nB {self.B}") + + def precompute_forward(self, delta): + return self.I + delta*self.A + + def precompute_backward(self, delta): + return torch.triangular_solve(self.I, self.I - delta*self.A, upper=False)[0] + + def precompute_exp(self, delta): + # NOTE this does not work because torch has no matrix exponential yet, support ongoing: + # https://github.com/pytorch/pytorch/issues/9983 + e = torch.expm(delta * self.AB) + return e[:-1, :-1], e[:-1, -1] + + # @profile + def forward_mult(self, u, delta, precompute=True): + """ Computes (I + d A) u + + A: (n, n) + u: (b1* d, n) d represents memory_size + delta: (b2*, d) or scalar + Assume len(b2) <= len(b1) + + output: (broadcast(b1, b2)*, d, n) + """ + + # For forward Euler, precompute materializes the matrix + if precompute: + if isinstance(delta, torch.Tensor): + delta = delta.unsqueeze(-1).unsqueeze(-1) + # print(delta, isinstance(delta, float), delta in self.forward_cache) + if isinstance(delta, float) and delta in self.forward_cache: + mat = self.forward_cache[delta] + else: + mat = self.precompute_forward(delta) + if len(u.shape) >= len(mat.shape): + # For memory efficiency, leverage extra batch dimensions + s = len(u.shape) + # TODO can make the permutation more efficient by just permuting the last 2 or 3 dim, but need to do more casework) + u = u.permute(list(range(1, s)) + [0]) + x = mat @ u + x = x.permute([s-1] + list(range(s-1))) + else: + x = (mat @ u.unsqueeze(-1))[..., 0] + # x = F.linear(u, mat) + else: + if isinstance(delta, torch.Tensor): + delta = delta.unsqueeze(-1) + x = F.linear(u, self.A) + x = u + delta * x + + return x + + + # @profile + def inverse_mult(self, u, delta, precompute=True): + """ Computes (I - d A)^-1 u """ + + if isinstance(delta, torch.Tensor): + delta = 
delta.unsqueeze(-1).unsqueeze(-1) + + if precompute: + if isinstance(delta, float) and delta in self.backward_cache: + mat = self.backward_cache[delta] + else: + mat = self.precompute_backward(delta) # (n, n) or (..., n, n) + + if len(u.shape) >= len(mat.shape): + # For memory efficiency, leverage extra batch dimensions + s = len(u.shape) + # TODO can make the permutation more efficient by just permuting the last 2 or 3 dim, but need to do more casework + u = u.permute(list(range(1, s)) + [0]) + x = mat @ u + x = x.permute([s-1] + list(range(s-1))) + else: + x = (mat @ u.unsqueeze(-1))[..., 0] + + else: + _A = self.I - delta*self.A + x = torch.triangular_solve(u.unsqueeze(-1), _A, upper=False)[0] + x = x[..., 0] + + return x + + def zoh(self, dt, u, v): + dA, dB = self.precompute_exp(dt) + return F.linear(u, dA) + dB * v.unsqueeze(-1) + +class LegSAdaptiveTransitionManual(ManualAdaptiveTransition): + measure = 'legs' + +class LegTAdaptiveTransitionManual(ManualAdaptiveTransition): + measure = 'legt' + +class LagTAdaptiveTransitionManual(ManualAdaptiveTransition): + measure = 'lagt' + +class TLagTAdaptiveTransitionManual(ManualAdaptiveTransition): + measure = 'tlagt' diff --git a/model/opcell.py b/model/opcell.py new file mode 100644 index 0000000..3ba6456 --- /dev/null +++ b/model/opcell.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +import numpy as np + +from model.memory import LTICell, LSICell +from model.op import transition + + +class OPLTICell(LTICell): + # name = 'lagt' + measure = None + + def __init__(self, input_size, hidden_size, memory_size=1, memory_order=-1, measure_args={}, + **kwargs + ): + if memory_order < 0: + memory_order = hidden_size + + # A, B = transition(type(self).measure, memory_order) + A, B = transition(type(self).measure, memory_order, **measure_args) + super().__init__(input_size, hidden_size, memory_size, memory_order, A, B, **kwargs) +class OPLSICell(LSICell): + # name = 'lagt' + measure = None + + def __init__(self, input_size, hidden_size, memory_size=1, memory_order=-1, measure_args={}, + **kwargs + ): + if memory_order < 0: + memory_order = hidden_size + + A, B = transition(type(self).measure, memory_order, **measure_args) + super().__init__(input_size, hidden_size, memory_size, memory_order, A, B, **kwargs) + +# TODO there should be a way to declare the parent class programatically to avoid duplicating this +# i.e. 
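# Usage sketch for the manual transitions defined above, assuming the package
# layout of this patch (model.op importable) and torch:
import torch
from model.op import LegSAdaptiveTransitionManual

transition_op = LegSAdaptiveTransitionManual(8)         # builds LegS (A, B) buffers
m = torch.zeros(4, 3, 8)                                # (batch, memory_size, order)
u = torch.randn(4, 3)                                   # input feature per memory dim
dt = torch.full((4, 3), 0.1)                            # per-element step sizes
m_next = transition_op.bilinear(dt, m, u, precompute=False)   # one bilinear step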
have a single OPCell that calls the appropriate superclass constructor +# for measure in ['lagt', 'legt', 'legs']: +# type('t'+measure, OPLTICell, {'measure': measure}): +# type('s'+measure, OPLSICell, {'measure': measure}): + +class LegendreTranslateCell(OPLTICell): + name = 'legt' + measure = 'legt' +class LegendreTranslateSCell(OPLSICell): + name = 'legts' + measure = 'legt' +class LegendreScaleCell(OPLSICell): + name = 'legs' + measure = 'legs' +class LegendreScaleTCell(OPLTICell): + name = 'legst' + measure = 'legs' +class LaguerreTranslateCell(OPLTICell): + name = 'lagt' + measure = 'lagt' +class LaguerreTranslateSCell(OPLSICell): + name = 'lagts' + measure = 'lagt' +class LMUTCell(OPLTICell): + name = 'lmut' + measure = 'lmu' +class LMUCell(OPLTICell): + name = 'lmu' + measure = 'lmu' + + def default_initializers(self): + return { + 'uxh': 'uniform', + 'ux': 'one', + 'uh': 'zero', + 'um': 'zero', + 'hxm': 'xavier', + 'hx': 'zero', + 'hh': 'zero', + 'hm': 'xavier', + } + + def default_architecture(self): + return { + 'ux': True, + 'um': True, + 'hx': True, + 'hm': True, + 'hh': True, + 'bias': False, + } + + def __init__(self, input_size, hidden_size, theta=100, dt=1., **kwargs): + super().__init__(input_size, hidden_size, dt=dt/theta, **kwargs) + + +class LegendreScaleNoiseCell(LTICell): + name = 'legsn' + measure = 'legs' + + def __init__(self, input_size, hidden_size, memory_size=1, memory_order=-1, + **kwargs + ): + if memory_order < 0: + memory_order = hidden_size + + A, B = transition(type(self).measure, memory_order) + N = memory_order + A = A + np.random.normal(size=(N, N)) / N + + super().__init__(input_size, hidden_size, memory_size, memory_order, A, B, **kwargs) diff --git a/model/orthogonalcell.py b/model/orthogonalcell.py new file mode 100644 index 0000000..8871a8b --- /dev/null +++ b/model/orthogonalcell.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn + +from model.exprnn.orthogonal import Orthogonal +from model.exprnn.trivializations import expm, cayley_map +from model.exprnn.initialization import henaff_init_, cayley_init_ + +from model.components import Modrelu + +param_name_to_param = {'cayley': cayley_map, 'expm': expm} +init_name_to_init = {'henaff': henaff_init_, 'cayley': cayley_init_} + + +class OrthogonalLinear(Orthogonal): + def __init__(self, input_size, output_size, method='exprnn', init='cayley', K=100): + """ Wrapper around expRNN's Orthogonal class taking care of parameter names """ + if method == "exprnn": + mode = "static" + param = 'expm' + elif method == "dtriv": + # We use 100 as the default to project back to the manifold. + # This parameter does not really affect the convergence of the algorithms, even for K=1 + mode = ("dynamic", ortho_args['K'], 100) # TODO maybe K=30? 
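# Illustrative sketch (numpy/scipy plus this patch's model.op) of how LMUCell
# above reduces to the LTI discretization: the window length theta enters only
# through the step size dt = 1/theta that it passes down to LTICell.
import numpy as np
from scipy import signal
from model.op import transition

N, theta = 8, 100
A, B = transition('lmu', N)
C, D = np.ones((1, N)), np.zeros((1,))
dA, dB, _, _, _ = signal.cont2discrete((A, B, C, D), dt=1.0 / theta, method='zoh')
dA = dA - np.eye(N)                              # same incremental form as LTICell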
check exprnn codebase + param = 'expm' + elif method == "cayley": + mode = "static" + param = 'cayley' + else: + assert False, f"OrthogonalLinear: orthogonal method {method} not supported" + + param = param_name_to_param[param] + init_A = init_name_to_init[init] + super().__init__(input_size, output_size, init_A, mode, param) + +class OrthogonalCell(nn.Module): + """ Replacement for expRNN's OrthogonalRNN class + + initializer_skew (str): either 'henaff' or 'cayley' + param (str): A parametrization of in terms of skew-symmetyric matrices, either 'cayley' or 'expm' + """ + def __init__(self, input_size, hidden_size, **ortho_args): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.recurrent_kernel = OrthogonalLinear(hidden_size, hidden_size, **ortho_args) + self.input_kernel = nn.Linear(in_features=self.input_size, out_features=self.hidden_size, bias=False) + self.nonlinearity = Modrelu(hidden_size) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_normal_(self.input_kernel.weight.data, nonlinearity="relu") + + def forward(self, input, hidden): + input = self.input_kernel(input) + hidden = self.recurrent_kernel(hidden) + out = input + hidden + out = self.nonlinearity(out) + + return out, out + + def default_state(self, input, batch_size=None): + return input.new_zeros(input.size(0) if batch_size is None else batch_size, + self.hidden_size, requires_grad=False) + + def output(self, h): + return h + + def state_size(self): + return self.hidden_size + + def output_size(self): + return self.hidden_size + + def initial_state(self, trainable=False): + """ Return initial state of the RNN + This should not need to see the input as it should be batch size agnostic and automatically broadcasted + + # TODO Currently not used + """ + if trainable: + self.initial_state = torch.zeros(self.hidden_size, requires_grad=True) + else: + return torch.zeros(self.hidden_size, requires_grad=True) diff --git a/model/rnn.py b/model/rnn.py new file mode 100644 index 0000000..5181ada --- /dev/null +++ b/model/rnn.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def apply_tuple(tup, fn): + """Apply a function to a Tensor or a tuple of Tensor + """ + if isinstance(tup, tuple): + return tuple((fn(x) if isinstance(x, torch.Tensor) else x) for x in tup) + else: + return fn(tup) + +def concat_tuple(tups, dim=0): + """Concat a list of Tensors or a list of tuples of Tensor + """ + if isinstance(tups[0], tuple): + return tuple((torch.cat(xs, dim) if isinstance(xs[0], torch.Tensor) else xs[0]) for xs in zip(*tups)) + else: + return torch.cat(tups, dim) + + +class RNN(nn.Module): + + def __init__(self, cell, dropout=0.0): + super().__init__() + self.cell = cell + + if dropout > 0.0: + self.use_dropout = True + self.drop_prob = dropout + self.dropout = nn.Dropout(p=dropout) + else: + self.use_dropout = False + + def forward(self, inputs, init_state=None, return_output=False): + """ + cell.forward : (input, state) -> (output, state) + inputs : [length, batch, dim] + """ + # Similar implementation to https://github.com/pytorch/pytorch/blob/9e94e464535e768ad3444525aecd78893504811f/torch/nn/modules/rnn.py#L202 + is_packed = isinstance(inputs, nn.utils.rnn.PackedSequence) + if is_packed: + inputs, batch_sizes, sorted_indices, unsorted_indices = inputs + max_batch_size = int(batch_sizes[0]) + else: + batch_sizes = None + max_batch_size = inputs.size(1) + sorted_indices = None + unsorted_indices = None + # Construct 
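# Usage sketch for the tuple helpers defined above (torch, with apply_tuple and
# concat_tuple in scope): cell states may be plain tensors or tuples that mix
# tensors with bookkeeping integers, and both cases are handled uniformly.
import torch

state_a = (torch.zeros(2, 4), torch.zeros(2, 4, 8), 0)   # e.g. a MemoryCell state
state_b = (torch.ones(3, 4), torch.ones(3, 4, 8), 0)

sliced = apply_tuple(state_a, lambda x: x[:1])           # slices tensors, keeps the int
merged = concat_tuple([state_a, state_b], dim=0)         # (5, 4), (5, 4, 8), and 0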
initial state + if init_state is None: + state = self.cell.default_state(inputs[0], max_batch_size) + else: + state = apply_tuple(init_state, lambda x: x[sorted_indices] if sorted_indices is not None else x) + # Construct recurrent dropout masks + if self.use_dropout: + input_dropout = self.dropout(torch.ones(max_batch_size, self.cell.input_size, device=inputs.device)) + recurrent_dropout = self.dropout(torch.ones(max_batch_size, self.cell.hidden_size, device=inputs.device)) + output_dropout = self.dropout(torch.ones(max_batch_size, self.output_size(), device=inputs.device)) + + outputs = [] + if not is_packed: + for input in torch.unbind(inputs, dim=0): + if self.use_dropout: + ## Recurrent Dropout + input = input * input_dropout + output, new_state = self.cell.forward(input, state) + if self.use_dropout: + output = output * output_dropout + try: + state = (self.dropout(new_state[0]),) + new_state[1:] # TODO not general + except: + state = self.dropout(new_state) + else: + state = new_state + if return_output: + outputs.append(output) + return torch.stack(outputs) if return_output else None, state + else: + # Following implementation at https://github.com/pytorch/pytorch/blob/9e94e464535e768ad3444525aecd78893504811f/aten/src/ATen/native/RNN.cpp#L621 + # Batch sizes is a sequence of decreasing lengths, which are offsets + # into a 1D list of inputs. At every step we slice out batch_size elements, + # and possibly account for the decrease in the batch size since the last step, + # which requires us to slice the hidden state (since some sequences + # are completed now). The sliced parts are also saved, because we will need + # to return a tensor of final hidden state. + batch_sizes_og = batch_sizes + batch_sizes = batch_sizes.detach().cpu().numpy() + input_offset = 0 + last_batch_size = batch_sizes[0] + saved_states = [] + for batch_size in batch_sizes: + step_input = inputs[input_offset:input_offset + batch_size] + input_offset += batch_size + dec = last_batch_size - batch_size + if (dec > 0): + saved_state = apply_tuple(state, lambda x: x[batch_size:]) + state = apply_tuple(state, lambda x: x[:batch_size]) + saved_states.append(saved_state) + last_batch_size = batch_size + if self.use_dropout: + step_input = step_input * input_dropout[:batch_size] + output, new_state = self.cell.forward(step_input, state) + if self.use_dropout: + output = output * output_dropout[:batch_size] + try: + state = (self.dropout(new_state[0]),) + new_state[1:] # TODO not general + except: + state = self.dropout(new_state) + else: + state = new_state + if return_output: + outputs.append(output) + saved_states.append(state) + saved_states.reverse() + state = concat_tuple(saved_states) + state = apply_tuple(state, lambda x: x[unsorted_indices] if unsorted_indices is not None else x) + if return_output: + outputs = nn.utils.rnn.PackedSequence(torch.cat(outputs, dim=0), batch_sizes_og, sorted_indices, unsorted_indices) + else: + outputs = None + return outputs, state + + def state_size(self): + return self.cell.state_size() + + def output_size(self): + return self.cell.output_size() + + def output(self, state): + return self.cell.output(state) + + +class RNNWrapper(nn.RNN): + + def forward(self, inputs, h_0=None): + output, h_n = super().forward(inputs, h_0) + return output, h_n.squeeze(0) + + +class LSTMWrapper(nn.LSTM): + + # return_output is only here to absorb the argument, making the interface compatible with RNN + def forward(self, inputs, return_output=None, init_state=None): + # init_state is just to absorb 
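# Self-contained sketch (torch only) of the packed-sequence input handled by
# RNN.forward above: variable-length sequences are padded, packed, and the
# PackedSequence's batch_sizes are exactly what the unroll loop slices by.
import torch
import torch.nn as nn

seqs = [torch.randn(5, 3), torch.randn(3, 3), torch.randn(2, 3)]   # lengths 5, 3, 2
padded = nn.utils.rnn.pad_sequence(seqs)                 # (max_len, batch, dim) = (5, 3, 3)
packed = nn.utils.rnn.pack_padded_sequence(padded, [5, 3, 2], enforce_sorted=False)
print(packed.batch_sizes)                                # tensor([3, 3, 2, 1, 1])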
the extra argument that can be passed into our custom RNNs. Replaces (h_0, c_0) argument of nn.LSTM + output, (h_n, c_n) = super().forward(inputs, init_state) + return output, (h_n.squeeze(0), c_n.squeeze(0)) + + def state_size(self): + return self.hidden_size + + def output_size(self): + return self.hidden_size + + def output(self, state): + return state[0] diff --git a/model/rnncell.py b/model/rnncell.py new file mode 100644 index 0000000..97c823f --- /dev/null +++ b/model/rnncell.py @@ -0,0 +1,269 @@ +""" Baseline RNN cells such as the vanilla RNN and GRU. """ + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.components import Gate, Linear_, Modrelu, get_activation, get_initializer +from model.orthogonalcell import OrthogonalLinear + + +class CellBase(nn.Module): + """ Abstract class for our recurrent cell interface. + + Passes input through + """ + registry = {} + + # https://www.python.org/dev/peps/pep-0487/#subclass-registration + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Only register classes with @name attribute + if hasattr(cls, 'name') and cls.name is not None: + cls.registry[cls.name] = cls + + name = 'id' + valid_keys = [] + + def default_initializers(self): + return {} + + def default_architecture(self): + return {} + + def __init__(self, input_size, hidden_size, initializers=None, architecture=None): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + + self.architecture = self.default_architecture() + self.initializers = self.default_initializers() + if initializers is not None: + self.initializers.update(initializers) + print("Initializers:", initializers) + if architecture is not None: + self.architecture.update(architecture) + + assert set(self.initializers.keys()).issubset(self.valid_keys) + assert set(self.architecture.keys()).issubset(self.valid_keys) + + self.reset_parameters() + + def reset_parameters(self): + pass + + def forward(self, input, hidden): + return input, input + + def default_state(self, input, batch_size=None): + return input.new_zeros(input.size(0) if batch_size is None else batch_size, + self.hidden_size, requires_grad=False) + + def output(self, h): + return h + + def state_size(self): + return self.hidden_size + + def output_size(self): + return self.hidden_size + + def initial_state(self, trainable=False): + """ Return initial state of the RNN + This should not need to see the input as it should be batch size agnostic and automatically broadcasted + + # TODO Currently not used + """ + if trainable: + self.initial_state = torch.zeros(self.hidden_size, requires_grad=True) + else: + return torch.zeros(self.hidden_size, requires_grad=True) + + + +class RNNCell(CellBase): + name = 'rnn' + + valid_keys = ['hx', 'hh', 'bias'] + + def default_initializers(self): + return { + 'hx': 'xavier', + 'hh': 'xavier', + } + + def default_architecture(self): + return { + 'bias': True, + } + + + def __init__(self, input_size, hidden_size, + hidden_activation='tanh', + orthogonal=False, + ortho_args=None, + zero_bias_init=False, + **kwargs + ): + + self.hidden_activation = hidden_activation + self.orthogonal = orthogonal + self.ortho_args = ortho_args + self.zero_bias_init=zero_bias_init + + super().__init__(input_size, hidden_size, + **kwargs, + ) + + def reset_parameters(self): + self.W_hx = Linear_(self.input_size, self.hidden_size, bias=self.architecture['bias'], zero_bias_init=self.zero_bias_init) + get_initializer(self.initializers['hx'], 
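# Standalone sketch of the subclass-registration pattern CellBase uses above
# (PEP 487 __init_subclass__); Base and Foo are hypothetical names:
class Base:
    registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if getattr(cls, 'name', None) is not None:
            cls.registry[cls.name] = cls         # registered under its name string

class Foo(Base):
    name = 'foo'

assert Base.registry['foo'] is Foo               # later looked up by name, as in Model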
self.hidden_activation)(self.W_hx.weight) + self.hidden_activation_fn = get_activation(self.hidden_activation, self.hidden_size) + + self.reset_hidden_to_hidden() + + def reset_hidden_to_hidden(self): + if self.orthogonal: + + if self.ortho_args is None: + self.ortho_args = {} + self.ortho_args['input_size'] = self.hidden_size + self.ortho_args['output_size'] = self.hidden_size + + self.W_hh = OrthogonalLinear(**self.ortho_args) + else: + self.W_hh = nn.Linear(self.hidden_size, self.hidden_size, bias=self.architecture['bias']) + get_initializer(self.initializers['hh'], self.hidden_activation)(self.W_hh.weight) + + def forward(self, input, h): + ### Update hidden state + hidden_preact = self.W_hx(input) + self.W_hh(h) + hidden = self.hidden_activation_fn(hidden_preact) + + return hidden, hidden + +class GatedRNNCell(RNNCell): + name = 'gru' + + def __init__(self, input_size, hidden_size, + gate='G', # 'N' | 'G' + reset='N', + **kwargs + ): + self.gate = gate + self.reset = reset + super().__init__(input_size, hidden_size, **kwargs) + + def reset_parameters(self): + super().reset_parameters() + + preact_ctor = Linear_ + preact_args = [self.input_size + self.hidden_size, self.hidden_size, self.architecture['bias']] + self.W_g = Gate(self.hidden_size, preact_ctor, preact_args, mechanism=self.gate) + self.W_reset = Gate(self.hidden_size, preact_ctor, preact_args, mechanism=self.reset) + + def forward(self, input, h): + hx = torch.cat((input, h), dim=-1) + reset = self.W_reset(hx) + + _, update = super().forward(input, reset*h) + + g = self.W_g(hx) + h = (1.-g) * h + g * update + + return h, h + +class MinimalRNNCell(CellBase): + name = 'mrnn' + + valid_keys = ['hx', 'bias'] + + def default_initializers(self): + return { + 'hx': 'xavier', + } + + def default_architecture(self): + return { + 'bias': True, + } + + + def __init__(self, input_size, hidden_size, + hidden_activation='tanh', + orthogonal=False, + ortho_args=None, + zero_bias_init=False, + **kwargs + ): + + self.hidden_activation = hidden_activation + self.zero_bias_init=zero_bias_init + + super().__init__(input_size, hidden_size, + **kwargs, + ) + + def reset_parameters(self): + self.W_hx = Linear_(self.input_size, self.hidden_size, bias=self.architecture['bias'], zero_bias_init=self.zero_bias_init) + get_initializer(self.initializers['hx'], self.hidden_activation)(self.W_hx.weight) + self.hidden_activation_fn = get_activation(self.hidden_activation, self.hidden_size) + + preact_ctor = Linear_ + preact_args = [self.input_size + self.hidden_size, self.hidden_size, self.architecture['bias']] + self.W_g = Gate(self.hidden_size, preact_ctor, preact_args, mechanism='G') + + + def forward(self, input, h): + ### Update hidden state + hidden_preact = self.W_hx(input) + hidden = self.hidden_activation_fn(hidden_preact) + hx = torch.cat((input, h), dim=-1) + g = self.W_g(hx) + h = (1.-g) * h + g * hidden + + return h, h + + +class GatedSRNNCell(GatedRNNCell): + name = 'grus' + + def __init__(self, input_size, hidden_size, + **kwargs + ): + super().__init__(input_size, hidden_size, **kwargs) + + def reset_parameters(self): + super().reset_parameters() + + def forward(self, input, hidden): + hidden, t = hidden + + hx = torch.cat((input, hidden), dim=-1) + reset = self.W_reset(hx) + + _, update = super().forward(input, reset*hidden) + + g = self.W_g(hx) + g = g * 1. 
/ (t+1) + h = (1.-g) * hidden + g * update + + return h, (h, t+1) + + def default_state(self, input, batch_size=None): + batch_size = input.size(0) if batch_size is None else batch_size + return (input.new_zeros(batch_size, self.hidden_size, requires_grad=False), + 0) + + def output(self, state): + """ Converts a state into a single output (tensor) """ + h, t = state + + return h + +class ExpRNNCell(RNNCell): + """ Note: there is a subtle distinction between this and the ExpRNN original cell (now implemented as orthogonalcell.OrthogonalCell) in the initialization of hx, but this shouldn't matter """ + name = 'exprnn' + + def __init__(self, input_size, hidden_size, **kwargs): + super().__init__(input_size, hidden_size, orthogonal=True, hidden_activation='modrelu', **kwargs) diff --git a/pl_runner.py b/pl_runner.py new file mode 100644 index 0000000..3f0ed45 --- /dev/null +++ b/pl_runner.py @@ -0,0 +1,34 @@ +import torch +import pytorch_lightning as pl + + +def pl_train(cfg, pl_model_class): + if cfg.seed is not None: + torch.manual_seed(cfg.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(cfg.seed) + model = pl_model_class(cfg.model, cfg.dataset, cfg.train) + if 'pl' in cfg and 'profile' in cfg.pl and cfg.pl.profile: + profiler_args = { 'profiler': pl.profiler.AdvancedProfiler(), } + else: + profiler_args = {} + if 'pl' in cfg and 'wandb' in cfg.pl and cfg.pl.wandb: + logger = WandbLogger(project='ops-memory-pl') + logger.log_hyperparams(cfg.model) + logger.log_hyperparams(cfg.dataset) + logger.log_hyperparams(cfg.train) + profiler_args['logger'] = logger + print("profiler args", profiler_args) + trainer = pl.Trainer( + gpus=1, + gradient_clip_val=cfg.train.gradient_clip_val, + max_epochs=1 if cfg.smoke_test else cfg.train.epochs, + early_stop_callback=False, progress_bar_refresh_rate=1, + limit_train_batches=cfg.train.limit_train_batches, + track_grad_norm=2, + **profiler_args, + ) + + trainer.fit(model) + trainer.test(model) + return trainer, model diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3c02033 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +pytorch-lightning>=0.7.6 +torchtext +numpy +scipy +sklearn +sktime +matplotlib +tqdm +hydra-core==1.0.0rc4 +omegaconf +munch diff --git a/tensorflow/hippo.py b/tensorflow/hippo.py new file mode 100644 index 0000000..087652b --- /dev/null +++ b/tensorflow/hippo.py @@ -0,0 +1,390 @@ +import numpy as np + +from keras import backend as K +from keras import activations, initializers +from keras.initializers import Constant, Initializer +from keras.layers import Layer + +from scipy import signal +from scipy import linalg as la +import math +import tensorflow as tf + + +def transition(measure, N, **measure_args): + """ A, B transition matrices for different measures + + measure: the type of measure + legt - Legendre (translated) + legs - Legendre (scaled) + glagt - generalized Laguerre (translated) + lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization + """ + # Laguerre (translated) + if measure == 'lagt': + b = measure_args.get('beta', 1.0) + A = np.eye(N) / 2 - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + if measure == 'tlagt': + # beta = 1 corresponds to no tilt + b = measure_args.get('beta', 1.0) + A = (1.-b)/2 * np.eye(N) - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + # Generalized Laguerre + # alpha 0, beta small is most stable (limits to the 'lagt' measure) + # alpha 0, beta 1 has transition matrix A = [lower triangular 1] + if 
measure == 'glagt': + alpha = measure_args.get('alpha', 0.0) + beta = measure_args.get('beta', 0.01) + A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) + B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] + + L = np.exp(.5 * (ss.gammaln(np.arange(N)+alpha+1) - ss.gammaln(np.arange(N)+1))) + A = (1./L[:, None]) * A * L[None, :] + B = (1./L[:, None]) * B * np.exp(-.5 * ss.gammaln(1-alpha)) * beta**((1-alpha)/2) + # Legendre (translated) + elif measure == 'legt': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1) ** .5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :] + B = R[:, None] + A = -A + # LMU: equivalent to LegT up to normalization + elif measure == 'lmu': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1)[:, None] # / theta + j, i = np.meshgrid(Q, Q) + A = np.where(i < j, -1, (-1.)**(i-j+1)) * R + B = (-1.)**Q[:, None] * R + # Legendre (scaled) + elif measure == 'legs': + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + + return A, B + +forward_aliases = ['euler', 'forward_euler', 'forward', 'forward_diff'] +backward_aliases = ['backward', 'backward_diff', 'backward_euler'] +bilinear_aliases = ['bilinear', 'tustin', 'trapezoidal', 'trapezoid'] +zoh_aliases = ['zoh'] + +class HippoTCell(Layer): + + def __init__(self, + units, + memory_order, + theta, # relative to dt=1 + measure='legt', + method='zoh', + trainable_input_encoders=True, + trainable_hidden_encoders=True, + trainable_memory_encoders=True, + trainable_input_kernel=True, + trainable_hidden_kernel=True, + trainable_memory_kernel=True, + trainable_A=False, + trainable_B=False, + input_encoders_initializer='lecun_uniform', + hidden_encoders_initializer='lecun_uniform', + memory_encoders_initializer=Constant(0), # 'lecun_uniform', + input_kernel_initializer='glorot_normal', + hidden_kernel_initializer='glorot_normal', + memory_kernel_initializer='glorot_normal', + hidden_activation='tanh', + **kwargs): + super().__init__(**kwargs) + + self.units = units + self.memory_order = memory_order + self.theta = theta + self.method = method + self.trainable_input_encoders = trainable_input_encoders + self.trainable_hidden_encoders = trainable_hidden_encoders + self.trainable_memory_encoders = trainable_memory_encoders + self.trainable_input_kernel = trainable_input_kernel + self.trainable_hidden_kernel = trainable_hidden_kernel + self.trainable_memory_kernel = trainable_memory_kernel + self.trainable_A = trainable_A + self.trainable_B = trainable_B + + self.input_encoders_initializer = initializers.get( + input_encoders_initializer) + self.hidden_encoders_initializer = initializers.get( + hidden_encoders_initializer) + self.memory_encoders_initializer = initializers.get( + memory_encoders_initializer) + self.input_kernel_initializer = initializers.get( + input_kernel_initializer) + self.hidden_kernel_initializer = initializers.get( + hidden_kernel_initializer) + self.memory_kernel_initializer = initializers.get( + memory_kernel_initializer) + + self.hidden_activation = activations.get(hidden_activation) + + A, B = transition(measure, memory_order) + # Construct A and B matrices + C = np.ones((1, memory_order)) + D = np.zeros((1,)) + dA, dB, _, _, _ = signal.cont2discrete((A, B, C, D), dt=1./theta, method=method) + + self._A = dA - np.eye(memory_order) # puts into form: x += Ax + self._B = dB + + 
self.state_size = (self.units, self.memory_order) + self.output_size = self.units + + def build(self, input_shape): + input_dim = input_shape[-1] + + self.input_encoders = self.add_weight( + name='input_encoders', + shape=(input_dim, 1), + initializer=self.input_encoders_initializer, + trainable=self.trainable_input_encoders) + + self.hidden_encoders = self.add_weight( + name='hidden_encoders', + shape=(self.units, 1), + initializer=self.hidden_encoders_initializer, + trainable=self.trainable_hidden_encoders) + + self.memory_encoders = self.add_weight( + name='memory_encoders', + shape=(self.memory_order, 1), + initializer=self.memory_encoders_initializer, + trainable=self.trainable_memory_encoders) + + self.input_kernel = self.add_weight( + name='input_kernel', + shape=(input_dim, self.units), + initializer=self.input_kernel_initializer, + trainable=self.trainable_input_kernel) + + self.hidden_kernel = self.add_weight( + name='hidden_kernel', + shape=(self.units, self.units), + initializer=self.hidden_kernel_initializer, + trainable=self.trainable_hidden_kernel) + + self.memory_kernel = self.add_weight( + name='memory_kernel', + shape=(self.memory_order, self.units), + initializer=self.memory_kernel_initializer, + trainable=self.trainable_memory_kernel) + + self.AT = self.add_weight( + name='AT', + shape=(self.memory_order, self.memory_order), + initializer=Constant(self._A.T), # note: transposed + trainable=self.trainable_A) + + self.BT = self.add_weight( + name='BT', + shape=(1, self.memory_order), # system is SISO + initializer=Constant(self._B.T), # note: transposed + trainable=self.trainable_B) + + self.built = True + + def call(self, inputs, states): + h, m = states + + u = (K.dot(inputs, self.input_encoders) + + K.dot(h, self.hidden_encoders) + + K.dot(m, self.memory_encoders)) + + m = m + K.dot(m, self.AT) + K.dot(u, self.BT) + + h = self.hidden_activation( + K.dot(inputs, self.input_kernel) + + K.dot(h, self.hidden_kernel) + + K.dot(m, self.memory_kernel)) + + return h, [h, m] + +class HippoSCell(Layer): + + def __init__(self, + units, + memory_order, + measure='legt', + method='zoh', + max_length=256, + trainable_input_encoders=True, + trainable_hidden_encoders=True, + trainable_memory_encoders=True, + trainable_input_kernel=True, + trainable_hidden_kernel=True, + trainable_memory_kernel=True, + trainable_A=False, + trainable_B=False, + input_encoders_initializer='lecun_uniform', + hidden_encoders_initializer='lecun_uniform', + memory_encoders_initializer=Constant(0), # 'lecun_uniform', + input_kernel_initializer='glorot_normal', + hidden_kernel_initializer='glorot_normal', + memory_kernel_initializer='glorot_normal', + hidden_activation='tanh', + gate=False, + **kwargs): + super().__init__(**kwargs) + + self.units = units + self.memory_order = memory_order + self.method = method + self.max_length = max_length + self.trainable_input_encoders = trainable_input_encoders + self.trainable_hidden_encoders = trainable_hidden_encoders + self.trainable_memory_encoders = trainable_memory_encoders + self.trainable_input_kernel = trainable_input_kernel + self.trainable_hidden_kernel = trainable_hidden_kernel + self.trainable_memory_kernel = trainable_memory_kernel + self.trainable_A = trainable_A + self.trainable_B = trainable_B + self.gate = gate + + self.input_encoders_initializer = initializers.get( + input_encoders_initializer) + self.hidden_encoders_initializer = initializers.get( + hidden_encoders_initializer) + self.memory_encoders_initializer = initializers.get( + 
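# Usage sketch for HippoTCell above, assuming this file is importable as
# `hippo` and a Keras version whose generic RNN layer accepts custom cells
# with tuple state_size (as this cell declares):
from keras.layers import Input, Dense, RNN
from keras.models import Model
from hippo import HippoTCell

cell = HippoTCell(units=64, memory_order=32, theta=100)
x = Input(shape=(None, 1))                       # (time, features) per example
h = RNN(cell)(x)                                 # final hidden state, (batch, 64)
y = Dense(10, activation='softmax')(h)
model = Model(x, y)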
memory_encoders_initializer) + self.input_kernel_initializer = initializers.get( + input_kernel_initializer) + self.hidden_kernel_initializer = initializers.get( + hidden_kernel_initializer) + self.memory_kernel_initializer = initializers.get( + memory_kernel_initializer) + + self.hidden_activation = activations.get(hidden_activation) + + A, B = transition(measure, memory_order) + # Construct A and B matrices + + A_stacked = np.empty((max_length, memory_order, memory_order), dtype=A.dtype) + B_stacked = np.empty((max_length, memory_order), dtype=B.dtype) + B = B[:,0] + N = memory_order + for t in range(1, max_length + 1): + At = A / t + Bt = B / t + # if discretization in forward_aliases: + if method in forward_aliases: + A_stacked[t - 1] = np.eye(N) + At + B_stacked[t - 1] = Bt + # elif discretization in backward_aliases: + elif method in backward_aliases: + A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, np.eye(N), lower=True) + B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True) + elif method in bilinear_aliases: + A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True) + B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, Bt, lower=True) + elif method in zoh_aliases: + A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t))) + B_stacked[t - 1] = la.solve_triangular(A, A_stacked[t - 1] @ B - B, lower=True) + B_stacked = B_stacked[:, :, None] + + A_stacked -= np.eye(memory_order) # puts into form: x += Ax + self._A = A_stacked - np.eye(memory_order) # puts into form: x += Ax + self._B = B_stacked + + self.state_size = (self.units, self.memory_order, 1) + self.output_size = self.units + + def build(self, input_shape): + input_dim = input_shape[-1] + + self.input_encoders = self.add_weight( + name='input_encoders', + shape=(input_dim, 1), + initializer=self.input_encoders_initializer, + trainable=self.trainable_input_encoders) + + self.hidden_encoders = self.add_weight( + name='hidden_encoders', + shape=(self.units, 1), + initializer=self.hidden_encoders_initializer, + trainable=self.trainable_hidden_encoders) + + self.memory_encoders = self.add_weight( + name='memory_encoders', + shape=(self.memory_order, 1), + initializer=self.memory_encoders_initializer, + trainable=self.trainable_memory_encoders) + + self.input_kernel = self.add_weight( + name='input_kernel', + shape=(input_dim, self.units), + initializer=self.input_kernel_initializer, + trainable=self.trainable_input_kernel) + + if self.trainable_hidden_kernel: + self.hidden_kernel = self.add_weight( + name='hidden_kernel', + shape=(self.units, self.units), + initializer=self.hidden_kernel_initializer, + trainable=self.trainable_hidden_kernel) + else: + self.hidden_kernel = self.add_weight( + name='hidden_kernel', + shape=(self.units, self.units), + initializer=Constant(0.), + trainable=False) + + self.memory_kernel = self.add_weight( + name='memory_kernel', + shape=(self.memory_order, self.units), + initializer=self.memory_kernel_initializer, + trainable=self.trainable_memory_kernel) + + self.A = self.add_weight( + name='A', + shape=(self.max_length, self.memory_order, self.memory_order), + initializer=Constant(self._A), # note: transposed + trainable=self.trainable_A) + + self.B = self.add_weight( + name='B', + shape=(self.max_length, self.memory_order, 1), # system is SISO + initializer=Constant(self._B), # note: transposed + trainable=self.trainable_B) + + if self.gate: + self.W_gate = self.add_weight( + name='gate', + shape=(self.units+self.memory_order, 
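# Illustrative sketch (numpy/scipy only) of the zoh branch used above for the
# scale-invariant system: over [t, t+1] the effective step in log-time is
# log(t+1) - log(t), giving a matrix-exponential update for the LegS operator
# A (written out here in closed form).
import math
import numpy as np
from scipy import linalg as la

N, t = 4, 3
q = np.arange(N, dtype=np.float64)
A = -(np.tril(np.sqrt((2 * q[:, None] + 1) * (2 * q[None, :] + 1)), -1) + np.diag(q + 1))
B = np.sqrt(2 * q + 1)
dA = la.expm(A * (math.log(t + 1) - math.log(t)))
dB = la.solve_triangular(A, dA @ B - B, lower=True)      # matches B_stacked[t-1]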
self.units), # system is SISO + initializer=initializers.get('glorot_normal'), # note: transposed + trainable=True) + + self.built = True + + def call(self, inputs, states): + h, m, t = states + tt = tf.cast(t, tf.int32) + tt = tt[0,0] + + tt = tf.math.minimum(tt, self.max_length-1) + u = (K.dot(inputs, self.input_encoders) + + K.dot(h, self.hidden_encoders) + + K.dot(m, self.memory_encoders)) + + m = m + K.dot(m, tf.transpose(self.A[tt])) + K.dot(u, tf.transpose(self.B[tt])) + + new_h = self.hidden_activation( + K.dot(inputs, self.input_kernel) + + K.dot(h, self.hidden_kernel) + + K.dot(m, self.memory_kernel)) + if self.gate: + g = tf.sigmoid(K.dot(tf.concat([h, m], axis=-1), self.W_gate)) + h = (1.-g)*h + g*new_h + else: + h = new_h + + return h, [h, m, t+1] diff --git a/tests/test_legs_extension.py b/tests/test_legs_extension.py new file mode 100644 index 0000000..1055854 --- /dev/null +++ b/tests/test_legs_extension.py @@ -0,0 +1,229 @@ +import math +import unittest + +import numpy as np +from scipy import linalg as la + +import torch +import torch.nn.functional as F +import hippo + +# from .op import transition + +def transition(measure, N, **measure_args): + """ A, B transition matrices for different measures """ + if measure == 'lagt': + # A_l = (1 - dt / 4) * np.eye(N) + dt / 2 * np.tril(np.ones((N, N))) + # A_r = (1 + dt / 4) * np.eye(N) - dt / 2 * np.tril(np.ones((N, N))) + # alpha = dt / 2 / (1 - dt / 4) + # col = -alpha / (1 + alpha) ** np.arange(1, N + 1) + # col[0] += 1 + # A_l_inv = la.toeplitz(col / (1 - dt / 4), np.zeros(N)) + b = measure_args.get('beta', 1.0) + A = np.eye(N) / 2 - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + if measure == 'tlagt': + # beta = 1 corresponds to no tilt + # b = measure_args['beta'] + b = measure_args.get('beta', 1.0) + A = (1.-b)/2 * np.eye(N) - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + elif measure == 'legt': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1)[:, None] # / theta + j, i = np.meshgrid(Q, Q) + A = np.where(i < j, -1, (-1.)**(i-j+1)) * R + B = (-1.)**Q[:, None] * R + + elif measure == 'legs': + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + + return A, B + + +def slo(input, N, d_t=1.0, method='trapezoidal'): + q = np.arange(N) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T) + # d, V = np.linalg.eig(A) + # d, V = d[::-1], V[:, ::-1] + c = np.zeros(N, dtype=np.float64) + c[0] = input[0] + for t in range(1, input.shape[0]): + At = A / t + Bt = B / t + u = input[t] + if method == 'euler' or method == 'forward_diff': + c = (np.eye(N) + d_t * At) @ c + d_t * Bt * u + elif method == 'backward_diff' or method == 'backward_euler': + c = la.solve_triangular(np.eye(N) - d_t * At, c + d_t * Bt * u, lower=True) + elif method == 'bilinear' or method == 'tustin' or method == 'trapezoidal': + c = la.solve_triangular(np.eye(N) - d_t / 2 * At, (np.eye(N) + d_t / 2 * At) @ c + d_t * Bt * u, lower=True) + elif method == 'zoh': + # aa, bb, _, _, _ = signal.cont2discrete((A, B[:, None], np.ones((1, N)), np.zeros((1,))), dt=math.log(t + d_t) - math.log(t), method='zoh') + # bb = bb.squeeze(-1) + aa = la.expm(A * (math.log(t + d_t) - math.log(t))) + bb = la.solve_triangular(A, aa @ B - B, lower=True) + c = aa @ c + bb * f(t) + 
else: + assert False, f'method {method} not supported' + # f_approx = (c @ (T @ ss.eval_legendre(np.arange(N)[:, None], 2 * t_vals / T_max - 1))) + return c + + +class LegSTest(unittest.TestCase): + + def setUp(self): + self.rtol = 10 + self.atol = 1e-3 + + def test_legs_euler_forward_cpu(self): + batch_size = 10 + memsize = 23 + memorder = 587 + dt = 0.27 + # batch_size = 1 + # memsize = 1 + # memorder = 5 + # dt = 0.5 + A, B = transition('legs', memorder) + A = torch.Tensor(A) + B = torch.Tensor(B).squeeze(-1) + x = torch.randn(batch_size, memsize, memorder) + input = torch.randn(batch_size, memsize) + out = hippo.legs_euler_forward(x, input, dt) + out_torch = x + dt * F.linear(x, A) + dt * input.unsqueeze(-1) * B + out_double = x.double() + dt * F.linear(x.double(), A.double()) + dt * input.unsqueeze(-1).double() * B.double() + err = (out - out_double).abs().max().item() + err_torch = (out_torch - out_double).abs().max().item() + # print(out_double) + print((out - out_double).abs().max().item()) + print((out_torch - out_double).abs().max().item()) + self.assertTrue(err <= err_torch * (1 + self.rtol) + self.atol, + ((out - out_torch).abs().max().item())) + + def test_legs_euler_backward_cpu(self): + batch_size = 10 + memsize = 23 + memorder = 587 + dt = 0.27 + # batch_size = 1 + # memsize = 1 + # memorder = 5 + # dt = 0.5 + A, B = transition('legs', memorder) + A_inv = la.solve_triangular(np.eye(memorder) - dt * A, np.eye(memorder), lower=True) + B_inv = la.solve_triangular(np.eye(memorder) - dt * A, B, lower=True) + A_inv = torch.Tensor(A_inv) + B_inv = torch.Tensor(B_inv).squeeze(-1) + x = torch.randn(batch_size, memsize, memorder) + input = torch.randn(batch_size, memsize) + out = hippo.legs_euler_backward(x, input, dt) + out_torch = F.linear(x, A_inv) + dt * input.unsqueeze(-1) * B_inv + out_double = F.linear(x.double(), A_inv.double()) + dt * input.unsqueeze(-1).double() * B_inv.double() + err = (out - out_double).abs().max().item() + err_torch = (out_torch - out_double).abs().max().item() + # print(out_double) + print((out - out_double).abs().max().item()) + print((out_torch - out_double).abs().max().item()) + self.assertTrue(err <= err_torch * (1 + self.rtol) + self.atol, + ((out - out_torch).abs().max().item())) + + def test_legs_trapezoidal_cpu(self): + batch_size = 10 + memsize = 23 + memorder = 587 + dt = 0.27 + # batch_size = 1 + # memsize = 1 + # memorder = 5 + # dt = 0.5 + A, B = transition('legs', memorder) + trap_A_inv = la.solve_triangular(np.eye(memorder) - dt / 2 * A, np.eye(memorder) + dt / 2 * A, lower=True) + trap_A_inv = torch.Tensor(trap_A_inv) + trap_B_inv = la.solve_triangular(np.eye(memorder) - dt / 2 * A, B, lower=True) + trap_B_inv = torch.Tensor(trap_B_inv).squeeze(-1) + x = torch.randn(batch_size, memsize, memorder) + input = torch.randn(batch_size, memsize) + out = hippo.legs_trapezoidal(x, input, dt) + out_torch = F.linear(x, trap_A_inv) + dt * input.unsqueeze(-1) * trap_B_inv + out_double = F.linear(x.double(), trap_A_inv.double()) + dt * input.unsqueeze(-1).double() * trap_B_inv.double() + err = (out - out_double).abs().max().item() + err_torch = (out_torch - out_double).abs().max().item() + # print(out_double) + print((out - out_double).abs().max().item()) + print((out_torch - out_double).abs().max().item()) + self.assertTrue(err <= err_torch * (1 + self.rtol) + self.atol, + ((out - out_torch).abs().max().item())) + + def test_function_approx(self): + length = int(1e3) + memorder = 256 + input = torch.randn(length, dtype=torch.float64) + mem = 
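# Sketch (numpy/scipy only) of decoding the coefficients that slo() above
# produces, mirroring the commented-out f_approx line: at time t the
# coefficients represent f on [0, t] in the scaled, normalized Legendre basis.
import numpy as np
from scipy import special as ss

def reconstruct(c, t, num_points=200):
    N = c.shape[0]
    T = np.sqrt(2 * np.arange(N) + 1)                    # same normalization as slo()
    s = np.linspace(1e-3, t, num_points)                 # evaluation grid on (0, t]
    P = ss.eval_legendre(np.arange(N)[:, None], 2 * s / t - 1)   # (N, num_points)
    return (c * T) @ P                                   # approximate f on the grid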
+        mem_np = torch.Tensor(slo(input.cpu().numpy().astype(np.float64), memorder)).double()
+        self.assertTrue(torch.allclose(mem, mem_np))
+
+
+def timeit(fn, nsteps):
+    import time
+    fn()
+    start = time.perf_counter()
+    for _ in range(nsteps):
+        fn()
+    end = time.perf_counter()
+    return (end - start) / nsteps
+
+
+def benchmark():
+    torch.set_num_threads(1)
+    batch_size = 1
+    memsize = 1
+    memorder = 256
+    dt = 0.27
+    A, B = transition('legs', memorder)
+    A_inv = la.solve_triangular(np.eye(memorder) - dt * A, np.eye(memorder), lower=True)
+    B_inv = la.solve_triangular(np.eye(memorder) - dt * A, B, lower=True)
+    A_inv = torch.Tensor(A_inv)
+    B_inv = torch.Tensor(B_inv).squeeze(-1)
+    trap_A_inv = la.solve_triangular(np.eye(memorder) - dt / 2 * A, np.eye(memorder) + dt / 2 * A, lower=True)
+    trap_A_inv = torch.Tensor(trap_A_inv)
+    trap_B_inv = la.solve_triangular(np.eye(memorder) - dt / 2 * A, B, lower=True)
+    trap_B_inv = torch.Tensor(trap_B_inv).squeeze(-1)
+    A = torch.Tensor(A)
+    B = torch.Tensor(B).squeeze(-1)
+    x = torch.randn(batch_size, memsize, memorder)
+    input = torch.randn(batch_size, memsize)
+    nsteps = 10000
+    euler_forward_fn = lambda: hippo.legs_euler_forward(x, input, dt)
+    euler_forward_torch_fn = lambda: x + dt * F.linear(x, A) + dt * input.unsqueeze(-1) * B
+    euler_backward_fn = lambda: hippo.legs_euler_backward(x, input, dt)
+    euler_backward_torch_fn = lambda: F.linear(x, A_inv) + dt * input.unsqueeze(-1) * B_inv
+    trapezoidal_fn = lambda: hippo.legs_trapezoidal(x, input, dt)
+    trapezoidal_torch_fn = lambda: F.linear(x, trap_A_inv) + dt * input.unsqueeze(-1) * trap_B_inv
+    print(f'Euler forward C++: {timeit(euler_forward_fn, nsteps)}s')
+    print(f'Euler backward C++: {timeit(euler_backward_fn, nsteps)}s')
+    print(f'Trapezoidal C++: {timeit(trapezoidal_fn, nsteps)}s')
+    print(f'Euler forward PyTorch: {timeit(euler_forward_torch_fn, nsteps)}s')
+    print(f'Euler backward PyTorch: {timeit(euler_backward_torch_fn, nsteps)}s')
+    print(f'Trapezoidal PyTorch: {timeit(trapezoidal_torch_fn, nsteps)}s')
+
+    length = int(1e6)
+    input = torch.randn(length, dtype=torch.float64)
+    trap_func_approx_fn = lambda: hippo.legs_function_approx_trapezoidal(input, memorder)
+    nsteps = 1
+    print(f'Function approx trapezoidal C++: {timeit(trap_func_approx_fn, nsteps)}s')
+
+
+if __name__ == "__main__":
+    benchmark()
diff --git a/tests/test_legt_extension.py b/tests/test_legt_extension.py
new file mode 100644
index 0000000..81cf3b8
--- /dev/null
+++ b/tests/test_legt_extension.py
@@ -0,0 +1,109 @@
+import math
+import unittest
+
+import numpy as np
+from scipy import linalg as la
+
+import torch
+import torch.nn.functional as F
+import hippo
+
+# from .op import transition
+
+def transition(measure, N, **measure_args):
+    """ A, B transition matrices for different measures """
+    if measure == 'lagt':
+        # A_l = (1 - dt / 4) * np.eye(N) + dt / 2 * np.tril(np.ones((N, N)))
+        # A_r = (1 + dt / 4) * np.eye(N) - dt / 2 * np.tril(np.ones((N, N)))
+        # alpha = dt / 2 / (1 - dt / 4)
+        # col = -alpha / (1 + alpha) ** np.arange(1, N + 1)
+        # col[0] += 1
+        # A_l_inv = la.toeplitz(col / (1 - dt / 4), np.zeros(N))
+        b = measure_args.get('beta', 1.0)
+        A = np.eye(N) / 2 - np.tril(np.ones((N, N)))
+        B = b * np.ones((N, 1))
+    elif measure == 'tlagt':
+        # beta = 1 corresponds to no tilt
+        # b = measure_args['beta']
+        b = measure_args.get('beta', 1.0)
+        A = (1.-b)/2 * np.eye(N) - np.tril(np.ones((N, N)))
+        B = b * np.ones((N, 1))
+    elif measure == 'legt':
+        Q = np.arange(N, dtype=np.float64)
+        R = (2*Q + 1)[:, None] # / theta
+        j, i = np.meshgrid(Q, Q)
+        A = np.where(i < j, -1, (-1.)**(i-j+1)) * R
+        B = (-1.)**Q[:, None] * R
+
+    elif measure == 'legs':
+        q = np.arange(N, dtype=np.float64)
+        col, row = np.meshgrid(q, q)
+        r = 2 * q + 1
+        M = -(np.where(row >= col, r, 0) - np.diag(q))
+        T = np.sqrt(np.diag(2 * q + 1))
+        A = T @ M @ np.linalg.inv(T)
+        B = np.diag(T)[:, None]
+
+    return A, B
+
+
+class LegtTest(unittest.TestCase):
+
+    def setUp(self):
+        self.rtol = 10
+        self.atol = 1e-3
+
+    def test_legt_euler_forward_cpu(self):
+        batch_size = 10
+        memsize = 23
+        memorder = 587
+        dt = 0.27
+        # batch_size = 1
+        # memsize = 1
+        # memorder = 5
+        # dt = 0.5
+        A, B = transition('legt', memorder)
+        A = torch.Tensor(A)
+        B = torch.Tensor(B).squeeze(-1)
+        x = torch.randn(batch_size, memsize, memorder)
+        input = torch.randn(batch_size, memsize)
+        out = hippo.legt_euler_forward(x, input, dt)
+        out_torch = x + dt * F.linear(x, A) + dt * input.unsqueeze(-1) * B
+        out_double = x.double() + dt * F.linear(x.double(), A.double()) + dt * input.unsqueeze(-1).double() * B.double()
+        err = (out - out_double).abs().max().item()
+        err_torch = (out_torch - out_double).abs().max().item()
+        # print(out_double)
+        print((out - out_double).abs().max().item())
+        print((out_torch - out_double).abs().max().item())
+        self.assertTrue(err <= err_torch * (1 + self.rtol) + self.atol,
+                        ((out - out_torch).abs().max().item()))
+
+def timeit(fn, nsteps):
+    import time
+    fn()
+    start = time.perf_counter()
+    for _ in range(nsteps):
+        fn()
+    end = time.perf_counter()
+    return (end - start) / nsteps
+
+
+def benchmark():
+    torch.set_num_threads(1)
+    batch_size = 1
+    memsize = 1
+    memorder = 256
+    dt = 0.27
+    A, B = transition('legt', memorder)
+    A = torch.Tensor(A)
+    B = torch.Tensor(B).squeeze(-1)
+    x = torch.randn(batch_size, memsize, memorder)
+    input = torch.randn(batch_size, memsize)
+    nsteps = 10000
+    euler_forward_fn = lambda: hippo.legt_euler_forward(x, input, dt)
+    euler_forward_torch_fn = lambda: x + dt * F.linear(x, A) + dt * input.unsqueeze(-1) * B
+    print(f'Euler forward C++: {timeit(euler_forward_fn, nsteps)}s')
+    print(f'Euler forward PyTorch: {timeit(euler_forward_torch_fn, nsteps)}s')
+
+if __name__ == "__main__":
+    benchmark()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..58c8ad2
--- /dev/null
+++ b/train.py
@@ -0,0 +1,124 @@
+from pathlib import Path
+project_root = Path(__file__).parent.absolute()
+import os
+# Add to $PYTHONPATH so that ray workers can see it
+os.environ['PYTHONPATH'] = str(project_root) + ":" + os.environ.get('PYTHONPATH', '')
+
+import numpy as np
+import torch
+import pytorch_lightning as pl
+
+import hydra
+from omegaconf import OmegaConf
+
+from model.model import Model
+from datasets import DatasetBase
+from model.exprnn.parametrization import get_parameters
+from utils import to_scalar
+
+
+class RNNTraining(pl.LightningModule):
+
+    def __init__(self, model_args, dataset_cfg, train_args):
+        super().__init__()
+        self.save_hyperparameters()
+        self.dataset_cfg = dataset_cfg
+        self.dataset = DatasetBase.registry[dataset_cfg.name](dataset_cfg)
+        self.train_args = train_args
+        self.model_args = model_args
+        self.model = Model(
+            self.dataset.input_size,
+            self.dataset.output_size,
+            output_len=self.dataset.output_len,
+            **model_args,
+        )
+
+    def forward(self, input):
+        return self.model(input)
+
+    def training_step(self, batch, batch_idx):
+        batch_x, batch_y, *len_batch = batch
+        # Either fixed length sequence or variable length
+        len_batch = None if not len_batch else len_batch[0]
+        out = self.model(batch_x, len_batch)
+        loss = self.dataset.loss(out, batch_y, len_batch)
+        metrics = self.dataset.metrics(out, batch_y)
+        return {'loss': loss, 'size': batch_x.shape[0], 'out': out, 'target': batch_y,
+                'progress_bar': metrics, 'log': metrics}
+
+    def training_epoch_end(self, outputs, prefix='train'):
+        losses = torch.stack([output['loss'] for output in outputs])
+        sizes = torch.tensor([output['size'] for output in outputs], device=losses.device)
+        loss_mean = (losses * sizes).sum() / sizes.sum()
+        outs = [output['out'] for output in outputs]
+        targets = [output['target'] for output in outputs]
+        metrics = self.dataset.metrics_epoch(outs, targets)
+        metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
+        results = {f'{prefix}_loss': loss_mean, **metrics}
+        results_scalar = {k: to_scalar(v) for k, v in results.items()} # PL prefers torch.Tensor while we prefer float
+        setattr(self, f'_{prefix}_results', results_scalar)
+        if getattr(self.train_args, 'verbose', False):
+            print(f'{prefix} set results:', results_scalar)
+        return {f'{prefix}_loss': loss_mean, 'log': results}
+
+    def validation_step(self, batch, batch_idx):
+        batch_x, batch_y, *len_batch = batch
+        # Either fixed length sequence or variable length
+        len_batch = None if not len_batch else len_batch[0]
+        out = self.model(batch_x, len_batch)
+        loss = self.dataset.loss(out, batch_y, len_batch)
+        metrics = self.dataset.metrics(out, batch_y)
+        return {'size': batch_x.shape[0], 'loss': loss, 'out': out, 'target': batch_y, **metrics}
+
+    def validation_epoch_end(self, outputs):
+        return self.training_epoch_end(outputs, prefix='val')
+
+    def test_step(self, batch, batch_idx):
+        return self.validation_step(batch, batch_idx)
+
+    def test_epoch_end(self, outputs):
+        return self.training_epoch_end(outputs, prefix='test')
+
+    def configure_optimizers(self):
+        name_to_opt = {'adam': torch.optim.Adam, 'rmsprop': torch.optim.RMSprop}
+        optimizer = name_to_opt[self.train_args.optimizer]
+        if self.model_args.cell == 'exprnn' or self.model_args.cell_args.get('orthogonal', False):
+            non_orth_params, log_orth_params = get_parameters(self.model)
+            return optimizer([
+                {'params': non_orth_params, 'lr': self.train_args.lr, 'weight_decay': self.train_args.wd},
+                {'params': log_orth_params, 'lr': self.train_args.lr/10.0},
+            ])
+        else:
+            return optimizer(self.model.parameters(), lr=self.train_args.lr)
+
+    def prepare_data(self):
+        self.dataset.prepare_data()
+        kwargs = {'num_workers': self.dataset_cfg.num_workers, 'pin_memory': True}
+        self.dataset.prepare_dataloader(self.train_args.batch_size, **kwargs)
+
+    def train_dataloader(self):
+        return self.dataset.train_loader
+
+    def val_dataloader(self):
+        return self.dataset.val_loader
+
+    def test_dataloader(self):
+        return self.dataset.test_loader
+
+
+@hydra.main(config_path="cfg/config.yaml", strict=False)
+def main(cfg: OmegaConf):
+    print(cfg.pretty())
+    if cfg.runner.name == 'pl':
+        from pl_runner import pl_train
+        trainer, model = pl_train(cfg, RNNTraining)
+    elif cfg.runner.name == 'ray':
+        # Shouldn't need to install ray unless doing distributed training
+        from ray_runner import ray_train
+        ray_train(cfg, RNNTraining)
+    else:
+        assert False, 'Only pl and ray runners are supported'
+
+
+if __name__ == "__main__":
+    main()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..93a0298
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,22 @@
+import torch
+
+from omegaconf.dictconfig import DictConfig
+from munch import Munch
+
+
+# pytorch-lightning returns pytorch 0-dim tensor instead of python scalar
+def to_scalar(x):
+    return x.item() if isinstance(x, torch.Tensor) else x
+
+
+def dictconfig_to_munch(d):
+    """Convert object of type OmegaConf to Munch so Wandb can log properly.
+    Supports nested dictionaries.
+    """
+    return Munch({k: dictconfig_to_munch(v) if isinstance(v, DictConfig)
+                  else v for k, v in d.items()})
+
+
+def munch_to_dictconfig(m):
+    return DictConfig({k: munch_to_dictconfig(v) if isinstance(v, Munch)
+                       else v for k, v in m.items()})
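
For reference, a minimal NumPy sketch of the single update that test_legs_trapezoidal above compares hippo.legs_trapezoidal against: build the LegS transition matrices (as in transition('legs', N)) and take one bilinear/trapezoidal step. The values of N, dt, and the scalar input u below are illustrative and not part of the patch.

    import numpy as np
    from scipy import linalg as la

    N, dt = 8, 0.1                                   # illustrative memory order and step size
    q = np.arange(N, dtype=np.float64)
    col, row = np.meshgrid(q, q)
    r = 2 * q + 1
    M = -(np.where(row >= col, r, 0) - np.diag(q))   # lower-triangular LegS generator
    T = np.sqrt(np.diag(2 * q + 1))
    A = T @ M @ np.linalg.inv(T)                     # HiPPO-LegS A (lower triangular)
    B = np.diag(T)                                   # HiPPO-LegS B

    c = np.zeros(N)                                  # memory coefficients
    u = 1.0                                          # one input sample (illustrative)
    # trapezoidal rule: (I - dt/2 A) c_next = (I + dt/2 A) c + dt B u
    c = la.solve_triangular(np.eye(N) - dt / 2 * A,
                            (np.eye(N) + dt / 2 * A) @ c + dt * B * u,
                            lower=True)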