Files
opencode-session-manager/database.py
2026-03-08 14:22:39 -06:00

463 lines
17 KiB
Python

"""Database layer for OpenCode session operations."""
import sqlite3
from pathlib import Path
from typing import Optional, List, Tuple
from dataclasses import dataclass, field
from datetime import datetime
@dataclass
class Project:
    """One row of the ``project`` table, as loaded by ``Database``.

    ``name`` and ``vcs`` are nullable columns and therefore may be ``None``.
    """

    # Primary key of the project row.
    id: str
    # Worktree path; display code falls back to it when ``name`` is empty.
    worktree: str
    name: Optional[str]
    vcs: Optional[str]
    # Integer timestamps exactly as stored (unit not visible here -- TODO confirm).
    time_created: int
    time_updated: int
@dataclass
class Workspace:
    """One row of the ``workspace`` table, as loaded by ``Database``.

    Display code labels a workspace by ``name``, then ``branch``, then ``id``.
    """

    # Primary key of the workspace row.
    id: str
    # Owning project (``project.id``).
    project_id: str
    branch: Optional[str]
    # Workspace kind as stored in the DB; the values are not visible here.
    type: str
    name: Optional[str]
    directory: Optional[str]
@dataclass
class Session:
    """One row of the ``session`` table plus values derived by ``Database``.

    The first eleven fields mirror columns of ``session``. The trailing
    fields are not columns: they are filled in from counts and from the
    in-memory project/workspace caches.
    """

    id: str
    project_id: str
    workspace_id: Optional[str]
    parent_id: Optional[str]
    slug: str
    directory: str
    title: str
    version: str
    time_created: int
    time_updated: int
    # None means the session is not archived (listings filter on IS NULL).
    time_archived: Optional[int] = field(default=None)

    # --- Derived values (not columns of ``session``) ---
    message_count: int = field(default=0)
    todo_count: int = field(default=0)
    workspace_name: Optional[str] = field(default=None)
    project_name: Optional[str] = field(default=None)
class Database:
    """SQLite connection wrapper with query and mutation helpers for OpenCode sessions.

    Typical use: ``connect()``, then ``load_reference_data()``, then the
    query/mutation methods, then ``close()``. ``load_reference_data()`` fills
    the project/workspace caches that display names and the move/copy target
    validation rely on.
    """

    # Session rows plus computed message/todo counts. Both get_sessions() and
    # get_session() use it, so they return the same computed fields. Callers
    # append further " AND ..." filters after the trailing WHERE 1=1.
    _SESSION_SELECT = """
        SELECT s.*,
            (SELECT COUNT(*) FROM message WHERE session_id = s.id) AS msg_count,
            (SELECT COUNT(*) FROM todo WHERE session_id = s.id) AS todo_count
        FROM session s
        WHERE 1=1
    """

    # Tables whose rows belong to a session through a session_id column.
    # The order is child-first: part rows carry a message_id, so parts are
    # deleted before messages (this matters when PRAGMA foreign_keys is ON).
    _CHILD_TABLES = ("part", "message", "todo", "session_share")

    def __init__(self, db_path: str = "opencode.db"):
        """Remember the database path. No connection is opened here.

        Args:
            db_path: Path to the SQLite database file.
        """
        self.db_path = Path(db_path)
        self.conn: Optional[sqlite3.Connection] = None
        # In-memory caches filled by load_reference_data().
        self.projects: dict[str, Project] = {}
        self.workspaces: dict[str, Workspace] = {}
        self.workspaces_by_project: dict[str, List[Workspace]] = {}

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def connect(self):
        """Open the database, closing any connection this object already holds.

        Raises:
            FileNotFoundError: if ``db_path`` does not exist. This is checked
                first so that sqlite3 does not silently create an empty file.
        """
        if not self.db_path.exists():
            raise FileNotFoundError(f"Database not found: {self.db_path}")
        # Calling connect() twice must not leak the previous handle.
        self.close()
        self.conn = sqlite3.connect(self.db_path)
        # Rows support access by column name, e.g. row["id"].
        self.conn.row_factory = sqlite3.Row

    def close(self):
        """Close the connection, if any, and mark this object as disconnected."""
        if self.conn is not None:
            self.conn.close()
            # Reset the attribute so later calls raise the "not connected"
            # RuntimeError instead of failing on a closed handle.
            self.conn = None

    def _require_conn(self) -> sqlite3.Connection:
        """Return the open connection, or raise RuntimeError if there is none."""
        if self.conn is None:
            raise RuntimeError("Database is not connected. Call connect() first.")
        return self.conn

    # ------------------------------------------------------------------
    # Reference data (projects / workspaces)
    # ------------------------------------------------------------------

    def load_reference_data(self):
        """(Re)load all projects and workspaces into the in-memory caches.

        The caches are rebuilt from scratch on each call, so rows deleted from
        the database since the previous call no longer linger.

        Raises:
            RuntimeError: if not connected.
        """
        cursor = self._require_conn().cursor()

        cursor.execute("SELECT id, worktree, name, vcs, time_created, time_updated FROM project")
        projects = {}
        for row in cursor.fetchall():
            projects[row["id"]] = Project(
                id=row["id"],
                worktree=row["worktree"],
                name=row["name"],
                vcs=row["vcs"],
                time_created=row["time_created"],
                time_updated=row["time_updated"],
            )

        cursor.execute("SELECT id, project_id, branch, type, name, directory FROM workspace")
        workspaces = {}
        for row in cursor.fetchall():
            workspaces[row["id"]] = Workspace(
                id=row["id"],
                project_id=row["project_id"],
                branch=row["branch"],
                type=row["type"],
                name=row["name"],
                directory=row["directory"],
            )

        # Replace the contents in place: stale entries are dropped, and any
        # caller holding a reference to these dicts still sees fresh data.
        self.projects.clear()
        self.projects.update(projects)
        self.workspaces.clear()
        self.workspaces.update(workspaces)

        self.workspaces_by_project = {}
        for ws in self.workspaces.values():
            self.workspaces_by_project.setdefault(ws.project_id, []).append(ws)

    def _project_label(self, project_id: str) -> Optional[str]:
        """Return a cached project's display name (name, else worktree), or None if unknown."""
        proj = self.projects.get(project_id)
        if proj is None:
            return None
        return proj.name or proj.worktree

    @staticmethod
    def _workspace_label(ws: "Workspace") -> str:
        """Return a workspace's display name: name, else branch, else id."""
        return ws.name or ws.branch or ws.id

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    @staticmethod
    def _like_pattern(text: str) -> str:
        """Wrap *text* in ``%...%`` after backslash-escaping LIKE metacharacters.

        The resulting pattern must be used with ``ESCAPE '\\'``.
        """
        escaped = text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        return f"%{escaped}%"

    def _session_from_row(self, row: sqlite3.Row) -> "Session":
        """Build a Session from a ``_SESSION_SELECT`` row and attach display names."""
        sess = Session(
            id=row["id"],
            project_id=row["project_id"],
            workspace_id=row["workspace_id"],
            parent_id=row["parent_id"],
            slug=row["slug"],
            directory=row["directory"],
            title=row["title"],
            version=row["version"],
            time_created=row["time_created"],
            time_updated=row["time_updated"],
            time_archived=row["time_archived"],
            message_count=row["msg_count"],
            todo_count=row["todo_count"],
        )
        sess.project_name = self._project_label(sess.project_id)
        ws = self.workspaces.get(sess.workspace_id) if sess.workspace_id else None
        if ws is not None:
            sess.workspace_name = self._workspace_label(ws)
        return sess

    def get_sessions(self,
                     include_archived: bool = False,
                     project_id: Optional[str] = None,
                     workspace_id: Optional[str] = None,
                     search: str = "") -> "List[Session]":
        """Fetch sessions, newest first (by ``time_created``), with optional filters.

        Args:
            include_archived: Also return sessions whose ``time_archived`` is set.
            project_id: Only sessions of this project (ignored when falsy).
            workspace_id: Only sessions of this workspace (ignored when falsy).
            search: Substring matched literally against title or slug, with
                both sides passed through SQLite ``LOWER()``. ``%`` and ``_``
                in the text have no wildcard meaning.

        Returns:
            Session objects with message/todo counts and display names filled in.

        Raises:
            RuntimeError: if not connected.
        """
        cursor = self._require_conn().cursor()
        query = self._SESSION_SELECT
        params = []
        if not include_archived:
            query += " AND s.time_archived IS NULL"
        if project_id:
            query += " AND s.project_id = ?"
            params.append(project_id)
        if workspace_id:
            query += " AND s.workspace_id = ?"
            params.append(workspace_id)
        if search:
            pattern = self._like_pattern(search)
            query += (" AND (LOWER(s.title) LIKE LOWER(?) ESCAPE '\\'"
                      " OR LOWER(s.slug) LIKE LOWER(?) ESCAPE '\\')")
            params.extend([pattern, pattern])
        query += " ORDER BY s.time_created DESC"
        cursor.execute(query, params)
        return [self._session_from_row(row) for row in cursor.fetchall()]

    def get_project_tree(self) -> List[Tuple[str, List[str], bool]]:
        """Build a sorted project -> workspaces view from the in-memory caches.

        Returns:
            Tuples ``(project_name, [workspace_names], has_workspaces)``, sorted.
        """
        tree = []
        for proj in self.projects.values():
            ws_list = self.workspaces_by_project.get(proj.id, [])
            tree.append((
                proj.name or proj.worktree,
                [self._workspace_label(w) for w in ws_list],
                bool(ws_list),
            ))
        return sorted(tree)

    def get_session_counts_by_project(self) -> dict:
        """Return a dict mapping project_id to its count of non-archived sessions.

        Raises:
            RuntimeError: if not connected.
        """
        cursor = self._require_conn().cursor()
        cursor.execute(
            "SELECT project_id, COUNT(*) FROM session WHERE time_archived IS NULL GROUP BY project_id"
        )
        return {row[0]: row[1] for row in cursor.fetchall()}

    def get_session(self, session_id: str) -> "Optional[Session]":
        """Get a single session by ID, with the same computed fields as get_sessions().

        Returns:
            The Session, or None if no row has this ID.

        Raises:
            RuntimeError: if not connected.
        """
        cursor = self._require_conn().cursor()
        cursor.execute(self._SESSION_SELECT + " AND s.id = ?", (session_id,))
        row = cursor.fetchone()
        return self._session_from_row(row) if row else None

    # ------------------------------------------------------------------
    # Mutations
    # ------------------------------------------------------------------

    def _validate_target(self, target_project_id: str,
                         target_workspace_id: Optional[str]) -> Optional[str]:
        """Return an ``"ERROR: ..."`` message if the move/copy target is invalid, else None."""
        if target_project_id not in self.projects:
            return f"ERROR: Project {target_project_id} not found"
        if target_workspace_id:
            ws = self.workspaces.get(target_workspace_id)
            if ws is None:
                return f"ERROR: Workspace {target_workspace_id} not found"
            if ws.project_id != target_project_id:
                return f"ERROR: Workspace {target_workspace_id} does not belong to project {target_project_id}"
        return None

    def _workspace_conflicts(self, cursor: sqlite3.Cursor, sess_id: str, project_id: str) -> bool:
        """Return True if the session's current workspace is cached as belonging to another project."""
        cursor.execute("SELECT workspace_id FROM session WHERE id = ?", (sess_id,))
        row = cursor.fetchone()
        ws = self.workspaces.get(row[0]) if row and row[0] else None
        return ws is not None and ws.project_id != project_id

    def move_sessions(self, session_ids: List[str], target_project_id: str,
                      target_workspace_id: Optional[str] = None) -> Tuple[bool, List[str]]:
        """Move sessions to a different project (and optionally a workspace).

        When ``target_workspace_id`` is falsy, each session keeps its current
        workspace unless that workspace is known to belong to a different
        project. In that case it is set to NULL, so a session never points at
        another project's workspace.

        Returns:
            ``(success, log)``. On a validation failure ``success`` is False
            and ``log`` holds one ``"ERROR: ..."`` line. Otherwise ``log`` holds
            a readable rendering of each UPDATE.

        Raises:
            RuntimeError: if not connected.
            sqlite3.Error: if an UPDATE fails. The whole move is rolled back first.
        """
        conn = self._require_conn()
        error = self._validate_target(target_project_id, target_workspace_id)
        if error:
            return False, [error]

        cursor = conn.cursor()
        sql_statements = []
        # The connection context manager commits on success and rolls back
        # (then re-raises) on any exception, so the move is all-or-nothing.
        with conn:
            for sess_id in session_ids:
                if target_workspace_id:
                    cursor.execute(
                        "UPDATE session SET project_id = ?, workspace_id = ? WHERE id = ?",
                        (target_project_id, target_workspace_id, sess_id),
                    )
                    sql_statements.append(f"UPDATE session SET project_id = {target_project_id}, workspace_id = {target_workspace_id} WHERE id = {sess_id}")
                elif self._workspace_conflicts(cursor, sess_id, target_project_id):
                    cursor.execute(
                        "UPDATE session SET project_id = ?, workspace_id = NULL WHERE id = ?",
                        (target_project_id, sess_id),
                    )
                    sql_statements.append(f"UPDATE session SET project_id = {target_project_id}, workspace_id = NULL WHERE id = {sess_id}")
                else:
                    cursor.execute(
                        "UPDATE session SET project_id = ? WHERE id = ?",
                        (target_project_id, sess_id),
                    )
                    sql_statements.append(f"UPDATE session SET project_id = {target_project_id} WHERE id = {sess_id}")
        return True, sql_statements

    @staticmethod
    def _is_missing_schema(err: sqlite3.OperationalError) -> bool:
        """Return True if *err* reports a missing table or column.

        Other errors, such as "database is locked", return False.
        """
        msg = str(err)
        return msg.startswith("no such table") or msg.startswith("no such column")

    @classmethod
    def _fetch_children(cls, cursor: sqlite3.Cursor, table: str, sess_id: str) -> list:
        """Return the rows of *table* for *sess_id*, or [] if the table or column is absent."""
        try:
            cursor.execute(f"SELECT * FROM {table} WHERE session_id = ?", (sess_id,))
        except sqlite3.OperationalError as err:
            # Best-effort across schema versions, but do not hide real failures.
            if cls._is_missing_schema(err):
                return []
            raise
        return cursor.fetchall()

    @staticmethod
    def _fresh_id(old_id) -> str:
        """Return a random ID that keeps *old_id*'s ``prefix_`` part, if it has one.

        The prefix is kept for consistency with the ``ses_`` prefix used for
        copied sessions.
        NOTE(review): if the app orders rows by ID, random IDs will not
        preserve the original ordering -- confirm.
        """
        import uuid

        prefix, sep, _ = str(old_id).partition("_")
        fresh = uuid.uuid4().hex
        return f"{prefix}_{fresh}" if sep else fresh

    @staticmethod
    def _insert_copy(cursor: sqlite3.Cursor, table: str, row: sqlite3.Row,
                     new_pk: Optional[str], overrides: dict) -> None:
        """Insert a copy of *row* into *table*.

        The original ``id`` is dropped. It is replaced by *new_pk* when that is
        not None. Values in *overrides* replace the copied values for columns
        that exist in *row*.
        """
        cols = [k for k in row.keys() if k != "id"]
        values = [overrides.get(k, row[k]) for k in cols]
        if new_pk is not None:
            cols.insert(0, "id")
            values.insert(0, new_pk)
        # Column names come from the database schema. They are quoted so that
        # names colliding with SQL keywords still work.
        col_list = ", ".join(f'"{c}"' for c in cols)
        placeholders = ", ".join(["?"] * len(cols))
        cursor.execute(f'INSERT INTO "{table}" ({col_list}) VALUES ({placeholders})', values)

    def _copy_children(self, cursor: sqlite3.Cursor, old_sess_id: str, new_sess_id: str) -> None:
        """Copy the message, part, todo and session_share rows of one session to another."""
        # Copy messages first and record old -> new IDs so parts can be relinked.
        msg_id_map = {}
        for msg in self._fetch_children(cursor, "message", old_sess_id):
            new_msg_id = self._fresh_id(msg["id"])
            msg_id_map[msg["id"]] = new_msg_id
            self._insert_copy(cursor, "message", msg, new_msg_id, {"session_id": new_sess_id})

        for part in self._fetch_children(cursor, "part", old_sess_id):
            keys = part.keys()
            overrides = {"session_id": new_sess_id}
            if "message_id" in keys:
                overrides["message_id"] = msg_id_map.get(part["message_id"], part["message_id"])
            new_pk = self._fresh_id(part["id"]) if "id" in keys else None
            self._insert_copy(cursor, "part", part, new_pk, overrides)

        # These tables reference only the session; they have no secondary FKs to remap.
        for table in ("todo", "session_share"):
            for rec in self._fetch_children(cursor, table, old_sess_id):
                new_pk = self._fresh_id(rec["id"]) if "id" in rec.keys() else None
                self._insert_copy(cursor, table, rec, new_pk, {"session_id": new_sess_id})

    def copy_sessions(self, session_ids: List[str], target_project_id: str,
                      target_workspace_id: Optional[str] = None) -> Tuple[bool, List[str], List[str]]:
        """Duplicate sessions, with their messages, parts, todos and shares, into a project.

        Args:
            session_ids: IDs to copy. IDs with no session row are skipped.
            target_project_id: Project that receives the copies.
            target_workspace_id: Workspace for the copies. When falsy, the
                copies get ``workspace_id = NULL``.

        Returns:
            ``(success, log, new_ids)``. On a validation failure ``success`` is
            False, ``log`` holds one ``"ERROR: ..."`` line and ``new_ids`` is empty.

        Raises:
            RuntimeError: if not connected.
            sqlite3.Error: if a statement fails. The whole copy is rolled back first.
        """
        import uuid

        conn = self._require_conn()
        error = self._validate_target(target_project_id, target_workspace_id)
        if error:
            return False, [error], []

        cursor = conn.cursor()
        new_session_ids = []
        sql_statements = []
        # Commit once at the end; roll back everything if any step raises.
        with conn:
            for sess_id in session_ids:
                cursor.execute("SELECT * FROM session WHERE id = ?", (sess_id,))
                row = cursor.fetchone()
                if not row:
                    continue
                new_id = "ses_" + uuid.uuid4().hex
                overrides = {
                    "project_id": target_project_id,
                    "workspace_id": target_workspace_id or None,
                }
                self._insert_copy(cursor, "session", row, new_id, overrides)
                sql_statements.append(f"INSERT INTO session ... VALUES ({new_id}, ...)")
                self._copy_children(cursor, sess_id, new_id)
                new_session_ids.append(new_id)
        return True, sql_statements, new_session_ids

    def delete_sessions(self, session_ids: List[str]) -> Tuple[bool, str]:
        """Permanently delete sessions and their rows in the child tables.

        Child tables that do not exist in this schema are skipped. Any other
        database error rolls back the whole deletion.

        Returns:
            ``(True, "")`` on success, or ``(False, error_message)``.

        Raises:
            RuntimeError: if not connected.
        """
        conn = self._require_conn()
        cursor = conn.cursor()
        try:
            with conn:
                for sess_id in session_ids:
                    for table in self._CHILD_TABLES:
                        try:
                            cursor.execute(f"DELETE FROM {table} WHERE session_id = ?", (sess_id,))
                        except sqlite3.OperationalError as err:
                            if not self._is_missing_schema(err):
                                raise
                    cursor.execute("DELETE FROM session WHERE id = ?", (sess_id,))
        except sqlite3.Error as e:
            return False, str(e)
        return True, ""

    def create_backup(self) -> Path:
        """Write a timestamped snapshot of the database next to the original file.

        The backup is named ``<stem>-YYYYmmdd-HHMMSS<suffix>``.

        Returns:
            Path of the backup file.

        Raises:
            FileNotFoundError: if ``db_path`` does not exist.
            sqlite3.Error: if the file cannot be read as an SQLite database.
        """
        if not self.db_path.exists():
            raise FileNotFoundError(f"Database not found: {self.db_path}")
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        backup_path = self.db_path.parent / f"{self.db_path.stem}-{timestamp}{self.db_path.suffix}"
        # Use SQLite's online backup API instead of a raw file copy. It yields
        # a consistent snapshot even while other connections are open, and it
        # includes committed content that has not yet been checkpointed from
        # a -wal file.
        source = sqlite3.connect(self.db_path)
        try:
            target = sqlite3.connect(backup_path)
            try:
                source.backup(target)
            finally:
                target.close()
        finally:
            source.close()
        return backup_path