Fix sqlite compatibility for cognition models

This commit is contained in:
Alexa Amundson
2025-11-20 16:45:53 -06:00
parent a332017fc9
commit 6a93dd62d2
3 changed files with 70 additions and 23 deletions

View File

@@ -18,13 +18,54 @@ Tables:
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, Float, Boolean, DateTime, ForeignKey, JSON, Enum from sqlalchemy import Column, Integer, String, Text, Float, Boolean, DateTime, ForeignKey, JSON, Enum
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.dialects.postgresql import UUID, JSONB from sqlalchemy.dialects.postgresql import UUID as PGUUID, JSONB
from sqlalchemy.types import TypeDecorator, CHAR
import uuid import uuid
import enum import enum
from ..database import Base from ..database import Base
class GUID(TypeDecorator):
"""Platform-independent GUID/UUID type.
Stores UUIDs as native UUID type on PostgreSQL and as 36-character
strings on other databases (e.g., SQLite used in tests).
"""
impl = CHAR
cache_ok = True
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(PGGUID(as_uuid=True))
return dialect.type_descriptor(CHAR(36))
def process_bind_param(self, value, dialect):
if value is None:
return value
if isinstance(value, uuid.UUID):
return value if dialect.name == "postgresql" else str(value)
return str(uuid.UUID(value))
def process_result_value(self, value, dialect):
if value is None:
return value
return uuid.UUID(value)
class JSONBCompat(TypeDecorator):
"""JSONB that falls back to generic JSON on non-Postgres engines."""
impl = JSON
cache_ok = True
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(JSONB)
return dialect.type_descriptor(JSON())
class WorkflowStatus(str, enum.Enum): class WorkflowStatus(str, enum.Enum):
"""Workflow execution status""" """Workflow execution status"""
PENDING = "pending" PENDING = "pending"
@@ -49,13 +90,13 @@ class Workflow(Base):
""" """
__tablename__ = "workflows" __tablename__ = "workflows"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(GUID(), primary_key=True, default=uuid.uuid4)
name = Column(String(200), nullable=False, index=True) name = Column(String(200), nullable=False, index=True)
description = Column(Text) description = Column(Text)
# Workflow configuration # Workflow configuration
mode = Column(Enum(ExecutionMode), default=ExecutionMode.SEQUENTIAL, nullable=False) mode = Column(Enum(ExecutionMode), default=ExecutionMode.SEQUENTIAL, nullable=False)
steps = Column(JSONB, nullable=False) # List of workflow steps steps = Column(JSONBCompat(), nullable=False) # List of workflow steps
timeout_seconds = Column(Integer, default=600) timeout_seconds = Column(Integer, default=600)
# Metadata # Metadata
@@ -66,7 +107,7 @@ class Workflow(Base):
is_template = Column(Boolean, default=False) is_template = Column(Boolean, default=False)
# Tags for categorization # Tags for categorization
tags = Column(JSONB, default=list) tags = Column(JSONBCompat(), default=list)
# Relationships # Relationships
executions = relationship("WorkflowExecution", back_populates="workflow", cascade="all, delete-orphan") executions = relationship("WorkflowExecution", back_populates="workflow", cascade="all, delete-orphan")
@@ -83,8 +124,8 @@ class WorkflowExecution(Base):
""" """
__tablename__ = "workflow_executions" __tablename__ = "workflow_executions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(GUID(), primary_key=True, default=uuid.uuid4)
workflow_id = Column(UUID(as_uuid=True), ForeignKey("workflows.id"), nullable=False, index=True) workflow_id = Column(GUID(), ForeignKey("workflows.id"), nullable=False, index=True)
# Execution details # Execution details
status = Column(Enum(WorkflowStatus), default=WorkflowStatus.PENDING, nullable=False, index=True) status = Column(Enum(WorkflowStatus), default=WorkflowStatus.PENDING, nullable=False, index=True)
@@ -93,17 +134,17 @@ class WorkflowExecution(Base):
duration_seconds = Column(Float) duration_seconds = Column(Float)
# Results # Results
step_results = Column(JSONB) # Results from each step step_results = Column(JSONBCompat()) # Results from each step
error_message = Column(Text) error_message = Column(Text)
error_details = Column(JSONB) error_details = Column(JSONBCompat())
# Metrics # Metrics
overall_confidence = Column(Float) overall_confidence = Column(Float)
total_agents_used = Column(Integer) total_agents_used = Column(Integer)
# Context # Context
initial_context = Column(JSONB) initial_context = Column(JSONBCompat())
final_memory = Column(JSONB) final_memory = Column(JSONBCompat())
# Relationships # Relationships
workflow = relationship("Workflow", back_populates="executions") workflow = relationship("Workflow", back_populates="executions")
@@ -122,8 +163,8 @@ class ReasoningTrace(Base):
""" """
__tablename__ = "reasoning_traces" __tablename__ = "reasoning_traces"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(GUID(), primary_key=True, default=uuid.uuid4)
execution_id = Column(UUID(as_uuid=True), ForeignKey("workflow_executions.id"), nullable=False, index=True) execution_id = Column(GUID(), ForeignKey("workflow_executions.id"), nullable=False, index=True)
# Step identification # Step identification
workflow_step_name = Column(String(100), nullable=False) workflow_step_name = Column(String(100), nullable=False)
@@ -138,7 +179,7 @@ class ReasoningTrace(Base):
confidence_score = Column(Float) confidence_score = Column(Float)
# Additional metadata # Additional metadata
metadata = Column(JSONB) trace_metadata = Column("metadata", JSONBCompat())
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True) timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
# Relationships # Relationships
@@ -157,12 +198,12 @@ class AgentMemory(Base):
""" """
__tablename__ = "agent_memory" __tablename__ = "agent_memory"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(GUID(), primary_key=True, default=uuid.uuid4)
execution_id = Column(UUID(as_uuid=True), ForeignKey("workflow_executions.id"), index=True) execution_id = Column(GUID(), ForeignKey("workflow_executions.id"), index=True)
# Memory data # Memory data
context = Column(JSONB, nullable=False) # Shared context dictionary context = Column(JSONBCompat(), nullable=False) # Shared context dictionary
confidence_scores = Column(JSONB) # Confidence per step confidence_scores = Column(JSONBCompat()) # Confidence per step
# Metadata # Metadata
created_at = Column(DateTime, default=datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
@@ -188,7 +229,7 @@ class PromptRegistry(Base):
""" """
__tablename__ = "prompt_registry" __tablename__ = "prompt_registry"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(GUID(), primary_key=True, default=uuid.uuid4)
# Prompt identification # Prompt identification
agent_name = Column(String(50), nullable=False, index=True) agent_name = Column(String(50), nullable=False, index=True)
@@ -201,8 +242,8 @@ class PromptRegistry(Base):
# Metadata # Metadata
description = Column(Text) description = Column(Text)
metadata = Column(JSONB) # Author, purpose, etc. prompt_metadata = Column("metadata", JSONBCompat()) # Author, purpose, etc.
tags = Column(JSONB, default=list) tags = Column(JSONBCompat(), default=list)
# Usage stats # Usage stats
usage_count = Column(Integer, default=0) usage_count = Column(Integer, default=0)
@@ -227,11 +268,11 @@ class AgentPerformanceMetric(Base):
""" """
__tablename__ = "agent_performance_metrics" __tablename__ = "agent_performance_metrics"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(GUID(), primary_key=True, default=uuid.uuid4)
# Agent identification # Agent identification
agent_name = Column(String(50), nullable=False, index=True) agent_name = Column(String(50), nullable=False, index=True)
execution_id = Column(UUID(as_uuid=True), ForeignKey("workflow_executions.id"), index=True) execution_id = Column(GUID(), ForeignKey("workflow_executions.id"), index=True)
# Performance metrics # Performance metrics
execution_time_seconds = Column(Float) execution_time_seconds = Column(Float)

View File

@@ -242,7 +242,7 @@ async def run_cognition(
input_context=request.input, input_context=request.input,
output=str(step_value), output=str(step_value),
confidence_score=pipeline.get('confidence', 0.0), confidence_score=pipeline.get('confidence', 0.0),
metadata={'mode': request.mode} trace_metadata={'mode': request.mode}
) )
db.add(trace) db.add(trace)
step_number += 1 step_number += 1

View File

@@ -3,12 +3,18 @@ import pytest
import pytest_asyncio import pytest_asyncio
import asyncio import asyncio
import os import os
import sys
from pathlib import Path
from typing import AsyncGenerator from typing import AsyncGenerator
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
from app.main import app from app.main import app
from app.database import get_db, Base from app.database import get_db, Base
from app.config import settings from app.config import settings