Skip to content

Commit

Permalink
Merge pull request #31 from wsmoses/vc/capi
Browse files Browse the repository at this point in the history
Use Enzyme C-API
  • Loading branch information
vchuravy authored Dec 25, 2020
2 parents 3a62aab + 72d5ad0 commit 93d2e29
Show file tree
Hide file tree
Showing 9 changed files with 556 additions and 149 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]
version = "0.2.2"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
Expand All @@ -12,7 +13,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[compat]
Cassette = "0.3"
Enzyme_jll = "0.0.3"
Enzyme_jll = "0.0.5"
GPUCompiler = "0.8, 0.9"
LLVM = "3.2"
julia = "1.5"
20 changes: 18 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export Const, Active, Duplicated

using Cassette

abstract type Annotation{T} end
abstract type Annotation{T} end
# Annotation that wraps a single value `val` of type `T`.
# NOTE(review): presumably marks the wrapped argument as not differentiated —
# confirm against the activity mapping in the compiler.
struct Const{T} <: Annotation{T}
val::T
end
Expand All @@ -24,6 +24,11 @@ end

# Element type of an annotation type: the parameter `T` of the `Annotation{T}` it subtypes.
function Base.eltype(::Type{<:Annotation{T}}) where {T}
    return T
end

import LLVM

include("api.jl")
include("typeanalysis.jl")
include("typetree.jl")
include("utils.jl")
include("compiler.jl")

Expand All @@ -44,7 +49,7 @@ end

import .Compiler: EnzymeCtx
# Ops that have intrinsics
for op in (sin, cos, exp)
for op in (sin, cos, tan, exp)
for (T, suffix) in ((Float32, "f32"), (Float64, "f64"))
llvmf = "llvm.$(nameof(op)).$suffix"
@eval begin
Expand All @@ -55,6 +60,17 @@ for op in (sin, cos, exp)
end
end

# Two-argument floating-point ops that have an LLVM intrinsic.
# For every op / float-width pair, define a `Cassette.overdub` method for `EnzymeCtx`
# that calls the intrinsic `llvm.<op>.<suffix>` through `llvmcall`.
for op in (copysign,), (T, suffix) in ((Float32, "f32"), (Float64, "f64"))
    intrinsic = "llvm.$(nameof(op)).$suffix"
    @eval @inline function Cassette.overdub(::EnzymeCtx, ::typeof($op), x::$T, y::$T)
        ccall($intrinsic, llvmcall, $T, ($T, $T), x, y)
    end
end

for op in (asin,)
for (T, llvm_t) in ((Float32, "float"), (Float64, "double"))
decl = "declare double @$(nameof(op))($llvm_t)"
Expand Down
150 changes: 150 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
module API

import LLVM.API: LLVMValueRef, LLVMModuleRef, LLVMTypeRef, LLVMContextRef
using Enzyme_jll
using Libdl
using CEnum

# Alias-analysis handle: returned by value from `EnzymeGetGlobalAA` and passed by
# value to `EnzymeFreeGlobalAA` and the `EnzymeCreate*` wrappers.
# Holds three opaque pointers.
# NOTE(review): the field layout presumably has to match the C-side struct in
# libEnzyme — confirm against the Enzyme C API header before changing it.
struct EnzymeAAResultsRef
a::Ptr{Cvoid}
b::Ptr{Cvoid}
c::Ptr{Cvoid}
end
# Opaque handle to a type-analysis object (the return type of `CreateTypeAnalysis`).
const EnzymeTypeAnalysisRef = Ptr{Cvoid}
# Opaque handle to the result of `EnzymeCreateAugmentedPrimal`.
const EnzymeAugmentedReturnPtr = Ptr{Cvoid}

# Pointer-plus-length pair describing a list of `Int64` values.
struct IntList
    data::Ptr{Int64}
    size::Csize_t
end

# Empty list: null data pointer and zero length.
IntList() = IntList(C_NULL, 0)

# Concrete type tags for Enzyme type trees; a value of this enum is the first
# argument that `EnzymeNewTypeTreeCT` passes to libEnzyme.
# NOTE(review): the integer values presumably must match the corresponding C enum in
# libEnzyme — confirm before reordering or renumbering.
@cenum(CConcreteType,
DT_Anything = 0,
DT_Integer = 1,
DT_Pointer = 2,
DT_Half = 3,
DT_Float = 4,
DT_Double = 5,
DT_Unknown = 6
)


# Empty struct used as the pointee type of `CTypeTreeRef`.
struct EnzymeTypeTree end
# Pointer handle to a libEnzyme type tree, as returned by the `EnzymeNewTypeTree*` wrappers.
const CTypeTreeRef = Ptr{EnzymeTypeTree}

# Constructors for libEnzyme type trees; each one returns a `CTypeTreeRef`.

# Call libEnzyme's `EnzymeNewTypeTree` with no arguments.
function EnzymeNewTypeTree()
    return ccall((:EnzymeNewTypeTree, libEnzyme), CTypeTreeRef, ())
end

# Build a type tree from the concrete type tag `T` in LLVM context `ctx`.
function EnzymeNewTypeTreeCT(T, ctx)
    return ccall(
        (:EnzymeNewTypeTreeCT, libEnzyme),
        CTypeTreeRef,
        (CConcreteType, LLVMContextRef),
        T, ctx,
    )
end

# Build a type tree from an existing type tree `tt`.
# NOTE(review): presumably a copy of `tt` — confirm against the Enzyme C API.
function EnzymeNewTypeTreeTR(tt)
    return ccall((:EnzymeNewTypeTreeTR, libEnzyme), CTypeTreeRef, (CTypeTreeRef,), tt)
end

# Wrappers around the libEnzyme entry points that operate on an existing type tree.
# All of them return `nothing`.

# Pass `tt` to libEnzyme's `EnzymeFreeTypeTree`.
function EnzymeFreeTypeTree(tt)
    return ccall((:EnzymeFreeTypeTree, libEnzyme), Cvoid, (CTypeTreeRef,), tt)
end

# Call `EnzymeSetTypeTree` with destination tree `dst` and source tree `src`.
function EnzymeSetTypeTree(dst, src)
    return ccall(
        (:EnzymeSetTypeTree, libEnzyme), Cvoid, (CTypeTreeRef, CTypeTreeRef), dst, src
    )
end

# Call `EnzymeMergeTypeTree` with destination tree `dst` and source tree `src`.
function EnzymeMergeTypeTree(dst, src)
    return ccall(
        (:EnzymeMergeTypeTree, libEnzyme), Cvoid, (CTypeTreeRef, CTypeTreeRef), dst, src
    )
end

# Call `EnzymeTypeTreeOnlyEq` on `dst` with the 64-bit integer `x`.
function EnzymeTypeTreeOnlyEq(dst, x)
    return ccall((:EnzymeTypeTreeOnlyEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64), dst, x)
end

# Call `EnzymeTypeTreeShiftIndiciesEq` on `dst`.
# `dl` is passed as a C string; `offset` and `maxSize` as `Int64`, `addOffset` as `UInt64`.
# NOTE(review): `dl` is presumably the LLVM data-layout string — confirm with callers.
function EnzymeTypeTreeShiftIndiciesEq(dst, dl, offset, maxSize, addOffset)
    return ccall(
        (:EnzymeTypeTreeShiftIndiciesEq, libEnzyme),
        Cvoid,
        (CTypeTreeRef, Cstring, Int64, Int64, UInt64),
        dst, dl, offset, maxSize, addOffset,
    )
end

# Type information for a function signature; passed by value as the `typeInfo`
# argument of `EnzymeCreatePrimalAndGradient` and `EnzymeCreateAugmentedPrimal`.
struct CFnTypeInfo
# pointer to an array of type trees for the arguments
arguments::Ptr{CTypeTreeRef}
# type tree of the return value
ret::CTypeTreeRef

# pointer to an array of `IntList`s
# NOTE(review): presumably one list of known integer values per argument — confirm
known_values::Ptr{IntList}
end

# Activity of a value with respect to differentiation. Used as the type of `retType`
# and as the element type of `constant_args` in the `EnzymeCreate*` wrappers.
# NOTE(review): the integer values presumably must match the corresponding C enum in
# libEnzyme — confirm before reordering or renumbering.
@cenum(CDIFFE_TYPE,
DFT_OUT_DIFF = 0, # add differential to an output struct
DFT_DUP_ARG = 1, # duplicate the argument and store differential inside
DFT_CONSTANT = 2, # no differential
DFT_DUP_NONEED = 3 # duplicate this argument and store differential inside,
# but don't need the forward
)


# Ask libEnzyme for the global alias-analysis handle of the LLVM module `mod`.
# Returns an `EnzymeAAResultsRef`.
EnzymeGetGlobalAA(mod) =
    ccall((:EnzymeGetGlobalAA, libEnzyme), EnzymeAAResultsRef, (LLVMModuleRef,), mod)

# Pass the alias-analysis handle `aa` (an `EnzymeAAResultsRef`) to libEnzyme's
# `EnzymeFreeGlobalAA`. Returns `nothing`.
EnzymeFreeGlobalAA(aa) =
    ccall((:EnzymeFreeGlobalAA, libEnzyme), Cvoid, (EnzymeAAResultsRef,), aa)

# Create the derivative function itself; returns it as an `LLVMValueRef`.
#
# Arguments:
#   todiff            the function to differentiate
#   retType           activity info of the return (a `CDIFFE_TYPE`)
#   constant_args     activity info of the arguments (array of `CDIFFE_TYPE`); its
#                     length is forwarded to C alongside the pointer
#   TA                type-analysis handle (`EnzymeTypeAnalysisRef`)
#   global_AA         alias-analysis handle (`EnzymeAAResultsRef`)
#   returnValue       whether the primal's return should also be returned
#   dretUsed          whether the shadow return value should also be returned
#   topLevel          flag forwarded unchanged to libEnzyme
#                     NOTE(review): semantics are not described here — confirm against
#                     the Enzyme C API
#   additionalArg     the type (or null) of an additional argument in the signature
#                     that holds the tape
#   typeInfo          type information about the calling context (`CFnTypeInfo`)
#   uncacheable_args  marks whether an argument may be rewritten before loads in the
#                     generated function (and thus cannot be cached); its length is
#                     forwarded to C alongside the pointer
#   augmented         data structure created by a prior call to an augmented forward pass
#   atomicAdd         whether to perform all adjoint updates to memory in an atomic way
#   postOpt           whether to perform basic optimization of the function after synthesis
function EnzymeCreatePrimalAndGradient(todiff, retType, constant_args, TA, global_AA,
                                       returnValue, dretUsed, topLevel, additionalArg,
                                       typeInfo, uncacheable_args, augmented, atomicAdd,
                                       postOpt)
    num_activities = length(constant_args)
    num_uncacheable = length(uncacheable_args)
    return ccall(
        (:EnzymeCreatePrimalAndGradient, libEnzyme),
        LLVMValueRef,
        (
            LLVMValueRef,              # todiff
            CDIFFE_TYPE,               # retType
            Ptr{CDIFFE_TYPE},          # constant_args
            Csize_t,                   # number of entries in constant_args
            EnzymeTypeAnalysisRef,     # TA
            EnzymeAAResultsRef,        # global_AA
            UInt8,                     # returnValue
            UInt8,                     # dretUsed
            UInt8,                     # topLevel
            LLVMTypeRef,               # additionalArg
            CFnTypeInfo,               # typeInfo
            Ptr{UInt8},                # uncacheable_args
            Csize_t,                   # number of entries in uncacheable_args
            EnzymeAugmentedReturnPtr,  # augmented
            UInt8,                     # atomicAdd
            UInt8,                     # postOpt
        ),
        todiff, retType, constant_args, num_activities, TA, global_AA,
        returnValue, dretUsed, topLevel, additionalArg, typeInfo,
        uncacheable_args, num_uncacheable, augmented, atomicAdd, postOpt,
    )
end

# Create an augmented forward pass; returns an `EnzymeAugmentedReturnPtr`.
#
# Arguments:
#   todiff              the function to differentiate
#   retType             activity info of the return (a `CDIFFE_TYPE`)
#   constant_args       activity info of the arguments (array of `CDIFFE_TYPE`); its
#                       length is forwarded to C alongside the pointer
#   TA                  type-analysis handle (`EnzymeTypeAnalysisRef`)
#   global_AA           alias-analysis handle (`EnzymeAAResultsRef`)
#   returnUsed          whether the primal's return should also be returned
#   typeInfo            type information about the calling context (`CFnTypeInfo`)
#   uncacheable_args    marks whether an argument may be rewritten before loads in the
#                       generated function (and thus cannot be cached); its length is
#                       forwarded to C alongside the pointer
#   forceAnonymousTape  forces the tape to be an i8* rather than the true tape structure
#   atomicAdd           whether to perform all adjoint updates to memory in an atomic way
#   postOpt             whether to perform basic optimization of the function after synthesis
function EnzymeCreateAugmentedPrimal(todiff, retType, constant_args, TA, global_AA,
                                     returnUsed, typeInfo, uncacheable_args,
                                     forceAnonymousTape, atomicAdd, postOpt)
    num_activities = length(constant_args)
    num_uncacheable = length(uncacheable_args)
    return ccall(
        (:EnzymeCreateAugmentedPrimal, libEnzyme),
        EnzymeAugmentedReturnPtr,
        (
            LLVMValueRef,           # todiff
            CDIFFE_TYPE,            # retType
            Ptr{CDIFFE_TYPE},       # constant_args
            Csize_t,                # number of entries in constant_args
            EnzymeTypeAnalysisRef,  # TA
            EnzymeAAResultsRef,     # global_AA
            UInt8,                  # returnUsed
            CFnTypeInfo,            # typeInfo
            Ptr{UInt8},             # uncacheable_args
            Csize_t,                # number of entries in uncacheable_args
            UInt8,                  # forceAnonymousTape
            UInt8,                  # atomicAdd
            UInt8,                  # postOpt
        ),
        todiff, retType, constant_args, num_activities, TA, global_AA, returnUsed,
        typeInfo, uncacheable_args, num_uncacheable, forceAnonymousTape, atomicAdd,
        postOpt,
    )
end

# Function-pointer type of a custom type-analysis rule, as declared on the C side:
# typedef bool (*CustomRuleType)(int /*direction*/, CTypeTree * /*return*/,
# CTypeTree * /*args*/, size_t /*numArgs*/,
# LLVMValueRef);
# Represented here as an untyped pointer; used as the element type of the `rules`
# array passed to `CreateTypeAnalysis`.
const CustomRuleType = Ptr{Cvoid}

# Create a libEnzyme type-analysis object and return its handle
# (`EnzymeTypeAnalysisRef`).
#
# Arguments:
#   triple     passed to C as a string
#              NOTE(review): presumably the LLVM target triple — confirm with callers
#   rulenames  array of rule names, passed to C as an array of C strings
#   rules      array of `CustomRuleType` function pointers, parallel to `rulenames`
#
# Throws `DimensionMismatch` when `rulenames` and `rules` differ in length. Only one
# count (`length(rules)`) is handed to C for both arrays, so a mismatch would make the
# C side index past the end of the shorter one. The check is an explicit `throw`
# rather than `@assert` because assertions may be disabled at some optimization levels.
function CreateTypeAnalysis(triple, rulenames, rules)
    nrules = length(rules)
    if length(rulenames) != nrules
        throw(DimensionMismatch(
            "rulenames has $(length(rulenames)) entries but rules has $nrules"))
    end
    return ccall(
        (:CreateTypeAnalysis, libEnzyme),
        EnzymeTypeAnalysisRef,
        (Cstring, Ptr{Cstring}, Ptr{CustomRuleType}, Csize_t),
        triple, rulenames, rules, nrules,
    )
end

# Pass the type-analysis handle `ta` to libEnzyme's `FreeTypeAnalysis`.
# Returns `nothing`.
#
# `ta` is the handle type produced by `CreateTypeAnalysis`, i.e. an
# `EnzymeTypeAnalysisRef` (a single pointer). The argument must be declared as that
# type: declaring it as `EnzymeAAResultsRef` would ask `ccall` to convert the handle
# to a three-pointer struct and pass that struct by value.
function FreeTypeAnalysis(ta)
    return ccall((:FreeTypeAnalysis, libEnzyme), Cvoid, (EnzymeTypeAnalysisRef,), ta)
end

# Call libEnzyme's `EnzymeExtractReturnInfo` for the augmented-return handle `ret`,
# handing it the two caller-allocated buffers `data` (`Int64` entries) and `existed`
# (`UInt8` entries). Returns `nothing`.
# NOTE(review): presumably the C side fills both buffers — confirm against the
# Enzyme C API.
#
# Throws `DimensionMismatch` when the two buffers differ in length. Only
# `length(data)` is handed to C as the size of both, so a shorter `existed` would be
# accessed out of bounds. The check is an explicit `throw` rather than `@assert`
# because assertions may be disabled at some optimization levels.
function EnzymeExtractReturnInfo(ret, data, existed)
    n = length(data)
    if length(existed) != n
        throw(DimensionMismatch(
            "data has $n entries but existed has $(length(existed))"))
    end
    return ccall(
        (:EnzymeExtractReturnInfo, libEnzyme),
        Cvoid,
        (EnzymeAugmentedReturnPtr, Ptr{Int64}, Ptr{UInt8}, Csize_t),
        ret, data, existed, n,
    )
end

# Return, as an `LLVMValueRef`, the function that libEnzyme extracts from the
# augmented-return handle `ret`.
EnzymeExtractFunctionFromAugmentation(ret) = ccall(
    (:EnzymeExtractFunctionFromAugmentation, libEnzyme),
    LLVMValueRef,
    (EnzymeAugmentedReturnPtr,),
    ret,
)


# Return, as an `LLVMTypeRef`, the tape type that libEnzyme extracts from the
# augmented-return handle `ret`.
EnzymeExtractTapeTypeFromAugmentation(ret) = ccall(
    (:EnzymeExtractTapeTypeFromAugmentation, libEnzyme),
    LLVMTypeRef,
    (EnzymeAugmentedReturnPtr,),
    ret,
)

end
129 changes: 79 additions & 50 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Compiler

import ..Enzyme: Const, Active, Duplicated, DuplicatedNoNeed
import ..Enzyme: API, TypeTree, typetree, TypeAnalysis, FnTypeInfo

using LLVM, GPUCompiler, Libdl
import Enzyme_jll
Expand Down Expand Up @@ -57,7 +58,7 @@ struct EnzymeCompilerParams <: AbstractCompilerParams end

## job

# TODO: We shouldn't blancket opt-out
# TODO: We shouldn't blanket opt-out
GPUCompiler.check_invocation(job::CompilerJob{EnzymeTarget}, entry::LLVM.Function) = nothing

GPUCompiler.runtime_module(target::CompilerJob{EnzymeTarget}) = Runtime
Expand Down Expand Up @@ -90,76 +91,104 @@ Create the `FunctionSpec` pair, and lookup the primal return type.
end


"""
wrapper!(::LLVM.Module, ::LLVM.Function, ::FunctionSpec, ::Type)
Generates a wrapper function that will call `__enzyme_autodiff` on the primal function,
named `enzyme_entry`.
"""
function wrapper!(mod, primalf, adjoint, rt, name = "enzyme_entry")
# create a wrapper function that will call `__enzyme_autodiff`
function enzyme!(mod, primalf, adjoint, rt, split)
ctx = context(mod)
rettype = convert(LLVMType, rt, ctx)
dl = string(LLVM.datalayout(mod))

tt = [adjoint.tt.parameters...,]
params = parameters(primalf)
adjoint_tt = LLVMType[]
for (i, T) in enumerate(tt)
llvmT = llvmtype(params[i])
push!(adjoint_tt, llvmT)
if T <: Duplicated
push!(adjoint_tt, llvmT)
end
end

llvmf = LLVM.Function(mod, name, LLVM.FunctionType(rettype, adjoint_tt))
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0, ctx))
args_activity = API.CDIFFE_TYPE[]
uncacheable_args = Bool[]
args_typeInfo = TypeTree[]
args_known_values = API.IntList[]

# Create the FunctionType and function declaration for the intrinsic
pt = LLVM.PointerType(LLVM.Int8Type(ctx))
ftd = LLVM.FunctionType(rettype, LLVMType[pt], vararg = true)
autodiff = LLVM.Function(mod, string("__enzyme_autodiff.", rt), ftd)

params = LLVM.Value[]
llvm_params = parameters(llvmf)
i = 1
for T in tt
if T <: Const
push!(params, MDString("enzyme_const"))
push!(args_activity, API.DFT_CONSTANT)
elseif T <: Active
push!(params, MDString("enzyme_out"))
elseif T <: Duplicated
push!(params, MDString("enzyme_dup"))
push!(params, llvm_params[i])
i += 1
push!(args_activity, API.DFT_OUT_DIFF)
elseif T <: Duplicated
push!(args_activity, API.DFT_DUP_ARG)
elseif T <: DuplicatedNoNeed
push!(params, MDString("enzyme_dupnoneed"))
push!(params, llvm_params[i])
i += 1
else
push!(args_activity, API.DFT_DUP_NONEED)
else
@assert("illegal annotation type")
end
push!(params, llvm_params[i])
i += 1
typeTree = typetree(T, ctx, dl)
push!(args_typeInfo, typeTree)
if split
push!(uncacheable_args, true)
else
push!(uncacheable_args, false)
end
push!(args_known_values, API.IntList())
end

Builder(ctx) do builder
entry = BasicBlock(llvmf, "entry", ctx)
position!(builder, entry)
# TODO ABI returned
# The return of createprimal and gradient has this ABI
# It returns a struct containing the following values
# If requested, the original return value of the function
# If requested, the shadow return value of the function
# For each active (non duplicated) argument
# The adjoint of that argument

if rt <: Integer
retType = API.DFT_CONSTANT
elseif rt <: AbstractFloat
retType = API.DFT_OUT_DIFF
elseif rt == Nothing
retType = API.DFT_CONSTANT
else
error("What even is $rt")
end

tc = bitcast!(builder, primalf, pt)
pushfirst!(params, tc)
TA = TypeAnalysis(triple(mod))
global_AA = API.EnzymeGetGlobalAA(mod)
retTT = typetree(rt, ctx, dl)

val = call!(builder, autodiff, params)
typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values)

ret!(builder, val)
end
if split
augmented = API.EnzymeCreateAugmentedPrimal(
primalf, retType, args_activity, TA, global_AA, #=returnUsed=# true,
typeInfo, uncacheable_args, #=forceAnonymousTape=# false, #=atomicAdd=# false, #=postOpt=# false)

# 2. get new_primalf
augmented_primalf = LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented))

# TODOs:
# 1. Handle mutable or !pointerfree arguments by introducing caching
# + specifically by setting uncacheable_args[i] = true
# 2. Forward tape from augmented primalf to adjoint (as last arg)
# 3. Make creation of augmented primalf vs joint forward and reverse optional

tape = API.EnzymeExtractTapeTypeFromAugmentation(augmented)
data = Array{Int64, 1}(undef, 3)
existed = Array{UInt8, 1}(undef, 3)

return llvmf
API.EnzymeExtractReturnInfo(augmented, data, existed)

adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient(
primalf, retType, args_activity, TA, global_AA,
#=returnValue=#false, #=dretUsed=#false, #=topLevel=#false,
#=additionalArg=#tape, typeInfo,
uncacheable_args, augmented, #=atomicAdd=#false, #=postOpt=#false))
else
adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient(
primalf, retType, args_activity, TA, global_AA,
#=returnValue=#false, #=dretUsed=#false, #=topLevel=#true,
#=additionalArg=#C_NULL, typeInfo,
uncacheable_args, #=augmented=#C_NULL, #=atomicAdd=#false, #=postOpt=#false))
augmented_primalf = nothing
end

API.EnzymeFreeGlobalAA(global_AA)
return adjointf, augmented_primalf
end

include("compiler/thunk.jl")
include("compiler/reflection.jl")
# include("compiler/validation.jl")

end
end
Loading

0 comments on commit 93d2e29

Please sign in to comment.