diff --git a/app/api/endpoints/analysis_endpoints.py b/app/api/endpoints/analysis_endpoints.py index bb70798..21497f0 100644 --- a/app/api/endpoints/analysis_endpoints.py +++ b/app/api/endpoints/analysis_endpoints.py @@ -24,6 +24,11 @@ logger = logging.getLogger(__name__) +@router.post("/create", response_model=AnalysisRead) +async def create_analysis_test(data: AnalysisRead) -> AnalysisRead: + pass + + @router.get("/{analysis_id}", response_model=AnalysisRead) async def get_analysis( analysis_id: UUID, diff --git a/app/core/llm/prompts.py b/app/core/llm/prompts.py index 0eb88aa..9b62df1 100644 --- a/app/core/llm/prompts.py +++ b/app/core/llm/prompts.py @@ -109,6 +109,27 @@ class AnalysisPrompt: Limitations of Previous Response: [LLM fills this] Overall Confidence in Previous Response (0-100): [LLM provides numerical score] """ + GET_ANTH_CONFIDENCE = """ + Statement: {statement} + Label: {label} + Is the given label correct for the statement? + Answer yes or no: + """ + + GET_ANTH_CONFIDENCE_MOD = """ + Statement: {statement} + Label: {label} + Explanation: {explanation} + Given the generated explanation, is the given label correct for the statement? + Answer yes or no: + """ + + GET_ANTH_CONFIDENCE_MOD_2 = """ + Statement: {statement} + Veracity score (a score from 0 to 100, where 0 represents definitively false and 100 represents definitively true): {score} + Is the given veracity score correct for the statement? + Answer yes or no: + """ HIGH_ASSERT = """ diff --git a/app/core/llm/together_ai_llama.py b/app/core/llm/together_ai_llama.py index fad9f66..817fc83 100644 --- a/app/core/llm/together_ai_llama.py +++ b/app/core/llm/together_ai_llama.py @@ -65,7 +65,7 @@ async def generate_response(self, messages: List[Message], temperature: float = messages=[{"role": m.role, "content": m.content} for m in messages], temperature=temperature, # CRITICAL: This enables the confidence data you need - logprobs=1, + logprobs=5, ) choice = response.choices[0] diff --git a/app/models/database/models.py b/app/models/database/models.py index 766ddf2..625f428 100644 --- a/app/models/database/models.py +++ b/app/models/database/models.py @@ -17,6 +17,7 @@ text, ARRAY, DOUBLE_PRECISION, + LargeBinary, ) from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -140,6 +141,8 @@ class AnalysisModel(Base): index=True, ) + log_probs: Mapped[bytes] = mapped_column(LargeBinary, nullable=True) + claim: Mapped["ClaimModel"] = relationship(back_populates="analyses", doc="Related claim") searches: Mapped[List["SearchModel"]] = relationship(back_populates="analysis", cascade="all, delete-orphan") feedbacks: Mapped[List["FeedbackModel"]] = relationship( diff --git a/app/models/domain/analysis.py b/app/models/domain/analysis.py index 2a6cac5..b107940 100644 --- a/app/models/domain/analysis.py +++ b/app/models/domain/analysis.py @@ -1,13 +1,22 @@ from dataclasses import dataclass from datetime import datetime -from typing import Optional, List +from typing import Optional, List, Dict from uuid import UUID +import pickle from app.models.database.models import AnalysisModel, AnalysisStatus from app.models.domain.feedback import Feedback from app.models.domain.search import Search +@dataclass +class LogProbsData: + anth_conf_score: float + tokens: List[str] + probs: List[float] + alternatives: List[Dict[str, float]] + + @dataclass class Analysis: """Domain model for claim analysis.""" @@ -20,12 +29,22 @@ class Analysis: status: str created_at: datetime updated_at: datetime + log_probs: Optional[LogProbsData] = None 
searches: Optional[List["Search"]] = None feedback: Optional[List["Feedback"]] = None @classmethod def from_model(cls, model: "AnalysisModel") -> "Analysis": """Create domain model from database model.""" + + log_probs_obj = None + if model.log_probs: + try: + log_probs_obj = pickle.loads(model.log_probs) + except Exception: + # Fallback in case unpickling fails (e.g., corrupt data) + log_probs_obj = None + return cls( id=model.id, claim_id=model.claim_id, @@ -35,12 +54,44 @@ def from_model(cls, model: "AnalysisModel") -> "Analysis": status=model.status.value, created_at=model.created_at, updated_at=model.updated_at, + log_probs=log_probs_obj, searches=[Search.from_model(s) for s in model.searches] if model.searches else None, feedback=[Feedback.from_model(f) for f in model.feedbacks] if model.feedbacks else None, ) + @classmethod + def from_model_safe(cls, model: "AnalysisModel") -> "Analysis": + """Create domain model from database model, explicitly ignoring relationships.""" + + log_probs_obj = None + if model.log_probs: + try: + log_probs_obj = pickle.loads(model.log_probs) + except Exception: + log_probs_obj = None + + return cls( + id=model.id, + claim_id=model.claim_id, + veracity_score=model.veracity_score, + confidence_score=model.confidence_score, + analysis_text=model.analysis_text, + status=model.status.value, + created_at=model.created_at, + updated_at=model.updated_at, + log_probs=log_probs_obj, + # empty initalization (they are empty at creation) + searches=None, + feedback=None, + ) + def to_model(self) -> "AnalysisModel": """Convert to database model.""" + + log_probs_bytes = None + if self.log_probs: + log_probs_bytes = pickle.dumps(self.log_probs) + return AnalysisModel( id=self.id, claim_id=self.claim_id, @@ -48,4 +99,5 @@ def to_model(self) -> "AnalysisModel": confidence_score=self.confidence_score, analysis_text=self.analysis_text, status=AnalysisStatus(self.status), + log_probs=log_probs_bytes, ) diff --git a/app/repositories/implementations/analysis_repository.py b/app/repositories/implementations/analysis_repository.py index 4ae9e41..ecd7a46 100644 --- a/app/repositories/implementations/analysis_repository.py +++ b/app/repositories/implementations/analysis_repository.py @@ -17,29 +17,11 @@ def __init__(self, session: AsyncSession): super().__init__(session, AnalysisModel) def _to_model(self, analysis: Analysis) -> AnalysisModel: - return AnalysisModel( - id=analysis.id, - claim_id=analysis.claim_id, - veracity_score=analysis.veracity_score, - confidence_score=analysis.confidence_score, - analysis_text=analysis.analysis_text, - status=AnalysisStatus(analysis.status), - ) + return analysis.to_model() def _to_domain(self, model: AnalysisModel) -> Analysis: """Convert database model to domain model without loading relationships.""" - return Analysis( - id=model.id, - claim_id=model.claim_id, - veracity_score=model.veracity_score, - confidence_score=model.confidence_score, - analysis_text=model.analysis_text, - status=model.status.value, - created_at=model.created_at, - updated_at=model.updated_at, - searches=None, - feedback=None, - ) + return Analysis.from_model_safe(model) async def create(self, analysis: Analysis) -> Analysis: """Create new analysis.""" diff --git a/app/schemas/analysis_schema.py b/app/schemas/analysis_schema.py index 0afbf5c..2499e04 100644 --- a/app/schemas/analysis_schema.py +++ b/app/schemas/analysis_schema.py @@ -1,6 +1,8 @@ from pydantic import BaseModel, ConfigDict from datetime import datetime from uuid import UUID +from typing import 
Optional +from app.models.domain.analysis import LogProbsData class AnalysisCreate(BaseModel): @@ -17,6 +19,7 @@ class AnalysisRead(BaseModel): confidence_score: float analysis_text: str created_at: datetime + log_probs: Optional[LogProbsData] model_config = ConfigDict(from_attributes=True) diff --git a/app/services/analysis_orchestrator.py b/app/services/analysis_orchestrator.py index 8b6df07..7ecb9e7 100644 --- a/app/services/analysis_orchestrator.py +++ b/app/services/analysis_orchestrator.py @@ -11,7 +11,7 @@ from app.core.llm.interfaces import LLMProvider from app.models.database.models import AnalysisStatus, ClaimStatus, ConversationStatus, MessageSenderType from app.models.domain.claim import Claim -from app.models.domain.analysis import Analysis +from app.models.domain.analysis import Analysis, LogProbsData from app.models.domain.search import Search from app.models.domain.message import Message from app.core.llm.messages import Message as LLMMessage @@ -203,8 +203,8 @@ async def _generate_analysis( yield {"type": "content", "content": chunk.text} else: full_text = "".join(analysis_text) - logger.warning(f"length {len(analysis_text)}, {analysis_text}") - logger.warning(f"length {len(log_probs)}, {log_probs}") + # logger.warning(f"length {len(analysis_text)}, {analysis_text}") + # logger.warning(f"length {len(log_probs)}, {log_probs}") try: # Clean the text before parsing @@ -247,6 +247,8 @@ async def _generate_analysis( con_score = await self._generate_logprob_confidence_score(log_probs=log_probs) logger.info(con_score) current_analysis.confidence_score = float(con_score) + # log_data = await self._get_anth_confidence_score(statement=claim_text, veracity_score=veracity_score) + # current_analysis.log_probs = log_data updated_analysis = await self._analysis_repo.update(current_analysis) @@ -748,6 +750,55 @@ async def _generate_confidence_score(self, statement: str, analysis: str, source return "" + async def _get_anth_confidence_score(self, statement: str, veracity_score: float): + # label = 'true' + # if veracity_score < 50: + # label ='false' + # elif veracity_score == 50: + # label = 'unknown' + + messages = [ + LLMMessage( + role="user", + content=AnalysisPrompt.GET_ANTH_CONFIDENCE_MOD_2.format( + date=datetime.now().isoformat(), + statement=statement, + score=veracity_score, + ), + ) + ] + response = await self._llm.generate_response(messages) + raw_logprobs = response.metadata.get("raw_logprobs") + + # Initialize defaults + log_probs_obj = LogProbsData(anth_conf_score=0, tokens=[], probs=[], alternatives=[]) + + if raw_logprobs: + # 2. Get the log_prob for "Yes" from the first token's alternatives + # First entry in top_logprobs list, get the value for key 'Yes' + first_token_alts = raw_logprobs.top_logprobs[0] + + p_yes = sum(math.exp(val) for key, val in first_token_alts.items() if key.strip().lower() == "yes") + + p_no = sum(math.exp(val) for key, val in first_token_alts.items() if key.strip().lower() == "no") + + # Convert log probability to linear probability: p = exp(log_p) + pvlm_score = 0 + if p_yes + p_no == 0: + pvlm_score = 0.5 # Neutral/Uncertain + else: + pvlm_score = p_yes / (p_yes + p_no) + + # 3. 
Construct the LogProbsData object + log_probs_obj = LogProbsData( + anth_conf_score=pvlm_score, + tokens=raw_logprobs.tokens, + probs=raw_logprobs.token_logprobs, + alternatives=raw_logprobs.top_logprobs, + ) + logger.info(f"Confidence: {pvlm_score}, Data: {log_probs_obj}") + return log_probs_obj + def _extract_search_query_or_none( self, assistant_response: str, diff --git a/migrations/versions/d2ffae797992_logits_column.py b/migrations/versions/d2ffae797992_logits_column.py new file mode 100644 index 0000000..42b5a68 --- /dev/null +++ b/migrations/versions/d2ffae797992_logits_column.py @@ -0,0 +1,30 @@ +"""logits column + +Revision ID: d2ffae797992 +Revises: 142219b495ef +Create Date: 2025-12-15 16:55:40.163460 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "d2ffae797992" +down_revision: Union[str, None] = "142219b495ef" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("analysis", sa.Column("log_probs", sa.LargeBinary(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("analysis", "log_probs") + # ### end Alembic commands ###
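
The `_get_anth_confidence_score` helper asks the model a yes/no question via `GET_ANTH_CONFIDENCE_MOD_2` (note the `date=` kwarg passed to `str.format` is silently ignored, since that template has no `{date}` placeholder) and turns the first generated token's alternatives into a P(true)-style score. This is presumably why `logprobs` was raised from 1 to 5 in `together_ai_llama.py`: with a single returned alternative per position, the competing "Yes"/"No" token is usually missing and the ratio mostly collapses to 0 or 1 instead of a graded probability. A minimal sketch of the normalization, using a hypothetical `top_logprobs[0]` dict rather than a live Together AI response:

```python
import math
from typing import Dict


def p_true_from_first_token(first_token_alts: Dict[str, float]) -> float:
    """Normalize the probability mass on 'yes' vs 'no' among the top
    alternatives of the first generated token (token -> log-prob)."""
    p_yes = sum(math.exp(lp) for tok, lp in first_token_alts.items() if tok.strip().lower() == "yes")
    p_no = sum(math.exp(lp) for tok, lp in first_token_alts.items() if tok.strip().lower() == "no")
    if p_yes + p_no == 0:
        return 0.5  # neither answer surfaced in the top-5: treat as uncertain
    return p_yes / (p_yes + p_no)


# Hypothetical top_logprobs[0] entry, not real Together AI output:
example_alts = {"Yes": -0.11, " yes": -2.60, "No": -3.20, "I": -5.10, "The": -6.00}
print(round(p_true_from_first_token(example_alts), 2))  # 0.96
```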
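
`Analysis.to_model()` now pickles the optional `LogProbsData` into the new `LargeBinary` column, and `from_model()` / `from_model_safe()` unpickle it behind a broad `try/except`. A round-trip sketch of that storage path, assuming `LogProbsData` is importable exactly as declared in `app/models/domain/analysis.py`:

```python
import pickle

from app.models.domain.analysis import LogProbsData

data = LogProbsData(
    anth_conf_score=0.96,
    tokens=["Yes"],
    probs=[-0.11],
    alternatives=[{"Yes": -0.11, "No": -3.20}],
)

blob = pickle.dumps(data)      # what to_model() writes into AnalysisModel.log_probs
restored = pickle.loads(blob)  # what from_model() / from_model_safe() read back
assert restored.anth_conf_score == data.anth_conf_score
```

One design note: pickle couples the stored bytes to the Python class location, so renaming or moving `LogProbsData` (or its module) makes old rows fail to unpickle and come back as `None` via the `except` fallback; a JSON/JSONB column fed by `dataclasses.asdict()` would be a more schema-tolerant alternative if that matters later.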
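
The new revision `d2ffae797992` only adds the nullable `log_probs` column, so it can be applied and rolled back on its own. A programmatic sketch of running it, assuming the project's `alembic.ini` sits at the repository root (the usual CLI equivalent is `alembic upgrade head`):

```python
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed path to the project's Alembic config
command.upgrade(cfg, "d2ffae797992")      # or "head"
# command.downgrade(cfg, "142219b495ef")  # reverts just this column
```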