import inspect
import json
import traceback
from typing import Callable, Type, TypeVar
from pydantic import (
BaseModel,
# For generation of JSON schemas, Pydantic imports the `root_model` module at runtime: https://github.com/pydantic/pydantic/blob/main/pydantic/json_schema.py#L1500
# As `componentize-py` resolves imports at build time, we are required to add this import here.
RootModel, # noqa: F401
)
from .bindings import exports
from .bindings.types import Err
from .csi import Csi
from .wit_csi import WitCsi
UserInput = TypeVar("UserInput", bound=BaseModel)
UserOutput = TypeVar("UserOutput", bound=BaseModel)
[docs]
def skill(
func: Callable[[Csi, UserInput], UserOutput],
) -> Callable[[Csi, UserInput], UserOutput]:
"""Turn a function with a specific signature into a skill that can be deployed on Pharia Kernel.
The decorated function must be typed. It must have exactly two input arguments. The first argument
must be of type `Csi`. The second argument must be a Pydantic model. The type of the return value
must also be a Pydantic model. Each module is expected to have only one function that is decorated
with `skill`.
Example::
from pharia_skill import ChatParams, Csi, Message, skill
from pydantic import BaseModel
class Input(BaseModel):
topic: str
class Output(BaseModel):
haiku: str
@skill
def run(csi: Csi, input: Input) -> Output:
system = Message.system("You are a poet who strictly speaks in haikus.")
user = Message.user(input.topic)
params = ChatParams(max_tokens=64)
response = csi.chat("llama-3.1-8b-instruct", [system, user], params)
return Output(haiku=response.message.content.strip())
"""
# The import is inside the decorator to ensure the imports only run when the decorator is interpreted.
# This is because we can only import them when targeting the `skill` world.
# If we target the `message-stream-skill` world with a component and have the imports for the `skill` world
# in this module at the top-level, we will get a build error in case this module is in the module graph.
from .bindings.exports.skill_handler import (
Error_Internal,
Error_InvalidInput,
SkillMetadata,
)
signature = list(inspect.signature(func).parameters.values())
assert len(signature) == 2, "Skills must have exactly two arguments."
input_model: Type[UserInput] = signature[1].annotation
assert issubclass(input_model, BaseModel), (
"The second argument must be a Pydantic model"
)
assert func.__annotations__.get("return") is not None, (
"The function must have a return type annotation"
)
output_model: Type[UserOutput] = func.__annotations__["return"]
assert issubclass(output_model, BaseModel), (
"The return type must be a Pydantic model"
)
# This code here inside the decorator (but outside of the `class SkillHandler`) is executed at build time.
# In version 0.3 of the wit world, we did not account for the fact that the metadata method may return
# an error. However, as pydantic does some imports at runtime, we need to take this possibility into account.
# By calculating the metadata at build time, we can (in case there is an error) give the user direct feedback,
# instead of failing at runtime.
description = func.__doc__
input_schema = json.dumps(input_model.model_json_schema()).encode()
output_schema = json.dumps(output_model.model_json_schema()).encode()
metadata = SkillMetadata(description, input_schema, output_schema)
class SkillHandler(exports.SkillHandler):
def run(self, input: bytes) -> bytes:
"""This is the function that gets executed when running the Skill as a Wasm component."""
try:
validated = input_model.model_validate_json(input)
except Exception:
raise Err(Error_InvalidInput(traceback.format_exc()))
try:
result = func(WitCsi(), validated)
return result.model_dump_json().encode()
except Exception:
raise Err(Error_Internal(traceback.format_exc()))
def metadata(self) -> SkillMetadata:
return metadata
assert "SkillHandler" not in func.__globals__, "`@skill` can only be used once."
def trace_skill(csi: Csi, input: UserInput) -> UserOutput:
"""This function is returned by the decorator and executed at test time.
The `opentelemetry` library import is moved to within the function to not make it a dependency of the
Wasm component.
"""
from opentelemetry import trace
with trace.get_tracer(__name__).start_as_current_span(func.__name__) as span:
span.set_attribute("input", input.model_dump_json())
result = func(csi, input)
span.set_attribute("output", result.model_dump_json())
return result
func.__globals__["SkillHandler"] = SkillHandler
trace_skill.__globals__["SkillHandler"] = SkillHandler
return trace_skill