Cassette API Documentation
Cassette.Context (Type)
Context{N<:Cassette.AbstractContextName,
        M<:Any,
        P<:Cassette.AbstractPass,
        T<:Union{Nothing,Cassette.Tag},
        B<:Union{Nothing,Cassette.BindingMetaDictCache},
        H<:Union{Nothing,Cassette.DisableHooks}}
A type representing a Cassette execution context. This type is normally interacted with through type aliases constructed via Cassette.@context:
julia> Cassette.@context MyCtx
Cassette.Context{nametype(MyCtx),M,P,T,B,H} where H<:Union{Nothing,DisableHooks}
where B<:Union{Nothing,IdDict{Module,Dict{Symbol,BindingMeta}}}
where P<:AbstractPass
where T<:Union{Nothing,Tag}
where M
Constructors
Given a context type alias named e.g. MyCtx, an instance of the type can be constructed via:
MyCtx(; metadata = nothing, pass = Cassette.NoPass())
To construct a new context instance using an existing context instance as a template, see the similarcontext function.
To enable contextual tagging for a given context instance, see the enabletagging function.
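The sketch below shows the keyword constructor in use. It assumes Cassette is installed for a Julia version it supports, and the context name `ConstructCtx` is hypothetical. Omitted keywords take the defaults shown in the constructor signature, and tagging stays off until enabletagging is used.

```julia
using Cassette

# Hypothetical context used only for this illustration.
Cassette.@context ConstructCtx

# No keywords: metadata defaults to `nothing` and pass to `Cassette.NoPass()`.
default_ctx = ConstructCtx()

# Any object can serve as trace-local metadata, e.g. a Dict used as a cache.
cache_ctx = ConstructCtx(metadata = Dict{Symbol,Int}())

default_ctx.metadata === nothing          # true
default_ctx.pass === Cassette.NoPass()    # true
Cassette.hastagging(typeof(default_ctx))  # false until enabletagging is applied
```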
Fields
- name::N<:Cassette.AbstractContextName: a parameter used to disambiguate different contexts for overloading purposes (e.g. distinguishes MyCtx from other Context type aliases).
- metadata::M<:Any: trace-local metadata as provided to the context constructor.
- tag::T<:Union{Nothing,Tag}: the tag object that is attached to values when they are tagged w.r.t. the context instance.
- pass::P<:Cassette.AbstractPass: the Cassette pass that will be applied to all method bodies encountered during contextual execution (see the @pass macro for details).
- bindingscache::B<:Union{Nothing,BindingMetaDictCache}: storage for metadata associated with tagged module bindings.
- hooktoggle::H<:Union{Nothing,DisableHooks}: configuration toggle for disabling the overdub pass's prehook/posthook injection (see disablehooks for details).
Cassette.similarcontext (Function)
similarcontext(context::Context;
               metadata = context.metadata,
               pass = context.pass)
Return a copy of the given context, where the copy's metadata and/or pass fields are replaced with those provided via the corresponding keyword arguments.
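The following sketch assumes Cassette is installed, and `SimCtx` is a hypothetical name. It replaces only the metadata: the pass is copied from the template, and the template itself is left as it was.

```julia
using Cassette

Cassette.@context SimCtx  # hypothetical context name

original = SimCtx(metadata = 1)

# Only `metadata` is overridden; `pass` is carried over from `original`.
updated = Cassette.similarcontext(original; metadata = "new")

updated.metadata                # "new"
updated.pass === original.pass  # true
original.metadata               # 1: the template is not modified
```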
Cassette.disablehooks (Function)
disablehooks(context::Cassette.Context)
Return a copy of the given context with prehook/posthook injection disabled for the context. Disabling hook injection can reduce IR bloat in scenarios where these hooks are not being utilized.
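This sketch assumes Cassette is installed, and `HookCtx` is a hypothetical name. It counts prehook invocations using a `Ref` stored as metadata. For `Int` arguments, `+` calls an intrinsic internally, so the prehook fires while hooks are enabled. After disablehooks, no prehook calls are injected, so the counter stays at zero.

```julia
using Cassette

Cassette.@context HookCtx  # hypothetical context name

# Increment the Ref stored as metadata on every injected prehook call.
Cassette.prehook(ctx::HookCtx, f, args...) = (ctx.metadata[] += 1)

hooked = HookCtx(metadata = Ref(0))
Cassette.overdub(hooked, +, 1, 2)    # 3
hooked.metadata[] > 0                # true: the prehook ran around inner calls

unhooked = Cassette.disablehooks(HookCtx(metadata = Ref(0)))
Cassette.overdub(unhooked, +, 1, 2)  # 3
unhooked.metadata[]                  # 0: no prehook calls were injected
```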
Cassette.enabletagging (Function)
enabletagging(context::Cassette.Context, f)
Return a copy of the given context with the tagging system enabled for the contextual execution of f.
Cassette uses the type of f to generate the tag field of the returned instance.
Note that it is generally unsafe to use the returned instance to contextually execute functions other than f. Specifically, in cases of nested contextual execution where both inner and outer contexts employ the tagging system, improper application of the tagging system could cause (for example) separate contexts to erroneously interfere with each other's metadata propagation.
See also: hastagging
Cassette.hastagging (Function)
hastagging(::Type{<:Cassette.Context})
Return true if the given type indicates that the contextual tagging system is enabled for context instances of the type; return false otherwise.
Example
julia> Cassette.@context MyCtx;
julia> ctx = MyCtx();
julia> Cassette.hastagging(typeof(ctx))
false
julia> ctx = Cassette.enabletagging(ctx, sum);
julia> Cassette.hastagging(typeof(ctx))
true
See also: enabletagging
Cassette.@context (Macro)
Cassette.@context Ctx
Define a new Cassette context type with the name Ctx. In reality, Ctx is simply a type alias for Cassette.Context{Cassette.nametype(Ctx)}.
Note that Cassette.overdub is automatically overloaded w.r.t. Ctx to define several primitives by default. A full list of these default primitives can be obtained by running:
methods(Cassette.overdub, (Ctx, Vararg{Any}))
Note also that many of the default primitives' signatures only match when contextual tagging is enabled.
See also: Context
Cassette.overdub (Function)
overdub(context::Context, f, args...)
Execute f(args...) overdubbed with respect to context.
More specifically, execute f(args...), but with every internal method invocation g(x...) replaced by statements similar to the following:
begin
prehook(context, g, x...)
overdub(context, g, x...) # %n
posthook(context, %n, g, x...)
%n
end
If Cassette cannot retrieve lowered IR for the method body of f(args...), then fallback(context, f, args...) is called instead. Cassette's canrecurse function is a useful utility for checking whether this will occur.
If the injected prehook/posthook statements are not needed for your use case, you can disable their injection via the disablehooks function.
Additionally, for every method body encountered in the execution trace, apply the compiler pass associated with context if one exists. Note that this user-provided pass is performed on the method IR before method invocations are transformed into the form specified above. See the @pass macro for further details.
If Cassette.hastagging(typeof(context)), then a number of additional passes are run in order to accommodate tagged value propagation:
- Expr(:new) is replaced with a call to Cassette.tagged_new
- Expr(:splatnew) is replaced with a call to Cassette.tagged_splatnew
- conditional values passed to Expr(:gotoifnot) are untagged
- arguments to Expr(:foreigncall) are untagged
- loads/stores to external module bindings are intercepted by the tagging system
The default definition of overdub is to recursively enter the given function and continue overdubbing, but one can interrupt/redirect this recursion by overloading overdub w.r.t. a given context and/or method signature to define new contextual execution primitives. For example:
julia> using Cassette
julia> Cassette.@context Ctx;
julia> Cassette.overdub(::Ctx, ::typeof(sin), x) = cos(x)
julia> Cassette.overdub(Ctx(), x -> sin(x) + cos(x), 1) == 2 * cos(1)
true
Cassette.@overdub (Macro)
Cassette.@overdub(ctx, expression)
A convenience macro for executing expression within the context ctx. This macro roughly expands to Cassette.recurse(ctx, () -> expression).
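This sketch assumes Cassette is installed, and `SwapCtx` is a hypothetical name. It combines @overdub with a contextual primitive. Because the expression is wrapped in a closure and then overdubbed, the `sin` call inside it is routed to the primitive.

```julia
using Cassette

Cassette.@context SwapCtx  # hypothetical context name

# Contextual primitive: within SwapCtx, calls to `sin` are redirected to `cos`.
Cassette.overdub(::SwapCtx, ::typeof(sin), x) = cos(x)

# The closure body is overdubbed, so `sin(0.5)` dispatches to the primitive above.
result = Cassette.@overdub(SwapCtx(), sin(0.5) + 1.0)

result == cos(0.5) + 1.0  # true
```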
Cassette.recurse (Function)
recurse(context::Context, f, args...)
Execute f(args...) overdubbed with respect to context.
This method performs exactly the same transformation as the default overdub transformation, but is not meant to be overloaded. Thus, one can call recurse to "continue" recursively overdubbing a function when calling overdub directly on that function might have dispatched to a contextual primitive.
To illustrate why recurse might be useful, consider the following example, which uses recurse as part of a Cassette-based memoization implementation for the classic Fibonacci function:
using Cassette: Cassette, @context, overdub, recurse
fib(x) = x < 3 ? 1 : fib(x - 2) + fib(x - 1)
fibtest(n) = fib(2 * n) + n
@context MemoizeCtx
function Cassette.overdub(ctx::MemoizeCtx, ::typeof(fib), x)
    result = get(ctx.metadata, x, 0)
    if result === 0
        result = recurse(ctx, fib, x)
        ctx.metadata[x] = result
    end
    return result
end
See Cassette's Contextual Dispatch documentation for more details and examples.
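The memoization example defines a primitive but never runs it. The smaller self-contained sketch below uses the same pattern. It assumes Cassette is installed, and `TraceCtx`, `inner`, and `outer` are hypothetical names. The primitive logs each argument passed to `inner`, then calls recurse so that `inner`'s own body is still overdubbed.

```julia
using Cassette

Cassette.@context TraceCtx  # hypothetical context name

inner(x) = x + 1
outer(x) = 2 * inner(x)

# Primitive for `inner`: log the argument, then keep overdubbing inner's body.
# Calling `Cassette.overdub(ctx, inner, x)` here instead would dispatch back to
# this same method and never terminate.
function Cassette.overdub(ctx::TraceCtx, ::typeof(inner), x)
    push!(ctx.metadata, x)
    return Cassette.recurse(ctx, inner, x)
end

ctx = TraceCtx(metadata = Any[])
result = Cassette.overdub(ctx, outer, 3)  # 8
ctx.metadata                              # Any[3]
```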
Cassette.prehook (Function)
prehook(context::Context, f, args...)
Overload this Cassette method w.r.t. a given context in order to define a new contextual prehook for that context.
To understand when/how this method is called, see the documentation for overdub.
Invoking prehook is a no-op by default (it immediately returns nothing).
See also: overdub, posthook, recurse, fallback
Examples
Simple trace logging:
julia> Cassette.@context PrintCtx;
julia> Cassette.prehook(::PrintCtx, f, args...) = println(f, args)
julia> Cassette.overdub(PrintCtx(), /, 1, 2)
float(1,)
AbstractFloat(1,)
Float64(1,)
sitofp(Float64, 1)
float(2,)
AbstractFloat(2,)
Float64(2,)
sitofp(Float64, 2)
/(1.0, 2.0)
div_float(1.0, 2.0)
0.5
Counting the number of method invocations with one or more arguments of a given type:
julia> mutable struct Count{T}
           count::Int
       end
julia> Cassette.@context CountCtx;
julia> Cassette.prehook(ctx::CountCtx{Count{T}}, f, arg::T, args::T...) where {T} = (ctx.metadata.count += 1)
# count the number of calls of the form `f(::Float64, ::Float64...)`
julia> ctx = CountCtx(metadata = Count{Float64}(0));
julia> Cassette.overdub(ctx, /, 1, 2)
0.5
julia> ctx.metadata.count
2
Cassette.posthook (Function)
posthook(context::Context, output, f, args...)
Overload this Cassette method w.r.t. a given context in order to define a new contextual posthook for that context.
To understand when/how this method is called, see the documentation for overdub.
Invoking posthook is a no-op by default (it immediately returns nothing).
See also: overdub, prehook, recurse, fallback
Examples
Simple trace logging:
julia> Cassette.@context PrintCtx;
julia> Cassette.posthook(::PrintCtx, output, f, args...) = println(output, " = ", f, args)
julia> Cassette.overdub(PrintCtx(), /, 1, 2)
1.0 = sitofp(Float64, 1)
1.0 = Float64(1,)
1.0 = AbstractFloat(1,)
1.0 = float(1,)
2.0 = sitofp(Float64, 2)
2.0 = Float64(2,)
2.0 = AbstractFloat(2,)
2.0 = float(2,)
0.5 = div_float(1.0, 2.0)
0.5 = /(1.0, 2.0)
0.5
Accumulate the sum of all numeric scalar outputs encountered in the trace:
julia> mutable struct Accum
           x::Number
       end
julia> Cassette.@context AccumCtx;
julia> Cassette.posthook(ctx::AccumCtx{Accum}, out::Number, f, args...) = (ctx.metadata.x += out)
julia> ctx = AccumCtx(metadata = Accum(0));
julia> Cassette.overdub(ctx, /, 1, 2)
0.5
julia> ctx.metadata.x
13.0
Cassette.fallback (Function)
fallback(context::Context, f, args...)
Overload this Cassette method w.r.t. a given context in order to define a new contextual execution fallback for that context.
To understand when/how this method is called, see the documentation for overdub and canrecurse.
By default, invoking fallback(context, f, args...) will simply call f(args...) (with all arguments automatically untagged, if hastagging(typeof(context))).
See also: canrecurse, overdub, recurse, prehook, posthook
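This sketch of a fallback overload assumes Cassette is installed, and `FallbackCtx` is a hypothetical name. The overload records every function Cassette could not recurse into, then calls it normally. Tagging is not enabled here, so no untagging is needed. For `Int` arguments, `+` is implemented via an intrinsic with no lowered IR, so at least one fallback is expected.

```julia
using Cassette

Cassette.@context FallbackCtx  # hypothetical context name

# Record each function Cassette could not recurse into, then call it as usual.
function Cassette.fallback(ctx::FallbackCtx, f, args...)
    push!(ctx.metadata, f)
    return f(args...)
end

ctx = FallbackCtx(metadata = Any[])
result = Cassette.overdub(ctx, +, 1, 2)  # 3
# ctx.metadata is expected to hold leaf functions such as Core.Intrinsics.add_int
```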
Cassette.canrecurse (Function)
canrecurse(context::Context, f, args...)
Return true if f(args...) has a lowered IR representation that Cassette can overdub, return false otherwise.
Alternatively, but equivalently:
Return false if recurse(context, f, args...) directly translates to fallback(context, f, args...), return true otherwise.
Note that unlike overdub, fallback, etc., this function is not intended to be overloaded.
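This sketch assumes Cassette is installed, and `RecurseCheckCtx` is a hypothetical name. It contrasts an ordinary Julia method with an intrinsic. The intrinsic is expected to have no lowered IR, so overdubbing it would go straight to fallback.

```julia
using Cassette

Cassette.@context RecurseCheckCtx  # hypothetical context name
ctx = RecurseCheckCtx()

# `+` on Int is an ordinary Julia method, so lowered IR is available.
Cassette.canrecurse(ctx, +, 1, 2)                        # true

# Intrinsics have no lowered IR; overdubbing them goes to `fallback`.
Cassette.canrecurse(ctx, Core.Intrinsics.add_int, 1, 2)  # false
```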
Cassette.@pass (Macro)
Cassette.@pass transform
Return a Cassette pass that can be provided to the Context constructor's pass keyword argument in order to apply transform to the lowered IR representations of all methods invoked during contextual execution.
transform must be a Julia object that is callable with the following signature:
transform(::Type{<:Context}, ::Cassette.Reflection)::Union{Expr,CodeInfo}
If isa(transform(...), Expr), then the returned Expr will be emitted immediately without any additional processing. Otherwise, if isa(transform(...), CodeInfo), then the returned CodeInfo will undergo the rest of Cassette's overdubbing transformation before being emitted from the overdub generator.
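A minimal pass is sketched below. It assumes Cassette is installed, and `PassCtx`, `identity_transform`, and `identity_pass` are hypothetical names. The transform returns the reflected CodeInfo unchanged, so Cassette still applies its own overdubbing afterwards. Because it always returns the same IR for a given method, it also trivially meets the purity requirement noted for transform.

```julia
using Cassette

Cassette.@context PassCtx  # hypothetical context name

# Hypothetical transform: return the lowered IR unchanged. Returning a CodeInfo
# (rather than an Expr) lets Cassette's own overdubbing run on it afterwards.
identity_transform(::Type{<:PassCtx}, reflection::Cassette.Reflection) = reflection.code_info

# Evaluated at top level, since @pass expands to an `eval` call.
const identity_pass = Cassette.@pass identity_transform

ctx = PassCtx(pass = identity_pass)
result = Cassette.overdub(ctx, x -> 2x + 1, 20)  # 41
```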
Two special Expr heads are available to Cassette pass authors that are not normally valid in Julia IR. Exprs with these heads can be used to interact with the downstream built-in Cassette passes that consume them.
- :nooverdub: Wrap an Expr with this head value around the first argument in an Expr(:call) to tell downstream built-in Cassette passes not to overdub that call. For example, Expr(:call, Expr(:nooverdub, GlobalRef(MyModule, :myfunc)), args...).
- :contextslot: Cassette will replace any Expr(:contextslot) with the actual SlotNumber corresponding to the context object associated with the execution trace. For example, one could construct an IR element that accesses the context's metadata field by emitting: Expr(:call, Expr(:nooverdub, GlobalRef(Core, :getfield)), Expr(:contextslot), QuoteNode(:metadata))
Cassette provides a few IR-munging utility functions of interest to pass authors; for details, see insert_statements!, replace_match!, and is_ir_element.
Note that the @pass macro expands to an eval call and thus should only be called at top-level. Furthermore, to avoid world-age issues, transform should not be overloaded after it has been registered with @pass.
Note also that transform should be "relatively pure." More specifically, Julia's compiler has license to apply transform multiple times, even if only compiling a single method invocation once. Thus, it is required that transform always return a generally "equivalent" CodeInfo for a given context, method body, and signature. If your transform implementation is not naturally "pure" in this sense, then it is still possible to guarantee this property by memoizing your implementation (i.e. maintaining a cache of previously computed IR results, instead of recomputing results every time).
Cassette.replace_match! (Function)
replace_match!(replace, ismatch, x)
Return x with all subelements y replaced with replace(y) if ismatch(y). If !ismatch(y), but y is of type Expr, Array, or SubArray, then replace y in x with replace_match!(replace, ismatch, y).
Generally, x should be of type Expr, Array, or SubArray.
Note that this function modifies x (and potentially its subelements) in-place.
See also: insert_statements!, is_ir_element
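Here replace_match! is applied to an ordinary quoted expression rather than real method IR, which keeps the example simple (this assumes Cassette is installed). Every subelement equal to the symbol `:f` is replaced, and non-matching Expr subelements are searched recursively.

```julia
using Cassette

ex = :(f(x) + f(y))

# Replace every subelement equal to the symbol `:f`, at any depth, with `:g`.
Cassette.replace_match!(s -> :g, s -> s === :f, ex)

ex  # :(g(x) + g(y)); `ex` was modified in place
```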
Cassette.insert_statements! (Function)
insert_statements!(code::Vector, codelocs::Vector, stmtcount, newstmts)
For every statement stmt at position i in code for which stmtcount(stmt, i) returns an Int, remove stmt, and in its place, insert the statements returned by newstmts(stmt, i). If stmtcount(stmt, i) returns nothing, leave stmt alone.
For every insertion, all downstream SSAValues, label indices, etc. are incremented appropriately according to the number of inserted statements.
Proper usage of this function dictates that the following properties hold true:
- code is expected to be a valid value for the code field of a CodeInfo object.
- codelocs is expected to be a valid value for the codelocs field of a CodeInfo object.
- newstmts(stmt, i) should return a Vector of valid IR statements.
- stmtcount and newstmts must obey stmtcount(stmt, i) == length(newstmts(stmt, i)) if isa(stmtcount(stmt, i), Int).
To gain a mental model for this function's behavior, consider the following scenario. Let's say our code object contains several statements:
code = Any[oldstmt1, oldstmt2, oldstmt3, oldstmt4, oldstmt5, oldstmt6]
codelocs = Int[1, 2, 3, 4, 5, 6]
Let's also say that our stmtcount returns 2 for stmtcount(oldstmt2, 2), returns 3 for stmtcount(oldstmt5, 5), and returns nothing for all other inputs. From this setup, we can think of code/codelocs being modified in the following manner:
newstmts2 = newstmts(oldstmt2, 2)
newstmts5 = newstmts(oldstmt5, 5)
code = Any[oldstmt1,
newstmts2[1], newstmts2[2],
oldstmt3, oldstmt4,
newstmts5[1], newstmts5[2], newstmts5[3],
oldstmt6]
codelocs = Int[1, 2, 2, 3, 4, 5, 5, 5, 6]
See also: replace_match!, is_ir_element
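Below is a concrete, runnable version of the insert_statements! scenario, assuming Cassette is installed. Plain symbols stand in for statements purely for illustration; real callers would pass a CodeInfo's code and codelocs vectors. Statement 2 is replaced by two new statements, and both inherit its code location.

```julia
using Cassette

# Placeholder "statements" (plain symbols) so the effect is easy to see.
code = Any[:stmt1, :stmt2, :stmt3]
codelocs = Int[1, 2, 3]

Cassette.insert_statements!(code, codelocs,
    (stmt, i) -> i == 2 ? 2 : nothing,   # replace only statement 2, with 2 statements
    (stmt, i) -> Any[:new2a, :new2b])    # the replacement statements

code      # Any[:stmt1, :new2a, :new2b, :stmt3]
codelocs  # [1, 2, 2, 3]
```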
Cassette.is_ir_element (Function)
is_ir_element(x, y, code::Vector)
Return true if x === y or if x is an SSAValue such that is_ir_element(code[x.id], y, code) is true.
See also: replace_match!, insert_statements!
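The sketch below assumes Cassette is installed and uses a hand-built two-statement code vector. In this vector, the callee of the second statement is an SSAValue pointing at a GlobalRef. is_ir_element follows that SSAValue back to the referenced statement before comparing.

```julia
using Cassette

code = Any[
    GlobalRef(Base, :println),               # %1
    Expr(:call, Core.SSAValue(1), "hello"),  # %2 = (%1)("hello")
]

callee = code[2].args[1]  # Core.SSAValue(1)

# The SSAValue resolves to code[1], which is the GlobalRef being searched for.
Cassette.is_ir_element(callee, GlobalRef(Base, :println), code)  # true
Cassette.is_ir_element(callee, GlobalRef(Base, :print), code)    # false
```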
Cassette.OVERDUB_CONTEXT_NAME (Constant)
Cassette.OVERDUB_CONTEXT_NAME
The variable name bound to overdub's Context argument in its @generated method definition.
This binding can be used to manually reference/destructure overdub arguments within Expr thunks emitted by user-provided passes.
See also: OVERDUB_ARGUMENTS_NAME, @pass, overdub
Cassette.OVERDUB_ARGUMENTS_NAME (Constant)
Cassette.OVERDUB_ARGUMENTS_NAME
The variable name bound to overdub's tuple of non-Context arguments in its @generated method definition.
This binding can be used to manually reference/destructure overdub arguments within Expr thunks emitted by user-provided passes.
See also: OVERDUB_CONTEXT_NAME, @pass, overdub
Cassette.Reflection (Type)
Cassette.Reflection
A struct representing the information retrieved via Cassette.reflect.
A Reflection is essentially just a convenient bundle of information about a specific method invocation.
Fields
- signature: the invocation signature (in Tuple{...} type form) for the invoked method.
- method: the Method object associated with the invoked method.
- static_params: a Vector representing the invoked method's static parameter list.
- code_info: the CodeInfo object associated with the invoked method.
Cassette.tag (Function)
tag(value, context::Context, metadata = Cassette.NoMetaData())
Return value tagged w.r.t. context, optionally associating metadata with the returned Tagged instance.
Any provided metadata must obey the type constraints determined by Cassette's metadatatype method.
Note that hastagging(typeof(context)) must be true for a value to be tagged w.r.t. context.
See also: untag, enabletagging, hastagging
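The sketch below assumes Cassette's tagging system works on the Julia version in use, and `TagCtx` is a hypothetical name. It tags a value without metadata and checks the documented round-trip properties of tag, istagged, hasmetadata, and untag.

```julia
using Cassette

Cassette.@context TagCtx  # hypothetical context name

# Tagging has to be enabled before values can be tagged w.r.t. the context.
ctx = Cassette.enabletagging(TagCtx(), sum)

x = Cassette.tag(1.0, ctx)  # no metadata given, so it defaults to NoMetaData()

Cassette.istagged(x, ctx)     # true
Cassette.hasmetadata(x, ctx)  # false
Cassette.untag(x, ctx)        # 1.0
Cassette.untag(2.0, ctx)      # 2.0: untagged values are returned unchanged
```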
Cassette.untag (Function)
untag(x, context::Context)
Return x untagged w.r.t. context if istagged(x, context), otherwise return x directly.
In other words, untag(tag(x, context), context) === x is always true.
If !istagged(x, context), then untag(x, context) === x is true.
Cassette.untagtype (Function)
untagtype(::Type{T}, ::Type{C<:Context})
Return typeof(untag(::T, ::C)).
In other words, untagtype(typeof(tag(x, context)), typeof(context)) === typeof(x) is always true.
If !istaggedtype(T, C), then untagtype(T, C) === T is true.
Cassette.metadata (Function)
metadata(x, context::Context)
Return the metadata attached to x if hasmetadata(x, context), otherwise return Cassette.NoMetaData().
In other words, metadata(tag(x, context, m), context) === m is always true.
If !hasmetadata(x, context), then metadata(x, context) === Cassette.NoMetaData() is true.
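The sketch below assumes Cassette's tagging system works on the Julia version in use, and `MetaCtx` is a hypothetical name. It first permits numeric metadata with a metadatatype overload, following the same pattern as the metadatatype examples in this reference. It then attaches and reads back metadata.

```julia
using Cassette

Cassette.@context MetaCtx  # hypothetical context name

# Allow any `T<:Number` value to carry metadata of type `T`.
Cassette.metadatatype(::Type{<:MetaCtx}, ::Type{T}) where {T<:Number} = T

ctx = Cassette.enabletagging(MetaCtx(), sum)

x = Cassette.tag(2.0, ctx, 3.0)  # value 2.0 carrying metadata 3.0

Cassette.metadata(x, ctx)     # 3.0
Cassette.hasmetadata(x, ctx)  # true
Cassette.metadata(2.0, ctx)   # Cassette.NoMetaData()
```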
Cassette.metadatatype (Function)
metadatatype(::Type{<:Context}, ::Type{T})
Overload this Cassette method w.r.t. a given context to define the type of metadata that can be tagged to values of type T within that context.
By default, this method is set such that associating metadata with any tagged value is disallowed.
Cassette uses metadatatype to statically compute a context-specific metadata type hierarchy for all tagged values within overdubbed programs. To gain a mental model for this mechanism, consider a simple struct definition as follows:
struct Foo
    x::Int
    y::Complex{Int}
end
Now, Cassette can use metadatatype to determine type constraints for metadata structures associated with tagged values of type Foo. In pseudo-Julia-code, these metadata structures might look something like the following for Foo:
struct IntMeta
    data::metadatatype(Ctx, Int)
    meta::Cassette.NoMetaMeta
end
struct ComplexIntMeta
    data::metadatatype(Ctx, Complex{Int})
    meta::NamedTuple{(:re,:im),Tuple{IntMeta,IntMeta}}
end
struct FooMeta
    data::metadatatype(Ctx, Foo)
    meta::NamedTuple{(:x,:y),Tuple{IntMeta,ComplexIntMeta}}
end
Examples
julia> Cassette.@context Ctx;
# any value of type `Number` can now be tagged with metadata of type `Number`
julia> Cassette.metadatatype(::Type{<:Ctx}, ::Type{<:Number}) = Number
# any value of type `T<:Number` can now be tagged with metadata of type `T`
julia> Cassette.metadatatype(::Type{<:Ctx}, ::Type{T}) where {T<:Number} = T
# any value of type `T<:Number` can now be tagged with metadata of type `promote_type(T, M)`
# where `M` is the type of the trace-local metadata associated with the context
julia> Cassette.metadatatype(::Type{<:Ctx{M}}, ::Type{T}) where {M<:Number,T<:Number} = promote_type(T, M)
Cassette.hasmetadata (Function)
hasmetadata(x, context::Context)
Return true if !isa(metadata(x, context), Cassette.NoMetaData), return false otherwise.
In other words, hasmetadata(tag(x, context, m), context) is always true and hasmetadata(tag(x, context), context) is always false.
See also: metadata
Cassette.istagged (Function)
istagged(x, context::Context)
Return true if x is tagged w.r.t. context, return false otherwise.
In other words, istagged(tag(x, context), context) is always true.
See also: tag, istaggedtype
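Cassette.istaggedtype (Function)
istaggedtype(::Type{T}, ::Type{C<:Context})
Return the result istagged(::T, ::C) would produce, computed from the types alone: true if values of type T are tagged w.r.t. contexts of type C, false otherwise.
In other words, istaggedtype(typeof(tag(x, context)), typeof(context)) is always true.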