Stopping AI Assistant Streaming - getting 'Run' from 'Stream'

Very happy to see streaming in the assistants API! I’m trying to figure out how best to allow users to stop assistant-streaming (and providing my workaround for others).

I am using something akin to the following:

def converse(self, user_id, message_text):
    """
    Continue an existing conversation, or start a new one if none exists
    """
    self.threads[user_id] = self.get_thread(user_id)

    self.client.beta.threads.messages.create(
        thread_id=self.threads[user_id].id,
        role="user",
        content=message_text
    )

    assistant = get_assistant(user_id)

    def generate_response():
        # Create a generator for streaming, which will yield the current
        # state of the text. Flask can handle generators, then we can use
        # js to async stream it
        with self.client.beta.threads.runs.create_and_stream(
            thread_id=self.threads[user_id].id,
            assistant_id=assistant.id,
        ) as stream:
            self.streams[user_id] = stream

            for text in stream.text_deltas:
                print(text, end="", flush=True)
            yield text

    self.generators[user_id] = generate_response()

This successfully allows me to later pass along the generator for streaming (similar to what is seen here).

However, if I later want to stop generating using something like:

self.client.beta.threads.runs.cancel(
    thread_id=thread.id,
    run_id=run.id
)

I am unsure how to pull the run.id from the stream object. The stream.current_run property seems to be the intended solution, and indeed I can see that it is set during streaming:

...
for text in stream.text_deltas:
    print(f"Current run: {stream.current_run}", flush=True)
    yield text

However, it is only set when actively streaming:

...
print(f"This prints 'None': {stream.current_run}", flush=True)
for text in stream.text_deltas:
    print(f"Current run: {stream.current_run}", flush=True)
    yield text

This is troublesome, of course, as the server wants to keep track of the run.id so we can cancel it, but the client is the one dictating when to start/consume the streaming generator.
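Part of this is generator laziness: the body of generate_response (including the create_and_stream call itself) does not execute until the generator is first advanced, so current_run cannot be populated before the client starts consuming. A toy model (no SDK involved; all names here are invented for illustration) shows the behavior:

```python
# Toy stand-in for the SDK stream: iteration side effects (like setting
# current_run) only happen once something actually pulls from the stream.
class FakeStream:
    def __init__(self):
        self.current_run = None

    @property
    def text_deltas(self):
        def deltas():
            self.current_run = "run_abc123"  # set only when streaming starts
            yield from ["Hel", "lo"]
        return deltas()

stream = FakeStream()

def generate_response():
    # Nothing in this body runs until the generator is advanced
    for text in stream.text_deltas:
        yield text

gen = generate_response()
assert stream.current_run is None   # generator created, but never advanced
first = next(gen)                   # advancing it finally starts the stream
assert stream.current_run == "run_abc123"
print(first)  # -> Hel
```

So any attribute the SDK sets "during streaming" is invisible until at least one next() has been issued, which is exactly the priming trick the workaround below relies on.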

Things I have tried:

  • Simply calling stream = client.beta.threads.runs.create_and_stream(..)
    but this returns an AssistantStreamManager, which throws AttributeError: 'AssistantStreamManager' object has no attribute 'current_run'.
  • Creating separate ‘initialize’ and ‘stream’ generators, where ‘initialize’ enters the stream context and ‘stream’ consumes it, but trying to re-enter the context throws AttributeError: __enter__ on the stream.
  • Creating the stream first, then defining the generator inside the stream's context. This doesn't work: once you exit the context, later use of the generator raises StreamClosed.
  • Using the class EventHandler(AssistantEventHandler): to create a custom handler, and pull the run from that instead - but I also could not figure out how the AssistantEventHandler is intended to access the run.id

(I could provide code for these as needed, but it seemed bloated to try to fit it all here)
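One alternative that I believe works is to iterate the raw event stream instead of text_deltas: the first event the Assistants streaming API emits is thread.run.created, whose payload carries the run id. Below is a sketch of that pattern against fake events (the event names are the documented ones; the payload objects and the runs dict are invented for illustration):

```python
from types import SimpleNamespace as NS

# Fake events mimicking the documented Assistants streaming event shapes
def fake_events():
    yield NS(event="thread.run.created", data=NS(id="run_123"))
    yield NS(event="thread.message.delta",
             data=NS(delta=NS(content=[NS(text=NS(value="Hello"))])))
    yield NS(event="thread.message.delta",
             data=NS(delta=NS(content=[NS(text=NS(value=" world"))])))

runs = {}  # user_id -> run_id, so a later /stop request can cancel the run

def generate_response(events, user_id):
    for event in events:
        if event.event == "thread.run.created":
            # The run id is available from the very first event
            runs[user_id] = event.data.id
        elif event.event == "thread.message.delta":
            for part in event.data.delta.content:
                yield part.text.value

chunks = list(generate_response(fake_events(), "alice"))
print(runs["alice"], "".join(chunks))  # -> run_123 Hello world
```

The run id is still only recorded once the client pulls the first chunk, but at least it arrives before any text does, so no sentinel yield is needed.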

Here is what “works”:

with self.client.beta.threads.runs.create_and_stream(
    thread_id=self.threads[user_id].id,
    assistant_id=assistant.id,
) as stream:
    self.streams[user_id] = stream

    # To populate stream.current_run, it seems we have to be
    # actively iterating through streamed content.
    # As a temporary workaround, our first 'yield' just
    # lets the caller start the stream - which we call instantly
    # - and subsequent yields return the actual streamed content
    started = False
    for text in stream.text_deltas:
        if not started:
            # Yield an empty sentinel first so next() can prime the stream
            yield ""
            started = True
        print(text, end="", flush=True)
        yield text


self.generators[user_id] = generate_response()
# Call the generator once to start the run and set the
# stream.current_run property
next(self.generators[user_id])
assert self.streams.get(user_id) and self.streams[user_id].current_run

It really seems like I am missing something - either with the API, or with generator best practices.
How do I get the run.id from a stream? Why does using stream = client.beta.threads.runs.create_and_stream(...) not provide a stream.current_run? Why is the current_run property not populated unless actively iterating?

I am certain this behavior is the result of valid, technical reasons - just trying to figure out what the best approach is for stopping a run generated from create_and_stream.

What I think I (and many others) are trying to do is mimic some behaviors seen in ChatGPT. E.g.:

  • Starting a stream on the server, and keeping track of the stream, run, and thread. Then, users can stop runs, create new branches, edit previous messages, etc.
  • Providing a way for clients to consume the stream (as a QOL feature, and so the user can ‘watch the AI as it works’)

I am sure there is an intended approach to do this with the assistants API, and that I am just missing it.
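For the ‘stop runs’ behavior specifically, the server-side bookkeeping can be as small as a dict mapping each user to their active (thread_id, run_id), populated as soon as the run id is known and consulted by a stop endpoint. A minimal sketch (all names invented; cancel_fn stands in for client.beta.threads.runs.cancel):

```python
class RunRegistry:
    """Track each user's active run so a later request can cancel it."""

    def __init__(self, cancel_fn):
        self.active = {}            # user_id -> (thread_id, run_id)
        self.cancel_fn = cancel_fn  # e.g. client.beta.threads.runs.cancel

    def start(self, user_id, thread_id, run_id):
        self.active[user_id] = (thread_id, run_id)

    def stop(self, user_id):
        # Pop first so a double-stop is a harmless no-op
        ids = self.active.pop(user_id, None)
        if ids:
            thread_id, run_id = ids
            self.cancel_fn(thread_id=thread_id, run_id=run_id)
        return ids

# Usage with a recording stub instead of the real client
cancelled = []
registry = RunRegistry(lambda **kw: cancelled.append(kw))
registry.start("alice", "thread_1", "run_1")
registry.stop("alice")
print(cancelled)  # -> [{'thread_id': 'thread_1', 'run_id': 'run_1'}]
```

The streaming generator calls start() as soon as it learns the run id, and a separate request handler calls stop() - which is exactly the server/client split described above.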

Forgive me if I have misunderstood any aspect of the API or of Python’s generators; I have been a developer for many years, but am always learning.

Let me know if I can clarify anything. Thank you for any help!


I have exactly the same issue (node.js) and have no idea how I can access the runId. Does anyone have any idea?

Actually, I found a workaround. I add the runId as the first stream chunk, and in the frontend an if statement checks whether it is the first chunk. If it is, I save it and carry on with the rest of the chunks as normal.

export async function POST(request) {
  try {
    const assistantId = process.env.XXX;
    const body = await request.json(); // Parse the incoming request body as JSON
    const threadId = body.threadId;

    const stream = await openai.beta.threads.runs.create(threadId, {
      assistant_id: assistantId,
      stream: true,
    });

    // Create an encoder to encode the data to be sent over the stream
    const encoder = new TextEncoder();

    // Create a readable stream to send SSE data
    const readableStream = new ReadableStream({
      async start(controller) {
        for await (const event of stream) {
          // The only way to pass the runId to the client is to add it to the stream
          // Get the runId from the last event before the first message delta
          if (event.event === "thread.message.in_progress") {
            const runId = event.data?.run_id;
            controller.enqueue(encoder.encode(JSON.stringify({ runId })));
          }

          // Queue the encoded message delta into the stream
          if (event.event === "thread.message.delta") {
            const delta = event.data?.delta?.content[0]?.text?.value;
            controller.enqueue(encoder.encode(delta));
          }
        }

        // Prevent more content from being added to the stream
        controller.close();
      },
    });

    // Return the readable stream (plain chunked text, not strict SSE framing)
    // with CORS headers
    return new Response(readableStream, {
      headers: {
        "Content-Type": "text/html; charset=utf-8",
        "Access-Control-Allow-Origin": "*",
      },
    });
  } catch (error) {
    console.error("Error creating stream:", error);
    return new Response(error.message, { status: error.status });
  }
}

Frontend

const streamRun = async (threadId, assistantMessageEl) => {
      const apiUrl = vercelApiRoot + "/stream-run";

      try {
        runStreamController = new AbortController();
        const { signal } = runStreamController;

        const response = await fetch(apiUrl, {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify({
            threadId,
          }),
          mode: "cors",
          signal,
        });

        if (!response.ok) {
          throw new Error(`HTTP error! Status: ${response.status}`);
        }

        // Change the loading message to thinking
        const loadingMessage = `
          <div class="loading-message">
            Thinking
            <div class="loading-spinner"></div>
          </div>
        `;
        assistantMessageEl.innerHTML = loadingMessage;

        // Change the send button text to stop
        sendBtn.innerHTML = "Stop";
        isGeneratingResponse = true;

        const reader = response.body.getReader();
        const decoder = new TextDecoder();

        let receivedText = "";
        while (true) {
          const { value, done } = await reader.read();
          if (done) break; // When no more data needs to be read, break the reading

          const decodedValue = decoder.decode(value, { stream: true });

          if (!activeRunId) {
            // Extract the run ID from the first chunk of the stream
            activeRunId = JSON.parse(decodedValue).runId;
          } else {
            receivedText += decodedValue;
          }

          // Replace the loading message and spinner with the received stream of text chunks
          fillMessageWithStream(assistantMessageEl, receivedText);
        }

        // Reset the send button text
        sendBtn.innerHTML = "Send";
        isGeneratingResponse = false;

        console.log("Stream completed");
      } catch (error) {
        if (error.name === "AbortError") {
          console.log("Stream aborted");
        } else {
          console.error("Error streaming run:", error.message);
        }
      } finally {
        activeRunId = null;
        runStreamController = null;
      }
    };

And I cancel by aborting the fetch on the front end, plus cancelling the run (a request is sent to OpenAI to cancel the run):

console.log("Stopping (aborting) stream...");
if (runStreamController) runStreamController.abort();
if (activeRunId) cancelRun();