# Source code for agentstr.database.sqlite

from typing import Optional, Any, List, Literal, Self
import json
import aiosqlite
from datetime import datetime, timezone

from agentstr.models import Message, User
from agentstr.database.base import BaseDatabase
from agentstr.logger import get_logger

# Module-level logger; methods of SQLiteDatabase emit debug lines prefixed "[SQLite]".
logger = get_logger(__name__)


class SQLiteDatabase(BaseDatabase):
    """SQLite implementation of ``BaseDatabase`` using ``aiosqlite``.

    Every user and message row is scoped by ``agent_name``, so several agents
    can share one database file. Call :meth:`async_init` before using any
    other method; it opens the connection and creates the schema.
    """

    def __init__(self, conn_str: Optional[str] = None, *, agent_name: str | None = None):
        """Configure the database without opening a connection.

        Args:
            conn_str: Connection string of the form ``sqlite://<path>``.
                Defaults to ``"sqlite://agentstr_local.db"``.
            agent_name: Name used to scope all rows; forwarded to the base
                class.
        """
        super().__init__(conn_str or "sqlite://agentstr_local.db", agent_name)
        # Strip only a *leading* scheme to obtain the filesystem path;
        # str.replace would also remove "sqlite://" appearing mid-string.
        self._db_path = self.conn_str.removeprefix("sqlite://")
        # Opened by async_init(); None before that and again after close().
        self.conn = None

    # --------------------------- helpers -------------------------------

    async def _execute(self, sql: str, params: tuple[Any, ...] = ()) -> None:
        """Run a statement that yields no rows and close its cursor immediately."""
        async with self.conn.execute(sql, params):
            pass

    async def _ensure_user_table(self) -> None:
        """Create the ``user`` table and its ``agent_name`` index if missing."""
        await self._execute(
            """CREATE TABLE IF NOT EXISTS user (
                agent_name TEXT NOT NULL,
                user_id TEXT NOT NULL,
                available_balance INTEGER NOT NULL,
                current_thread_id TEXT,
                PRIMARY KEY (agent_name, user_id)
            )"""
        )
        # Index on agent_name for faster agent filtering
        await self._execute(
            "CREATE INDEX IF NOT EXISTS idx_user_agent_name ON user (agent_name)"
        )
        await self.conn.commit()

    async def _ensure_message_table(self) -> None:
        """Create the ``message`` table and its lookup indexes if missing."""
        await self._execute(
            """CREATE TABLE IF NOT EXISTS message (
                agent_name TEXT NOT NULL,
                thread_id TEXT NOT NULL,
                idx INTEGER NOT NULL,
                user_id TEXT NOT NULL,
                role TEXT NOT NULL,
                message TEXT,
                content TEXT NOT NULL,
                kind TEXT,
                satoshis INTEGER,
                extra_inputs TEXT,
                extra_outputs TEXT,
                created_at DATETIME NOT NULL,
                PRIMARY KEY (agent_name, thread_id, idx, user_id)
            )"""
        )
        # Index on agent_name for faster agent filtering
        await self._execute(
            "CREATE INDEX IF NOT EXISTS idx_message_agent_name ON message (agent_name)"
        )
        # Index on thread_id for faster thread filtering
        await self._execute(
            "CREATE INDEX IF NOT EXISTS idx_message_thread_id ON message (thread_id)"
        )
        await self.conn.commit()

    # --------------------------- API ----------------------------------

    async def async_init(self) -> Self:
        """Open the SQLite connection, ensure the schema exists, and return ``self``."""
        self.conn = await aiosqlite.connect(self._db_path)
        # Return rows as mappings so we can access by column name
        self.conn.row_factory = aiosqlite.Row
        await self._ensure_user_table()
        await self._ensure_message_table()
        return self

    async def close(self) -> None:
        """Close the connection if one is open; calling it again is a no-op."""
        if self.conn is not None:
            await self.conn.close()
            self.conn = None

    async def get_user(self, user_id: str) -> User:
        """Return the stored user for this agent, or ``User(user_id=user_id)`` if absent."""
        logger.debug("[SQLite] Getting user %s", user_id)
        async with self.conn.execute(
            "SELECT available_balance, current_thread_id FROM user WHERE agent_name = ? AND user_id = ?",
            (self.agent_name, user_id),
        ) as cursor:
            row = await cursor.fetchone()
        if row:
            return User(user_id=user_id, available_balance=row[0], current_thread_id=row[1])
        return User(user_id=user_id)

    async def get_current_thread_id(self, user_id: str) -> str | None:
        """Return the current thread id for *user_id* within this agent scope."""
        user = await self.get_user(user_id)
        return user.current_thread_id

    async def set_current_thread_id(self, user_id: str, thread_id: str | None) -> None:
        """Persist *thread_id* as the current thread for *user_id*.

        If the user has no row yet, one is created from ``User(user_id=...)``
        with the thread id applied.
        """
        user = await self.get_user(user_id)
        user.current_thread_id = thread_id
        await self.upsert_user(user)

    async def upsert_user(self, user: User) -> None:
        """Insert *user*, or overwrite balance and current thread of the existing row."""
        logger.debug("[SQLite] Upserting user %s", user)
        await self._execute(
            """INSERT INTO user (agent_name, user_id, available_balance, current_thread_id)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(agent_name, user_id) DO UPDATE SET
                    available_balance = excluded.available_balance,
                    current_thread_id = excluded.current_thread_id""",
            (self.agent_name, user.user_id, user.available_balance, user.current_thread_id),
        )
        await self.conn.commit()

    async def add_message(
        self,
        thread_id: str,
        user_id: str,
        role: Literal["user", "agent", "tool"],
        message: str = "",
        content: str = "",
        kind: str = "request",
        satoshis: int | None = None,
        extra_inputs: dict[str, Any] | None = None,
        extra_outputs: dict[str, Any] | None = None,
    ) -> Message:
        """Append a message to a thread and return the stored model.

        The message gets the next ``idx`` in the thread (0 for an empty
        thread). Empty ``extra_inputs``/``extra_outputs`` are stored as
        NULL; non-empty ones are stored as JSON text.

        Args:
            thread_id: Thread to append to.
            user_id: User the message belongs to.
            role: Author role of the message.
            message: Raw message text.
            content: Message content.
            kind: Message kind label.
            satoshis: Optional amount attached to the message.
            extra_inputs: Optional extra data; ``None`` means an empty dict.
            extra_outputs: Optional extra data; ``None`` means an empty dict.

        Returns:
            The ``Message`` that was written, with a UTC ``created_at``.
        """
        # Fresh dicts per call: a shared ``{}`` default would be handed to
        # every returned Message and corrupted if any caller mutated it.
        if extra_inputs is None:
            extra_inputs = {}
        if extra_outputs is None:
            extra_outputs = {}

        # Determine next index for thread.
        # NOTE(review): the MAX(idx) lookup and the INSERT below are separate
        # statements, so overlapping calls for the same thread could compute
        # the same idx.
        async with self.conn.execute(
            "SELECT COALESCE(MAX(idx), -1) + 1 AS next_idx FROM message WHERE agent_name = ? AND thread_id = ?",
            (self.agent_name, thread_id),
        ) as cursor:
            row = await cursor.fetchone()
        next_idx = row[0]

        created_at = datetime.now(timezone.utc)
        await self._execute(
            "INSERT INTO message (agent_name, thread_id, idx, user_id, role, message, content, kind, satoshis, extra_inputs, extra_outputs, created_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)",
            (
                self.agent_name,
                thread_id,
                next_idx,
                user_id,
                role,
                message,
                content,
                kind,
                satoshis,
                json.dumps(extra_inputs) if extra_inputs else None,
                json.dumps(extra_outputs) if extra_outputs else None,
                created_at.isoformat(),
            ),
        )
        await self.conn.commit()
        return Message(
            agent_name=self.agent_name,
            thread_id=thread_id,
            idx=next_idx,
            user_id=user_id,
            role=role,
            message=message,
            content=content,
            kind=kind,
            satoshis=satoshis,
            extra_inputs=extra_inputs,
            extra_outputs=extra_outputs,
            created_at=created_at,
        )

    async def get_messages(
        self,
        thread_id: str,
        user_id: str,
        *,
        limit: int | None = None,
        before_idx: int | None = None,
        after_idx: int | None = None,
        reverse: bool = False,
    ) -> List[Message]:
        """Retrieve messages for *thread_id* with optional pagination.

        Rows are filtered by this agent, *thread_id* and *user_id*.
        ``after_idx`` and ``before_idx`` are exclusive bounds on ``idx``.
        Results are ordered by ``idx`` ascending (descending when *reverse*),
        and ``limit`` is applied after ordering.
        """
        query = "SELECT * FROM message WHERE agent_name = ? AND thread_id = ? AND user_id = ?"
        params: list[Any] = [self.agent_name, thread_id, user_id]
        if after_idx is not None:
            query += " AND idx > ?"
            params.append(after_idx)
        if before_idx is not None:
            query += " AND idx < ?"
            params.append(before_idx)
        # ``order`` is one of two fixed keywords, never user input.
        order = "DESC" if reverse else "ASC"
        query += f" ORDER BY idx {order}"
        if limit is not None:
            query += " LIMIT ?"
            params.append(limit)
        async with self.conn.execute(query, tuple(params)) as cursor:
            rows = await cursor.fetchall()
        return [Message.from_row(dict(r)) for r in rows]