from collections.abc import Callable
from dataclasses import dataclass
from typing import Generic, Protocol, TypeVar
from pydantic import BaseModel
from pharia_skill.csi.inference import ChatStreamResponse, CompletionStreamResponse
# Type of the optional structured payload attached to the end of a message.
# Bound to pydantic models (or None) so payloads are pydantic-model values.
Payload = TypeVar(
    "Payload",
    bound=BaseModel | None,
)
[docs]
@dataclass
class MessageBegin:
role: str | None
@dataclass
class MessageAppend:
    """Chunk of text appended to the message that is currently open.

    ``MessageWriter.append_to_message`` produces this item; when a response is
    forwarded, one item is written per streamed chunk.
    """

    # Text fragment to append to the open message.
    text: str
@dataclass
class MessageEnd(Generic[Payload]):
    """Closing item of a message, optionally carrying a structured payload.

    ``MessageWriter.end_message`` produces this item. The class is generic over
    ``Payload`` so the type of the attached payload can be expressed.
    """

    # Payload attached when the message ends; None if none was supplied.
    payload: Payload | None
# Any item a MessageWriter can write. The alias stays generic in Payload,
# so it can be parameterized as MessageItem[SomePayload].
MessageItem = (
    MessageBegin
    | MessageAppend
    | MessageEnd[Payload]
)
[docs]
class MessageWriter(Protocol, Generic[Payload]):
"""Write messages to the output stream."""
[docs]
def write(self, item: MessageItem[Payload]) -> None: ...
[docs]
def begin_message(self, role: str | None = None) -> None:
self.write(MessageBegin(role))
[docs]
def append_to_message(self, text: str) -> None:
self.write(MessageAppend(text))
[docs]
def end_message(self, payload: Payload | None = None) -> None:
self.write(MessageEnd(payload))
[docs]
def forward_response(
self,
response: CompletionStreamResponse | ChatStreamResponse,
payload: Callable[..., Payload] | Payload | None = None,
) -> None:
match response:
case CompletionStreamResponse():
self._forward_completion(response, payload)
case ChatStreamResponse():
self._forward_chat(response, payload)
def _forward_completion(
self,
response: CompletionStreamResponse,
payload: Callable[..., Payload] | Payload | None = None,
) -> None:
self.begin_message()
for append in response.stream():
self.append_to_message(append.text)
self.end_message(payload(response) if callable(payload) else payload)
def _forward_chat(
self,
response: ChatStreamResponse,
payload: Callable[..., Payload] | Payload | None = None,
) -> None:
self.begin_message(response.role)
for append in response.stream():
self.append_to_message(append.content)
self.end_message(payload(response) if callable(payload) else payload)