import asyncio
import json
import time
import uuid
from collections.abc import Callable
from expiringdict import ExpiringDict
from pydantic import BaseModel
from pynostr.encrypted_dm import EncryptedDirectMessage
from pynostr.event import Event, EventKind
from pynostr.filters import Filters
from pynostr.key import PrivateKey, PublicKey
from pynostr.utils import get_public_key, get_timestamp
from websockets.asyncio.client import connect
from websockets.exceptions import ConnectionClosedError
from agentstr.logger import get_logger
logger = get_logger(__name__)
[docs]
class DecryptedMessage(BaseModel):
"""A decrypted message from a Nostr relay."""
event: Event #: The Nostr event containing the message.
message: str #: The decrypted message content.
[docs]
def create_subscription(filters: Filters) -> list[str]:
"""Create a subscription for the given filters.
Args:
filters: The filters to apply to the subscription.
Returns:
A list containing the subscription request components.
"""
return ["REQ", uuid.uuid4().hex, filters.to_dict()]
[docs]
class EventRelay:
"""Handles communication with a single Nostr relay.
Args:
relay: WebSocket URL of the Nostr relay.
private_key: Private key for signing events.
public_key: Optional public key (derived from private_key if not provided).
"""
[docs]
def __init__(self, relay: str, private_key: PrivateKey | None = None, public_key: PublicKey | None = None):
self.relay = relay
self.private_key = private_key
self.public_key = public_key if public_key else (self.private_key.public_key if self.private_key else None)
[docs]
async def get_events(self, filters: Filters, limit: int = 10, timeout: int = 30, close_on_eose: bool = True) -> list[Event]:
"""Fetch events matching the given filters from this relay.
Args:
filters: The filters to apply when fetching events.
limit: Maximum number of events to return. Defaults to 10.
timeout: Maximum time to wait for events in seconds. Defaults to 30.
close_on_eose: Whether to close the subscription after EOSE. Defaults to True.
Returns:
A list of up to `limit` events that match the filters, or an empty list if none found.
Note:
Times out after `timeout` seconds if no matching events are found.
"""
limit = filters.limit if filters.limit else limit
subscription = create_subscription(filters)
events = []
t0 = time.time()
time_remaining = timeout
async with connect(self.relay) as ws:
logger.debug(f"Sending subscription: {json.dumps(subscription)}")
await ws.send(json.dumps(subscription))
t0 = time.time()
found = 0
await asyncio.sleep(0)
try:
while time.time() < t0 + timeout and found < limit:
response = await asyncio.wait_for(ws.recv(), timeout=time_remaining)
response = json.loads(response)
logger.debug(f"Received full message in get_events: {response}")
if (len(response) > 2):
found += 1
logger.debug(f"Received message {found} in get_event: {response[2]}")
events.append(Event.from_dict(response[2]))
else:
if response[0] == "EOSE":
logger.debug("Received EOSE in get_events")
if close_on_eose:
logger.debug("Closing connection on EOSE.")
break
else:
logger.warning(f"Invalid event: {response}")
await asyncio.sleep(0)
time_remaining = t0 + timeout - time.time()
if time_remaining <= 0:
raise TimeoutError()
except TimeoutError:
logger.warning("Timeout in get_events")
pass
return events
[docs]
async def get_event(self, filters: Filters, timeout: int = 30, close_on_eose: bool = True) -> Event | None:
"""Get a single event matching the filters or None if not found."""
events = await self.get_events(filters, limit=1, timeout=timeout, close_on_eose=close_on_eose)
if len(events) > 0:
return events[0]
else:
return None
[docs]
async def send_event(self, event: Event):
"""Publish an event to this relay."""
if not event.sig:
event.sign(self.private_key.hex())
message = event.to_message()
async with connect(self.relay) as ws:
logger.debug(f"Sending message: {message}")
await ws.send(message)
response = await ws.recv()
logger.debug(f"Received send_event response: {response}")
[docs]
def decrypt_message(self, event: Event) -> DecryptedMessage | None:
if event and event.has_pubkey_ref(self.public_key.hex()):
rdm = EncryptedDirectMessage.from_event(event)
rdm.decrypt(self.private_key.hex(), public_key_hex=event.pubkey)
logger.debug(f"New dm received: {event.date_time()} {rdm.cleartext_content}")
return DecryptedMessage(
event=event,
message=rdm.cleartext_content,
)
return None
[docs]
async def send_message(self, message: str | dict, recipient_pubkey: str, event_ref: str | None = None) -> Event:
recipient = get_public_key(recipient_pubkey)
dm = EncryptedDirectMessage(reference_event_id=event_ref)
if isinstance(message, dict):
message = json.dumps(message)
dm.encrypt(self.private_key.hex(), cleartext_content=message, recipient_pubkey=recipient.hex())
dm_event = dm.to_event()
await self.send_event(dm_event)
return dm_event
[docs]
async def receive_message(self, author_pubkey: str, timestamp: int | None = None, timeout: int = 30) -> DecryptedMessage | None:
"""Wait for and return the next direct message from the specified author."""
author = get_public_key(author_pubkey)
authors = [author.hex()]
filters = Filters(authors=authors, kinds=[EventKind.ENCRYPTED_DIRECT_MESSAGE],
pubkey_refs=[self.public_key.hex()], since=timestamp or get_timestamp(), limit=1)
event = await self.get_event(filters, timeout, close_on_eose=False)
if event:
return self.decrypt_message(event)
return None
[docs]
async def send_receive_message(self, message: str | dict, recipient_pubkey: str, timeout: int = 3, event_ref: str | None = None) -> DecryptedMessage | None:
dm_event = await self.send_message(message, recipient_pubkey, event_ref)
timestamp = dm_event.created_at
return await self.receive_message(recipient_pubkey, timestamp, timeout)
[docs]
async def event_listener(self, filters: Filters, callback: Callable[[Event], None], event_cache: ExpiringDict, lock: asyncio.Lock):
"""Continuously listen for events matching filters and call the callback for each one."""
subscription = create_subscription(filters)
logger.debug(f"Sending note subscription: {json.dumps(subscription)}")
latest_timestamp = filters.since or get_timestamp()
while True:
try:
async with connect(self.relay) as ws:
await ws.send(json.dumps(subscription))
while True:
response = await ws.recv()
response = json.loads(response)
if (len(response) > 2):
event = Event.from_dict(response[2])
logger.debug(f"Checking lock with event id: {event.id}")
latest_timestamp = event.created_at
async with lock:
if event.id in event_cache:
continue
event_cache[event.id] = True
logger.info(f"Event listener received event {event.id[:10]}: {event.content}")
try:
await callback(event)
except Exception as e:
logger.error(f"Error in event_listener callback: {e}")
await asyncio.sleep(0)
except Exception as e:
logger.warning(f"Connection closed in event_listener at {int(time.time())} trying again: {e}")
filters.since = latest_timestamp + 1
subscription = create_subscription(filters)
logger.debug(f"Sending event subscription: {json.dumps(subscription)}")
await asyncio.sleep(0)
[docs]
async def direct_message_listener(self, filters: Filters, callback: Callable[[Event, str], None], event_cache: ExpiringDict, lock: asyncio.Lock):
"""Listen for direct messages and call the callback with decrypted content."""
subscription = create_subscription(filters)
logger.debug(f"Sending DM subscription: {json.dumps(subscription)}")
latest_timestamp = filters.since or get_timestamp()
while True:
try:
async with connect(self.relay) as ws:
await ws.send(json.dumps(subscription))
while True:
response = await ws.recv()
response = json.loads(response)
if (len(response) > 2):
logger.debug(f"Received message in direct_message_listener: {response[2]}")
event = Event.from_dict(response[2])
logger.debug(f"Checking lock with event id: {event.id}")
latest_timestamp = event.created_at
async with lock:
if event.id in event_cache:
continue
event_cache[event.id] = True
dm = self.decrypt_message(event)
if dm:
logger.info(f"Listener received DM from {event.pubkey[:10]}: {dm.message}")
try:
await callback(dm.event, dm.message)
except Exception as e:
logger.error(f"Error in direct_message_listener callback: {e}")
await asyncio.sleep(0)
except Exception as e:
logger.warning(f"Connection closed in direct_message_listener at {int(time.time())} trying again: {e}")
filters.since = latest_timestamp + 1
subscription = create_subscription(filters)
logger.debug(f"Sending DM subscription: {json.dumps(subscription)}")
await asyncio.sleep(0)