5 changes: 5 additions & 0 deletions app/api/endpoints/analysis_endpoints.py
@@ -24,6 +24,11 @@
logger = logging.getLogger(__name__)


@router.post("/create", response_model=AnalysisRead)
async def create_analysis_test(data: AnalysisRead) -> AnalysisRead:
# Stub: echo the payload back so the response satisfies the declared response_model.
return data


@router.get("/{analysis_id}", response_model=AnalysisRead)
async def get_analysis(
analysis_id: UUID,
21 changes: 21 additions & 0 deletions app/core/llm/prompts.py
@@ -109,6 +109,27 @@ class AnalysisPrompt:
Limitations of Previous Response: [LLM fills this]
Overall Confidence in Previous Response (0-100): [LLM provides numerical score]
"""
GET_ANTH_CONFIDENCE = """
Statement: {statement}
Label: {label}
Is the given label correct for the statement?
Answer yes or no:
"""

GET_ANTH_CONFIDENCE_MOD = """
Statement: {statement}
Label: {label}
Explanation: {explanation}
Given the generated explanation, is the given label correct for the statement?
Answer yes or no:
"""

GET_ANTH_CONFIDENCE_MOD_2 = """
Statement: {statement}
Veracity score (a score from 0 to 100, where 0 represents definitively false and 100 represents definitively true): {score}
Is the given veracity score correct for the statement?
Answer yes or no:
"""

HIGH_ASSERT = """

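The three GET_ANTH_CONFIDENCE templates added above implement the same P(True)-style self-check: pose a yes/no question about the model's own output and read the probability mass on the first answer token. A minimal sketch of formatting the score-based variant (statement and score are made up; the real call site is in analysis_orchestrator.py below):

from app.core.llm.prompts import AnalysisPrompt

prompt = AnalysisPrompt.GET_ANTH_CONFIDENCE_MOD_2.format(
    statement="The Eiffel Tower is in Paris.",
    score=87,
)
# Sent as a single user message; the first generated token ("yes"/"no")
# carries the confidence signal through its logprobs.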
2 changes: 1 addition & 1 deletion app/core/llm/together_ai_llama.py
@@ -65,7 +65,7 @@ async def generate_response(self, messages: List[Message], temperature: float =
messages=[{"role": m.role, "content": m.content} for m in messages],
temperature=temperature,
# CRITICAL: this returns the per-token logprob data the confidence scoring needs
logprobs=1,
logprobs=5,
)

choice = response.choices[0]
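Why 5 rather than 1: with logprobs=1 only the sampled token's logprob is returned, so if the model answers "No" there is no way to recover P("Yes"). Requesting the top 5 alternatives per position exposes both answer tokens at the first position, which is what the orchestrator's top_logprobs[0] lookup depends on. An illustrative shape (token strings and values are made up):

# What top_logprobs[0] might contain after the yes/no question:
first_token_alts = {"Yes": -0.11, " No": -2.35, "yes": -4.20, ".": -6.10}
# With logprobs=1 only {"Yes": -0.11} would come back, and P("No")
# could not be reconstructed from the response.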
3 changes: 3 additions & 0 deletions app/models/database/models.py
@@ -17,6 +17,7 @@
text,
ARRAY,
DOUBLE_PRECISION,
LargeBinary,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship

@@ -140,6 +141,8 @@ class AnalysisModel(Base):
index=True,
)

log_probs: Mapped[Optional[bytes]] = mapped_column(LargeBinary, nullable=True)

claim: Mapped["ClaimModel"] = relationship(back_populates="analyses", doc="Related claim")
searches: Mapped[List["SearchModel"]] = relationship(back_populates="analysis", cascade="all, delete-orphan")
feedbacks: Mapped[List["FeedbackModel"]] = relationship(
54 changes: 53 additions & 1 deletion app/models/domain/analysis.py
@@ -1,13 +1,22 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, List
from typing import Optional, List, Dict
from uuid import UUID
import pickle

from app.models.database.models import AnalysisModel, AnalysisStatus
from app.models.domain.feedback import Feedback
from app.models.domain.search import Search


@dataclass
class LogProbsData:
anth_conf_score: float
tokens: List[str]
probs: List[float]
alternatives: List[Dict[str, float]]


@dataclass
class Analysis:
"""Domain model for claim analysis."""
@@ -20,12 +29,22 @@ class Analysis:
status: str
created_at: datetime
updated_at: datetime
log_probs: Optional[LogProbsData] = None
searches: Optional[List["Search"]] = None
feedback: Optional[List["Feedback"]] = None

@classmethod
def from_model(cls, model: "AnalysisModel") -> "Analysis":
"""Create domain model from database model."""

log_probs_obj = None
if model.log_probs:
try:
log_probs_obj = pickle.loads(model.log_probs)
except Exception:
# Fallback in case unpickling fails (e.g., corrupt data)
log_probs_obj = None

return cls(
id=model.id,
claim_id=model.claim_id,
@@ -35,17 +54,50 @@ def from_model(cls, model: "AnalysisModel") -> "Analysis":
status=model.status.value,
created_at=model.created_at,
updated_at=model.updated_at,
log_probs=log_probs_obj,
searches=[Search.from_model(s) for s in model.searches] if model.searches else None,
feedback=[Feedback.from_model(f) for f in model.feedbacks] if model.feedbacks else None,
)

@classmethod
def from_model_safe(cls, model: "AnalysisModel") -> "Analysis":
"""Create domain model from database model, explicitly ignoring relationships."""

log_probs_obj = None
if model.log_probs:
try:
log_probs_obj = pickle.loads(model.log_probs)
except Exception:
log_probs_obj = None

return cls(
id=model.id,
claim_id=model.claim_id,
veracity_score=model.veracity_score,
confidence_score=model.confidence_score,
analysis_text=model.analysis_text,
status=model.status.value,
created_at=model.created_at,
updated_at=model.updated_at,
log_probs=log_probs_obj,
# empty initialization (relationships are empty at creation)
searches=None,
feedback=None,
)

def to_model(self) -> "AnalysisModel":
"""Convert to database model."""

log_probs_bytes = None
if self.log_probs:
log_probs_bytes = pickle.dumps(self.log_probs)

return AnalysisModel(
id=self.id,
claim_id=self.claim_id,
veracity_score=self.veracity_score,
confidence_score=self.confidence_score,
analysis_text=self.analysis_text,
status=AnalysisStatus(self.status),
log_probs=log_probs_bytes,
)
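to_model and from_model round-trip LogProbsData through pickle, so analysis.log_probs holds an opaque blob that only this application should write or read (pickle.loads on untrusted bytes can execute arbitrary code). A minimal round-trip check of the serialization step:

import pickle

from app.models.domain.analysis import LogProbsData

original = LogProbsData(
    anth_conf_score=0.8,
    tokens=["Yes"],
    probs=[-0.223],
    alternatives=[{"Yes": -0.223, "No": -1.609}],
)
blob = pickle.dumps(original)   # what to_model() stores in analysis.log_probs
restored = pickle.loads(blob)   # what from_model() reconstructs
assert restored == original     # dataclasses compare field by field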
22 changes: 2 additions & 20 deletions app/repositories/implementations/analysis_repository.py
@@ -17,29 +17,11 @@ def __init__(self, session: AsyncSession):
super().__init__(session, AnalysisModel)

def _to_model(self, analysis: Analysis) -> AnalysisModel:
return AnalysisModel(
id=analysis.id,
claim_id=analysis.claim_id,
veracity_score=analysis.veracity_score,
confidence_score=analysis.confidence_score,
analysis_text=analysis.analysis_text,
status=AnalysisStatus(analysis.status),
)
return analysis.to_model()

def _to_domain(self, model: AnalysisModel) -> Analysis:
"""Convert database model to domain model without loading relationships."""
return Analysis(
id=model.id,
claim_id=model.claim_id,
veracity_score=model.veracity_score,
confidence_score=model.confidence_score,
analysis_text=model.analysis_text,
status=model.status.value,
created_at=model.created_at,
updated_at=model.updated_at,
searches=None,
feedback=None,
)
return Analysis.from_model_safe(model)

async def create(self, analysis: Analysis) -> Analysis:
"""Create new analysis."""
3 changes: 3 additions & 0 deletions app/schemas/analysis_schema.py
@@ -1,6 +1,8 @@
from pydantic import BaseModel, ConfigDict
from datetime import datetime
from uuid import UUID
from typing import Optional
from app.models.domain.analysis import LogProbsData


class AnalysisCreate(BaseModel):
@@ -17,6 +19,7 @@ class AnalysisRead(BaseModel):
confidence_score: float
analysis_text: str
created_at: datetime
log_probs: Optional[LogProbsData] = None

model_config = ConfigDict(from_attributes=True)

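ConfigDict pins this to Pydantic v2, which validates and serializes plain dataclasses used as field types, so LogProbsData needs no custom encoder here. A quick standalone check (field values are made up):

from pydantic import TypeAdapter

from app.models.domain.analysis import LogProbsData

adapter = TypeAdapter(LogProbsData)
data = adapter.validate_python(
    {"anth_conf_score": 0.8, "tokens": ["Yes"], "probs": [-0.223], "alternatives": [{"Yes": -0.223}]}
)
payload = adapter.dump_python(data)  # plain dict, ready for a JSON response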
57 changes: 54 additions & 3 deletions app/services/analysis_orchestrator.py
@@ -11,7 +11,7 @@
from app.core.llm.interfaces import LLMProvider
from app.models.database.models import AnalysisStatus, ClaimStatus, ConversationStatus, MessageSenderType
from app.models.domain.claim import Claim
from app.models.domain.analysis import Analysis
from app.models.domain.analysis import Analysis, LogProbsData
from app.models.domain.search import Search
from app.models.domain.message import Message
from app.core.llm.messages import Message as LLMMessage
@@ -203,8 +203,8 @@ async def _generate_analysis(
yield {"type": "content", "content": chunk.text}
else:
full_text = "".join(analysis_text)
logger.warning(f"length {len(analysis_text)}, {analysis_text}")
logger.warning(f"length {len(log_probs)}, {log_probs}")
# logger.warning(f"length {len(analysis_text)}, {analysis_text}")
# logger.warning(f"length {len(log_probs)}, {log_probs}")

try:
# Clean the text before parsing
@@ -247,6 +247,8 @@ async def _generate_analysis(
con_score = await self._generate_logprob_confidence_score(log_probs=log_probs)
logger.info(con_score)
current_analysis.confidence_score = float(con_score)
# log_data = await self._get_anth_confidence_score(statement=claim_text, veracity_score=veracity_score)
# current_analysis.log_probs = log_data

updated_analysis = await self._analysis_repo.update(current_analysis)

@@ -748,6 +750,55 @@ async def _generate_confidence_score(self, statement: str, analysis: str, source

return ""

async def _get_anth_confidence_score(self, statement: str, veracity_score: float) -> LogProbsData:
# label = 'true'
# if veracity_score < 50:
# label ='false'
# elif veracity_score == 50:
# label = 'unknown'

messages = [
LLMMessage(
role="user",
content=AnalysisPrompt.GET_ANTH_CONFIDENCE_MOD_2.format(
statement=statement,
score=veracity_score,
),
)
]
response = await self._llm.generate_response(messages)
raw_logprobs = response.metadata.get("raw_logprobs")

# Initialize defaults
log_probs_obj = LogProbsData(anth_conf_score=0, tokens=[], probs=[], alternatives=[])

if raw_logprobs:
# 1. Take the alternatives recorded for the first generated token,
# where the model answers the yes/no question.
first_token_alts = raw_logprobs.top_logprobs[0]

# 2. Convert log probabilities to linear probabilities (p = exp(log_p)),
# pooling every casing/spacing variant of "yes" and "no".
p_yes = sum(math.exp(val) for key, val in first_token_alts.items() if key.strip().lower() == "yes")

p_no = sum(math.exp(val) for key, val in first_token_alts.items() if key.strip().lower() == "no")

# 3. Normalize: P(true) = P(yes) / (P(yes) + P(no)); fall back to a
# neutral 0.5 when neither answer token appears among the alternatives.
if p_yes + p_no == 0:
pvlm_score = 0.5
else:
pvlm_score = p_yes / (p_yes + p_no)

# 4. Package everything into a LogProbsData object for persistence.
log_probs_obj = LogProbsData(
anth_conf_score=pvlm_score,
tokens=raw_logprobs.tokens,
probs=raw_logprobs.token_logprobs,
alternatives=raw_logprobs.top_logprobs,
)
logger.info(f"Confidence: {pvlm_score}, Data: {log_probs_obj}")
return log_probs_obj

def _extract_search_query_or_none(
self,
assistant_response: str,
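A worked example of the yes/no normalization in _get_anth_confidence_score above, runnable on its own (the alternatives dict is hypothetical):

import math

# Hypothetical top-5 alternatives for the first answer token:
first_token_alts = {"Yes": -0.2231, " no": -1.6094, ".": -4.0}

p_yes = sum(math.exp(v) for k, v in first_token_alts.items() if k.strip().lower() == "yes")  # ~0.800
p_no = sum(math.exp(v) for k, v in first_token_alts.items() if k.strip().lower() == "no")    # ~0.200

score = p_yes / (p_yes + p_no) if (p_yes + p_no) else 0.5
print(round(score, 3))  # 0.8 -- the "." token is ignored by the yes/no filter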
30 changes: 30 additions & 0 deletions migrations/versions/d2ffae797992_logits_column.py
@@ -0,0 +1,30 @@
"""logits column

Revision ID: d2ffae797992
Revises: 142219b495ef
Create Date: 2025-12-15 16:55:40.163460

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "d2ffae797992"
down_revision: Union[str, None] = "142219b495ef"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("analysis", sa.Column("log_probs", sa.LargeBinary(), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("analysis", "log_probs")
# ### end Alembic commands ###
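Because the new column is nullable, existing analysis rows are unaffected; the migration applies with the standard Alembic flow (alembic upgrade head, reverted by alembic downgrade -1).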