Cassette API Documentation

Context{N<:Cassette.AbstractContextName,
        M<:Any,
        P<:Cassette.AbstractPass,
        T<:Union{Nothing,Cassette.Tag},
        B<:Union{Nothing,Cassette.BindingMetaDictCache}}

A type representing a Cassette execution context. This type is normally interacted with through type aliases constructed via Cassette.@context:

julia> Cassette.@context MyCtx
Cassette.Context{nametype(MyCtx),M,P,T,B} where B<:Union{Nothing,IdDict{Module,Dict{Symbol,BindingMeta}}}
                                          where P<:Cassette.AbstractPass
                                          where T<:Union{Nothing,Tag}
                                          where M

Constructors

Given a context type alias named e.g. MyCtx, an instance of the type can be constructed via:

MyCtx(; metadata = nothing, pass = Cassette.NoPass())

To construct a new context instance using an existing context instance as a template, see the similarcontext function.

To enable contextual tagging for a given context instance, see the enabletagging function.

Fields

  • name::N<:Cassette.AbstractContextName: a parameter used to disambiguate different contexts for overloading purposes (e.g. distinguishes MyCtx from other Context type aliases).

  • metadata::M<:Any: trace-local metadata as provided to the context constructor.

  • pass::P<:Cassette.AbstractPass: the Cassette pass that will be applied to all method bodies encountered during contextual execution (see the @pass macro for details).

  • tag::T<:Union{Nothing,Tag}: the tag object that is attached to values when they are tagged w.r.t. the context instance.

  • bindingscache::B<:Union{Nothing,BindingMetaDictCache}: storage for metadata associated with tagged module bindings.

similarcontext(context::Context;
               metadata = context.metadata,
               pass = context.pass,
               tag = context.tag,
               bindingscache = context.bindingscache)

Return a copy of the given context, replacing field values in the returned instance with those provided via the keyword arguments.
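
The keyword defaults in the signature above are read from the context being copied, so only the fields you pass are replaced. The following is a minimal plain-Julia sketch of that pattern. It does not use Cassette: `SimpleContext` and `toy_similarcontext` are hypothetical stand-ins with only two fields, not Cassette's actual `Context` type.

```julia
# Hypothetical stand-in for Cassette.Context with only two fields.
struct SimpleContext{M,P}
    metadata::M
    pass::P
end

# Keyword defaults are taken from the existing instance, so callers only
# override the fields they name; the original instance is left untouched.
toy_similarcontext(ctx::SimpleContext; metadata = ctx.metadata, pass = ctx.pass) =
    SimpleContext(metadata, pass)

ctx  = SimpleContext(:old_meta, :nopass)
ctx2 = toy_similarcontext(ctx; metadata = :new_meta)
```

Here `ctx2` carries the new metadata but keeps `ctx`'s pass, and `ctx` itself is unchanged.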

enabletagging(context::Cassette.Context, f)

Return a copy of the given context with the tagging system enabled for the contextual execution of f.

Cassette uses the type of f to generate the tag field of the returned instance.

Note that it is generally unsafe to use the returned instance to contextually execute functions other than f. Specifically, in cases of nested contextual execution where both inner and outer contexts employ the tagging system, improper application of the tagging system could cause (for example) separate contexts to erroneously interfere with each other's metadata propagation.

See also: hastagging

Cassette.hastagging (Function)
hastagging(::Type{<:Cassette.Context})

Return true if the given type indicates that the contextual tagging system is enabled for context instances of the type; return false otherwise.

Example

julia> Cassette.@context MyCtx;

julia> ctx = MyCtx();

julia> Cassette.hastagging(typeof(ctx))
false

julia> ctx = Cassette.enabletagging(ctx, sum);

julia> Cassette.hastagging(typeof(ctx))
true

See also: enabletagging

Cassette.@context Ctx

Define a new Cassette context type with the name Ctx. In reality, Ctx is simply a type alias for Cassette.Context{Cassette.nametype(Ctx)}.

Note that Cassette.execute is automatically overloaded w.r.t. Ctx to define several primitives by default. A full list of these default primitives can be obtained by running:

methods(Cassette.execute, (Ctx, Vararg{Any}))

Note also that many of the default primitives' signatures only match when contextual tagging is enabled.

See also: Context

Cassette.overdub (Function)
overdub(context::Context, f, args...)

Execute f(args...) overdubbed with respect to context.

More specifically, execute f(args...), but with every internal method invocation g(x...) replaced by statements similar to the following:

begin
    prehook(context, g, x...)
    tmp = execute(context, g, x...)
    tmp = isa(tmp, Cassette.OverdubInstead) ? overdub(context, g, x...) : tmp
    posthook(context, tmp, g, x...)
    tmp
end
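
To make the control flow above concrete, here is a plain-Julia sketch that models a single intercepted call. It does not use Cassette and performs no IR rewriting, so calls made inside g are not intercepted. `AbstractToyContext`, `LogCtx`, `ToyOverdubInstead`, and every `toy_*` function are hypothetical names chosen for illustration.

```julia
# Hypothetical, self-contained model of the per-call protocol shown above.
# No IR rewriting happens here: nested calls inside `g` are NOT intercepted.
struct ToyOverdubInstead end            # stands in for Cassette.OverdubInstead

abstract type AbstractToyContext end

struct LogCtx <: AbstractToyContext
    log::Vector{Any}
end

# Defaults mirror the documented ones: no-op hooks, execute defers.
toy_prehook(::AbstractToyContext, f, args...) = nothing
toy_posthook(::AbstractToyContext, output, f, args...) = nothing
toy_execute(::AbstractToyContext, f, args...) = ToyOverdubInstead()

# Stand-in for recursive overdubbing; real Cassette would transform g's IR.
toy_recurse(::AbstractToyContext, g, x...) = g(x...)

function toy_call(ctx::AbstractToyContext, g, x...)
    toy_prehook(ctx, g, x...)
    tmp = toy_execute(ctx, g, x...)
    tmp = isa(tmp, ToyOverdubInstead) ? toy_recurse(ctx, g, x...) : tmp
    toy_posthook(ctx, tmp, g, x...)
    tmp
end

# Context-specific overloads: log around every call, and treat `*` as a
# primitive whose result this context replaces.
toy_prehook(ctx::LogCtx, f, args...) = push!(ctx.log, (:pre, f, args))
toy_posthook(ctx::LogCtx, out, f, args...) = push!(ctx.log, (:post, out, f))
toy_execute(::LogCtx, ::typeof(*), a, b) = a * b + 1000

ctx = LogCtx(Any[])
```

With this setup, `toy_call(ctx, +, 1, 2)` falls through to the plain call and returns 3, while `toy_call(ctx, *, 2, 3)` is handled by the overloaded `toy_execute` and returns 1006.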

If Cassette cannot retrieve lowered IR for the method body of f(args...) (as determined by canoverdub(context, f, args...)), then overdub(context, f, args...) will directly translate to a call to fallback(context, f, args...).

Additionally, for every method body encountered in the execution trace, apply the compiler pass associated with context, if one exists. Note that this user-provided pass is performed on the method IR before method invocations are transformed into the form specified above. See the @pass macro for further details.

If Cassette.hastagging(typeof(context)), then a number of additional passes are run in order to accommodate tagged value propagation:

  • Expr(:new) is replaced with a call to Cassette.tagged_new
  • conditional values passed to Expr(:gotoifnot) are untagged
  • arguments to Expr(:foreigncall) are untagged
  • loads from and stores to external module bindings are intercepted by the tagging system

Cassette.@overdub(ctx, expression)

A convenience macro for executing expression within the context ctx. This macro roughly expands to Cassette.overdub(ctx, () -> expression).
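
A hedged sketch of the closure wrapping described above, written without Cassette: the hypothetical `toy_overdub` function stands in for Cassette.overdub and merely records the callable it receives before calling it, and the hypothetical `@toy_overdub` macro performs the same `() -> expression` wrapping.

```julia
# Hypothetical stand-in for `overdub`: records the callable, then calls it.
function toy_overdub(trace::Vector{Any}, f, args...)
    push!(trace, f)
    return f(args...)
end

# Wrap `expression` in a zero-argument closure; `esc` keeps the user's
# variables referring to the caller's scope rather than the macro's.
macro toy_overdub(trace, expression)
    return :(toy_overdub($(esc(trace)), () -> $(esc(expression))))
end

trace = Any[]
result = @toy_overdub(trace, 1 + 2)
```

After this runs, `result` is 3 and `trace` holds the single closure that was passed in.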

See also: overdub

Cassette.prehook (Function)
prehook(context::Context, f, args...)

Overload this Cassette method w.r.t. a given context in order to define a new contextual prehook for that context.

To understand when/how this method is called, see the documentation for overdub.

Invoking prehook is a no-op by default (it immediately returns nothing).

See also: overdub, posthook, execute, fallback

Examples

Simple trace logging:

julia> Cassette.@context PrintCtx;

julia> Cassette.prehook(::PrintCtx, f, args...) = println(f, args)

julia> Cassette.overdub(PrintCtx(), /, 1, 2)
float(1,)
AbstractFloat(1,)
Float64(1,)
sitofp(Float64, 1)
float(2,)
AbstractFloat(2,)
Float64(2,)
sitofp(Float64, 2)
/(1.0, 2.0)
div_float(1.0, 2.0)
0.5

Counting the number of method invocations with one or more arguments of a given type:

julia> mutable struct Count{T}
           count::Int
       end

julia> Cassette.@context CountCtx;

julia> Cassette.prehook(ctx::CountCtx{Count{T}}, f, arg::T, args::T...) where {T} = (ctx.metadata.count += 1)

# count the number of calls of the form `f(::Float64, ::Float64...)`
julia> ctx = CountCtx(metadata = Count{Float64}(0));

julia> Cassette.overdub(ctx, /, 1, 2)
0.5

julia> ctx.metadata.count
2

Cassette.posthook (Function)
posthook(context::Context, output, f, args...)

Overload this Cassette method w.r.t. a given context in order to define a new contextual posthook for that context.

To understand when/how this method is called, see the documentation for overdub.

Invoking posthook is a no-op by default (it immediately returns nothing).

See also: overdub, prehook, execute, fallback

Examples

Simple trace logging:

julia> Cassette.@context PrintCtx;

julia> Cassette.posthook(::PrintCtx, output, f, args...) = println(output, " = ", f, args)

julia> Cassette.overdub(PrintCtx(), /, 1, 2)
1.0 = sitofp(Float64, 1)
1.0 = Float64(1,)
1.0 = AbstractFloat(1,)
1.0 = float(1,)
2.0 = sitofp(Float64, 2)
2.0 = Float64(2,)
2.0 = AbstractFloat(2,)
2.0 = float(2,)
0.5 = div_float(1.0, 2.0)
0.5 = /(1.0, 2.0)
0.5

Accumulate the sum of all numeric scalar outputs encountered in the trace:

julia> mutable struct Accum
           x::Number
       end

julia> Cassette.@context AccumCtx;

julia> Cassette.posthook(ctx::AccumCtx{Accum}, out::Number, f, args...) = (ctx.metadata.x += out)

julia> ctx = AccumCtx(metadata = Accum(0));

julia> Cassette.overdub(ctx, /, 1, 2)
0.5

julia> ctx.metadata.x
13.0

Cassette.execute (Function)
execute(context::Context, f, args...)

Overload this Cassette method w.r.t. a given context in order to define a new contextual execution primitive for that context.

To understand when/how this method is called, see the documentation for overdub.

Invoking execute immediately returns Cassette.OverdubInstead() by default.

See also: overdub, prehook, posthook, fallback

Cassette.fallback (Function)
fallback(context::Context, f, args...)

Overload this Cassette method w.r.t. a given context in order to define a new contextual execution fallback for that context.

To understand when/how this method is called, see the documentation for overdub and canoverdub.

By default, invoking fallback(context, f, args...) will simply call f(args...) (with all arguments automatically untagged, if hastagging(typeof(context))).

See also: canoverdub, overdub, execute, prehook, posthook

Cassette.canoverdub (Function)
canoverdub(context::Context, f, args...)

Return true if f(args...) has a lowered IR representation that Cassette can overdub, return false otherwise.

Alternatively, but equivalently:

Return false if overdub(context, f, args...) directly translates to fallback(context, f, args...), return true otherwise.

Note that unlike execute, fallback, etc., this function is not intended to be overloaded.

See also: overdub, fallback, execute

Cassette.@pass (Macro)
Cassette.@pass transform

Return a Cassette pass that can be provided to the Context constructor's pass keyword argument in order to apply transform to the lowered IR representations of all methods invoked during contextual execution.

transform must be a Julia object that is callable with the following signature:

transform(::Type{<:Context}, signature::Type{Tuple{...}}, method_body::CodeInfo)::CodeInfo

Note that the @pass macro expands to an eval call and thus should only be called at top-level. Furthermore, to avoid world-age issues, transform should not be overloaded after it has been registered with @pass.

Note also that transform should be "relatively pure." More specifically, Julia's compiler has license to apply transform multiple times, even if only compiling a single method invocation once. Thus, it is required that transform always return a generally "equivalent" CodeInfo for a given context, method body, and signature. If your transform implementation is not naturally "pure" in this sense, then it is still possible to guarantee this property by memoizing your implementation (i.e. maintaining a cache of previously computed IR results, instead of recomputing results every time).
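
A minimal sketch of the memoization approach, assuming that a given (context type, signature) pair always corresponds to the same method body. It does not involve Cassette or @pass: `raw_transform` is a hypothetical stand-in that merely copies the IR and counts its invocations, `memoized_transform` has the callable shape documented above, and `Nothing` is used in place of a real context type. The lock is a precaution, not a documented requirement.

```julia
# Cache of previously computed results, keyed by (context type, signature).
const TRANSFORM_CACHE = Dict{Any,Core.CodeInfo}()
const TRANSFORM_LOCK = ReentrantLock()
const TRANSFORM_CALLS = Ref(0)

# Hypothetical "impure" transform: copies the IR but counts how often it runs.
function raw_transform(::Type, ::Type, ci::Core.CodeInfo)
    TRANSFORM_CALLS[] += 1
    return copy(ci)
end

# Same call shape as a pass transform: (context type, signature, CodeInfo).
# Repeated calls return the identical cached CodeInfo instead of recomputing.
function memoized_transform(ctxtype::Type, sig::Type, ci::Core.CodeInfo)
    lock(TRANSFORM_LOCK) do
        get!(TRANSFORM_CACHE, (ctxtype, sig)) do
            raw_transform(ctxtype, sig, ci)
        end
    end
end

ci = code_lowered(+, (Int, Int))[1]
sig = Tuple{typeof(+),Int,Int}
r1 = memoized_transform(Nothing, sig, ci)
r2 = memoized_transform(Nothing, sig, ci)
```

The second call returns the very same object as the first, and the underlying transform runs only once.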

Two special Expr heads are available to Cassette pass authors that are not normally valid in Julia IR. Exprs with these heads can be used to interact with the downstream built-in Cassette passes that consume them.

  • :nooverdub: Wrap an Expr with this head value around the first argument in an Expr(:call) to tell downstream built-in Cassette passes not to overdub that call. For example, Expr(:call, Expr(:nooverdub, GlobalRef(MyModule, :myfunc)), args...).

  • :contextslot: Cassette will replace any Expr(:contextslot) with the actual SlotNumber corresponding to the context object associated with the execution trace. For example, one could construct an IR element that accesses the context's metadata field by emitting: Expr(:call, Expr(:nooverdub, GlobalRef(Core, :getfield)), Expr(:contextslot), QuoteNode(:metadata)).

Cassette provides a few IR-munging utility functions of interest to pass authors: insert_statements! and replace_match!.

See also: Context, overdub

replace_match!(replace, ismatch, x)

Return x with all subelements y replaced with replace(y) if ismatch(y). If !ismatch(y), but y is of type Expr, Array, or SubArray, then replace y in x with replace_match!(replace, ismatch, y).

Generally, x should be of type Expr, Array, or SubArray.

Note that this function modifies x (and potentially its subelements) in-place.
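
To illustrate the documented semantics, here is a hypothetical re-implementation, `toy_replace_match!`, written only from the description above (it is not Cassette's source). Matched subelements are substituted; unmatched Expr, Array, or SubArray subelements are searched recursively, and the input is modified in place.

```julia
# Expr: operate on its argument vector in place, then return the Expr itself.
toy_replace_match!(replace, ismatch, x::Expr) =
    (toy_replace_match!(replace, ismatch, x.args); x)

function toy_replace_match!(replace, ismatch, x::Union{Array,SubArray})
    for i in eachindex(x)
        y = x[i]
        if ismatch(y)
            x[i] = replace(y)                               # matched: substitute
        elseif y isa Union{Expr,Array,SubArray}
            x[i] = toy_replace_match!(replace, ismatch, y)  # recurse in place
        end
    end
    return x
end

# Rename every occurrence of the symbol `x` to `z` inside an expression.
ex = :(f(x, g(x, y)))
toy_replace_match!(s -> :z, s -> s === :x, ex)
```

After the call, `ex` equals `:(f(z, g(z, y)))`.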

See also: insert_statements!

insert_statements!(code::Vector, codelocs::Vector, stmtcount, newstmts)

For every statement stmt at position i in code for which stmtcount(stmt, i) returns an Int, remove stmt, and in its place, insert the statements returned by newstmts(stmt, i). If stmtcount(stmt, i) returns nothing, leave stmt alone.

For every insertion, all downstream SSAValues, label indices, etc. are incremented appropriately according to the number of inserted statements.

Proper usage of this function dictates that the following properties hold true:

  • code is expected to be a valid value for the code field of a CodeInfo object.
  • codelocs is expected to be a valid value for the codelocs field of a CodeInfo object.
  • newstmts(stmt, i) should return a Vector of valid IR statements.
  • stmtcount and newstmts must obey stmtcount(stmt, i) == length(newstmts(stmt, i)) if isa(stmtcount(stmt, i), Int).

To gain a mental model for this function's behavior, consider the following scenario. Let's say our code object contains several statements:

code = Any[oldstmt1, oldstmt2, oldstmt3, oldstmt4, oldstmt5, oldstmt6]
codelocs = Int[1, 2, 3, 4, 5, 6]

Let's also say that our stmtcount returns 2 for stmtcount(oldstmt2, 2), returns 3 for stmtcount(oldstmt5, 5), and returns nothing for all other inputs. From this setup, we can think of code/codelocs being modified in the following manner:

newstmts2 = newstmts(oldstmt2, 2)
newstmts5 = newstmts(oldstmt5, 5)
code = Any[oldstmt1,
           newstmts2[1], newstmts2[2],
           oldstmt3, oldstmt4,
           newstmts5[1], newstmts5[2], newstmts5[3],
           oldstmt6]
codelocs = Int[1, 2, 2, 3, 4, 5, 5, 5, 6]
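
The scenario above can be reproduced with a simplified, hypothetical `toy_insert_statements!`. It models only the splicing of code and codelocs; unlike the real function, it does not renumber SSAValues or label indices. Symbols stand in for real IR statements, and `demo_count`/`demo_new` are illustrative helpers.

```julia
function toy_insert_statements!(code::Vector, codelocs::Vector, stmtcount, newstmts)
    newcode = similar(code, 0)
    newlocs = similar(codelocs, 0)
    for (i, stmt) in enumerate(code)
        n = stmtcount(stmt, i)
        if n === nothing
            push!(newcode, stmt)                    # leave stmt alone
            push!(newlocs, codelocs[i])
        else
            stmts = newstmts(stmt, i)
            length(stmts) == n || error("stmtcount and newstmts disagree at $i")
            append!(newcode, stmts)                 # replacement statements
            append!(newlocs, fill(codelocs[i], n))  # all inherit the old location
        end
    end
    # Write the results back into the caller's vectors (in-place update).
    append!(empty!(code), newcode)
    append!(empty!(codelocs), newlocs)
    return code, codelocs
end

code = Any[:oldstmt1, :oldstmt2, :oldstmt3, :oldstmt4, :oldstmt5, :oldstmt6]
codelocs = Int[1, 2, 3, 4, 5, 6]
demo_count(stmt, i) = i == 2 ? 2 : i == 5 ? 3 : nothing
demo_new(stmt, i) = [Symbol(stmt, "_new", k) for k in 1:demo_count(stmt, i)]
toy_insert_statements!(code, codelocs, demo_count, demo_new)
```

The resulting codelocs are `[1, 2, 2, 3, 4, 5, 5, 5, 6]`, matching the layout shown above.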

See also: replace_match!

Cassette.tag (Function)
tag(value, context::Context, metadata = Cassette.NoMetaData())

Return value tagged w.r.t. context, optionally associating metadata with the returned Tagged instance.

Any provided metadata must obey the type constraints determined by Cassette's metadatatype method.

Note that hastagging(typeof(context)) must be true for a value to be tagged w.r.t. context.

See also: untag, enabletagging, hastagging

Cassette.untag (Function)
untag(x, context::Context)

Return x untagged w.r.t. context if istagged(x, context), otherwise return x directly.

In other words, untag(tag(x, context), context) === x is always true.

If !istagged(x, context), then untag(x, context) === x is true.

See also: tag, istagged

Cassette.untagtype (Function)
untagtype(::Type{T}, ::Type{C<:Context})

Return typeof(untag(::T, ::C)).

In other words, untagtype(typeof(tag(x, context)), typeof(context)) === typeof(x) is always true.

If !istaggedtype(T, C), then untagtype(T, C) === T is true.

Cassette.metadata (Function)
metadata(x, context::Context)

Return the metadata attached to x if hasmetadata(x, context), otherwise return Cassette.NoMetaData().

In other words, metadata(tag(x, context, m), context) === m is always true.

If !hasmetadata(x, context), then metadata(x, context) === Cassette.NoMetaData() is true.
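
The invariants stated for tag, untag, istagged, metadata, and hasmetadata can be checked against a deliberately tiny toy model. This is a sketch under heavy simplification: `ToyNoMetaData`, `ToyTagCtx`, `ToyTagged`, and the `toy_*` functions are hypothetical, and Cassette's real Tagged type, Tag objects, and metadatatype checks are far more involved.

```julia
struct ToyNoMetaData end                 # stands in for Cassette.NoMetaData

struct ToyTagCtx                         # hypothetical context; `id` plays the tag's role
    id::Symbol
end

struct ToyTagged{V,M}
    tag::Symbol
    value::V
    meta::M
end

toy_tag(x, ctx::ToyTagCtx, m = ToyNoMetaData()) = ToyTagged(ctx.id, x, m)
toy_istagged(x, ctx::ToyTagCtx) = x isa ToyTagged && x.tag === ctx.id
toy_untag(x, ctx::ToyTagCtx) = toy_istagged(x, ctx) ? x.value : x
toy_metadata(x, ctx::ToyTagCtx) = toy_istagged(x, ctx) ? x.meta : ToyNoMetaData()
toy_hasmetadata(x, ctx::ToyTagCtx) = !isa(toy_metadata(x, ctx), ToyNoMetaData)

ctx, other = ToyTagCtx(:a), ToyTagCtx(:b)
x = 1.5
```

In this model, a value tagged w.r.t. `ctx` is not considered tagged w.r.t. `other`, mirroring the per-context behavior described above.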

Cassette.metadatatype (Function)
metadatatype(::Type{<:Context}, ::Type{T})

Overload this Cassette method w.r.t. a given context to define the type of metadata that can be tagged to values of type T within that context.

By default, this method is set such that associating metadata with any tagged value is disallowed.

Cassette uses metadatatype to statically compute a context-specific metadata type hierarchy for all tagged values within overdubbed programs. To gain a mental model for this mechanism, consider a simple struct definition as follows:

struct Foo
    x::Int
    y::Complex{Int}
end

Now, Cassette can use metadatatype to determine type constraints for metadata structures associated with tagged values of type Foo. In pseudo-Julia code, these metadata structures might look something like the following for Foo:

struct IntMeta
    data::metadatatype(Ctx, Int)
    meta::Cassette.NoMetaMeta
end

struct ComplexIntMeta
    data::metadatatype(Ctx, Complex{Int})
    meta::NamedTuple{(:re,:im),Tuple{IntMeta,IntMeta}}
end

struct FooMeta
    data::metadatatype(Ctx, Foo)
    meta::NamedTuple{(:x,:y),Tuple{IntMeta,ComplexIntMeta}}
end
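
The nesting in the pseudo-code above can be derived mechanically from a type's fields. The following hypothetical sketch does this with plain reflection (fieldcount, fieldnames, fieldtypes). It is not Cassette's implementation: `toy_metadatatype` plays the role of metadatatype(Ctx, T) and, as an arbitrary assumption, allows Float64 metadata on Int and nothing elsewhere. Leaf types use Nothing where the pseudo-code uses Cassette.NoMetaMeta. It assumes concrete, non-recursive types.

```julia
toy_metadatatype(::Type) = Nothing
toy_metadatatype(::Type{Int}) = Float64

struct ToyMeta{D,M}
    data::D      # metadata for the value itself
    meta::M      # metadata for its fields (Nothing for leaf types)
end

# Field-level metadata layout: a NamedTuple of per-field ToyMeta types.
function toy_metameta(::Type{T}) where {T}
    fieldcount(T) == 0 && return Nothing          # leaf, e.g. Int
    fields = map(S -> ToyMeta{toy_metadatatype(S), toy_metameta(S)}, fieldtypes(T))
    return NamedTuple{fieldnames(T), Tuple{fields...}}
end

toy_metastruct(::Type{T}) where {T} = ToyMeta{toy_metadatatype(T), toy_metameta(T)}

struct Foo
    x::Int
    y::Complex{Int}
end

IntMeta = toy_metastruct(Int)
ComplexIntMeta = toy_metastruct(Complex{Int})
FooMeta = toy_metastruct(Foo)
```

Here `FooMeta` comes out as `ToyMeta{Nothing, NamedTuple{(:x,:y),Tuple{IntMeta,ComplexIntMeta}}}`, mirroring the FooMeta pseudo-struct.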

Examples

julia> Cassette.@context Ctx;

# any value of type `Number` can now be tagged with metadata of type `Number`
julia> Cassette.metadatatype(::Type{<:Ctx}, ::Type{<:Number}) = Number

# any value of type `T<:Number` can now be tagged with metadata of type `T`
julia> Cassette.metadatatype(::Type{<:Ctx}, ::Type{T}) where {T<:Number} = T

# any value of type `T<:Number` can now be tagged with metadata of type `promote_type(T, M)`
# where `M` is the type of the trace-local metadata associated with the context
julia> Cassette.metadatatype(::Type{<:Ctx{M}}, ::Type{T}) where {M<:Number,T<:Number} = promote_type(T, M)

Cassette.hasmetadata (Function)
hasmetadata(x, context::Context)

Return true if !isa(metadata(x, context), Cassette.NoMetaData), return false otherwise.

In other words, hasmetadata(tag(x, context, m), context) is always true and hasmetadata(tag(x, context), context) is always false.

See also: metadata

Cassette.istagged (Function)
istagged(x, context::Context)

Return true if x is tagged w.r.t. context, return false otherwise.

In other words, istagged(tag(x, context), context) is always true.

See also: tag, istaggedtype

Cassette.istaggedtype (Function)
istaggedtype(::Type{T}, ::Type{C<:Context})

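Return what istagged(::T, ::C) would return, computed from the types alone: true if a value of type T is tagged w.r.t. a context of type C, and false otherwise.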

In other words, istaggedtype(typeof(tag(x, context)), typeof(context)) is always true.

See also: tag, istagged
