Source code for agentstr.database.postgres

import asyncpg
import json
from datetime import datetime, timezone
from typing import Self, Any, List, Literal

from agentstr.models import Message, User
from agentstr.database.base import BaseDatabase
from agentstr.logger import get_logger

logger = get_logger(__name__)


[docs] class PostgresDatabase(BaseDatabase): """PostgreSQL implementation using `asyncpg`.""" USER_TABLE_NAME = "agentstr_users" MESSAGE_TABLE_NAME = "agentstr_messages"
[docs] def __init__(self, conn_str: str, *, agent_name: str | None = None): super().__init__(conn_str, agent_name)
[docs] async def async_init(self) -> Self: logger.debug("Connecting to Postgres: %s", self.conn_str) self.conn = await asyncpg.connect(dsn=self.conn_str) await self._ensure_user_table() await self._ensure_message_table() return self
[docs] async def close(self) -> None: if self.conn: await self.conn.close() self.conn = None
# --------------------------- helpers ------------------------------- async def _ensure_user_table(self) -> None: await self.conn.execute( f"""CREATE TABLE IF NOT EXISTS {self.USER_TABLE_NAME} ( agent_name TEXT NOT NULL, user_id TEXT NOT NULL, available_balance INTEGER NOT NULL, current_thread_id TEXT, PRIMARY KEY (agent_name, user_id) )""" ) # Index for agent filtering await self.conn.execute( f"CREATE INDEX IF NOT EXISTS idx_{self.USER_TABLE_NAME}_agent_name ON {self.USER_TABLE_NAME} (agent_name)" ) async def _ensure_message_table(self) -> None: """Create message table if it doesn't exist.""" await self.conn.execute( f"""CREATE TABLE IF NOT EXISTS {self.MESSAGE_TABLE_NAME} ( agent_name TEXT NOT NULL, thread_id TEXT NOT NULL, idx INTEGER NOT NULL, user_id TEXT NOT NULL, role TEXT NOT NULL, message TEXT, content TEXT NOT NULL, kind TEXT, satoshis INTEGER, extra_inputs TEXT, extra_outputs TEXT, created_at TIMESTAMP NOT NULL, PRIMARY KEY (agent_name, thread_id, idx, user_id) )""", ) await self.conn.execute( f"CREATE INDEX IF NOT EXISTS idx_{self.MESSAGE_TABLE_NAME}_thread ON {self.MESSAGE_TABLE_NAME} (agent_name, thread_id)" ) await self.conn.execute( f"CREATE INDEX IF NOT EXISTS idx_{self.MESSAGE_TABLE_NAME}_user ON {self.MESSAGE_TABLE_NAME} (agent_name, user_id)" )
[docs] async def add_message( self, thread_id: str, user_id: str, role: Literal["user", "agent", "tool"], message: str = "", content: str = "", kind: str = "request", satoshis: int | None = None, extra_inputs: dict[str, Any] = {}, extra_outputs: dict[str, Any] = {}, ) -> Message: next_idx: int = await self.conn.fetchval( f"SELECT COALESCE(MAX(idx), -1) + 1 FROM {self.MESSAGE_TABLE_NAME} WHERE agent_name = $1 AND thread_id = $2", self.agent_name, thread_id, ) created_at = datetime.now(timezone.utc) await self.conn.execute( f"INSERT INTO {self.MESSAGE_TABLE_NAME} (agent_name, thread_id, idx, user_id, role, message, content, kind, satoshis, extra_inputs, extra_outputs, created_at) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12)", self.agent_name, thread_id, next_idx, user_id, role, message, content, kind, satoshis, json.dumps(extra_inputs) if extra_inputs else None, json.dumps(extra_outputs) if extra_outputs else None, created_at, ) return Message( agent_name=self.agent_name, thread_id=thread_id, idx=next_idx, user_id=user_id, role=role, message=message, content=content, kind=kind, satoshis=satoshis, extra_inputs=extra_inputs, extra_outputs=extra_outputs, created_at=created_at.astimezone(timezone.utc), )
[docs] async def get_messages( self, thread_id: str, user_id: str, *, limit: int | None = None, before_idx: int | None = None, after_idx: int | None = None, reverse: bool = False, ) -> List[Message]: """Retrieve messages for *thread_id* ordered by idx.""" base_query = f"SELECT * FROM {self.MESSAGE_TABLE_NAME} WHERE agent_name = $1 AND thread_id = $2 AND user_id = $3" params: list[Any] = [self.agent_name, thread_id, user_id] param_pos = 4 # next positional argument index for $ placeholders if after_idx is not None: base_query += f" AND idx > ${param_pos}" params.append(after_idx) param_pos += 1 if before_idx is not None: base_query += f" AND idx < ${param_pos}" params.append(before_idx) param_pos += 1 order = "DESC" if reverse else "ASC" base_query += f" ORDER BY idx {order}" if limit is not None: base_query += f" LIMIT ${param_pos}" params.append(limit) rows = await self.conn.fetch(base_query, *params) return [Message.from_row(row) for row in rows]
# --------------------------- API ---------------------------------- # ------------------- thread helpers -------------------
[docs] async def get_current_thread_id(self, user_id: str) -> str | None: """Return the current thread id for *user_id* within this agent scope.""" user = await self.get_user(user_id) return user.current_thread_id
[docs] async def set_current_thread_id(self, user_id: str, thread_id: str | None) -> None: """Persist *thread_id* as the current thread for *user_id*.""" user = await self.get_user(user_id) user.current_thread_id = thread_id await self.upsert_user(user)
# --------------------------- API ----------------------------------
[docs] async def get_user(self, user_id: str) -> User: logger.debug("[Postgres] Getting user %s", user_id) row = await self.conn.fetchrow( f"SELECT available_balance, current_thread_id FROM {self.USER_TABLE_NAME} WHERE agent_name = $1 AND user_id = $2", self.agent_name, user_id, ) if row: return User(user_id=user_id, available_balance=row["available_balance"], current_thread_id=row["current_thread_id"]) return User(user_id=user_id)
[docs] async def upsert_user(self, user: User) -> None: logger.debug("[Postgres] Upserting user %s", user) await self.conn.execute( f"""INSERT INTO {self.USER_TABLE_NAME} (agent_name, user_id, available_balance, current_thread_id) VALUES ($1, $2, $3, $4) ON CONFLICT (agent_name, user_id) DO UPDATE SET available_balance = EXCLUDED.available_balance, current_thread_id = EXCLUDED.current_thread_id""", self.agent_name, user.user_id, user.available_balance, user.current_thread_id, )