# Source: opencode-session-manager/database.py
# Last modified: 2026-03-07 17:15:22 -07:00
# Size: 383 lines, 14 KiB (Python)

"""Database layer for OpenCode session operations."""
import sqlite3
from pathlib import Path
from typing import Optional, List, Tuple
from dataclasses import dataclass, field
from datetime import datetime
@dataclass
class Project:
    """One row of the ``project`` table, held in memory for lookups.

    Field order matches the positional constructor and must not change.
    """

    # Primary key; used as the key of ``Database.projects``.
    id: str
    # Fallback display label when ``name`` is empty.
    # NOTE(review): presumably the project's working-tree path -- confirm.
    worktree: str
    # Optional human-readable name; preferred over ``worktree`` for display.
    name: Optional[str]
    # Optional version-control identifier, stored as-is from the database.
    vcs: Optional[str]
    # Creation / last-update timestamps as stored in the database
    # (integer; the unit is not visible from this module).
    time_created: int
    time_updated: int
@dataclass
class Workspace:
    """One row of the ``workspace`` table, held in memory for lookups.

    Field order matches the positional constructor and must not change.
    """

    # Primary key; used as the key of ``Database.workspaces`` and as the
    # last-resort display label.
    id: str
    # Owning project; sessions may only be moved/copied into a workspace
    # whose ``project_id`` equals the target project.
    project_id: str
    # Optional branch name; second choice for the display label.
    branch: Optional[str]
    # Workspace kind, stored as-is from the database.
    type: str
    # Optional human-readable name; first choice for the display label.
    name: Optional[str]
    # Optional directory associated with the workspace.
    directory: Optional[str]
@dataclass
class Session:
    """One row of the ``session`` table plus values computed at query time.

    The first block of fields mirrors database columns; the trailing fields
    with defaults are filled in by the query layer and are not stored.
    Field order matches the positional constructor and must not change.
    """

    # --- Stored columns -------------------------------------------------
    id: str
    project_id: str
    # Optional workspace the session is attached to.
    workspace_id: Optional[str]
    # Optional id of a parent session.
    parent_id: Optional[str]
    slug: str
    directory: str
    title: str
    version: str
    # Timestamps as stored in the database (integer; unit not visible here).
    time_created: int
    time_updated: int
    # ``None`` means the session is active (not archived).
    time_archived: Optional[int] = None

    # --- Computed fields ------------------------------------------------
    # Number of ``message`` / ``todo`` rows referencing this session.
    message_count: int = 0
    todo_count: int = 0
    # Display labels resolved from the in-memory project/workspace lookups.
    workspace_name: Optional[str] = None
    project_name: Optional[str] = None
class Database:
    """SQLite access layer for OpenCode projects, workspaces and sessions.

    Call ``connect()`` before any query. Call ``load_reference_data()``
    before using methods that consult the in-memory ``projects`` /
    ``workspaces`` lookups (session display names, project tree, and the
    target validation done by move/copy).
    """

    # Tables holding per-session rows, paired with the column that points at
    # ``session.id``. Order matters when copying: "message" comes before
    # "part" so copied parts can be re-pointed at the copied messages.
    _RELATED_TABLES = (
        ("message", "session_id"),
        ("part", "session_id"),
        ("todo", "session_id"),
        ("session_share", "session_id"),
    )

    # Base session query; callers append " AND <condition>" fragments.
    _SESSION_SELECT = """
        SELECT s.*,
            (SELECT COUNT(*) FROM message WHERE session_id = s.id) as msg_count,
            (SELECT COUNT(*) FROM todo WHERE session_id = s.id) as todo_count
        FROM session s
        WHERE 1=1
    """

    def __init__(self, db_path: str = "opencode.db"):
        """Remember the database location; no connection is opened yet."""
        self.db_path = Path(db_path)
        self.conn: Optional[sqlite3.Connection] = None
        self.projects: dict[str, Project] = {}
        self.workspaces: dict[str, Workspace] = {}
        self.workspaces_by_project: dict[str, List[Workspace]] = {}

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    def connect(self):
        """Establish database connection."""
        self.conn = sqlite3.connect(self.db_path)
        # Rows are accessed by column name throughout this class.
        self.conn.row_factory = sqlite3.Row

    def close(self):
        """Close database connection (safe to call when not connected)."""
        if self.conn is not None:
            self.conn.close()
            # Drop the handle so later calls fail with the clear
            # "not connected" error instead of using a closed connection.
            self.conn = None

    def _connection(self) -> sqlite3.Connection:
        """Return the open connection.

        Raises:
            AssertionError: if ``connect()`` has not been called. Raised
                explicitly (not via ``assert``) so the check survives
                ``python -O``.
        """
        if self.conn is None:
            raise AssertionError("Database not connected")
        return self.conn

    # ------------------------------------------------------------------
    # Reference data
    # ------------------------------------------------------------------

    def load_reference_data(self):
        """Load projects and workspaces into memory for fast lookups.

        The lookups are rebuilt from scratch on every call, so rows deleted
        from the database since the previous load do not linger.
        """
        cursor = self._connection().cursor()

        # Load projects
        cursor.execute("SELECT id, worktree, name, vcs, time_created, time_updated FROM project")
        projects = {}
        for row in cursor.fetchall():
            projects[row["id"]] = Project(
                id=row["id"],
                worktree=row["worktree"],
                name=row["name"],
                vcs=row["vcs"],
                time_created=row["time_created"],
                time_updated=row["time_updated"],
            )

        # Load workspaces
        cursor.execute("SELECT id, project_id, branch, type, name, directory FROM workspace")
        workspaces = {}
        by_project = {}
        for row in cursor.fetchall():
            ws = Workspace(
                id=row["id"],
                project_id=row["project_id"],
                branch=row["branch"],
                type=row["type"],
                name=row["name"],
                directory=row["directory"],
            )
            workspaces[ws.id] = ws
            by_project.setdefault(ws.project_id, []).append(ws)

        # Publish only after both queries succeeded.
        self.projects = projects
        self.workspaces = workspaces
        self.workspaces_by_project = by_project

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    @staticmethod
    def _like_pattern(text: str) -> str:
        """Build a "contains" LIKE pattern that matches ``text`` literally.

        LIKE wildcards (percent, underscore) and the escape character itself
        are backslash-escaped; the SQL must use the matching ESCAPE clause.
        """
        escaped = text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        return f"%{escaped}%"

    def _session_from_row(self, row: sqlite3.Row) -> "Session":
        """Convert a ``_SESSION_SELECT`` row into a ``Session`` with names."""
        sess = Session(
            id=row["id"],
            project_id=row["project_id"],
            workspace_id=row["workspace_id"],
            parent_id=row["parent_id"],
            slug=row["slug"],
            directory=row["directory"],
            title=row["title"],
            version=row["version"],
            time_created=row["time_created"],
            time_updated=row["time_updated"],
            time_archived=row["time_archived"],
            message_count=row["msg_count"],
            todo_count=row["todo_count"],
        )
        # Add project name
        project = self.projects.get(sess.project_id)
        if project is not None:
            sess.project_name = project.name or project.worktree
        # Add workspace name
        workspace = self.workspaces.get(sess.workspace_id) if sess.workspace_id else None
        if workspace is not None:
            sess.workspace_name = workspace.name or workspace.branch or workspace.id
        return sess

    def get_sessions(self,
                     include_archived: bool = False,
                     project_id: Optional[str] = None,
                     workspace_id: Optional[str] = None,
                     search: str = "") -> List["Session"]:
        """
        Fetch sessions with optional filtering.

        Args:
            include_archived: Include archived sessions
            project_id: Filter by project
            workspace_id: Filter by workspace
            search: Case-insensitive literal text search in title or slug

        Returns:
            List of Session objects with computed counts, newest first
        """
        cursor = self._connection().cursor()
        clauses: List[str] = []
        params: List[str] = []
        if not include_archived:
            clauses.append("s.time_archived IS NULL")
        if project_id:
            clauses.append("s.project_id = ?")
            params.append(project_id)
        if workspace_id:
            clauses.append("s.workspace_id = ?")
            params.append(workspace_id)
        if search:
            # ESCAPE makes '%' and '_' typed by the user match literally.
            clauses.append(
                "(LOWER(s.title) LIKE LOWER(?) ESCAPE '\\'"
                " OR LOWER(s.slug) LIKE LOWER(?) ESCAPE '\\')"
            )
            pattern = self._like_pattern(search)
            params.extend([pattern, pattern])

        query = (
            self._SESSION_SELECT
            + "".join(f" AND {clause}" for clause in clauses)
            + " ORDER BY s.time_created DESC"
        )
        cursor.execute(query, params)
        return [self._session_from_row(row) for row in cursor.fetchall()]

    def get_project_tree(self) -> List[Tuple[str, List[str], bool]]:
        """
        Build a hierarchical view of projects and workspaces.

        Returns:
            Sorted list of tuples: (project_name, [workspace_names], has_workspaces)
        """
        result = []
        for proj in self.projects.values():
            ws_list = self.workspaces_by_project.get(proj.id, [])
            ws_names = [w.name or w.branch or w.id for w in ws_list]
            result.append((proj.name or proj.worktree, ws_names, bool(ws_list)))
        return sorted(result)

    def get_session_counts_by_project(self) -> dict:
        """Return a dict of project_id -> active (non-archived) session count."""
        cursor = self._connection().cursor()
        cursor.execute(
            "SELECT project_id, COUNT(*) FROM session WHERE time_archived IS NULL GROUP BY project_id"
        )
        return {row[0]: row[1] for row in cursor.fetchall()}

    def get_session(self, session_id: str) -> Optional["Session"]:
        """Get a single session by ID (archived sessions included).

        Queries just the requested row instead of loading every session.
        """
        cursor = self._connection().cursor()
        cursor.execute(self._SESSION_SELECT + " AND s.id = ?", (session_id,))
        row = cursor.fetchone()
        return None if row is None else self._session_from_row(row)

    # ------------------------------------------------------------------
    # Helpers shared by move / copy / delete
    # ------------------------------------------------------------------

    def _validate_target(self, target_project_id: str,
                         target_workspace_id: Optional[str]) -> Optional[str]:
        """Return an ``ERROR: ...`` message if the target is invalid, else None.

        Validation uses the in-memory lookups, so ``load_reference_data()``
        must have been called.
        """
        # Verify target project exists
        if target_project_id not in self.projects:
            return f"ERROR: Project {target_project_id} not found"
        # Verify workspace exists and belongs to target project
        if target_workspace_id:
            if target_workspace_id not in self.workspaces:
                return f"ERROR: Workspace {target_workspace_id} not found"
            if self.workspaces[target_workspace_id].project_id != target_project_id:
                return f"ERROR: Workspace {target_workspace_id} does not belong to project {target_project_id}"
        return None

    def _related_tables(self, cursor: sqlite3.Cursor) -> List[Tuple[str, str]]:
        """Return the ``_RELATED_TABLES`` entries present in this database.

        Some tables are absent in some schema versions. Probing the schema
        (rather than catching ``OperationalError`` around each statement)
        keeps genuine failures such as a locked database from being
        silently skipped.
        """
        present = []
        for table, foreign_key in self._RELATED_TABLES:
            # Table names come from the constant above, never from callers.
            cursor.execute(f"PRAGMA table_info({table})")
            column_names = {info[1] for info in cursor.fetchall()}
            if foreign_key in column_names:
                present.append((table, foreign_key))
        return present

    @staticmethod
    def _insert_row(cursor: sqlite3.Cursor, table: str,
                    columns: List[str], values: list) -> None:
        """Insert one row, quoting column names and binding all values."""
        column_sql = ", ".join(f'"{col}"' for col in columns)
        placeholders = ", ".join(["?"] * len(columns))
        cursor.execute(f"INSERT INTO {table} ({column_sql}) VALUES ({placeholders})", values)

    # ------------------------------------------------------------------
    # Mutations
    # ------------------------------------------------------------------

    def move_sessions(self, session_ids: List[str], target_project_id: str,
                      target_workspace_id: Optional[str] = None) -> Tuple[bool, List[str]]:
        """
        Move sessions to a different project (and optionally workspace).

        All updates are committed together; on a database error everything
        is rolled back.

        Returns:
            (success, list of SQL statements executed); on failure the list
            holds a single "ERROR: ..." message instead.
        """
        conn = self._connection()
        error = self._validate_target(target_project_id, target_workspace_id)
        if error is not None:
            return False, [error]

        cursor = conn.cursor()
        sql_statements = []
        try:
            for sess_id in session_ids:
                if target_workspace_id:
                    cursor.execute(
                        "UPDATE session SET project_id = ?, workspace_id = ? WHERE id = ?",
                        (target_project_id, target_workspace_id, sess_id),
                    )
                    sql_statements.append(f"UPDATE session SET project_id = {target_project_id}, workspace_id = {target_workspace_id} WHERE id = {sess_id}")
                else:
                    # workspace_id is left untouched here.
                    # NOTE(review): the session may then keep a workspace that
                    # belongs to its previous project -- confirm this is intended.
                    cursor.execute(
                        "UPDATE session SET project_id = ? WHERE id = ?",
                        (target_project_id, sess_id),
                    )
                    sql_statements.append(f"UPDATE session SET project_id = {target_project_id} WHERE id = {sess_id}")
            conn.commit()
        except sqlite3.Error as exc:
            # Never leave a half-applied move pending on the connection.
            conn.rollback()
            return False, [f"ERROR: {exc}"]
        return True, sql_statements

    def copy_sessions(self, session_ids: List[str], target_project_id: str,
                      target_workspace_id: Optional[str] = None) -> Tuple[bool, List[str], List[str]]:
        """
        Copy sessions (and their related rows) to a target project.

        Unknown session ids are skipped. The copy gets the target workspace,
        or no workspace when none is given. All inserts are committed
        together; on a database error everything is rolled back.

        Returns:
            (success, list of SQL statements, list of new session IDs); on
            failure the statement list holds a single "ERROR: ..." message.
        """
        import uuid

        conn = self._connection()
        error = self._validate_target(target_project_id, target_workspace_id)
        if error is not None:
            return False, [error], []

        cursor = conn.cursor()
        new_session_ids = []
        sql_statements = []
        try:
            related = self._related_tables(cursor)
            for sess_id in session_ids:
                # Get original session
                cursor.execute("SELECT * FROM session WHERE id = ?", (sess_id,))
                row = cursor.fetchone()
                if row is None:
                    continue

                # Create new session ID
                new_id = "ses_" + uuid.uuid4().hex

                # Clone every column except the primary key, then override
                # the project and workspace.
                cols = [k for k in row.keys() if k != "id"]
                values = [row[col] for col in cols]
                values[cols.index("project_id")] = target_project_id
                if "workspace_id" in cols:
                    # Clear workspace if not specified
                    values[cols.index("workspace_id")] = (
                        target_workspace_id if target_workspace_id else None
                    )
                self._insert_row(cursor, "session", ["id"] + cols, [new_id] + values)
                sql_statements.append(f"INSERT INTO session ... VALUES ({new_id}, ...)")

                # Copy related records. Old -> new message ids are tracked so
                # rows carrying a message_id follow the copied message instead
                # of pointing back into the original session.
                # NOTE(review): assumes message_id references message.id --
                # confirm against the schema.
                message_ids = {}
                for table, foreign_key in related:
                    cursor.execute(f"SELECT * FROM {table} WHERE {foreign_key} = ?", (sess_id,))
                    for rec in cursor.fetchall():
                        keys = rec.keys()
                        rec_cols = [k for k in keys if k != "id"]
                        rec_values = []
                        for k in rec_cols:
                            if k == foreign_key:
                                rec_values.append(new_id)
                            elif k == "message_id" and rec[k] in message_ids:
                                rec_values.append(message_ids[rec[k]])
                            else:
                                rec_values.append(rec[k])
                        # Generate new ID for tables with id column
                        if "id" in keys:
                            new_row_id = uuid.uuid4().hex
                            if table == "message":
                                message_ids[rec["id"]] = new_row_id
                            rec_cols = ["id"] + rec_cols
                            rec_values = [new_row_id] + rec_values
                        self._insert_row(cursor, table, rec_cols, rec_values)

                new_session_ids.append(new_id)
            conn.commit()
        except sqlite3.Error as exc:
            # Never leave a partially copied session pending on the connection.
            conn.rollback()
            return False, [f"ERROR: {exc}"], []
        return True, sql_statements, new_session_ids

    def delete_sessions(self, session_ids: List[str]) -> Tuple[bool, str]:
        """
        Permanently delete sessions and all related records.

        All deletes are committed together; on a database error everything
        is rolled back.

        Returns:
            (success, error_message)
        """
        conn = self._connection()
        cursor = conn.cursor()
        try:
            related = self._related_tables(cursor)
            for sess_id in session_ids:
                # Dependent rows first, then the session itself.
                for table, foreign_key in related:
                    cursor.execute(f"DELETE FROM {table} WHERE {foreign_key} = ?", (sess_id,))
                cursor.execute("DELETE FROM session WHERE id = ?", (sess_id,))
            conn.commit()
            return True, ""
        except sqlite3.Error as e:
            conn.rollback()
            return False, str(e)

    # ------------------------------------------------------------------
    # Backup
    # ------------------------------------------------------------------

    def create_backup(self) -> Path:
        """Create a timestamped backup of the database.

        The backup is written next to the database as
        ``<stem>-YYYYmmdd-HHMMSS<suffix>`` using SQLite's online backup API,
        which produces a consistent snapshot even while the database is in
        use (a plain file copy can miss data still held in journal/WAL
        files).

        Returns:
            Path of the backup file.

        Raises:
            FileNotFoundError: if the database file does not exist.
            sqlite3.Error: if SQLite cannot read the source or write the copy.
        """
        if not self.db_path.exists():
            raise FileNotFoundError(f"Database not found: {self.db_path}")
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        backup_path = self.db_path.parent / f"{self.db_path.stem}-{timestamp}{self.db_path.suffix}"

        # A dedicated source connection sees committed data only, regardless
        # of what is pending on self.conn.
        source = sqlite3.connect(self.db_path)
        try:
            target = sqlite3.connect(backup_path)
            try:
                source.backup(target)
            finally:
                target.close()
        finally:
            source.close()
        return backup_path