Source code for pharia_skill.message_stream.decorator

import inspect
import traceback
from typing import Callable, Type, TypeVar

from pydantic import BaseModel

from pharia_skill import Csi
from pharia_skill.message_stream.writer import MessageWriter, Payload

UserInput = TypeVar("UserInput", bound=BaseModel)


[docs] def message_stream( func: Callable[[Csi, MessageWriter[Payload], UserInput], None], ) -> Callable[[Csi, MessageWriter[Payload], UserInput], None]: """Turn a function with a specific signature into a (streaming) skill that can be deployed on Pharia Kernel. By using the response object, a Skill decorated with `@message_stream` can return intermediate results that are streamed to the caller. The decorated function must be typed. It must have exactly three arguments. The first argument must be of type `Csi`. The second argument must be a `Response` object. The third argument must be a Pydantic model. The function must not return anything. Example:: from pharia_skill import Csi, ChatParams, Message, message_stream, MessageWriter from pydantic import BaseModel from pharia_skill.csi.inference import FinishReason class Input(BaseModel): topic: str class SkillOutput(BaseModel): finish_reason: FinishReason @message_stream def haiku_stream(csi: Csi, writer: MessageWriter[SkillOutput], input: Input) -> None: model = "llama-3.1-8b-instruct" messages = [ Message.system("You are a poet who strictly speaks in haikus."), Message.user(input.topic), ] params = ChatParams() with csi.chat_stream(model, messages, params) as response: writer.begin_message() for event in response.stream(): writer.append_to_message(event.content) writer.end_message(SkillOutput(finish_reason=response.finish_reason())) """ # The import is inside the decorator to ensure the imports only run when the decorator is interpreted. # This is because we can only import them when targeting the `message-stream-skill` world. # If we target the `skill` world with a component and have the imports for the `message-stream-skill` world # in this module at the top-level, we will get a build error in case this module is in the module graph. from pharia_skill.bindings import exports from pharia_skill.bindings.exports.message_stream import ( Error_Internal, Error_InvalidInput, ) from pharia_skill.bindings.imports import streaming_output as wit from pharia_skill.bindings.types import Err from pharia_skill.message_stream.wit_writer import WitMessageWriter from pharia_skill.wit_csi.csi import WitCsi signature = list(inspect.signature(func).parameters.values()) assert len(signature) == 3, ( "Message Stream Skills must have exactly three arguments." ) input_model: Type[UserInput] = signature[2].annotation assert isinstance(input_model, type) and issubclass(input_model, BaseModel), ( "The third argument must be a Pydantic model, found: " + str(input_model) ) assert func.__annotations__.get("return") is None, ( "The function must not return anything" ) # We don't require the schema of the end payload in the function definition. # The only use case for this would be to know the metadata of the end payload. # Since we don't do metadata for streaming skills at the moment, it is not needed. class MessageStream(exports.MessageStream): def run(self, input: bytes, output: wit.StreamOutput) -> None: """This is the function that gets executed when running the Skill as a Wasm component.""" try: validated = input_model.model_validate_json(input) except Exception: raise Err(Error_InvalidInput(traceback.format_exc())) try: with WitMessageWriter[Payload](output) as writer: func(WitCsi(), writer, validated) except Exception: raise Err(Error_Internal(traceback.format_exc())) assert "MessageStream" not in func.__globals__, ( "Make sure to decorate with `@message_stream` only once." ) def trace_message_stream( csi: Csi, writer: MessageWriter[Payload], input: UserInput ) -> None: """This function is returned by the decorator and executed at test time. The `opentelemetry` library import is moved to within the function to not make it a dependency of the Wasm component. """ from opentelemetry import trace from pharia_skill.testing.dev.streaming_output import MessageRecorder with trace.get_tracer(__name__).start_as_current_span(func.__name__) as span: writer.span = span # type: ignore span.set_attribute("input", input.model_dump_json()) func(csi, writer, input) # We do rely on the user passing in a message recorder at test time if they # want tracing to work. if isinstance(writer, MessageRecorder): span.set_attribute("output", writer.skill_output()) return func.__globals__["MessageStream"] = MessageStream trace_message_stream.__globals__["MessageStream"] = MessageStream return trace_message_stream