feat: add AI coding assistant packages

This commit is contained in:
root 2026-05-19 03:59:50 +08:00
parent 98d113883e
commit 3213c98fc7
22 changed files with 568 additions and 0 deletions

View File

@ -0,0 +1,121 @@
"""Agent-internal message types.
All message types share a common `to_dict()` / `from_dict()` interface for
serialization. The module-level `agent_message_from_dict()` function is a
factory that dispatches to the correct concrete class based on the ``role``
field; ``AgentMessage`` itself is only a union alias for type annotations.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
@dataclass(slots=True)
class ToolCall:
"""A single tool call requested by the assistant."""
id: str
name: str
arguments: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
return {"id": self.id, "name": self.name, "arguments": self.arguments}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> ToolCall:
return cls(
id=data["id"],
name=data["name"],
arguments=data.get("arguments", {}),
)
@dataclass(slots=True)
class SystemMessage:
role: str = field(default="system", init=False)
content: str
def to_dict(self) -> dict[str, Any]:
return {"role": self.role, "content": self.content}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> SystemMessage:
return cls(content=data["content"])
@dataclass(slots=True)
class UserMessage:
role: str = field(default="user", init=False)
content: str
def to_dict(self) -> dict[str, Any]:
return {"role": self.role, "content": self.content}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> UserMessage:
return cls(content=data["content"])
@dataclass(slots=True)
class AssistantMessage:
role: str = field(default="assistant", init=False)
content: str
tool_calls: list[ToolCall] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return {
"role": self.role,
"content": self.content,
"tool_calls": [tc.to_dict() for tc in self.tool_calls],
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> AssistantMessage:
tool_calls = [
ToolCall.from_dict(tc) for tc in data.get("tool_calls", [])
]
return cls(content=data.get("content", ""), tool_calls=tool_calls)
@dataclass(slots=True)
class ToolResultMessage:
role: str = field(default="tool", init=False)
tool_call_id: str
content: str
is_error: bool = False
def to_dict(self) -> dict[str, Any]:
return {
"role": self.role,
"tool_call_id": self.tool_call_id,
"content": self.content,
"is_error": self.is_error,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> ToolResultMessage:
return cls(
tool_call_id=data["tool_call_id"],
content=data["content"],
is_error=data.get("is_error", False),
)
# Union of the four concrete message classes, for use in type annotations.
# It is an alias only: it has no methods of its own, so deserialization goes
# through the module-level factory function rather than through this name.
AgentMessage = SystemMessage | UserMessage | AssistantMessage | ToolResultMessage
def agent_message_from_dict(data: dict[str, Any]) -> AgentMessage:
    """Deserialize ``data`` into the message class selected by its ``role`` key.

    Args:
        data: Dict form of a message, as produced by a message's ``to_dict()``.

    Returns:
        The result of the matching class's ``from_dict(data)``.

    Raises:
        ValueError: If ``role`` is missing or not one of ``"system"``,
            ``"user"``, ``"assistant"``, ``"tool"``.
    """
    by_role: dict[str, type[AgentMessage]] = {
        "system": SystemMessage,
        "user": UserMessage,
        "assistant": AssistantMessage,
        "tool": ToolResultMessage,
    }
    role = data.get("role")
    # Guard clause: reject anything outside the table before dispatching.
    if role not in by_role:
        raise ValueError(f"unknown role: {role!r}")
    message_cls = by_role[role]
    return message_cls.from_dict(data)

View File

@ -0,0 +1,109 @@
"""LLM Provider abstraction — types and ABC.
Defines the contract that every LLM provider adapter must implement,
plus shared data types for messages, tool definitions, and stream chunks.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import Any
# ── LLM API message format ───────────────────────────────────
@dataclass(slots=True)
class Message:
"""A message in the LLM provider's native API format."""
role: str # "system" | "user" | "assistant" | "tool"
content: str | None = None
tool_calls: list[dict[str, Any]] | None = None
tool_call_id: str | None = None
name: str | None = None
# ── Tool definition types ────────────────────────────────────
@dataclass(slots=True)
class ToolParameter:
"""Describes a single parameter of a tool."""
type: str
description: str = ""
required: bool = True
enum: list[str] | None = None
@dataclass(slots=True)
class ToolDef:
"""Tool definition passed to LLM providers."""
name: str
description: str
parameters: dict[str, ToolParameter] = field(default_factory=dict)
def to_openai_format(self) -> dict[str, Any]:
"""Convert to OpenAI function-calling tool format."""
properties: dict[str, Any] = {}
required: list[str] = []
for pname, param in self.parameters.items():
prop: dict[str, Any] = {"type": param.type, "description": param.description}
if param.enum is not None:
prop["enum"] = param.enum
properties[pname] = prop
if param.required:
required.append(pname)
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": properties,
"required": required,
},
},
}
# ── Stream chunk ─────────────────────────────────────────────
@dataclass(slots=True)
class StreamChunk:
"""A single chunk emitted during LLM streaming."""
type: str # "content" | "tool_call" | "done" | "error"
content: str | None = None
tool_call: dict[str, Any] | None = None
# ── LLM Provider ABC ─────────────────────────────────────────
class LLMProvider(ABC):
    """Abstract base class for LLM provider adapters.

    A concrete adapter must override all three abstract methods below;
    instantiating a class that leaves any of them abstract raises
    ``TypeError``.
    """

    @abstractmethod
    def stream_chat(
        self,
        messages: list[Message],
        tools: list[ToolDef] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[StreamChunk]:
        """Stream a chat completion from the LLM.

        Declared with plain ``def`` on purpose. An ``async def`` whose body
        has no ``yield`` is a coroutine that *returns* an async iterator,
        which type checkers treat as a different signature from an async
        generator. Implementations are expected to be async generators
        (``async def`` containing ``yield``) or regular methods returning an
        async iterator; both match this declaration.

        Args:
            messages: Messages to send, in ``Message`` form.
            tools: Tool definitions to offer, or ``None`` for none.
            **kwargs: Extra options; not interpreted by this base class.

        Returns:
            An async iterator of ``StreamChunk`` objects.
        """

    @abstractmethod
    def convert_messages(self, agent_messages: list[Any]) -> list[Message]:
        """Convert agent-internal messages to the provider's ``Message`` format."""

    @abstractmethod
    def convert_tools(self, tools: list[ToolDef]) -> list[dict[str, Any]]:
        """Convert a ``ToolDef`` list to the provider's native tool format."""

View File

@ -0,0 +1,30 @@
# Build backend: the package is built with hatchling.
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
# Distribution metadata and runtime dependencies.
[project]
name = "ai-coding-assistant"
version = "0.1.0"
description = "An extensible AI coding assistant with multi-LLM support"
requires-python = ">=3.11"
dependencies = [
"httpx>=0.27",
"pydantic>=2.0",
]
# Optional "dev" extra: test tooling only, not needed at runtime.
[project.optional-dependencies]
dev = [
"pytest>=8.0",
"pytest-asyncio>=0.23",
"respx>=0.21",
]
# Console command `ai-ca`, bound to `main` in the `ai_coding_assistant.cli` module.
[project.scripts]
ai-ca = "ai_coding_assistant.cli:main"
# The importable package lives under packages/core/, not at the repository root.
[tool.hatch.build.targets.wheel]
packages = ["packages/core/ai_coding_assistant"]
# pytest settings. NOTE(review): asyncio_mode = "auto" is a pytest-asyncio option;
# presumably it lets `async def` tests run without per-test markers — confirm
# against the plugin documentation.
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

View File

View File

@ -0,0 +1,156 @@
"""Tests for agent message models — RED phase."""
from __future__ import annotations
import json
import pytest
from ai_coding_assistant.agent.messages import (
AgentMessage,
AssistantMessage,
SystemMessage,
ToolCall,
ToolResultMessage,
UserMessage,
agent_message_from_dict,
)
# ── Construction ──────────────────────────────────────────────
class TestUserMessage:
    """Construction and dict (de)serialization of ``UserMessage``."""

    def test_create(self):
        """The role is fixed to "user" and the content is stored as given."""
        user_msg = UserMessage(content="hello")
        assert user_msg.role == "user"
        assert user_msg.content == "hello"

    def test_to_dict(self):
        """``to_dict()`` yields exactly the role and content keys."""
        expected = {"role": "user", "content": "hello"}
        assert UserMessage(content="hello").to_dict() == expected

    def test_from_dict(self):
        """``from_dict()`` restores both role and content."""
        payload = {"role": "user", "content": "hello"}
        user_msg = UserMessage.from_dict(payload)
        assert user_msg.role == "user"
        assert user_msg.content == "hello"
class TestSystemMessage:
    """Construction and round-tripping of ``SystemMessage``."""

    def test_create(self):
        """The role is fixed to "system" and the content is stored as given."""
        prompt = "you are helpful"
        system_msg = SystemMessage(content=prompt)
        assert system_msg.role == "system"
        assert system_msg.content == "you are helpful"

    def test_to_dict_roundtrip(self):
        """Content survives ``to_dict()`` followed by ``from_dict()``."""
        before = SystemMessage(content="you are helpful")
        after = SystemMessage.from_dict(before.to_dict())
        assert after.content == before.content
class TestAssistantMessage:
    """``AssistantMessage`` with and without attached tool calls."""

    @staticmethod
    def _echo_call() -> ToolCall:
        """Return the tool-call fixture shared by several tests below."""
        return ToolCall(id="call_1", name="echo", arguments={"text": "hello"})

    def test_create_text_only(self):
        """Without tool calls the list defaults to empty."""
        reply = AssistantMessage(content="hi there")
        assert reply.role == "assistant"
        assert reply.content == "hi there"
        assert reply.tool_calls == []

    def test_create_with_tool_calls(self):
        """Tool calls passed to the constructor are kept on the message."""
        reply = AssistantMessage(content="", tool_calls=[self._echo_call()])
        calls = reply.tool_calls
        assert len(calls) == 1
        assert calls[0].name == "echo"

    def test_to_dict_with_tool_calls(self):
        """Tool calls are serialized into the ``tool_calls`` list."""
        reply = AssistantMessage(content="", tool_calls=[self._echo_call()])
        as_dict = reply.to_dict()
        assert as_dict["role"] == "assistant"
        assert len(as_dict["tool_calls"]) == 1
        assert as_dict["tool_calls"][0]["id"] == "call_1"

    def test_from_dict_with_tool_calls(self):
        """Nested tool-call dicts are rebuilt as ``ToolCall`` objects."""
        call_payload = {"id": "call_1", "name": "echo", "arguments": {"text": "hello"}}
        payload = {
            "role": "assistant",
            "content": "",
            "tool_calls": [call_payload],
        }
        reply = AssistantMessage.from_dict(payload)
        assert reply.role == "assistant"
        assert len(reply.tool_calls) == 1
        first_call = reply.tool_calls[0]
        assert first_call.name == "echo"
        assert first_call.arguments == {"text": "hello"}

    def test_roundtrip(self):
        """Content and tool-call ids survive a to_dict/from_dict cycle."""
        call = self._echo_call()
        before = AssistantMessage(content="thinking...", tool_calls=[call])
        after = AssistantMessage.from_dict(before.to_dict())
        assert after.content == before.content
        assert len(after.tool_calls) == 1
        assert after.tool_calls[0].id == call.id
class TestToolResultMessage:
    """Construction and round-tripping of ``ToolResultMessage``."""

    def test_create(self):
        """A successful result keeps its call id and reports no error."""
        result = ToolResultMessage(
            tool_call_id="call_1",
            content="result text",
            is_error=False,
        )
        assert result.role == "tool"
        assert result.tool_call_id == "call_1"
        assert result.is_error is False

    def test_create_error(self):
        """The error flag is stored when set."""
        failure = ToolResultMessage(
            tool_call_id="call_2",
            content="something failed",
            is_error=True,
        )
        assert failure.is_error is True

    def test_to_dict_roundtrip(self):
        """Call id, content and error flag survive a to_dict/from_dict cycle."""
        before = ToolResultMessage(
            tool_call_id="call_1",
            content="result text",
            is_error=False,
        )
        after = ToolResultMessage.from_dict(before.to_dict())
        assert after.tool_call_id == before.tool_call_id
        assert after.content == before.content
        assert after.is_error == before.is_error
# ── AgentMessage union helpers ────────────────────────────────
class TestAgentMessageUnion:
    """Role-based dispatch of ``agent_message_from_dict`` plus JSON export."""

    def test_from_dict_user(self):
        """Role "user" produces a ``UserMessage``."""
        payload = {"role": "user", "content": "hi"}
        assert isinstance(agent_message_from_dict(payload), UserMessage)

    def test_from_dict_system(self):
        """Role "system" produces a ``SystemMessage``."""
        payload = {"role": "system", "content": "sys"}
        assert isinstance(agent_message_from_dict(payload), SystemMessage)

    def test_from_dict_assistant(self):
        """Role "assistant" produces an ``AssistantMessage``."""
        payload = {"role": "assistant", "content": "hi"}
        assert isinstance(agent_message_from_dict(payload), AssistantMessage)

    def test_from_dict_tool(self):
        """Role "tool" produces a ``ToolResultMessage``."""
        payload = {
            "role": "tool",
            "tool_call_id": "c1",
            "content": "ok",
            "is_error": False,
        }
        assert isinstance(agent_message_from_dict(payload), ToolResultMessage)

    def test_from_dict_unknown_role_raises(self):
        """An unrecognized role is rejected with ``ValueError``."""
        payload = {"role": "alien", "content": "???"}
        with pytest.raises(ValueError, match="unknown role"):
            agent_message_from_dict(payload)

    def test_json_serializable(self):
        """All message types should be JSON-serializable via to_dict()."""
        call = ToolCall(id="call_1", name="echo", arguments={"text": "hi"})
        samples: list[AgentMessage] = [
            SystemMessage(content="system"),
            UserMessage(content="hello"),
            AssistantMessage(content="", tool_calls=[call]),
            ToolResultMessage(tool_call_id="call_1", content="hi", is_error=False),
        ]
        encoded = [json.dumps(sample.to_dict()) for sample in samples]
        # Every message must have produced a non-empty JSON string.
        assert all(encoded)

View File

@ -0,0 +1,147 @@
"""Tests for LLM Provider abstraction — RED phase."""
from __future__ import annotations
import pytest
from ai_coding_assistant.llm.base import (
LLMProvider,
Message,
StreamChunk,
ToolDef,
ToolParameter,
)
# ── Data model construction ──────────────────────────────────
class TestMessage:
    """Construction of provider-format ``Message`` objects for each role."""

    def test_user_message(self):
        """A user message keeps its content and has no tool calls."""
        user_msg = Message(role="user", content="hello")
        assert user_msg.role == "user"
        assert user_msg.content == "hello"
        assert user_msg.tool_calls is None

    def test_assistant_with_tool_calls(self):
        """Tool calls are stored as the raw dicts that were passed in."""
        raw_call = {"id": "c1", "name": "echo", "arguments": {"text": "hi"}}
        assistant_msg = Message(role="assistant", content="", tool_calls=[raw_call])
        assert len(assistant_msg.tool_calls) == 1
        assert assistant_msg.tool_calls[0]["name"] == "echo"

    def test_system_message(self):
        """A system message reports the system role."""
        assert Message(role="system", content="you are helpful").role == "system"

    def test_tool_message(self):
        """A tool message carries the id of the call it answers."""
        tool_msg = Message(role="tool", content="result", tool_call_id="c1")
        assert tool_msg.role == "tool"
        assert tool_msg.tool_call_id == "c1"
class TestToolParameter:
    """Construction of ``ToolParameter`` with and without an enum."""

    def test_create(self):
        """The declared type is stored as given."""
        text_param = ToolParameter(type="string", description="text to echo")
        assert text_param.type == "string"

    def test_with_enum(self):
        """Allowed values passed as ``enum`` are stored as given."""
        allowed = ["fast", "slow"]
        mode_param = ToolParameter(type="string", description="mode", enum=allowed)
        assert mode_param.enum == ["fast", "slow"]
class TestToolDef:
    """Construction of ``ToolDef`` and its OpenAI-format export."""

    @staticmethod
    def _echo_tool() -> ToolDef:
        """Return the single-parameter tool definition used by both tests."""
        text_param = ToolParameter(type="string", description="text to echo")
        return ToolDef(
            name="echo",
            description="Echo back the input",
            parameters={"text": text_param},
        )

    def test_create(self):
        """Name and parameters are stored as given."""
        tool = self._echo_tool()
        assert tool.name == "echo"
        assert "text" in tool.parameters

    def test_to_openai_format(self):
        """The export is a function-type entry listing the parameter as a property."""
        exported = self._echo_tool().to_openai_format()
        function = exported["function"]
        assert exported["type"] == "function"
        assert function["name"] == "echo"
        assert "text" in function["parameters"]["properties"]
class TestStreamChunk:
    """Construction of ``StreamChunk`` for each chunk type."""

    def test_content_chunk(self):
        """A content chunk carries text and no tool call."""
        text_chunk = StreamChunk(type="content", content="Hello")
        assert text_chunk.type == "content"
        assert text_chunk.content == "Hello"
        assert text_chunk.tool_call is None

    def test_tool_call_chunk(self):
        """A tool-call chunk carries the raw tool-call dict."""
        raw_call = {"id": "c1", "name": "echo", "arguments": "{}"}
        call_chunk = StreamChunk(type="tool_call", tool_call=raw_call)
        assert call_chunk.type == "tool_call"
        assert call_chunk.tool_call["name"] == "echo"

    def test_done_chunk(self):
        """A done chunk needs nothing but its type."""
        assert StreamChunk(type="done").type == "done"

    def test_error_chunk(self):
        """An error chunk carries its message in ``content``."""
        error_chunk = StreamChunk(type="error", content="rate limited")
        assert error_chunk.type == "error"
        assert error_chunk.content == "rate limited"
# ── LLMProvider ABC ──────────────────────────────────────────
class TestLLMProviderABC:
    """Abstractness guarantees of the ``LLMProvider`` base class."""

    def test_cannot_instantiate(self):
        """LLMProvider is abstract and cannot be instantiated directly."""
        with pytest.raises(TypeError):
            LLMProvider()  # type: ignore[abstract]

    def test_subclass_must_implement_methods(self):
        """A subclass that doesn't implement all abstract methods cannot be instantiated."""

        class IncompleteProvider(LLMProvider):
            """Overrides nothing, so it remains abstract."""

        with pytest.raises(TypeError):
            IncompleteProvider()  # type: ignore[abstract]

    def test_subclass_can_be_created_with_all_methods(self):
        """A subclass that implements all abstract methods can be instantiated."""

        class DummyProvider(LLMProvider):
            """Minimal adapter overriding every abstract method."""

            async def stream_chat(self, messages, tools=None, **kwargs):
                yield StreamChunk(type="done")

            def convert_messages(self, agent_messages):
                return []

            def convert_tools(self, tools):
                return []

        instance = DummyProvider()
        assert isinstance(instance, LLMProvider)

5
memory/2026-05-13.md Normal file
View File

@ -0,0 +1,5 @@
# 2026-05-13
## 纪要
- 工作区已初始化,可在此持续记录当天上下文。