Restartable functions from first principles

Often when thinking about Jetpack Compose state, we think in terms of what code will be re-executed when that state changes. We call functions that get re-executed like this “restartable functions”. Composable functions are generally restartable, but some other functions are restartable too, and you can make your own.

📖
This article is part of a series on Compose state. Check out the other articles here.

Other common restartable functions include:

I specifically say “the function passed to” because that's the only code that you as a developer see that gets restarted. The implementation of Modifier.drawBehind, for example, might re-execute some of its own internal code as well, but the drawBehind function itself is not called again.

ℹ️
If I've lost you already, check out Introduction to the Compose Snapshot system.

How do functions restart?

To make a function restartable when state changes, two behaviors need to collaborate:

  1. The caller of the function must track what state objects are read by the function.
  2. The caller of the function must listen for a callback when state objects change and, when one or more of the objects from step (1) have changed, re-execute the function. Or, more typically, schedule the function to re-execute at some point in the future.

It's possible to write a very basic implementation of this in just a few lines of code:

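import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.snapshots.Snapshot

// Create some state inputs to read later.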
val state1 = mutableIntStateOf(0)
val state2 = mutableIntStateOf(0)

// Remember what state objects were read.
val readSet = mutableSetOf<Any>()

fun printResult() {
  // Before tracking new reads, clear the old ones.
  readSet.clear()
  Snapshot.observe(
    // This gets called by every state object's getter (.value).
    readObserver = { readSet += it }
  ) {
    // Any state reads done in this lambda will be tracked in readSet.
    // This lambda is effectively the restartable function.
    val result = state1.value + state2.value
    println(result)
  }
}

// This lambda will be called any time any snapshot in the process is
// applied.
Snapshot.registerApplyObserver { changedObjects, _ ->
  if (changedObjects.any { it in readSet }) {
    // Immediately restart printResult. In a real app, this would
    // probably be scheduled e.g. to happen on the next frame, and
    // only once per frame.
    printResult()
  }
}

// Get the party started!
printResult()

The first part to look at here is the read observer.

Snapshot.observe(
  readObserver = { readSet += it }
) {
  TODO("read state objects, do calculations, etc.")
}

We use the Snapshot.observe function to run some code and register a listener that will be invoked whenever a state object is read. The callback gets passed the state object that was read. This could be any state object—a mutableStateListOf, a mutableStateOf, or even a derivedStateOf—anyone can define a state object, so really the only thing we can do in this callback is add the object to a set to check membership later.
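To make "the callback gets passed the state object" concrete, here is a small sketch. It assumes the Compose runtime (for example its JVM/desktop artifact) is on the classpath, and `recordReads` is a hypothetical helper name. It reads two different kinds of state object and records exactly what the read observer was handed.

```kotlin
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateListOf
import androidx.compose.runtime.snapshots.Snapshot

val count = mutableIntStateOf(1)
val names = mutableStateListOf("a", "b")
val unread = mutableIntStateOf(100) // Never read by the block below.

// Hypothetical helper: runs block once and returns its result together with
// every object that was reported to the read observer while it ran.
fun <T> recordReads(block: () -> T): Pair<T, List<Any>> {
    val reads = mutableListOf<Any>()
    val result = Snapshot.observe(readObserver = { reads += it }, block = block)
    return result to reads
}

fun main() {
    val (sum, reads) = recordReads { count.intValue + names.size }
    println(sum)                         // 3
    // The observer receives the state objects themselves, not their values.
    println(reads.any { it === count })  // true
    println(reads.any { it === names })  // true
    println(reads.any { it === unread }) // false
}
```

Identity checks (`===`) are used on purpose: the read set only needs to answer "was this exact object read?", which is all the apply observer will ask later.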

Reporting state reads

To understand how the read observer gets notified, you first need to understand a little bit about how state objects work. The State<T> implementation of mutableStateOf looks like this (source):

override var value: T
    get() = next.readable(this).value

next is the first StateRecord in the linked list, and readable does two things: it selects the state record that is readable in this snapshot, and—what we care about most right now—it calls the current snapshot's read observer with the state object passed to it. It's very clear in the code (source):

fun <T : StateRecord> T.readable(state: StateObject): T {
    val snapshot = Snapshot.current
    snapshot.readObserver?.invoke(state)
    …
}

That's why this is passed to readable—it's the value that ends up being passed to the readObserver lambda above. All state objects—even more complicated ones that don't have a single getter, such as SnapshotStateList—use readable whenever they need to access their data. SnapshotStateList, for example, uses it in every List method it implements, and so each of those methods will invoke the read observer (source):

internal val readable: StateListStateRecord<T>
    get() = (firstStateRecord as StateListStateRecord<T>).readable(this)

…
        
override val size: Int
    get() = readable.list.size

override fun contains(element: T) = readable.list.contains(element)

override fun containsAll(elements: Collection<T>) = readable.list.containsAll(elements)

override fun get(index: Int) = readable.list[index]

override fun indexOf(element: T): Int = readable.list.indexOf(element)

override fun isEmpty() = readable.list.isEmpty()

For more information about StateRecord and StateObject see Implementing snapshot-aware data structures.

Detecting when state objects change

Now that we know what state objects were read in the lambda, we need to know when they change. For that we use the Snapshot.registerApplyObserver function, which takes a callback that will be invoked whenever a snapshot is applied. It's invoked synchronously from whatever thread applies the snapshot, so a real-world implementation needs to be thread-safe, but I'll ignore that here to focus on the change detection logic. The callback receives a set of all the state objects that were written to in the snapshot, so we just loop through those objects to see if any of the ones we care about were written.

Snapshot.registerApplyObserver { changedObjects, _ ->
  if (changedObjects.any { it in readSet }) {
    TODO("schedule observer to run again")
  }
}

Now, somewhere else in our program, we can change one of the inputs and watch the new result get printed:

// Create a new snapshot to explicitly trigger the apply observer.
Snapshot.withMutableSnapshot {
  state1.value = 42
} // <- Here, before withMutableSnapshot returns, the apply
  //    observer will be called.

restartWhenStateChanged

We could wrap this logic up into a generic helper function that lets us restart any code whenever a state object it reads changes, something like this:

suspend fun restartWhenStateChanged(
    block: () -> Unit
) {
    val readSet = mutableSetOf<Any>()

    // Calls block, tracking any state objects it reads.
    fun observeBlock() {
        readSet.clear()
        Snapshot.observe(readObserver = { readSet += it }) {
            block()
        }
    }

    // Restart observeBlock any time a state it read is changed.
    val observerHandle = Snapshot.registerApplyObserver { changedObjects, _ ->
        if (changedObjects.any { it in readSet }) {
            // Immediately re-run block with observation.
            observeBlock()
        }
    }

    try {
        // Bootstrap the call so we have some state objects to watch for.
        observeBlock()

        // Suspend indefinitely to let the apply observer do its thing.
        awaitCancellation()
    } finally {
        observerHandle.dispose()
    }
}

This is pretty close to the simple snippet we started with, but we've put it in a suspend function to manage the lifetime of the restartable code and properly clean up the apply observer when we're done. We can call our function like this:

val state1 = mutableIntStateOf(0)
val state2 = mutableIntStateOf(0)

coroutineScope.launch {
    restartWhenStateChanged {
        println(state1.value + state2.value)
    }
}
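To see the lifetime management in action, here is a self-contained sketch that runs the function in a coroutine, changes a state, then cancels the coroutine and changes the state again. It assumes the Compose runtime and kotlinx.coroutines are on the classpath; `lifetimeDemo` is a hypothetical driver, and the function body is a condensed copy of the one above.

```kotlin
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.snapshots.Snapshot
import kotlinx.coroutines.awaitCancellation
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.yield

// The synchronous restartWhenStateChanged from above, condensed.
suspend fun restartWhenStateChanged(block: () -> Unit) {
    val readSet = mutableSetOf<Any>()

    fun observeBlock() {
        readSet.clear()
        Snapshot.observe(readObserver = { readSet += it }) { block() }
    }

    val observerHandle = Snapshot.registerApplyObserver { changedObjects, _ ->
        if (changedObjects.any { it in readSet }) observeBlock()
    }
    try {
        observeBlock()
        awaitCancellation()
    } finally {
        observerHandle.dispose()
    }
}

// Hypothetical driver: returns every value the restartable block observed.
fun lifetimeDemo(): List<Int> = runBlocking {
    val state = mutableIntStateOf(0)
    val seen = mutableListOf<Int>()
    val job = launch { restartWhenStateChanged { seen += state.intValue } }
    yield() // Let the launched coroutine start and run the bootstrap call.

    Snapshot.withMutableSnapshot { state.intValue = 1 } // Block re-runs synchronously.
    job.cancelAndJoin()                                  // finally disposes the observer.
    Snapshot.withMutableSnapshot { state.intValue = 2 } // Nobody is listening anymore.
    seen
}

fun main() {
    println(lifetimeDemo()) // [0, 1]
}
```

Cancellation unwinds `awaitCancellation`, the `finally` block disposes the apply observer, and the last write goes unobserved.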

That's a nice simple API. But our restartWhenStateChanged implementation is not production-ready:

  1. It re-runs the block synchronously, from the apply observer, which can be executed on any thread. This means the caller has no idea what thread the lambda they pass will run on.
  2. There's a race condition: we're reading from readSet in the apply observer. We could be looping through it on one thread while another thread is running observeBlock and writing to it. This could throw a ConcurrentModificationException or, worse, silently produce inconsistent data.
  3. We're doing potentially a lot of work in the apply observer. Whenever any snapshot is applied anywhere in the process, every registered apply observer is invoked, so apply observers must be fast. If changedObjects is large, it could take a while to iterate over, and that won't scale very well.

Let's try to solve (1) first. We're already in a suspend function, which means we implicitly have access to a scheduling mechanism: the coroutine dispatcher. We can make sure the observe block always runs in the originating coroutine context by using a channel to coordinate:

suspend fun restartWhenStateChanged(
    block: () -> Unit
) {
    val readSet = mutableSetOf<Any>()
    val restartTrigger = Channel<Unit>(
        capacity = 1,
        onBufferOverflow = BufferOverflow.DROP_OLDEST
    )

    // Calls block, tracking any state objects it reads.
    fun observeBlock() {
        readSet.clear()
        Snapshot.observe(readObserver = { readSet += it }) {
            block()
        }
    }

    // Restart observeBlock any time a state it read is changed.
    val observerHandle = Snapshot.registerApplyObserver { changedObjects, _ ->
        if (changedObjects.any { it in readSet }) {
            // Schedule observeBlock to run again.
            restartTrigger.trySend(Unit)
        }
    }

    try {
        // Bootstrap the call so we have some state objects to watch for.
        observeBlock()

        // Wait for the apply observer to detect a change, then re-run block.
        for (unit in restartTrigger) {
            observeBlock()
        }
    } finally {
        observerHandle.dispose()
    }
}

We make a channel that resumes the loop in the main function body whenever the apply observer needs to reschedule the block. We give it a capacity of one and DROP_OLDEST behavior, since the values in the channel aren't significant and multiple schedule requests can be coalesced. Now observeBlock is always called directly from restartWhenStateChanged, so it always runs in the caller's coroutine context. Kotlin guarantees that the code in a suspend function runs sequentially, one statement after another, so two calls to observeBlock can never overlap. And no matter what thread applies a snapshot, the apply observer merely schedules the block to run again; it doesn't run it directly. But we've still got that race condition in the apply observer.
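Here is a minimal sketch of the coalescing behavior, using only kotlinx.coroutines. `restartsForBurst` is a hypothetical helper that fires a burst of schedule requests into a channel configured like `restartTrigger` above, then counts how many restarts the loop would see.

```kotlin
import kotlinx.coroutines.channels.BufferOverflow
import kotlinx.coroutines.channels.Channel

// Hypothetical helper: how many restarts does a burst of schedule requests produce?
fun restartsForBurst(requests: Int): Int {
    val restartTrigger = Channel<Unit>(
        capacity = 1,
        onBufferOverflow = BufferOverflow.DROP_OLDEST
    )
    // Like the apply observer: trySend never suspends, and with DROP_OLDEST it
    // never fails either; extra requests just replace the buffered one.
    repeat(requests) { restartTrigger.trySend(Unit) }

    // Like the restart loop: drain whatever is buffered without suspending.
    var restarts = 0
    while (restartTrigger.tryReceive().isSuccess) restarts++
    return restarts
}

fun main() {
    println(restartsForBurst(5)) // 1
    println(restartsForBurst(0)) // 0
}
```

Because the element is always `Unit`, DROP_LATEST would coalesce just as well; what matters is that the buffer never holds more than one pending restart.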

Well, we already have a channel set up, but we're only using it as a trigger signal, not passing any data through it. We can move the changedObjects processing out of the apply observer using the same channel. changedObjects has type Set<Any>, so that becomes the channel's element type, and since we now care about the values in the channel, we don't want to drop any, so let's change the capacity to UNLIMITED. This lets the non-suspending apply observer send values into the channel with trySend, and they'll always be accepted. Since an unlimited channel can never be full, we don't need to specify buffer overflow behavior anymore.

suspend fun restartWhenStateChanged(
    block: () -> Unit
) {
    val readSet = mutableSetOf<Any>()
    val changedObjectsChannel = Channel<Set<Any>>(
        capacity = Channel.UNLIMITED
    )

    // Calls block, tracking any state objects it reads.
    fun observeBlock() {
        readSet.clear()
        Snapshot.observe(readObserver = { readSet += it }) {
            block()
        }
    }

    // Restart observeBlock any time a state it read is changed.
    val observerHandle = Snapshot.registerApplyObserver { changedObjects, _ ->
        // Sending to an unlimited channel is fast and doesn't depend on the
        // number of changed objects.
        changedObjectsChannel.trySend(changedObjects)
    }

    try {
        // Bootstrap the call so we have some state objects to watch for.
        observeBlock()

        // Wait for a snapshot to be applied…
        for (changedObjects in changedObjectsChannel) {
            // …then check if it changed any objects we care about.
            if (changedObjects.any { it in readSet }) {
                observeBlock()
            }
        }
    } finally {
        observerHandle.dispose()
    }
}

Now whenever a snapshot is applied, all we do is push changedObjects into a channel, a very fast operation. All the actual work, including checking for relevant objects and re-executing the block, is done in the context of the caller, and readSet is only ever touched from that one coroutine. We've made the apply observer faster and solved all the races and threading issues! I think we can ship this.
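To check the threading claim, here is a sketch that applies a snapshot from a background thread and records which thread the block ran on each time. It assumes the Compose runtime and kotlinx.coroutines are on the classpath; `crossThreadDemo` is a hypothetical driver, and the function is a condensed copy of the final version above.

```kotlin
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.snapshots.Snapshot
import kotlin.concurrent.thread
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.yield

// The final channel-based restartWhenStateChanged from above, condensed.
suspend fun restartWhenStateChanged(block: () -> Unit) {
    val readSet = mutableSetOf<Any>()
    val changedObjectsChannel = Channel<Set<Any>>(capacity = Channel.UNLIMITED)

    fun observeBlock() {
        readSet.clear()
        Snapshot.observe(readObserver = { readSet += it }) { block() }
    }

    val observerHandle = Snapshot.registerApplyObserver { changedObjects, _ ->
        changedObjectsChannel.trySend(changedObjects)
    }
    try {
        observeBlock()
        for (changedObjects in changedObjectsChannel) {
            if (changedObjects.any { it in readSet }) observeBlock()
        }
    } finally {
        observerHandle.dispose()
    }
}

// Hypothetical driver: returns (value seen, thread) for every run of the block.
fun crossThreadDemo(): List<Pair<Int, Thread>> = runBlocking {
    val state = mutableIntStateOf(0)
    val runs = mutableListOf<Pair<Int, Thread>>()
    val job = launch {
        restartWhenStateChanged { runs += state.intValue to Thread.currentThread() }
    }
    yield() // Bootstrap: the block runs once on this thread.

    // The apply observer fires on the background thread and only enqueues the set.
    thread { Snapshot.withMutableSnapshot { state.intValue = 1 } }.join()
    yield() // Let the restart loop process the queued change set.

    job.cancelAndJoin()
    runs
}

fun main() {
    // runBlocking runs on the calling thread, so every run should report it.
    val runs = crossThreadDemo()
    println(runs.map { it.first })                             // [0, 1]
    println(runs.all { it.second === Thread.currentThread() }) // true
}
```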

snapshotFlow

Our handy little restartWhenStateChanged function is pretty basic. All it lets us do is re-run a block of code whenever any state it reads is changed. In practice, this is of limited use. A more helpful function might, for example, allow the block to produce a value from one or more states, and then emit the resulting values from a Flow. Luckily, this exact function is built into Compose: snapshotFlow.
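Here is a sketch of snapshotFlow producing values, assuming the Compose runtime and kotlinx.coroutines are on the classpath (`snapshotFlowDemo` is a hypothetical driver). It also counts how often the block runs, which shows that snapshotFlow re-runs the block on every relevant change but, per its documentation, only emits results that differ from the previous one.

```kotlin
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.snapshotFlow
import androidx.compose.runtime.snapshots.Snapshot
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.yield

// Hypothetical driver: returns (values emitted, number of times the block ran).
fun snapshotFlowDemo(): Pair<List<Int>, Int> = runBlocking {
    val state1 = mutableIntStateOf(0)
    val state2 = mutableIntStateOf(0)
    val emitted = mutableListOf<Int>()
    var blockRuns = 0

    val job = launch {
        snapshotFlow {
            blockRuns++
            state1.intValue + state2.intValue
        }.collect { emitted += it }
    }
    yield() // Collector starts: block runs and 0 is emitted.

    Snapshot.withMutableSnapshot { state1.intValue = 42 }
    yield() // Block re-runs and 42 is emitted.

    Snapshot.withMutableSnapshot {
        state1.intValue = 40
        state2.intValue = 2
    }
    yield() // Block re-runs, but 40 + 2 == 42, so nothing new is emitted.

    job.cancelAndJoin()
    emitted to blockRuns
}

fun main() {
    println(snapshotFlowDemo()) // ([0, 42], 3)
}
```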

We can actually use snapshotFlow to (mostly) implement restartWhenStateChanged:

suspend fun restartWhenStateChanged(
    block: () -> Unit
) {
    snapshotFlow(block)
        .collect()
}

Wow, that's a lot simpler! The behavior is almost identical, except that this doesn't support doing state writes in block. For an explanation of why, see Read-only snapshot below.

By some cosmic coincidence, it turns out that the last full implementation of restartWhenStateChanged in the previous section is very close to the actual implementation of snapshotFlow. Take a look for yourself. There are a few minor differences, but the basic logic is the same. Let's go over the differences.

ReaderKind

if (it is StateObjectImpl) {
    it.recordReadIn(ReaderKind.SnapshotFlow)
}
…
val maybeObserved = changed.fastAny { 
    it !is StateObjectImpl || it.isReadIn(ReaderKind.SnapshotFlow)
}

ReaderKind is an internal optimization used by the state objects built into the Compose runtime (e.g. mutableStateOf) that allows those state objects to track not only when they're read, but what context they were read in. Apply observers can then reject objects more quickly based on this flag. For example, the state observer used for composable functions never cares about state objects that were only read inside a snapshotFlow, and vice versa, so this flag allows them to ignore those objects without even checking them for membership in their read sets.

Read-only snapshot

Snapshot.takeSnapshot(readObserver).run {
    try {
        enter(block)
    } finally {
        dispose()
    }
}

In restartWhenStateChanged we used Snapshot.observe to observe reads while running the block. However, this allows the block to perform state writes as well as state reads. This might be useful for that function, but snapshotFlow is designed to take a pure function that produces a value—it shouldn't write anything. Snapshot.takeSnapshot takes a read-only snapshot, so any attempt to write a state object inside block will throw. This helps enforce the purity of the function and is also a little more efficient, since the snapshot doesn't need to do any of the bookkeeping involved in state writes.
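A quick sketch of that guarantee, assuming the Compose runtime is on the classpath (`tryWriteInReadOnlySnapshot` is a hypothetical helper): it enters a read-only snapshot from Snapshot.takeSnapshot, attempts a write, and reports whether the attempt failed and what the state holds afterwards.

```kotlin
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.snapshots.Snapshot

// Hypothetical helper: returns (outcome of the write attempt, value afterwards).
fun tryWriteInReadOnlySnapshot(): Pair<Result<Unit>, Int> {
    val state = mutableIntStateOf(0)
    val snapshot = Snapshot.takeSnapshot() // Read-only; no read observer needed here.
    val attempt = try {
        // Reads would be fine here; the write is expected to throw.
        runCatching { snapshot.enter { state.intValue = 1 } }
    } finally {
        snapshot.dispose() // Always release the snapshot.
    }
    return attempt to state.intValue
}

fun main() {
    val (attempt, after) = tryWriteInReadOnlySnapshot()
    println(attempt.isFailure) // true
    println(after)             // 0
}
```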

Change polling

while (true) {
    …
    var changedObjects = appliedChanges.receive()
    
    // Poll for any other changes before running block to minimize the number of
    // additional times it runs for the same data
    while (true) {
        // Assumption: readSet will typically be smaller than changed set
        found = found || readSet.intersects(changedObjects)
        changedObjects = appliedChanges.tryReceive().getOrNull() ?: break
    }
    …
}

Instead of simply iterating over the channel, snapshotFlow uses a while loop and manually calls receive on the channel. More notably, it tries to receive multiple times in the same main loop iteration. Since apply observers can be invoked on any thread, and the channel has unlimited capacity, it's possible for a bunch of change sets to queue up in the channel while the main coroutine is waiting to be scheduled, or while it's checking for relevant changed objects. This inner while loop tries to drain the change channel entirely before re-running the block, coalescing the restart when multiple snapshots are applied in quick succession or from multiple threads.
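The drain step can be isolated into a small sketch that depends only on kotlinx.coroutines. `awaitRelevantChanges` is a hypothetical helper, and it uses a plain `any` check in place of the runtime's internal `intersects` helper.

```kotlin
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.runBlocking

// Hypothetical helper mirroring the polling loop above: suspends until one change
// set arrives, drains every other queued set without suspending, and reports
// whether any of them touched an object in readSet.
suspend fun awaitRelevantChanges(
    appliedChanges: Channel<Set<Any>>,
    readSet: Set<Any>,
): Boolean {
    var found = false
    var changedObjects = appliedChanges.receive()
    while (true) {
        // Iterate the (typically smaller) read set, as the original comment suggests.
        found = found || readSet.any { it in changedObjects }
        changedObjects = appliedChanges.tryReceive().getOrNull() ?: break
    }
    return found
}

fun main() = runBlocking {
    val read = Any()
    val appliedChanges = Channel<Set<Any>>(Channel.UNLIMITED)
    // Three snapshots applied in quick succession, before the loop gets to run.
    appliedChanges.trySend(setOf(Any()))
    appliedChanges.trySend(setOf(read))
    appliedChanges.trySend(setOf(Any()))

    // A single call consumes all three sets, so the block would re-run only once.
    println(awaitRelevantChanges(appliedChanges, readSet = setOf(read))) // true
    println(appliedChanges.tryReceive().isFailure)                       // true
}
```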

SnapshotStateObserver

That is how restartable functions are built using the lowest-level APIs from the Compose snapshot system. However, with the notable exception of snapshotFlow, most code does not actually do all this directly; it uses a helper class called SnapshotStateObserver. This class uses the same basic principles, but is designed to share some of the bookkeeping data between multiple different restartable scopes to reduce resource usage. It also implements some additional bookkeeping specifically to optimize derivedStateOf. Composable functions are a special case: they do even more complicated bookkeeping than SnapshotStateObserver to optimize further, so they don't use it directly, although they use similar logic to handle derivedStateOf.

⚠️
App code should never need to create or use SnapshotStateObserver instances. It's a complicated API that is hard to use correctly, and overkill for most use cases anyway. I'm discussing it here for completeness.

As an exercise, we can rewrite restartWhenStateChanged to use SnapshotStateObserver. However this is not a good use case for SnapshotStateObserver, since it is only meant to be used when you have multiple types of restartable functions to manage and want to share some of the cost between them.

suspend fun restartWhenStateChanged(
    block: () -> Unit
) {
    val observer = SimpleStateObserver(block)
    observer.run()
}

private class SimpleStateObserver(
    private val block: () -> Unit
) {
    private val onChangedChannel = Channel<() -> Unit>(Channel.UNLIMITED)

    // Use the channel to ensure all change notifications are processed in the calling
    // coroutine context.
    private val stateObserver = SnapshotStateObserver(
        onChangedExecutor = { onChangedChannel.trySend(it) }
    )

    // The changed callback is allocated as a single object shared across all invocations
    // of observeReads below. Due to how SnapshotStateObserver uses this function,
    // allocating a new one every time would degrade performance over time, so putting it
    // in a val makes the desired behavior explicit.
    private val onValueChangedInBlock = { _: Unit ->
        observeBlock()
    }

    private fun observeBlock() {
        // This function manages the read observer and invalidation logic for us.
        stateObserver.observeReads(
            // Scope can be used to pass different data to the onValueChangedInBlock function,
            // but we don't need that.
            scope = Unit,
            onValueChangedForScope = onValueChangedInBlock,
            block = block
        )
    }

    suspend fun run() {
        // SnapshotStateObserver has explicit start/stop methods to set up its listeners
        // and tear them down.
        stateObserver.start()
        try {
            // Bootstrap the observation.
            observeBlock()
            
            // Process change notifications until cancelled.
            for (onChanged in onChangedChannel) {
                onChanged()
            }
        } finally {
            stateObserver.stop()
        }
    }
}

As you can see, it's not really better. Firstly, we need to extract most of the logic into a class so we can reference onValueChangedInBlock from observeBlock and vice versa. We still have to specify how to do scheduling. And the observeReads method takes a scope parameter that we're not even using.

For a more realistic example, take a look at OwnerSnapshotObserver in Compose UI.

Conclusion

Restartable functions are a core concept of Jetpack Compose. They use snapshots to observe state reads in their own code, and state changes from anywhere else. snapshotFlow is probably the API you'll use most often to make your own restartable code.

The official Compose docs have more information:

Also, check out my other articles about Compose state.

Thanks for reading!