Cassette API Documentation
Cassette.Context — Type.
Context{N<:Cassette.AbstractContextName,
        M<:Any,
        P<:Cassette.AbstractPass,
        T<:Union{Nothing,Cassette.Tag},
        B<:Union{Nothing,Cassette.BindingMetaDictCache}}
A type representing a Cassette execution context. This type is normally interacted with through type aliases constructed via Cassette.@context:
julia> Cassette.@context MyCtx
Cassette.Context{nametype(MyCtx),M,P,T,B} where B<:Union{Nothing,IdDict{Module,Dict{Symbol,BindingMeta}}}
where P<:Cassette.AbstractPass
where T<:Union{Nothing,Tag}
where M
Constructors
Given a context type alias named e.g. MyCtx, an instance of the type can be constructed via:
MyCtx(; metadata = nothing, pass = Cassette.NoPass())
To construct a new context instance using an existing context instance as a template, see the similarcontext function.
To enable contextual tagging for a given context instance, see the enabletagging function.
Fields
- name::N<:Cassette.AbstractContextName: a parameter used to disambiguate different contexts for overloading purposes (e.g. distinguishes MyCtx from other Context type aliases).
- metadata::M<:Any: trace-local metadata as provided to the context constructor
- pass::P<:Cassette.AbstractPass: the Cassette pass that will be applied to all method bodies encountered during contextual execution (see the @pass macro for details).
- tag::T<:Union{Nothing,Tag}: the tag object that is attached to values when they are tagged w.r.t. the context instance
- bindingscache::B<:Union{Nothing,BindingMetaDictCache}: storage for metadata associated with tagged module bindings
Cassette.similarcontext — Function.
similarcontext(context::Context;
metadata = context.metadata,
pass = context.pass,
tag = context.tag,
bindingscache = context.bindingscache)
Return a copy of the given context, replacing field values in the returned instance with those provided via the keyword arguments.
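A short usage sketch, assuming Cassette.jl is installed. DemoCtx is a hypothetical context name. The sketch constructs a context with trace-local metadata, reads its metadata field, and then uses similarcontext to derive a copy that differs only in metadata:

```julia
using Cassette  # assumes Cassette.jl is installed

# Define a context type alias (hypothetical name)
Cassette.@context DemoCtx

# Construct an instance carrying trace-local metadata
ctx = DemoCtx(metadata = 1)

# Copy it, replacing only `metadata`; the remaining fields are taken from `ctx`
ctx2 = Cassette.similarcontext(ctx; metadata = 2)

println(ctx.metadata)          # the original instance is left unchanged
println(ctx2.metadata)
println(ctx2.pass === ctx.pass)
```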
Cassette.enabletagging — Function.
enabletagging(context::Cassette.Context, f)
Return a copy of the given context with the tagging system enabled for the contextual execution of f.
Cassette uses the type of f to generate the tag field of the returned instance.
Note that it is generally unsafe to use the returned instance to contextually execute functions other than f. Specifically, in cases of nested contextual execution where both inner and outer contexts employ the tagging system, improper application of the tagging system could cause (for example) separate contexts to erroneously interfere with each other's metadata propagation.
See also: hastagging
Cassette.hastagging — Function.
hastagging(::Type{<:Cassette.Context})
Returns true if the given type indicates that the contextual tagging system is enabled for context instances of the type, returns false otherwise.
Example
julia> Cassette.@context MyCtx;
julia> ctx = MyCtx();
julia> Cassette.hastagging(typeof(ctx))
false
julia> ctx = Cassette.enabletagging(ctx, sum);
julia> Cassette.hastagging(typeof(ctx))
true
See also: enabletagging
Cassette.@context — Macro.
Cassette.@context Ctx
Define a new Cassette context type with the name Ctx. In reality, Ctx is simply a type alias for Cassette.Context{Cassette.nametype(Ctx)}.
Note that Cassette.execute is automatically overloaded w.r.t. Ctx to define several primitives by default. A full list of these default primitives can be obtained by running:
methods(Cassette.execute, (Ctx, Vararg{Any}))
Note also that many of the default primitives' signatures only match when contextual tagging is enabled.
See also: Context
Cassette.overdub — Function.
overdub(context::Context, f, args...)
Execute f(args...) overdubbed with respect to context.
More specifically, execute f(args...), but with every internal method invocation g(x...) replaced by statements similar to the following:
begin
prehook(context, g, x...)
tmp = execute(context, g, x...)
tmp = isa(tmp, Cassette.OverdubInstead) ? overdub(context, g, x...) : tmp
posthook(context, tmp, g, x...)
tmp
end
If Cassette cannot retrieve lowered IR for the method body of f(args...) (as determined by canoverdub(context, f, args...)), then overdub(context, f, args...) will directly translate to a call to fallback(context, f, args...).
Additionally, for every method body encountered in the execution trace, apply the compiler pass associated with context, if one exists. Note that this user-provided pass is performed on the method IR before method invocations are transformed into the form specified above. See the @pass macro for further details.
If Cassette.hastagging(typeof(context)), then a number of additional passes are run in order to accommodate tagged value propagation:
- Expr(:new) is replaced with a call to Cassette.tagged_new
- conditional values passed to Expr(:gotoifnot) are untagged
- arguments to Expr(:foreigncall) are untagged
- loads/stores to external module bindings are intercepted by the tagging system
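The sketch below is a toy model in plain Julia that makes the per-call rewrite above concrete without relying on Cassette internals. Every name in it (ToyCtx, ToyOverdubInstead, toy_prehook, toy_execute, toy_call_site, etc.) is hypothetical. toy_call_site mirrors the statement block that replaces each call g(x...). The toy cannot inspect lowered IR, so toy_overdub always behaves as if canoverdub had returned false and defers to the fallback.

```julia
# Hypothetical, Cassette-free model of the documented call-site rewrite.
struct ToyOverdubInstead end

struct ToyCtx
    log::Vector{Any}   # records hook invocations, in order
end

# Hooks only record what happened; the call site ignores their return values
toy_prehook(ctx::ToyCtx, f, args...) = push!(ctx.log, (:pre, f, args))
toy_posthook(ctx::ToyCtx, out, f, args...) = push!(ctx.log, (:post, out, f, args))

# Default execute: request overdubbing instead (mirrors the documented default)
toy_execute(ctx::ToyCtx, f, args...) = ToyOverdubInstead()
# A hypothetical primitive: `+` is computed directly and never overdubbed
toy_execute(ctx::ToyCtx, ::typeof(+), args...) = +(args...)

# This toy has no access to lowered IR, so nothing counts as overdubbable
toy_canoverdub(ctx::ToyCtx, f, args...) = false
toy_fallback(ctx::ToyCtx, f, args...) = f(args...)

function toy_overdub(ctx::ToyCtx, f, args...)
    toy_canoverdub(ctx, f, args...) || return toy_fallback(ctx, f, args...)
    error("unreachable in this toy: there is no IR to rewrite")
end

# The statement block that replaces each internal call g(x...)
function toy_call_site(ctx::ToyCtx, g, x...)
    toy_prehook(ctx, g, x...)
    tmp = toy_execute(ctx, g, x...)
    tmp = isa(tmp, ToyOverdubInstead) ? toy_overdub(ctx, g, x...) : tmp
    toy_posthook(ctx, tmp, g, x...)
    return tmp
end

ctx = ToyCtx(Any[])
r1 = toy_call_site(ctx, +, 1, 2)   # handled by the toy_execute primitive
r2 = toy_call_site(ctx, *, 3, 4)   # default execute, then toy_overdub, then fallback
```

In real Cassette, the overdub step recurses into g's lowered IR, so the same rewrite is applied to the calls inside g. The toy stops at the fallback instead.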
Cassette.@overdub — Macro.
Cassette.@overdub(ctx, expression)
A convenience macro for executing expression within the context ctx. This macro roughly expands to Cassette.overdub(ctx, () -> expression).
See also: overdub
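The sketch below checks the rough equivalence stated above, assuming Cassette.jl is installed. SumCtx is a hypothetical context name. No hooks or primitives are overloaded, so both forms are expected to return the same value as the plain call sum([1, 2, 3]).

```julia
using Cassette  # assumes Cassette.jl is installed

Cassette.@context SumCtx   # hypothetical context name

ctx = SumCtx()
# The macro form and the explicit closure form should behave the same
a = Cassette.@overdub(ctx, sum([1, 2, 3]))
b = Cassette.overdub(ctx, () -> sum([1, 2, 3]))
```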
Cassette.prehook — Function.
prehook(context::Context, f, args...)
Overload this Cassette method w.r.t. a given context in order to define a new contextual prehook for that context.
To understand when/how this method is called, see the documentation for overdub.
Invoking prehook is a no-op by default (it immediately returns nothing).
See also: overdub, posthook, execute, fallback
Examples
Simple trace logging:
julia> Cassette.@context PrintCtx;
julia> Cassette.prehook(::PrintCtx, f, args...) = println(f, args)
julia> Cassette.overdub(PrintCtx(), /, 1, 2)
float(1,)
AbstractFloat(1,)
Float64(1,)
sitofp(Float64, 1)
float(2,)
AbstractFloat(2,)
Float64(2,)
sitofp(Float64, 2)
/(1.0, 2.0)
div_float(1.0, 2.0)
0.5
Counting the number of method invocations with one or more arguments of a given type:
julia> mutable struct Count{T}
count::Int
end
julia> Cassette.@context CountCtx;
julia> Cassette.prehook(ctx::CountCtx{Count{T}}, f, arg::T, args::T...) where {T} = (ctx.metadata.count += 1)
# count the number of calls of the form `f(::Float64, ::Float64...)`
julia> ctx = CountCtx(metadata = Count{Float64}(0));
julia> Cassette.overdub(ctx, /, 1, 2)
0.5
julia> ctx.metadata.count
2
Cassette.posthook — Function.
posthook(context::Context, output, f, args...)
Overload this Cassette method w.r.t. a given context in order to define a new contextual posthook for that context.
To understand when/how this method is called, see the documentation for overdub.
Invoking posthook is a no-op by default (it immediately returns nothing).
See also: overdub, prehook, execute, fallback
Examples
Simple trace logging:
julia> Cassette.@context PrintCtx;
julia> Cassette.posthook(::PrintCtx, output, f, args...) = println(output, " = ", f, args)
julia> Cassette.overdub(PrintCtx(), /, 1, 2)
1.0 = sitofp(Float64, 1)
1.0 = Float64(1,)
1.0 = AbstractFloat(1,)
1.0 = float(1,)
2.0 = sitofp(Float64, 2)
2.0 = Float64(2,)
2.0 = AbstractFloat(2,)
2.0 = float(2,)
0.5 = div_float(1.0, 2.0)
0.5 = /(1.0, 2.0)
0.5
Accumulate the sum of all numeric scalar outputs encountered in the trace:
julia> mutable struct Accum
x::Number
end
julia> Cassette.@context AccumCtx;
julia> Cassette.posthook(ctx::AccumCtx{Accum}, out::Number, f, args...) = (ctx.metadata.x += out)
julia> ctx = AccumCtx(metadata = Accum(0));
julia> Cassette.overdub(ctx, /, 1, 2)
0.5
julia> ctx.metadata.x
13.0
Cassette.execute — Function.
execute(context::Context, f, args...)
Overload this Cassette method w.r.t. a given context in order to define a new contextual execution primitive for that context.
To understand when/how this method is called, see the documentation for overdub.
Invoking execute immediately returns Cassette.OverdubInstead() by default.
Cassette.fallback — Function.
fallback(context::Context, f, args...)
Overload this Cassette method w.r.t. a given context in order to define a new contextual execution fallback for that context.
To understand when/how this method is called, see the documentation for overdub and canoverdub.
By default, invoking fallback(context, f, args...) will simply call f(args...) (with all arguments automatically untagged, if hastagging(typeof(context))).
See also: canoverdub, overdub, execute, prehook, posthook
Cassette.canoverdub — Function.
canoverdub(context::Context, f, args...)
Return true if f(args...) has a lowered IR representation that Cassette can overdub, return false otherwise.
Alternatively, but equivalently: return false if overdub(context, f, args...) directly translates to fallback(context, f, args...), return true otherwise.
Note that unlike execute, fallback, etc., this function is not intended to be overloaded.
Cassette.@pass — Macro.
Cassette.@pass transform
Return a Cassette pass that can be provided to the Context constructor's pass keyword argument in order to apply transform to the lowered IR representations of all methods invoked during contextual execution.
transform must be a Julia object that is callable with the following signature:
transform(::Type{<:Context}, signature::Type{Tuple{...}}, method_body::CodeInfo)::CodeInfo
Note that the @pass macro expands to an eval call and thus should only be called at top-level. Furthermore, to avoid world-age issues, transform should not be overloaded after it has been registered with @pass.
Note also that transform should be "relatively pure." More specifically, Julia's compiler has license to apply transform multiple times, even if only compiling a single method invocation once. Thus, it is required that transform always return a generally "equivalent" CodeInfo for a given context, method body, and signature. If your transform implementation is not naturally "pure" in this sense, then it is still possible to guarantee this property by memoizing your implementation (i.e. maintaining a cache of previously computed IR results, instead of recomputing results every time).
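The following is a minimal sketch of the memoization advice above. It has the transform signature described for @pass but does not use Cassette. The names my_transform, expensive_rewrite, TRANSFORM_CACHE and FakeCtx are hypothetical, and the cache key (context type, signature) assumes the signature determines the method body within one world age. Registering the function with @pass is not shown.

```julia
# Hypothetical memoized transform: repeated calls with the same key return the
# identical CodeInfo instead of recomputing it.
const TRANSFORM_CACHE = Dict{Any,Core.CodeInfo}()
const TRANSFORM_CALLS = Ref(0)   # counts how often the real work runs

function expensive_rewrite(ci::Core.CodeInfo)
    TRANSFORM_CALLS[] += 1
    return copy(ci)   # a real pass would edit the copied IR here
end

function my_transform(ctxtype::Type, signature::Type, ci::Core.CodeInfo)::Core.CodeInfo
    return get!(TRANSFORM_CACHE, (ctxtype, signature)) do
        expensive_rewrite(ci)
    end
end

# Usage with a stand-in context type (a real pass would receive a Cassette Context type)
struct FakeCtx end
ci = code_lowered(sin, (Float64,))[1]
r1 = my_transform(FakeCtx, Tuple{typeof(sin),Float64}, ci)
r2 = my_transform(FakeCtx, Tuple{typeof(sin),Float64}, ci)
```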
Two special Expr heads are available to Cassette pass authors that are not normally valid in Julia IR. Exprs with these heads can be used to interact with the downstream built-in Cassette passes that consume them.
- :nooverdub: Wrap an Expr with this head value around the first argument in an Expr(:call) to tell downstream built-in Cassette passes not to overdub that call. For example, Expr(:call, Expr(:nooverdub, GlobalRef(MyModule, :myfunc)), args...).
- :contextslot: Cassette will replace any Expr(:contextslot) with the actual SlotNumber corresponding to the context object associated with the execution trace. For example, one could construct an IR element that accesses the context's metadata field by emitting: Expr(:call, Expr(:nooverdub, GlobalRef(Core, :getfield)), Expr(:contextslot), QuoteNode(:metadata))
Cassette provides a few IR-munging utility functions of interest to pass authors: insert_statements!, replace_match!.
Cassette.replace_match! — Function.
replace_match!(replace, ismatch, x)
Return x with all subelements y replaced with replace(y) if ismatch(y). If !ismatch(y), but y is of type Expr, Array, or SubArray, then replace y in x with replace_match!(replace, ismatch, y).
Generally, x should be of type Expr, Array, or SubArray.
Note that this function modifies x (and potentially its subelements) in-place.
See also: insert_statements!
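The sketch below models the documented semantics of replace_match! in plain Julia. It is a hypothetical re-implementation named toy_replace_match!, not Cassette's source. For an Expr it walks the args vector, for arrays and views it walks the elements, and it recurses only into elements that do not match.

```julia
# Hypothetical model of replace_match!: rewrite matching subelements in place
function toy_replace_match!(replace, ismatch, x)
    # Expr keeps its children in `args`; arrays and views are indexed directly
    items = x isa Expr ? x.args : x
    for i in eachindex(items)
        y = items[i]
        if ismatch(y)
            items[i] = replace(y)
        elseif y isa Union{Expr,Array,SubArray}
            items[i] = toy_replace_match!(replace, ismatch, y)
        end
    end
    return x
end

ex = :(f(a, g(a, b)))
toy_replace_match!(y -> :z, y -> y === :a, ex)   # rewrites every `a` to `z` in place
```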
Cassette.insert_statements! — Function.
insert_statements!(code::Vector, codelocs::Vector, stmtcount, newstmts)
For every statement stmt at position i in code for which stmtcount(stmt, i) returns an Int, remove stmt, and in its place, insert the statements returned by newstmts(stmt, i). If stmtcount(stmt, i) returns nothing, leave stmt alone.
For every insertion, all downstream SSAValues, label indices, etc. are incremented appropriately according to the number of inserted statements.
Proper usage of this function dictates that the following properties hold true:
- code is expected to be a valid value for the code field of a CodeInfo object.
- codelocs is expected to be a valid value for the codelocs field of a CodeInfo object.
- newstmts(stmt, i) should return a Vector of valid IR statements.
- stmtcount and newstmts must obey stmtcount(stmt, i) == length(newstmts(stmt, i)) if isa(stmtcount(stmt, i), Int).
To gain a mental model for this function's behavior, consider the following scenario. Let's say our code object contains several statements:
code = Any[oldstmt1, oldstmt2, oldstmt3, oldstmt4, oldstmt5, oldstmt6]
codelocs = Int[1, 2, 3, 4, 5, 6]
Let's also say that our stmtcount returns 2 for stmtcount(oldstmt2, 2), returns 3 for stmtcount(oldstmt5, 5), and returns nothing for all other inputs. From this setup, we can think of code/codelocs being modified in the following manner:
newstmts2 = newstmts(oldstmt2, 2)
newstmts5 = newstmts(oldstmt5, 5)
code = Any[oldstmt1,
newstmts2[1], newstmts2[2],
oldstmt3, oldstmt4,
newstmts5[1], newstmts5[2], newstmts5[3],
oldstmt6]
codelocs = Int[1, 2, 2, 3, 4, 5, 5, 5, 6]
See also: replace_match!
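The scenario above can be reproduced with a small model in plain Julia. toy_insert_statements! is a hypothetical sketch, not Cassette's implementation. It handles only code and codelocs. It does not renumber SSAValues or labels, which the real function does. Each inserted statement inherits the location of the statement it replaces, matching the codelocs shown above. The symbols stand in for real IR statements.

```julia
# Hypothetical model of insert_statements!, limited to `code` and `codelocs`
function toy_insert_statements!(code::Vector, codelocs::Vector, stmtcount, newstmts)
    newcode = Any[]
    newlocs = eltype(codelocs)[]
    for i in eachindex(code)
        stmt = code[i]
        n = stmtcount(stmt, i)
        if n === nothing
            # leave stmt alone
            push!(newcode, stmt)
            push!(newlocs, codelocs[i])
        else
            stmts = newstmts(stmt, i)
            n == length(stmts) || error("stmtcount and newstmts disagree at statement $i")
            append!(newcode, stmts)
            append!(newlocs, fill(codelocs[i], n))  # inherit the replaced location
        end
    end
    # write the results back into the caller's vectors
    append!(empty!(code), newcode)
    append!(empty!(codelocs), newlocs)
    return code, codelocs
end

code = Any[:oldstmt1, :oldstmt2, :oldstmt3, :oldstmt4, :oldstmt5, :oldstmt6]
codelocs = Int[1, 2, 3, 4, 5, 6]
count_fn(stmt, i) = i == 2 ? 2 : i == 5 ? 3 : nothing
new_fn(stmt, i) = [Symbol(:new, i, :_, k) for k in 1:count_fn(stmt, i)]
toy_insert_statements!(code, codelocs, count_fn, new_fn)
```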
Cassette.tag — Function.
tag(value, context::Context, metadata = Cassette.NoMetaData())
Return value tagged w.r.t. context, optionally associating metadata with the returned Tagged instance.
Any provided metadata must obey the type constraints determined by Cassette's metadatatype method.
Note that hastagging(typeof(context)) must be true for a value to be tagged w.r.t. context.
See also: untag, enabletagging, hastagging
Cassette.untag — Function.
Cassette.untagtype — Function.
untagtype(::Type{T}, ::Type{C<:Context})
Return typeof(untag(::T, ::C)).
In other words, untagtype(typeof(tag(x, context)), typeof(context)) === typeof(x) is always true.
If !istaggedtype(T, C), then untagtype(T, C) === T is true.
Cassette.metadata — Function.
metadata(x, context::Context)
Return the metadata attached to x if hasmetadata(x, context), otherwise return Cassette.NoMetaData().
In other words, metadata(tag(x, context, m), context) === m is always true.
If !hasmetadata(x, context), then metadata(x, context) === Cassette.NoMetaData() is true.
Cassette.metadatatype — Function.
metadatatype(::Type{<:Context}, ::Type{T})
Overload this Cassette method w.r.t. a given context to define the type of metadata that can be tagged to values of type T within that context.
By default, this method is set such that associating metadata with any tagged value is disallowed.
Cassette uses metadatatype to statically compute a context-specific metadata type hierarchy for all tagged values within overdubbed programs. To gain a mental model for this mechanism, consider a simple struct definition as follows:
struct Foo
x::Int
y::Complex{Int}
end
Now, Cassette can use metadatatype to determine type constraints for metadata structures associated with tagged values of type Foo. In pseudo-Julia-code, these metadata structures might look something like the following for Foo:
struct IntMeta
data::metadatatype(Ctx, Int)
meta::Cassette.NoMetaMeta
end
struct ComplexIntMeta
data::metadatatype(Ctx, Complex{Int})
meta::NamedTuple{(:re,:im),Tuple{IntMeta,IntMeta}}
end
struct FooMeta
data::metadatatype(Ctx, Foo)
meta::NamedTuple{(:x,:y),Tuple{IntMeta,ComplexIntMeta}}
end
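The nesting in the pseudo-code above can be computed with ordinary reflection. The following is a hypothetical sketch, not Cassette's algorithm. toy_metadatatype stands in for an overloaded metadatatype, allowing Float64 metadata on numbers and Nothing elsewhere. meta_layout pairs each level's allowed metadata type (data) with the layouts of its fields (meta), and uses nothing where the pseudo-code has Cassette.NoMetaMeta.

```julia
# Hypothetical stand-in for an overloaded Cassette.metadatatype(Ctx, T)
toy_metadatatype(::Type{<:Number}) = Float64
toy_metadatatype(::Type) = Nothing

struct Foo
    x::Int
    y::Complex{Int}
end

# Recursively walk a concrete type's fields, mirroring the IntMeta/ComplexIntMeta/FooMeta nesting
function meta_layout(::Type{T}) where {T}
    data = toy_metadatatype(T)
    fieldcount(T) == 0 && return (data = data, meta = nothing)
    fields = Tuple(meta_layout(S) for S in fieldtypes(T))
    return (data = data, meta = NamedTuple{fieldnames(T)}(fields))
end

layout = meta_layout(Foo)
```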
Examples
julia> Cassette.@context Ctx;
# any value of type `Number` can now be tagged with metadata of type `Number`
julia> Cassette.metadatatype(::Type{<:Ctx}, ::Type{<:Number}) = Number
# any value of type `T<:Number` can now be tagged with metadata of type `T`
julia> Cassette.metadatatype(::Type{<:Ctx}, ::Type{T}) where {T<:Number} = T
# any value of type `T<:Number` can now be tagged with metadata of type `promote_type(T, M)`
# where `M` is the type of the trace-local metadata associated with the context
julia> Cassette.metadatatype(::Type{<:Ctx{M}}, ::Type{T}) where {M<:Number,T<:Number} = promote_type(T, M)
Cassette.hasmetadata — Function.
hasmetadata(x, context::Context)
Return true if !isa(metadata(x, context), Cassette.NoMetaData), return false otherwise.
In other words, hasmetadata(tag(x, context, m), context) is always true and hasmetadata(tag(x, context), context) is always false.
See also: metadata
Cassette.istagged — Function.
istagged(x, context::Context)
Return true if x is tagged w.r.t. context, return false otherwise.
In other words, istagged(tag(x, context), context) is always true.
See also: tag, istaggedtype
Cassette.istaggedtype — Function.