"""SQLite-backed persistence for Codex memory vector packets.
This module extends :mod:`memory_vector_bridge` by providing a durable SQL
storage layer for fragment collections and their derived vector embeddings.
It focuses on SQLite for ease of deployment, yet the SQL emitted is portable
and can be adapted to other relational engines if required. The store keeps
fragments, packet metadata and decoded vector payloads which allows for
subsequent similarity searches without regenerating embeddings.
"""
from __future__ import annotations
import json
import sqlite3
import time
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Iterator, Mapping, MutableMapping, Sequence
from memory_vector_bridge import (
MemoryVectorPacket,
build_memory_vector_packet,
cosine_similarity,
)
from sqlite_helpers import normalise_sqlite_connect_target
# Public API of this module.
__all__ = [
    "MemorySQLStore",
    "SQLPacketRecord",
]
def _clean_fragments(fragments: Iterable[str]) -> list[str]:
    """Return ``fragments`` with each entry stripped and empty values dropped.

    Falsy entries (``""`` and the like) are skipped outright; the rest are
    whitespace-trimmed and kept only when something remains.
    """
    return [
        stripped
        for fragment in fragments
        if fragment and (stripped := fragment.strip())
    ]
@dataclass(slots=True)
class SQLPacketRecord:
"""Representation of a stored packet for API responses."""
packet_id: int
session_key: str
label: str | None
summary: str
algorithm: str
vector_size: int
fragment_count: int
created_at: float
def as_dict(self) -> MutableMapping[str, object]:
"""Return a JSON-serialisable mapping."""
return {
"packet_id": self.packet_id,
"session_key": self.session_key,
"label": self.label,
"summary": self.summary,
"algorithm": self.algorithm,
"vector_size": self.vector_size,
"fragment_count": self.fragment_count,
"created_at": self.created_at,
}
class MemorySQLStore:
"""Persist memory fragments and vector packets in an SQLite database."""
def __init__(
self,
database: str | Path | None = "memory_vectors.db",
*,
timeout: float = 5.0,
connection: sqlite3.Connection | None = None,
) -> None:
"""Initialise the store.
Parameters
----------
database:
Path to the SQLite database file. Required when ``connection`` is
not provided. Ignored when ``connection`` is supplied.
timeout:
SQLite busy timeout used when the store owns the connection.
connection:
Optional pre-configured DB-API connection. When supplied the store
will reuse it instead of creating a new SQLite connection. This is
primarily used by integration flows that rely on externally managed
connection pools (for example, PostgreSQL via :mod:`psycopg`).
"""
if connection is None and database is None:
raise ValueError("database must be provided when connection is None")
self.database = str(database) if database is not None else ""
target, uri = normalise_sqlite_connect_target(self.database)
connect_kwargs = {
"timeout": timeout,
"detect_types": sqlite3.PARSE_DECLTYPES,
"check_same_thread": False,
}
if uri:
connect_kwargs["uri"] = True
self._connection = connection or sqlite3.connect(target, **connect_kwargs)
try:
self._connection.row_factory = sqlite3.Row
except Exception: # pragma: no cover - optional DB drivers may not expose row_factory
pass
self._owns_connection = connection is None
self._ensure_schema()
# ------------------------------------------------------------------
# connection helpers
# ------------------------------------------------------------------
@contextmanager
def _cursor(self) -> Iterator[sqlite3.Cursor]:
cursor = self._connection.cursor()
try:
yield cursor
self._connection.commit()
except Exception:
self._connection.rollback()
raise
finally:
cursor.close()
def close(self) -> None:
"""Close the underlying database connection when owned by the store."""
if self._owns_connection:
self._connection.close()
    def __enter__(self) -> "MemorySQLStore":
        """Enter a ``with`` block; the store itself is the context value."""
        return self
    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the owned connection (if any) when leaving a ``with`` block."""
        self.close()
# ------------------------------------------------------------------
# schema management
# ------------------------------------------------------------------
    def _ensure_schema(self) -> None:
        """Create the sessions/fragments/packets tables when missing.

        Idempotent (``CREATE TABLE IF NOT EXISTS``) and executed as a single
        committed transaction via :meth:`_cursor`.  ``ON DELETE CASCADE``
        keeps dependent rows in sync with session/packet deletion — note that
        SQLite only honours it when ``PRAGMA foreign_keys`` is enabled, which
        this method does not set; confirm with callers if cascade is relied on.
        """
        with self._cursor() as cur:
            # One row per logical fragment collection, keyed by a unique name.
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS sessions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_key TEXT NOT NULL UNIQUE,
                    created_at REAL NOT NULL
                )
                """
            )
            # Current fragments of a session, ordered by ``position``.
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS fragments (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
                    position INTEGER NOT NULL,
                    text TEXT NOT NULL
                )
                """
            )
            # Derived vector packets; the vector is stored both encoded and as
            # decoded JSON so similarity search needs no re-decoding.
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS packets (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
                    label TEXT,
                    summary TEXT NOT NULL,
                    algorithm TEXT NOT NULL,
                    encoded_vector TEXT NOT NULL,
                    vector_json TEXT NOT NULL,
                    vector_size INTEGER NOT NULL,
                    fragment_count INTEGER NOT NULL,
                    created_at REAL NOT NULL
                )
                """
            )
            # Immutable snapshot of the fragments a packet was built from.
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS packet_fragments (
                    packet_id INTEGER NOT NULL REFERENCES packets(id) ON DELETE CASCADE,
                    position INTEGER NOT NULL,
                    text TEXT NOT NULL,
                    PRIMARY KEY(packet_id, position)
                )
                """
            )
def _ensure_session(self, session_key: str) -> int:
now = time.time()
with self._cursor() as cur:
cur.execute(
"INSERT OR IGNORE INTO sessions(session_key, created_at) VALUES (?, ?)",
(session_key, now),
)
cur.execute("SELECT id FROM sessions WHERE session_key = ?", (session_key,))
row = cur.fetchone()
if row is None: # pragma: no cover - defensive safeguard
raise RuntimeError(f"Failed to materialise session for key '{session_key}'")
return int(row[0])
def _lookup_session_id(self, session_key: str) -> int | None:
with self._cursor() as cur:
cur.execute("SELECT id FROM sessions WHERE session_key = ?", (session_key,))
row = cur.fetchone()
return int(row[0]) if row is not None else None
def _replace_session_fragments(self, session_id: int, fragments: Sequence[str]) -> None:
with self._cursor() as cur:
cur.execute("DELETE FROM fragments WHERE session_id = ?", (session_id,))
if fragments:
cur.executemany(
"INSERT INTO fragments(session_id, position, text) VALUES (?, ?, ?)",
((session_id, index, text) for index, text in enumerate(fragments)),
)
# ------------------------------------------------------------------
# public API
# ------------------------------------------------------------------
def store_fragments(self, session_key: str, fragments: Iterable[str]) -> list[str]:
"""Persist ``fragments`` for ``session_key`` and return the cleaned list."""
cleaned = _clean_fragments(fragments)
session_id = self._ensure_session(session_key)
self._replace_session_fragments(session_id, cleaned)
return cleaned
def store_packet(
self,
session_key: str,
packet: MemoryVectorPacket,
*,
fragments: Iterable[str] | None = None,
label: str | None = None,
) -> int:
"""Persist ``packet`` and optionally refresh associated ``fragments``."""
cleaned = _clean_fragments(fragments or []) if fragments is not None else None
session_id = self._ensure_session(session_key)
timestamp = time.time()
with self._cursor() as cur:
if cleaned is not None:
cur.execute("DELETE FROM fragments WHERE session_id = ?", (session_id,))
if cleaned:
cur.executemany(
"INSERT INTO fragments(session_id, position, text) VALUES (?, ?, ?)",
((session_id, index, text) for index, text in enumerate(cleaned)),
)
cur.execute(
"""
INSERT INTO packets(
session_id,
label,
summary,
algorithm,
encoded_vector,
vector_json,
vector_size,
fragment_count,
created_at
)
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
session_id,
label,
packet.summary,
packet.algorithm,
packet.encoded_vector,
json.dumps(packet.decode()),
packet.vector_size,
packet.fragment_count,
timestamp,
),
)
packet_id = int(cur.lastrowid)
if cleaned is not None:
cur.execute("DELETE FROM packet_fragments WHERE packet_id = ?", (packet_id,))
if cleaned:
cur.executemany(
"INSERT INTO packet_fragments(packet_id, position, text) VALUES (?, ?, ?)",
((packet_id, index, text) for index, text in enumerate(cleaned)),
)
return packet_id
def store_from_fragments(
self,
session_key: str,
fragments: Iterable[str],
*,
label: str | None = None,
algorithm: str = "gzip",
) -> int:
"""Create and persist a packet derived from ``fragments``."""
packet = build_memory_vector_packet(fragments, algorithm=algorithm)
return self.store_packet(session_key, packet, fragments=fragments, label=label)
def list_sessions(self) -> list[str]:
"""Return all session keys known to the store."""
with self._cursor() as cur:
cur.execute("SELECT session_key FROM sessions ORDER BY session_key")
rows = cur.fetchall()
return [str(row[0]) for row in rows]
def list_packets(self, session_key: str | None = None) -> list[SQLPacketRecord]:
"""Return packets optionally filtered by ``session_key``."""
if session_key is None:
query = """
SELECT p.id, s.session_key, p.label, p.summary, p.algorithm,
p.vector_size, p.fragment_count, p.created_at
FROM packets AS p
JOIN sessions AS s ON s.id = p.session_id
ORDER BY p.created_at DESC, p.id DESC
"""
params: Sequence[object] = ()
else:
query = """
SELECT p.id, s.session_key, p.label, p.summary, p.algorithm,
p.vector_size, p.fragment_count, p.created_at
FROM packets AS p
JOIN sessions AS s ON s.id = p.session_id
WHERE s.session_key = ?
ORDER BY p.created_at DESC, p.id DESC
"""
params = (session_key,)
with self._cursor() as cur:
cur.execute(query, params)
rows = cur.fetchall()
return [
SQLPacketRecord(
packet_id=int(row[0]),
session_key=str(row[1]),
label=row[2] if row[2] is None else str(row[2]),
summary=str(row[3]),
algorithm=str(row[4]),
vector_size=int(row[5]),
fragment_count=int(row[6]),
created_at=float(row[7]),
)
for row in rows
]
def load_packet(self, packet_id: int) -> MemoryVectorPacket:
"""Return the stored packet instance for ``packet_id``."""
with self._cursor() as cur:
cur.execute(
"""
SELECT summary, algorithm, encoded_vector, vector_size, fragment_count
FROM packets
WHERE id = ?
""",
(packet_id,),
)
row = cur.fetchone()
if row is None:
raise KeyError(f"Packet {packet_id} not found")
return MemoryVectorPacket(
summary=str(row[0]),
algorithm=str(row[1]),
encoded_vector=str(row[2]),
vector_size=int(row[3]),
fragment_count=int(row[4]),
)
def packet_payload(self, packet_id: int) -> Mapping[str, object]:
"""Return the JSON payload stored for ``packet_id``."""
packet = self.load_packet(packet_id)
payload = packet.as_dict()
payload["packet_id"] = packet_id
return payload
def get_fragments_for_packet(self, packet_id: int) -> list[str]:
"""Return fragments captured when ``packet_id`` was stored."""
with self._cursor() as cur:
cur.execute(
"SELECT text FROM packet_fragments WHERE packet_id = ? ORDER BY position",
(packet_id,),
)
rows = cur.fetchall()
if rows:
return [str(row[0]) for row in rows]
with self._cursor() as cur:
cur.execute(
"""
SELECT f.text
FROM fragments AS f
JOIN packets AS p ON p.session_id = f.session_id
WHERE p.id = ?
ORDER BY f.position
""",
(packet_id,),
)
rows = cur.fetchall()
return [str(row[0]) for row in rows]
def get_session_fragments(self, session_key: str) -> list[str]:
"""Return all stored fragments for ``session_key``."""
session_id = self._lookup_session_id(session_key)
if session_id is None:
return []
with self._cursor() as cur:
cur.execute(
"SELECT text FROM fragments WHERE session_id = ? ORDER BY position",
(session_id,),
)
rows = cur.fetchall()
return [str(row[0]) for row in rows]
def search_similar(
self,
vector: Sequence[float],
*,
limit: int = 5,
min_similarity: float = 0.0,
) -> list[dict[str, object]]:
"""Return packets ranked by cosine similarity to ``vector``."""
if not vector:
return []
with self._cursor() as cur:
cur.execute(
"""
SELECT p.id, s.session_key, p.label, p.vector_json
"""FastAPI application exposing the SQL-backed memory vector store."""
from __future__ import annotations
from typing import Sequence
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, model_validator
from memory_sql_store import MemorySQLStore
from memory_vector_bridge import build_memory_vector_packet, fragments_to_embedding
# Application singleton and a module-level store shared by all request
# handlers; the store is closed by the shutdown hook below.
app = FastAPI(title="Aeon Memory SQL API", version="1.0.0")
store = MemorySQLStore()
class FragmentRequest(BaseModel):
    """Request model storing fragments for a session."""

    # Fragment order is preserved; empty/whitespace-only entries are dropped
    # by the store layer, so the stored count may be lower than submitted.
    fragments: list[str] = Field(..., description="Ordered memory fragments")
class PacketCreateRequest(FragmentRequest):
    """Request payload for storing a packet created from fragments.

    Inherits the mandatory ``fragments`` field from :class:`FragmentRequest`.
    """

    session_key: str = Field(..., description="Identifier for the fragment collection")
    label: str | None = Field(default=None, description="Optional human readable label")
    algorithm: str = Field(default="gzip", description="Encoding algorithm for the vector")
class SearchRequest(BaseModel):
    """Payload for cosine similarity searches against stored packets."""

    vector: list[float] | None = Field(default=None, description="Explicit vector for similarity search")
    fragments: list[str] | None = Field(default=None, description="Fragments used to derive a search vector")
    limit: int = Field(default=5, ge=1, le=50, description="Number of results to return")
    min_similarity: float = Field(default=0.0, ge=0.0, le=1.0, description="Minimum similarity threshold")

    @model_validator(mode="after")
    def _validate_vector(self) -> "SearchRequest":
        # Exactly one of the two search inputs must be supplied (XOR):
        # neither is an unusable request, both is ambiguous.
        if self.vector is None and self.fragments is None:
            raise ValueError("Either 'vector' or 'fragments' must be provided")
        if self.vector is not None and self.fragments is not None:
            raise ValueError("Provide only one of 'vector' or 'fragments'")
        return self
def _ensure_vector(sequence: Sequence[float]) -> list[float]:
    """Coerce every element of ``sequence`` to ``float``.

    Non-numeric elements surface as an HTTP 400 rather than a server error.
    """
    try:
        return [float(value) for value in sequence]
    except (TypeError, ValueError) as exc:  # pragma: no cover - defensive guard
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.post("/sessions/{session_key}/fragments")
def update_fragments(session_key: str, request: FragmentRequest) -> dict[str, object]:
    """Store fragments for ``session_key`` and return the stored count."""
    stored = store.store_fragments(session_key, request.fragments)
    return {"session_key": session_key, "count": len(stored)}
@app.post("/packets", response_model=dict)
def create_packet(request: PacketCreateRequest) -> dict[str, object]:
    """Build a :class:`MemoryVectorPacket` from fragments and persist it.

    The response is the packet payload augmented with its database id,
    session key and optional label.
    """
    packet = build_memory_vector_packet(request.fragments, algorithm=request.algorithm)
    packet_id = store.store_packet(
        request.session_key,
        packet,
        fragments=request.fragments,
        label=request.label,
    )
    payload = packet.as_dict()
    payload.update(
        packet_id=packet_id,
        session_key=request.session_key,
        label=request.label,
    )
    return payload
@app.get("/packets")
def list_packets(session_key: str | None = None) -> dict[str, object]:
    """Return stored packets optionally filtered by ``session_key``."""
    return {"results": [entry.as_dict() for entry in store.list_packets(session_key)]}
@app.get("/packets/{packet_id}")
def get_packet(packet_id: int) -> dict[str, object]:
    """Return a stored packet and its associated fragments.

    Unknown ids map to HTTP 404.
    """
    try:
        packet = store.load_packet(packet_id)
    except KeyError as exc:  # pragma: no cover - defensive guard
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    payload = packet.as_dict()
    payload["packet_id"] = packet_id
    payload["fragments"] = store.get_fragments_for_packet(packet_id)
    return payload
@app.post("/search")
def search_packets(request: SearchRequest) -> dict[str, object]:
    """Perform cosine similarity search against stored packets."""
    if request.vector is None:
        # Model validation guarantees ``fragments`` is set in this branch.
        vector = fragments_to_embedding(request.fragments or [])
    else:
        vector = _ensure_vector(request.vector)
    if not vector:
        raise HTTPException(status_code=400, detail="Unable to derive a search vector")
    matches = store.search_similar(
        vector,
        limit=request.limit,
        min_similarity=request.min_similarity,
    )
    return {"results": matches}
@app.on_event("shutdown")
def shutdown_event() -> None:
    """Ensure the SQLite connection is closed on shutdown.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favour of lifespan handlers — worth migrating when convenient.
    """
    store.close()
# Names re-exported for ``from <module> import *`` users.
__all__ = ["app", "store"]
# Wait… I thought memory was hard…