Multiple Tools in Assistant Streaming API

I found a solution that worked for me.

  1. I found a post on Medium from HawkFlow.ai that put me on the right track by using the event handler.
  2. Refactored that code to suit my setup and the way my functions are defined.
  3. Created a list variable to store the tool calls captured from on_tool_call_created.
  4. Looped through the tool calls, passed each one to my functions, and added each response to a tool_outputs list.
  5. Submitted that list with submit_tool_outputs_stream.

It sounds convoluted when written out, but here’s my code:

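(For context, these are roughly the imports and client setup the snippet assumes, using the openai v1.x Python SDK with the Assistants streaming helpers; adjust to however you already create your client.)

import json

from typing_extensions import override
from openai import OpenAI, AssistantEventHandler
from openai.types.beta import AssistantStreamEvent
from openai.types.beta.threads import Message, MessageDelta
from openai.types.beta.threads.runs import RunStep, ToolCall

client = OpenAI()  # reads OPENAI_API_KEY from the environment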
class EventHandler(AssistantEventHandler):
def __init__(self, request, thread_id, assistant_id):
    super().__init__()
    self.output = None
    self.request = request
    self.tool_id = None
    self.function_arguments = None
    self.thread_id = thread_id
    self.assistant_id = assistant_id
    self.run_id = None
    self.run_step = None
    self.function_name = ""
    self.arguments = ""
    self.tool_calls = []

@override
def on_text_created(self, text) -> None:
    print(f"\nassistant on_text_created > ", end="", flush=True)

@override
def on_text_delta(self, delta, snapshot):
    print(f"{delta.value}")

@override
def on_end(self):
    print(f"\n end assistant > ", self.current_run_step_snapshot, end="", flush=True)

@override
def on_exception(self, exception: Exception) -> None:
    print(f"\nassistant > {exception}\n", end="", flush=True)

@override
def on_message_created(self, message: Message) -> None:
    print(f"\nassistant on_message_created > {message}\n", end="", flush=True)

@override
def on_message_done(self, message: Message) -> None:
    print(f"\nassistant on_message_done > {message}\n", end="", flush=True)

@override
def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
    pass

def on_tool_call_created(self, tool_call):
    print(f"\nassistant on_tool_call_created > {tool_call}")
    self.function_name = tool_call.function.name
    self.function_arguments = tool_call.function.arguments  # Capture the arguments
    self.tool_id = tool_call.id
    print(f"\on_tool_call_created > run_step.status > {self.run_step.status}")
    print(f"\nassistant > {tool_call.type} {self.function_name}\n", flush=True)

    keep_retrieving_run = client.beta.threads.runs.retrieve(
        thread_id=self.thread_id,
        run_id=self.run_id
    )

    # wait for the run to leave the queued/in-progress states
    while keep_retrieving_run.status in ["queued", "in_progress"]:
        keep_retrieving_run = client.beta.threads.runs.retrieve(
            thread_id=self.thread_id,
            run_id=self.run_id
        )
        print(f"\nSTATUS: {keep_retrieving_run.status}")

    # once the run is waiting on tool outputs, store every pending tool call
    if keep_retrieving_run.status == "requires_action":
        for call in keep_retrieving_run.required_action.submit_tool_outputs.tool_calls:
            if call.id not in [c.id for c in self.tool_calls]:
                self.tool_calls.append(call)

@override
def on_tool_call_done(self, tool_call: ToolCall) -> None:       
    keep_retrieving_run = client.beta.threads.runs.retrieve(
        thread_id=self.thread_id,
        run_id=self.run_id
    )

    print(f"\nDONE STATUS: {keep_retrieving_run.status}")

    if keep_retrieving_run.status == "completed":
        all_messages = client.beta.threads.messages.list(
            thread_id=self.thread_id
        )

        print(all_messages.data[0].content[0].text.value, "", "")
        return

    elif keep_retrieving_run.status == "requires_action":
        print("here you would call your function")

        run = client.beta.threads.runs.retrieve(
            thread_id=self.thread_id,
            run_id=self.run_id
        )
        tool_outputs = []
        for tool_call in self.tool_calls:
            # process tool_call
            tool_call_result = ai_call(tool_call, self.request)
            tool_call_id = tool_call_result['tool_call_id']
            output = tool_call_result['output']
            output_json = json.dumps(output)
            formatted_output = {
                "tool_call_id": tool_call_id,
                "output": output_json
            }
            tool_outputs.append(formatted_output)

        with client.beta.threads.runs.submit_tool_outputs_stream(
            thread_id=self.thread_id,
            run_id=self.run_id,
            # submit every collected output, one entry per tool call
            tool_outputs=tool_outputs,
            event_handler=EventHandler(self.request, self.thread_id, self.assistant_id)
        ) as stream:
            stream.until_done()                       
    else:
        print(f"\nassistant on_tool_call_done > {tool_call}\n", end="", flush=True)

@override
def on_run_step_created(self, run_step: RunStep) -> None:
    print(f"on_run_step_created")
    self.run_id = run_step.run_id
    self.run_step = run_step
    print("The type ofrun_step run step is ", type(run_step), flush=True)
    print(f"\n run step created assistant > {run_step}\n", flush=True)

@override
def on_run_step_done(self, run_step: RunStep) -> None:
    print(f"\n run step done assistant > {run_step}\n", flush=True)

def on_tool_call_delta(self, delta, snapshot): 
    if delta.type == 'function':
        print(delta.function.arguments, end="", flush=True)
        self.arguments += delta.function.arguments
    elif delta.type == 'code_interpreter':
        print(f"on_tool_call_delta > code_interpreter")
        if delta.code_interpreter.input:
            print(delta.code_interpreter.input, end="", flush=True)
        if delta.code_interpreter.outputs:
            print(f"\n\noutput >", flush=True)
            for output in delta.code_interpreter.outputs:
                if output.type == "logs":
                    print(f"\n{output.logs}", flush=True)
    else:
        print("ELSE")
        print(delta, end="", flush=True)

@override
def on_event(self, event: AssistantStreamEvent) -> None:
    if event.event == "thread.run.requires_action":
        print("\nthread.run.requires_action > submit tool call")
        print(f"ARGS: {self.arguments}")
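And for completeness, the handler gets kicked off with the SDK’s standard streaming helper, roughly like this (thread_id, assistant_id and request come from wherever you created the thread and handled the incoming request):

with client.beta.threads.runs.stream(
    thread_id=thread_id,
    assistant_id=assistant_id,
    event_handler=EventHandler(request, thread_id, assistant_id),
) as stream:
    stream.until_done()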