diff --git a/src/compiler.jl b/src/compiler.jl
index 33f14424a6..9f7ca6c373 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -3361,20 +3361,42 @@ struct EnzymeCacheToken
     always_inline
     method_table::Core.MethodTable
     param_type::Type
-    mode::API.CDerivativeMode
+    # Whether the job's mode is `API.DEM_ForwardMode` (computed in `ci_cache_token`).
+    is_fwd::Bool
 end
 
 GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
     EnzymeCacheToken(
         typeof(job.config.target), job.config.always_inline, GPUCompiler.method_table(job),
-        typeof(job.config.params), job.config.params.mode,
+        typeof(job.config.params), job.config.params.mode == API.DEM_ForwardMode,
     )
 
 GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
     Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache_token(job), GPUCompiler.method_table(job), job.world, job.config.params.mode)
 else
+
+# Code-instance caches; they should only be used when constructing the interpreter.
+# The interpreter changes codegen in just two ways: by not inlining a forward-mode rule
+# or by not inlining a reverse-mode rule. Otherwise, all caches can be re-used.
+const GLOBAL_FWD_CACHE = GPUCompiler.CodeCache()
+const GLOBAL_REV_CACHE = GPUCompiler.CodeCache()
+# Return the global cache for `job`: `GLOBAL_FWD_CACHE` when its mode is
+# `API.DEM_ForwardMode`, `GLOBAL_REV_CACHE` for every other mode.
+function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams})
+    return if job.config.params.mode == API.DEM_ForwardMode
+        GLOBAL_FWD_CACHE
+    else
+        GLOBAL_REV_CACHE
+    end
+end
+
+# On Julia < 1.8, `GPUCompiler.ci_cache` is overloaded to return the same cache.
+@static if VERSION < v"1.8"
+GPUCompiler.ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = enzyme_ci_cache(job)
+end
+
 GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
-    Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode)
+    Interpreter.EnzymeInterpreter(enzyme_ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode)
 end
 
 include("compiler/passes.jl")
@@ -6952,7 +6974,7 @@ end
             run_enzyme = false
             Const
         else
-             A
+            A
         end
 
         if run_enzyme && !(A2 <: Const) && guaranteed_const_nongen(rrt, World)