"""OpenAI Provider"""

from collections.abc import AsyncIterator

from openai import AsyncOpenAI

from app.providers.base import BaseLLMProvider, LLMMessage, LLMResponse


class OpenAIProvider(BaseLLMProvider):
    """LLM provider backed by the OpenAI (or an OpenAI-compatible) API.

    Chat completion, streaming chat completion and embeddings are all served
    through a single ``AsyncOpenAI`` client created at construction time.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4",
        base_url: str | None = None,
        embedding_model: str | None = None,
    ):
        """Create the provider and its async OpenAI client.

        Args:
            api_key: API key passed to ``AsyncOpenAI``.
            model: Model name used for chat completions.
            base_url: Optional alternative endpoint; ``None`` lets the client
                use its default.
            embedding_model: Model name used by :meth:`embed`. When omitted,
                ``model`` is used (the previous behaviour). Chat models such
                as ``gpt-4`` cannot create embeddings, so set this whenever
                :meth:`embed` is called.
        """
        self.api_key = api_key
        self.model = model
        self.embedding_model = embedding_model
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    @staticmethod
    def _to_openai_messages(messages: list[LLMMessage]) -> list[dict]:
        """Convert provider-neutral messages into the OpenAI ``messages`` payload."""
        return [{"role": m.role, "content": m.content} for m in messages]

    async def chat(self, messages: list[LLMMessage]) -> LLMResponse:
        """Run a non-streaming chat completion.

        Args:
            messages: Conversation to send, in order.

        Returns:
            An ``LLMResponse`` with the first choice's text ("" if the model
            returned no content), the model name reported by the API, and a
            token-usage dict. Usage counts are 0 when the server omits usage.
        """
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=self._to_openai_messages(messages),
        )
        # `usage` is optional in the response schema; some compatible servers
        # reached through `base_url` leave it out.
        usage = response.usage
        return LLMResponse(
            content=response.choices[0].message.content or "",
            model=response.model,
            usage={
                "total_tokens": usage.total_tokens if usage else 0,
                "prompt_tokens": usage.prompt_tokens if usage else 0,
                "completion_tokens": usage.completion_tokens if usage else 0,
            },
        )

    async def chat_stream(self, messages: list[LLMMessage]) -> AsyncIterator[str]:
        """Stream a chat completion, yielding non-empty text deltas as they arrive.

        Args:
            messages: Conversation to send, in order.

        Yields:
            Successive content fragments of the first choice.
        """
        stream = await self.client.chat.completions.create(
            model=self.model,
            messages=self._to_openai_messages(messages),
            stream=True,
        )
        async for chunk in stream:
            # Some chunks carry an empty `choices` list (e.g. the usage-only
            # chunk emitted when usage reporting is enabled). Indexing [0] on
            # such a chunk would raise IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed each string in ``texts``.

        Args:
            texts: Strings to embed. An empty list returns ``[]`` without an
                API call, because the API rejects empty input.

        Returns:
            One embedding vector per input string, in the same order as ``texts``.
        """
        if not texts:
            return []
        response = await self.client.embeddings.create(
            model=self.embedding_model or self.model,
            input=texts,
        )
        # Each item carries the `index` of its input. Sort on it so the
        # vectors line up with `texts` regardless of response ordering.
        ordered = sorted(response.data, key=lambda item: item.index)
        return [item.embedding for item in ordered]