📦 EqualifyEverything / equalify-reflow

📄 test_document_context.py · 146 lines
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Tests for DocumentSummary and ObservationContext models."""


from src.shared.models.document_context import DocumentSummary, ObservationContext
from src.shared.models.observation import Observation, ObservationLocation
from src.shared.models.remediation import HeadingTree


class TestDocumentSummary:
    """Tests for the DocumentSummary model."""

    def test_create_document_summary(self):
        """Test creating a document summary with every field supplied."""
        summary = DocumentSummary(
            title="yt: An Open Source Framework",
            document_type="research_paper",
            topic_summary="Research paper about yt, an analysis framework.",
            structure_summary="9 pages, two-column layout",
            key_entities=["yt", "Enzo", "Matthew Turk"],
            domain_terms=["parallelization", "MPI", "OpenMP"],
            expected_elements=["abstract", "introduction", "references"],
            audience_level="academic",
        )

        # Pin every constructor argument so a dropped or renamed field fails
        # here rather than silently falling back to its default.
        assert summary.title == "yt: An Open Source Framework"
        assert summary.document_type == "research_paper"
        assert summary.topic_summary == "Research paper about yt, an analysis framework."
        assert summary.structure_summary == "9 pages, two-column layout"
        assert summary.key_entities == ["yt", "Enzo", "Matthew Turk"]
        assert summary.domain_terms == ["parallelization", "MPI", "OpenMP"]
        assert summary.expected_elements == ["abstract", "introduction", "references"]
        assert summary.audience_level == "academic"

    def test_default_values(self):
        """Test default values for optional fields."""
        summary = DocumentSummary(
            title="Test Document",
            document_type="other",
            topic_summary="A test document",
        )

        assert summary.structure_summary == ""
        assert summary.key_entities == []
        assert summary.domain_terms == []
        assert summary.expected_elements == []
        assert summary.audience_level == "general"

    def test_json_serialization(self):
        """Test that a summary survives a JSON round-trip unchanged."""
        summary = DocumentSummary(
            title="Test",
            document_type="syllabus",
            topic_summary="Course syllabus",
            key_entities=["CS101", "Professor Smith"],
        )

        json_str = summary.model_dump_json()
        restored = DocumentSummary.model_validate_json(json_str)

        # Compare the complete field dump rather than a hand-picked subset, so
        # any field (including defaulted ones) lost or altered by
        # serialization is caught.
        assert restored.model_dump() == summary.model_dump()


class TestObservationContext:
    """Tests for the ObservationContext model."""

    def _create_observation(self) -> Observation:
        """Helper to create a test observation."""
        return Observation(
            job_id="job-123",
            agent="figures",
            visual_description="Image shows a chart",
            markup_description="Empty alt text",
            location=ObservationLocation(
                location_type="element",
                value="img[src='fig1.png']",
                page_num=1,
            ),
        )

    def _create_summary(self) -> DocumentSummary:
        """Helper to create a test summary."""
        return DocumentSummary(
            title="Test Document",
            document_type="research_paper",
            topic_summary="A research paper",
        )

    def _create_heading_tree(self) -> HeadingTree:
        """Helper to create a test heading tree."""
        return HeadingTree(
            document_title="Test Document",
            sections=[],
        )

    def _create_context(self, **overrides) -> ObservationContext:
        """Helper to build an ObservationContext from the shared fixtures.

        Args:
            **overrides: Constructor keyword arguments that replace or extend
                the defaults below, so each test only states the fields it
                actually exercises.

        Returns:
            A freshly constructed ObservationContext.
        """
        defaults = {
            "observation": self._create_observation(),
            "document_summary": self._create_summary(),
            "heading_tree": self._create_heading_tree(),
            "markdown_excerpt": "text",
            "page_num": 1,
        }
        return ObservationContext(**{**defaults, **overrides})

    def test_create_observation_context(self):
        """Test creating an observation context."""
        context = self._create_context(
            markdown_excerpt="## Introduction\n\nSome text...",
            before_context="# Title\n\n",
            after_context="\n\n## Methods",
        )

        assert context.observation.job_id == "job-123"
        assert context.document_summary.title == "Test Document"
        assert context.heading_tree.document_title == "Test Document"
        # Text fields are passed straight through; check they are stored as given.
        assert context.markdown_excerpt == "## Introduction\n\nSome text..."
        assert context.before_context == "# Title\n\n"
        assert context.after_context == "\n\n## Methods"
        assert context.page_num == 1

    def test_optional_visual_context(self):
        """Test that visual context is optional."""
        context = self._create_context()

        assert context.page_image_base64 is None
        assert context.line_range is None

    def test_line_range(self):
        """Test line range field."""
        context = self._create_context(line_range=(10, 25))

        assert context.line_range == (10, 25)