463 lines
17 KiB
Python
463 lines
17 KiB
Python
"""Database layer for OpenCode session operations."""
|
|
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from typing import Optional, List, Tuple
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
|
|
|
|
@dataclass
class Project:
    """Project record, mirroring one row of the ``project`` table."""

    id: str
    # Used as the display name whenever ``name`` is empty.
    worktree: str
    name: Optional[str]
    vcs: Optional[str]
    time_created: int
    time_updated: int
|
|
|
|
|
|
@dataclass
class Workspace:
    """Workspace record, mirroring one row of the ``workspace`` table."""

    id: str
    # Owning project; sessions may only be placed in a workspace whose
    # project_id matches their target project.
    project_id: str
    branch: Optional[str]
    type: str
    name: Optional[str]
    directory: Optional[str]
|
|
|
|
|
|
@dataclass
class Session:
    """Session record: the ``session`` table columns plus display fields."""

    # Columns read straight from the ``session`` table.
    id: str
    project_id: str
    workspace_id: Optional[str]
    parent_id: Optional[str]
    slug: str
    directory: str
    title: str
    version: str
    time_created: int
    time_updated: int
    # None while the session is active (not archived).
    time_archived: Optional[int] = None

    # Computed fields, filled in by the database layer after loading.
    message_count: int = 0
    todo_count: int = 0
    workspace_name: Optional[str] = None
    project_name: Optional[str] = None
|
|
|
|
|
|
class Database:
    """Database connection and query handler.

    Wraps a single sqlite3 connection to the OpenCode database and keeps an
    in-memory copy of the ``project`` and ``workspace`` tables so sessions
    can be decorated with display names without extra queries.
    """

    # Session query shared by get_sessions() and get_session() so both
    # return the same computed counts.
    _SESSION_SELECT = (
        "SELECT s.*,"
        " (SELECT COUNT(*) FROM message WHERE session_id = s.id) AS msg_count,"
        " (SELECT COUNT(*) FROM todo WHERE session_id = s.id) AS todo_count"
        " FROM session s"
    )

    # Tables holding per-session rows, in deletion order: ``part`` rows carry
    # a message_id, so they are removed before the ``message`` rows.
    _SESSION_CHILD_TABLES = ("part", "message", "todo", "session_share")

    def __init__(self, db_path: str = "opencode.db"):
        """Remember the database location; no connection is opened yet.

        Args:
            db_path: Path to the SQLite database file.
        """
        self.db_path = Path(db_path)
        self.conn: Optional[sqlite3.Connection] = None
        self.projects: dict[str, Project] = {}
        self.workspaces: dict[str, Workspace] = {}
        self.workspaces_by_project: dict[str, List[Workspace]] = {}

    def connect(self):
        """Establish database connection.

        Raises:
            FileNotFoundError: If the database file does not exist.
        """
        if not self.db_path.exists():
            raise FileNotFoundError(f"Database not found: {self.db_path}")
        # Release any earlier connection instead of leaking it.
        self.close()
        self.conn = sqlite3.connect(self.db_path)
        self.conn.row_factory = sqlite3.Row

    def close(self):
        """Close database connection (no-op when not connected)."""
        if self.conn is not None:
            self.conn.close()
            # Reset so later calls hit the "not connected" guard rather than
            # operating on a closed connection.
            self.conn = None

    def _connection(self) -> sqlite3.Connection:
        """Return the open connection or raise RuntimeError if there is none."""
        if self.conn is None:
            raise RuntimeError("Database is not connected. Call connect() first.")
        return self.conn

    def load_reference_data(self):
        """Load projects and workspaces into memory for fast lookups.

        The caches are rebuilt from scratch on every call, so rows deleted
        from the database since the previous load do not linger.

        Raises:
            RuntimeError: If connect() has not been called.
        """
        cursor = self._connection().cursor()

        projects = {}
        cursor.execute("SELECT id, worktree, name, vcs, time_created, time_updated FROM project")
        for row in cursor.fetchall():
            projects[row["id"]] = Project(
                id=row["id"],
                worktree=row["worktree"],
                name=row["name"],
                vcs=row["vcs"],
                time_created=row["time_created"],
                time_updated=row["time_updated"],
            )

        workspaces = {}
        cursor.execute("SELECT id, project_id, branch, type, name, directory FROM workspace")
        for row in cursor.fetchall():
            workspaces[row["id"]] = Workspace(
                id=row["id"],
                project_id=row["project_id"],
                branch=row["branch"],
                type=row["type"],
                name=row["name"],
                directory=row["directory"],
            )

        by_project = {}
        for ws in workspaces.values():
            by_project.setdefault(ws.project_id, []).append(ws)

        # Swap in the new caches only after both queries succeeded.
        self.projects = projects
        self.workspaces = workspaces
        self.workspaces_by_project = by_project

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE wildcards in *text* for use with ``ESCAPE '\\'``."""
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def _row_to_session(self, row: sqlite3.Row) -> "Session":
        """Build a Session from a _SESSION_SELECT row and attach display names."""
        sess = Session(
            id=row["id"],
            project_id=row["project_id"],
            workspace_id=row["workspace_id"],
            parent_id=row["parent_id"],
            slug=row["slug"],
            directory=row["directory"],
            title=row["title"],
            version=row["version"],
            time_created=row["time_created"],
            time_updated=row["time_updated"],
            time_archived=row["time_archived"],
            message_count=row["msg_count"],
            todo_count=row["todo_count"],
        )

        project = self.projects.get(sess.project_id)
        if project is not None:
            sess.project_name = project.name or project.worktree

        workspace = self.workspaces.get(sess.workspace_id) if sess.workspace_id else None
        if workspace is not None:
            sess.workspace_name = workspace.name or workspace.branch or workspace.id

        return sess

    def get_sessions(self,
                     include_archived: bool = False,
                     project_id: Optional[str] = None,
                     workspace_id: Optional[str] = None,
                     search: str = "") -> List["Session"]:
        """
        Fetch sessions with optional filtering.

        Args:
            include_archived: Include archived sessions
            project_id: Filter by project
            workspace_id: Filter by workspace
            search: Text search in title or slug (case-insensitive substring;
                ``%`` and ``_`` are matched literally)

        Returns:
            List of Session objects with computed counts, newest first

        Raises:
            RuntimeError: If connect() has not been called.
        """
        cursor = self._connection().cursor()

        clauses = []
        params = []

        if not include_archived:
            clauses.append("s.time_archived IS NULL")

        if project_id:
            clauses.append("s.project_id = ?")
            params.append(project_id)

        if workspace_id:
            clauses.append("s.workspace_id = ?")
            params.append(workspace_id)

        if search:
            # Escape the user text so it is matched as a plain substring.
            clauses.append(
                "(LOWER(s.title) LIKE LOWER(?) ESCAPE '\\'"
                " OR LOWER(s.slug) LIKE LOWER(?) ESCAPE '\\')"
            )
            pattern = f"%{self._escape_like(search)}%"
            params.extend([pattern, pattern])

        query = self._SESSION_SELECT
        if clauses:
            query += " WHERE " + " AND ".join(clauses)
        query += " ORDER BY s.time_created DESC"

        cursor.execute(query, params)
        return [self._row_to_session(row) for row in cursor.fetchall()]

    def get_project_tree(self) -> List[Tuple[str, List[str], bool]]:
        """
        Build a hierarchical view of projects and workspaces.

        Uses only the data cached by load_reference_data().

        Returns:
            List of tuples: (project_name, [workspace_names], has_workspaces),
            sorted by tuple order (project name first)
        """
        result = []
        for proj in self.projects.values():
            ws_list = self.workspaces_by_project.get(proj.id, [])
            ws_names = [w.name or w.branch or w.id for w in ws_list]
            result.append((proj.name or proj.worktree, ws_names, bool(ws_list)))
        return sorted(result)

    def get_session_counts_by_project(self) -> dict:
        """Return a dict of project_id -> active session count.

        Raises:
            RuntimeError: If connect() has not been called.
        """
        cursor = self._connection().cursor()
        cursor.execute(
            "SELECT project_id, COUNT(*) FROM session WHERE time_archived IS NULL GROUP BY project_id"
        )
        return {row[0]: row[1] for row in cursor.fetchall()}

    def get_session(self, session_id: str) -> Optional["Session"]:
        """Get a single session by ID.

        Returns:
            The Session with the same computed fields get_sessions() fills
            in (message and todo counts, display names), or None if no
            session has that ID.

        Raises:
            RuntimeError: If connect() has not been called.
        """
        cursor = self._connection().cursor()
        cursor.execute(self._SESSION_SELECT + " WHERE s.id = ?", (session_id,))
        row = cursor.fetchone()
        if row is None:
            return None
        return self._row_to_session(row)

    def _validate_target(self, target_project_id: str,
                         target_workspace_id: Optional[str]) -> Optional[str]:
        """Return an error message if the target is invalid, else None."""
        if target_project_id not in self.projects:
            return f"ERROR: Project {target_project_id} not found"
        if target_workspace_id:
            if target_workspace_id not in self.workspaces:
                return f"ERROR: Workspace {target_workspace_id} not found"
            if self.workspaces[target_workspace_id].project_id != target_project_id:
                return f"ERROR: Workspace {target_workspace_id} does not belong to project {target_project_id}"
        return None

    @staticmethod
    def _is_schema_mismatch(exc: sqlite3.OperationalError) -> bool:
        """Tell whether *exc* reports a missing table or column."""
        message = str(exc).lower()
        return "no such table" in message or "no such column" in message

    def _select_children(self, cursor: sqlite3.Cursor, table: str, session_id: str) -> list:
        """Fetch all rows of *table* for a session; [] if the table is absent."""
        try:
            cursor.execute(f"SELECT * FROM {table} WHERE session_id = ?", (session_id,))
        except sqlite3.OperationalError as exc:
            # Optional tables may be missing from this schema; any other
            # operational error (e.g. a locked database) must not be hidden.
            if self._is_schema_mismatch(exc):
                return []
            raise
        return cursor.fetchall()

    @staticmethod
    def _insert_copy(cursor: sqlite3.Cursor, table: str, row: sqlite3.Row, overrides: dict) -> None:
        """Insert a copy of *row* into *table*, replacing the overridden columns.

        Overrides naming columns the row does not have are ignored.
        """
        columns = row.keys()
        values = [overrides[name] if name in overrides else row[name] for name in columns]
        column_sql = ", ".join('"' + name.replace('"', '""') + '"' for name in columns)
        placeholders = ", ".join("?" for _ in columns)
        cursor.execute(f"INSERT INTO {table} ({column_sql}) VALUES ({placeholders})", values)

    def move_sessions(self, session_ids: List[str], target_project_id: str,
                      target_workspace_id: Optional[str] = None) -> Tuple[bool, List[str]]:
        """
        Move sessions to a different project (and optionally workspace).

        All updates are applied in one transaction; if any of them fails the
        transaction is rolled back and the error is re-raised.

        Returns:
            (success, list of SQL statements executed); on a validation
            failure, (False, [error message])

        Raises:
            RuntimeError: If connect() has not been called.
        """
        conn = self._connection()

        error = self._validate_target(target_project_id, target_workspace_id)
        if error:
            return False, [error]

        cursor = conn.cursor()
        sql_statements = []
        try:
            for sess_id in session_ids:
                if target_workspace_id:
                    cursor.execute(
                        "UPDATE session SET project_id = ?, workspace_id = ? WHERE id = ?",
                        (target_project_id, target_workspace_id, sess_id),
                    )
                    sql_statements.append(f"UPDATE session SET project_id = {target_project_id}, workspace_id = {target_workspace_id} WHERE id = {sess_id}")
                else:
                    # NOTE(review): the session keeps its current workspace_id
                    # here, which may belong to the previous project; confirm
                    # this is intended (copy_sessions clears it instead).
                    cursor.execute(
                        "UPDATE session SET project_id = ? WHERE id = ?",
                        (target_project_id, sess_id),
                    )
                    sql_statements.append(f"UPDATE session SET project_id = {target_project_id} WHERE id = {sess_id}")
            conn.commit()
        except Exception:
            # Do not leave a half-applied move pending on the connection.
            conn.rollback()
            raise

        return True, sql_statements

    def copy_sessions(self, session_ids: List[str], target_project_id: str,
                      target_workspace_id: Optional[str] = None) -> Tuple[bool, List[str], List[str]]:
        """
        Copy sessions to a target project.

        Each session row is duplicated under a new ID together with its
        message, part, todo and session_share rows (tables missing from the
        schema are skipped). The copy goes into the target workspace, or has
        its workspace cleared when none is given. Session IDs that do not
        exist are skipped. Everything runs in one transaction; on failure it
        is rolled back and the error is re-raised.

        Returns:
            (success, list of SQL statements, list of new session IDs); on a
            validation failure, (False, [error message], [])

        Raises:
            RuntimeError: If connect() has not been called.
        """
        conn = self._connection()
        import uuid

        error = self._validate_target(target_project_id, target_workspace_id)
        if error:
            return False, [error], []

        cursor = conn.cursor()
        new_session_ids = []
        sql_statements = []

        try:
            for sess_id in session_ids:
                cursor.execute("SELECT * FROM session WHERE id = ?", (sess_id,))
                row = cursor.fetchone()
                if row is None:
                    continue

                new_id = "ses_" + uuid.uuid4().hex
                self._insert_copy(cursor, "session", row, {
                    "id": new_id,
                    "project_id": target_project_id,
                    "workspace_id": target_workspace_id or None,
                })
                sql_statements.append(f"INSERT INTO session ... VALUES ({new_id}, ...)")

                # Copy messages first, recording old_id -> new_id so the
                # parts copied next can be relinked to the new messages.
                msg_id_map = {}
                for msg in self._select_children(cursor, "message", sess_id):
                    new_msg_id = uuid.uuid4().hex
                    msg_id_map[msg["id"]] = new_msg_id
                    self._insert_copy(cursor, "message", msg, {
                        "id": new_msg_id,
                        "session_id": new_id,
                    })

                for part in self._select_children(cursor, "part", sess_id):
                    overrides = {"id": uuid.uuid4().hex, "session_id": new_id}
                    if "message_id" in part.keys():
                        old_msg_id = part["message_id"]
                        overrides["message_id"] = msg_id_map.get(old_msg_id, old_msg_id)
                    self._insert_copy(cursor, "part", part, overrides)

                # todo and session_share only reference the session; the
                # "id" override is ignored for tables without an id column.
                for table in ("todo", "session_share"):
                    for child in self._select_children(cursor, table, sess_id):
                        self._insert_copy(cursor, table, child, {
                            "id": uuid.uuid4().hex,
                            "session_id": new_id,
                        })

                new_session_ids.append(new_id)

            conn.commit()
        except Exception:
            # Do not leave a partially copied session pending on the connection.
            conn.rollback()
            raise

        return True, sql_statements, new_session_ids

    def delete_sessions(self, session_ids: List[str]) -> Tuple[bool, str]:
        """
        Permanently delete sessions and all related records.

        Related tables that do not exist in this schema are skipped. On any
        database error the whole deletion is rolled back.

        Returns:
            (success, error_message)

        Raises:
            RuntimeError: If connect() has not been called.
        """
        conn = self._connection()
        cursor = conn.cursor()

        try:
            for sess_id in session_ids:
                for table in self._SESSION_CHILD_TABLES:
                    try:
                        cursor.execute(f"DELETE FROM {table} WHERE session_id = ?", (sess_id,))
                    except sqlite3.OperationalError as exc:
                        # Only a missing table/column is tolerated; anything
                        # else aborts the deletion via the handler below.
                        if not self._is_schema_mismatch(exc):
                            raise
                cursor.execute("DELETE FROM session WHERE id = ?", (sess_id,))
            conn.commit()
            return True, ""
        except sqlite3.Error as e:
            conn.rollback()
            return False, str(e)

    def create_backup(self) -> Path:
        """Create a timestamped backup of the database.

        The copy is made with SQLite's online backup API rather than a raw
        file copy, so it is a consistent snapshot even while the database is
        open. The backup is written next to the original as
        ``<stem>-<YYYYmmdd-HHMMSS><suffix>``.

        Returns:
            Path of the backup file.

        Raises:
            FileNotFoundError: If the database file does not exist.
        """
        if not self.db_path.exists():
            raise FileNotFoundError(f"Database not found: {self.db_path}")

        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        backup_path = self.db_path.parent / f"{self.db_path.stem}-{timestamp}{self.db_path.suffix}"

        # A dedicated source connection keeps the backup independent of
        # whatever state self.conn is in.
        source = sqlite3.connect(self.db_path)
        try:
            target = sqlite3.connect(backup_path)
            try:
                source.backup(target)
            finally:
                target.close()
        finally:
            source.close()
        return backup_path
|