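/* ENSSA — RiskTransitionSankey.
   Sankey of student risk-band transitions between terms. Desktop/tablet:
   Term 1/3/4 toggles (2 or 3 columns) and a Show Unknown checkbox. Mobile:
   segmented term-pair selector (always 2 columns) with Unknown always shown.
   Flow hover/dim/tooltip and click-through drill-down. Pure SVG. */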

/* global React, useMediaQuery */

// Fill colour per risk band (nodes, flow ribbons, tooltip and drill-down swatches).
const SANKEY_COLORS = Object.fromEntries([
  ['high', '#C5283D'],
  ['moderate', '#C97A1F'],
  ['low', '#2E7D5B'],
  ['unknown', '#9ca3af'],
]);

// Human-readable name per risk band.
const BAND_LABELS = Object.fromEntries([
  ['high', 'High risk'],
  ['moderate', 'Moderate risk'],
  ['low', 'Low risk'],
  ['unknown', 'Unknown'],
]);

// Display name per term key.
const TERM_LABEL = Object.fromEntries([
  ['t1', 'Term 1'],
  ['t3', 'Term 3'],
  ['t4', 'Term 4'],
]);

/**
 * Build the SVG path for one Sankey ribbon: a band of constant thickness `h`
 * whose top and bottom edges are cubic Béziers from (fromX, fromY) to
 * (toX, toY). Control points sit at 40% and 60% of the horizontal span.
 * @param {{fromX: number, toX: number, fromY: number, toY: number, height: number}} geom
 * @returns {string} closed path `d` string.
 */
function generateFlowPath({ fromX, toX, fromY, toY, height: h }) {
  const span = toX - fromX;
  const ctrlA = fromX + span * 0.4;
  const ctrlB = fromX + span * 0.6;
  const fromBottom = fromY + h;
  const toBottom = toY + h;
  const segments = [
    `M ${fromX} ${fromY}`,                                          // top-left
    `C ${ctrlA} ${fromY}, ${ctrlB} ${toY}, ${toX} ${toY}`,          // top edge →
    `L ${toX} ${toBottom}`,                                         // right side ↓
    `C ${ctrlB} ${toBottom}, ${ctrlA} ${fromBottom}, ${fromX} ${fromBottom}`, // bottom edge ←
    'Z',
  ];
  return segments.join(' ');
}

/**
 * Stack one column's band nodes top-to-bottom, each sized in proportion to its
 * count. Bands with no students are skipped entirely, so gap space is reserved
 * only between the nodes actually drawn and the column fills `totalH`
 * (except where the 4px minimum node height pushes it over).
 * @param {Object<string, number>} counts - students per band.
 * @param {string[]} bands - band keys in top-to-bottom order.
 * @param {number} totalH - height available for nodes plus gaps.
 * @param {number} gap - vertical spacing between adjacent nodes.
 * @returns {{band: string, y: number, h: number, n: number}[]} nodes; y is relative to the column top.
 */
function layoutNodes(counts, bands, totalH, gap) {
  const visible = bands.filter(b => (counts[b] || 0) > 0);
  const total = visible.reduce((s, b) => s + counts[b], 0) || 1;
  // Only gaps *between visible nodes* consume height; empty bands get none.
  const usableH = totalH - gap * Math.max(0, visible.length - 1);
  const nodes = [];
  let y = 0;
  for (const band of visible) {
    const n = counts[band];
    const h = Math.max(4, (n / total) * usableH);
    nodes.push({ band, y, h, n });
    y += h + gap;
  }
  return nodes;
}

/**
 * Turn per-pair transition counts into positioned SVG ribbons between two
 * adjacent node columns.
 *
 * Flows are placed in a fixed band order (source band, then target band), and
 * each node keeps a running vertical cursor. Ribbons therefore tile each node
 * top-to-bottom in band order. A ribbon's thickness is its share of the source
 * node's students times that node's height (minimum 1px).
 *
 * @param {{band: string, y: number, h: number, n: number}[]} leftNodes - source column nodes.
 * @param {{band: string, y: number, h: number, n: number}[]} rightNodes - target column nodes.
 * @param {{from: string, to: string, n: number, fromTerm: string, toTerm: string}[]} flows - transition counts (not mutated).
 * @param {number} leftX - x where ribbons start.
 * @param {number} rightX - x where ribbons end.
 * @returns {Object[]} ribbons `{ key, from, to, n, fromTerm, toTerm, path, color }`. Flows whose
 *   source or target band has no node, or whose n is not positive, are dropped.
 */
function buildFlows(leftNodes, rightNodes, flows, leftX, rightX) {
  const BAND_RANK = ['high', 'moderate', 'low', 'unknown'];
  const rank = band => BAND_RANK.indexOf(band);

  const sourceByBand = {};
  const targetByBand = {};
  const sourceUsed = {};
  const targetUsed = {};
  for (const node of leftNodes) {
    sourceByBand[node.band] = node;
    sourceUsed[node.band] = 0;
  }
  for (const node of rightNodes) {
    targetByBand[node.band] = node;
    targetUsed[node.band] = 0;
  }

  const ordered = [...flows].sort(
    (a, b) => (rank(a.from) - rank(b.from)) || (rank(a.to) - rank(b.to)),
  );

  const ribbons = [];
  for (const flow of ordered) {
    const src = sourceByBand[flow.from];
    const dst = targetByBand[flow.to];
    if (!src || !dst || !(flow.n > 0)) continue;

    const height = Math.max(1, (flow.n / src.n) * src.h);
    const fromY = src.y + (sourceUsed[flow.from] || 0);
    const toY = dst.y + (targetUsed[flow.to] || 0);
    sourceUsed[flow.from] = (sourceUsed[flow.from] || 0) + height;
    targetUsed[flow.to] = (targetUsed[flow.to] || 0) + height;

    ribbons.push({
      key: `${flow.from}-${flow.to}-${flow.fromTerm}-${flow.toTerm}`,
      from: flow.from,
      to: flow.to,
      n: flow.n,
      fromTerm: flow.fromTerm,
      toTerm: flow.toTerm,
      path: generateFlowPath({ fromX: leftX, toX: rightX, fromY, toY, height }),
      color: SANKEY_COLORS[flow.from],
    });
  }
  return ribbons;
}

/* Term pairs offered by the mobile pair selector. Each id is the two term keys
   concatenated, mapped to those terms in chronological order. */
const PAIR_TERMS = Object.fromEntries(
  [['t1', 't3'], ['t3', 't4'], ['t1', 't4']].map(pair => [pair.join(''), pair]),
);

function RiskTransitionSankey({ transitions, districtId, schoolId, classId, onOpenClass }) {
  const isMobile = useMediaQuery('(max-width: 767px)');
  const [hovered, setHovered]         = React.useState(null);
  const [tooltip, setTooltip]         = React.useState(null);
  // Desktop/tablet: full multi-term selection (original behaviour, all 3 by default)
  const [activeTerms, setActiveTerms] = React.useState(['t1','t3','t4']);
  // Mobile-only: pair selection (one of three pairs)
  const [mobilePair, setMobilePair]   = React.useState('t1t4');
  const [showUnknown, setShowUnknown] = React.useState(true);
  const [drillFlow, setDrillFlow]     = React.useState(null); // { fromTerm, toTerm, fromRisk, toRisk }

  // Effective unknown: forced on for mobile (no toggle visible there)
  const effectiveShowUnknown = isMobile ? true : showUnknown;
  // Effective active terms: mobile uses the pair, desktop/tablet uses the multi-toggle
  const effectiveActiveTerms = isMobile ? PAIR_TERMS[mobilePair] : activeTerms;

  const toggleTerm = (t) => {
    setActiveTerms(prev => {
      if (prev.includes(t)) {
        if (prev.length <= 2) return prev;
        return prev.filter(x => x !== t);
      }
      return [...prev, t].sort();
    });
  };

  const data = React.useMemo(() => {
    if (!transitions || !transitions.length) return null;
    const counts = (term) => {
      const c = { high: 0, moderate: 0, low: 0, unknown: 0 };
      transitions.forEach(t => { const v = t[term+'Risk']; if (v && c[v] !== undefined) c[v]++; });
      if (!effectiveShowUnknown) c.unknown = 0;
      return c;
    };
    const flowsBetween = (tA, tB) => {
      const map = {};
      transitions.forEach(t => {
        const a = t[tA+'Risk'], b = t[tB+'Risk'];
        if (!a || !b) return;
        if (!effectiveShowUnknown && (a === 'unknown' || b === 'unknown')) return;
        const k = `${a}__${b}__${tA}__${tB}`;
        if (!map[k]) map[k] = { from: a, to: b, n: 0, fromTerm: tA, toTerm: tB };
        map[k].n++;
      });
      return Object.values(map);
    };
    return {
      t1: counts('t1'), t3: counts('t3'), t4: counts('t4'),
      flows13: flowsBetween('t1','t3'),
      flows34: flowsBetween('t3','t4'),
      flows14: flowsBetween('t1','t4'),
    };
  }, [transitions, effectiveShowUnknown]);

  if (!data) {
    return <div style={{ padding: 24, color: 'var(--fg-3)', fontSize: 13 }}>No transition data available.</div>;
  }

  // ── Geometry: same logic as original; mobile is always 2-col via PAIR_TERMS ──
  const at = effectiveActiveTerms;
  const is3 = at.length === 3;
  const mode = is3 ? 't1t3t4'
    : (at.includes('t1') && at.includes('t3')) ? 't1t3'
    : (at.includes('t3') && at.includes('t4')) ? 't3t4'
    : 't1t4';

  const W = is3 ? 800 : 500, H = 320, INNER_H = 260, GAP = 6, NODE_W = 14, TOP = 36;
  const BANDS_T1   = ['high','moderate','low'];
  const BANDS_REST = effectiveShowUnknown ? ['high','moderate','low','unknown'] : ['high','moderate','low'];

  const COL_X = is3
    ? { t1: 140, t3: 400, t4: 660 }
    : mode === 't1t3' ? { t1: 140, t3: 360 }
    : mode === 't3t4' ? { t3: 140, t4: 360 }
    : { t1: 140, t4: 360 };

  const off = nodes => nodes.map(n => ({ ...n, y: n.y + TOP }));
  const nMap = {};
  if (mode === 't1t3t4') {
    nMap.t1 = off(layoutNodes(data.t1, BANDS_T1,   INNER_H, GAP));
    nMap.t3 = off(layoutNodes(data.t3, BANDS_REST,  INNER_H, GAP));
    nMap.t4 = off(layoutNodes(data.t4, BANDS_REST,  INNER_H, GAP));
  } else if (mode === 't1t3') {
    nMap.t1 = off(layoutNodes(data.t1, BANDS_T1,   INNER_H, GAP));
    nMap.t3 = off(layoutNodes(data.t3, BANDS_REST,  INNER_H, GAP));
  } else if (mode === 't3t4') {
    nMap.t3 = off(layoutNodes(data.t3, BANDS_REST,  INNER_H, GAP));
    nMap.t4 = off(layoutNodes(data.t4, BANDS_REST,  INNER_H, GAP));
  } else {
    nMap.t1 = off(layoutNodes(data.t1, BANDS_T1,   INNER_H, GAP));
    nMap.t4 = off(layoutNodes(data.t4, BANDS_REST,  INNER_H, GAP));
  }

  const allFlows = [];
  if (mode === 't1t3t4') {
    allFlows.push(...buildFlows(nMap.t1, nMap.t3, data.flows13, COL_X.t1 + NODE_W, COL_X.t3));
    allFlows.push(...buildFlows(nMap.t3, nMap.t4, data.flows34, COL_X.t3 + NODE_W, COL_X.t4));
  } else if (mode === 't1t3') {
    allFlows.push(...buildFlows(nMap.t1, nMap.t3, data.flows13, COL_X.t1 + NODE_W, COL_X.t3));
  } else if (mode === 't3t4') {
    allFlows.push(...buildFlows(nMap.t3, nMap.t4, data.flows34, COL_X.t3 + NODE_W, COL_X.t4));
  } else {
    allFlows.push(...buildFlows(nMap.t1, nMap.t4, data.flows14, COL_X.t1 + NODE_W, COL_X.t4));
  }

  const firstTermKey = at[0];
  const lastTermKey  = at[at.length - 1];
  const totalTracked = (nMap[firstTermKey] || []).reduce((s, n) => s + n.n, 0);

  const handleEnter = (f, e) => { setHovered(f); setTooltip({ x: e.clientX, y: e.clientY }); };
  const handleLeave = ()    => { setHovered(null); setTooltip(null); };
  const handleMove  = (e)   => { if (hovered) setTooltip({ x: e.clientX, y: e.clientY }); };

  return (
    <div style={{ position: 'relative' }}>
      {/* Controls — desktop/tablet: original 3-toggle + unknown checkbox.
                     mobile: centred pair pill (no unknown toggle). */}
      {isMobile ? (
        <div style={{ display: 'flex', justifyContent: 'center', marginBottom: 12 }}>
          <div role="radiogroup" aria-label="Term pair to compare"
            style={{
              display: 'inline-flex',
              background: 'var(--ink-50)',
              borderRadius: 8,
              padding: 2,
              border: '1px solid var(--border-subtle)',
            }}>
            {[['t1t3', 'T1 → T3'], ['t3t4', 'T3 → T4'], ['t1t4', 'T1 → T4']].map(([id, lbl]) => {
              const active = mobilePair === id;
              return (
                <button key={id} role="radio" aria-checked={active}
                  onClick={() => setMobilePair(id)}
                  style={{
                    border: 0,
                    background: active ? '#fff' : 'transparent',
                    boxShadow: active ? 'var(--sh-1)' : 'none',
                    padding: '6px 14px',
                    fontSize: 12,
                    fontWeight: 700,
                    color: active ? 'var(--fg-1)' : 'var(--fg-3)',
                    borderRadius: 6,
                    cursor: 'pointer',
                    fontFamily: 'inherit',
                    whiteSpace: 'nowrap',
                    transition: 'background 120ms, color 120ms',
                  }}>
                  {lbl}
                </button>
              );
            })}
          </div>
        </div>
      ) : (
        <div style={{ display: 'flex', alignItems: 'center', gap: 16, marginBottom: 12, flexWrap: 'wrap' }}>
          <div style={{ display: 'flex', gap: 4 }}>
            {['t1','t3','t4'].map(t => (
              <button key={t} onClick={() => toggleTerm(t)} style={{
                padding: '5px 12px', borderRadius: 6, fontSize: 12, fontWeight: 700,
                cursor: 'pointer', fontFamily: 'inherit',
                border: '1px solid var(--border-default)',
                background: at.includes(t) ? 'var(--enssa-aubergine)' : '#fff',
                color:      at.includes(t) ? '#fff' : 'var(--fg-2)',
              }}>{TERM_LABEL[t]}</button>
            ))}
          </div>
          <label style={{ display: 'flex', alignItems: 'center', gap: 6, fontSize: 12, fontWeight: 600, color: 'var(--fg-2)', cursor: 'pointer' }}>
            <input type="checkbox" checked={showUnknown} onChange={e => setShowUnknown(e.target.checked)} />
            Show unknown
          </label>
        </div>
      )}

      <svg viewBox={`0 0 ${W} ${H}`} style={{ width: '100%', height: 'auto', display: 'block', overflow: 'visible' }}>
        {/* Flows */}
        {allFlows.map(f => {
          const isHov = hovered && hovered.key === f.key;
          const opacity = hovered ? (isHov ? 0.82 : 0.1) : 0.45;
          return (
            <path key={f.key} d={f.path} fill={f.color} opacity={opacity}
              style={{ cursor: 'pointer', transition: 'opacity 150ms ease' }}
              onMouseEnter={e => handleEnter(f, e)}
              onMouseLeave={handleLeave}
              onMouseMove={handleMove}
              onClick={() => setDrillFlow({ fromTerm: f.fromTerm, toTerm: f.toTerm, fromRisk: f.from, toRisk: f.to })} />
          );
        })}
        {/* Node columns */}
        {Object.entries(nMap).map(([termKey, nodes], colIdx) => {
          const isFirst = colIdx === 0;
          const isLast  = termKey === lastTermKey;
          // A band is "new" if it didn't appear in ANY earlier column
          const earlierBands = new Set(
            Object.entries(nMap)
              .filter((_, i) => i < colIdx)
              .flatMap(([, ns]) => ns.map(n => n.band))
          );
          const x = COL_X[termKey] || 0;
          return (
            <g key={termKey}>
              <text x={x + NODE_W/2} y={TOP - 12} textAnchor="middle"
                fontSize={isMobile ? "16" : "10"} fontWeight="700" fill="var(--fg-3)" letterSpacing="1">
                {TERM_LABEL[termKey].toUpperCase()}
              </text>
              {nodes.map(n => {
                const isNew = !isFirst && !earlierBands.has(n.band);
                const labelFs = isMobile ? "18" : "11";
                return (
                  <g key={n.band}>
                    <rect x={x} y={n.y} width={NODE_W} height={n.h} fill={SANKEY_COLORS[n.band]} rx="2" />
                    {isFirst && !isLast && (
                      <text x={x - 8} y={n.y + n.h/2 + 4} textAnchor="end" fontSize={labelFs} fontWeight="600" fill="var(--fg-2)">
                        {BAND_LABELS[n.band]}<tspan fontFamily="var(--font-mono)" fill="var(--fg-3)"> {n.n}</tspan>
                      </text>
                    )}
                    {!isFirst && !isLast && (
                      <text x={x - 8} y={n.y + n.h/2 + 4} textAnchor="end" fontSize={labelFs} fontWeight="600" fill="var(--fg-2)">
                        {isNew && <tspan>{BAND_LABELS[n.band]} </tspan>}
                        <tspan fontFamily="var(--font-mono)" fill="var(--fg-3)">{n.n}</tspan>
                      </text>
                    )}
                    {isLast && (
                      <text x={x + NODE_W + 8} y={n.y + n.h/2 + 4} textAnchor="start" fontSize={labelFs} fontWeight="600" fill="var(--fg-2)">
                        {isNew && <tspan>{BAND_LABELS[n.band]} </tspan>}
                        <tspan fontFamily="var(--font-mono)" fill="var(--fg-3)">{n.n}</tspan>
                      </text>
                    )}
                    {isFirst && isLast && (
                      <text x={x - 8} y={n.y + n.h/2 + 4} textAnchor="end" fontSize={labelFs} fontWeight="600" fill="var(--fg-2)">
                        {BAND_LABELS[n.band]}<tspan fontFamily="var(--font-mono)" fill="var(--fg-3)"> {n.n}</tspan>
                      </text>
                    )}
                  </g>
                );
              })}
            </g>
          );
        })}
      </svg>

      {/* Drill-down overlay */}
      {drillFlow && (() => {
        const { fromTerm, toTerm, fromRisk, toRisk } = drillFlow;
        const fromField = fromTerm + 'Risk';
        const toField   = toTerm   + 'Risk';
        const matching  = transitions.filter(s => s[fromField] === fromRisk && s[toField] === toRisk);

        // Determine drill level
        let drillContent = null;
        if (classId) {
          // Class level: list individual students with score change T1→latest term
          const lastTerm  = toTerm;
          const firstTerm = fromTerm;

          // Derive a plausible score for each term from indexScore + risk shift.
          // We use a simple deterministic offset based on risk movement direction.
          function scoreForTerm(s, term) {
            if (term === 't1') return s.indexScore;
            const r1 = s.t1Risk, rT = s[term + 'Risk'];
            const order = ['high','moderate','low'];
            const delta = (order.indexOf(rT) - order.indexOf(r1)) * 8; // ~8pts per band
            return Math.max(1, Math.min(99, s.indexScore + delta));
          }

          drillContent = (
            <div style={{ overflowX: 'auto' }}>
              <table style={{ width: '100%', borderCollapse: 'collapse', fontSize: 13 }}>
                <thead>
                  <tr>
                    <th style={{ textAlign: 'left',  padding: '8px 12px', fontSize: 11, fontWeight: 700, textTransform: 'uppercase', letterSpacing: '0.08em', color: 'var(--fg-3)', borderBottom: '1px solid var(--border-subtle)', background: 'var(--ink-50)' }}>Student</th>
                    <th style={{ textAlign: 'right', padding: '8px 12px', fontSize: 11, fontWeight: 700, textTransform: 'uppercase', letterSpacing: '0.08em', color: 'var(--fg-3)', borderBottom: '1px solid var(--border-subtle)', background: 'var(--ink-50)' }}>Score</th>
                  </tr>
                </thead>
                <tbody>
                  {matching.map(s => {
                    const s1 = scoreForTerm(s, firstTerm);
                    const s2 = scoreForTerm(s, lastTerm);
                    const delta = s2 - s1;
                    const hasUnknown = s[firstTerm + 'Risk'] === 'unknown' || s[lastTerm + 'Risk'] === 'unknown';
                    return (
                      <tr key={s.studentId} style={{ borderBottom: '1px solid var(--border-subtle)' }}>
                        <td style={{ padding: '10px 12px', fontWeight: 600 }}>{s.studentName}</td>
                        <td style={{ padding: '10px 12px', textAlign: 'right', fontFamily: 'var(--font-mono)' }}>
                          {hasUnknown ? (
                            <>
                              <span style={{ fontWeight: 700 }}>{s1}</span>
                              <span style={{ color: 'var(--fg-3)', margin: '0 4px' }}>→</span>
                              <span style={{ color: 'var(--fg-3)', fontWeight: 700 }}>?</span>
                            </>
                          ) : (
                            <>
                              <span style={{ fontWeight: 700 }}>{s1}</span>
                              <span style={{ color: 'var(--fg-3)', margin: '0 4px' }}>→</span>
                              <span style={{ fontWeight: 700 }}>{s2}</span>
                              {' '}
                              <span style={{ fontSize: 12, fontWeight: 700, color: delta > 0 ? 'var(--risk-on-track-fg)' : delta < 0 ? 'var(--risk-concern-fg)' : 'var(--fg-3)' }}>
                                ({delta > 0 ? '+' : ''}{delta})
                              </span>
                            </>
                          )}
                        </td>
                      </tr>
                    );
                  })}
                </tbody>
              </table>
            </div>
          );
        } else if (schoolId) {
          // School level: breakdown by class, grouped by year level
          const byYear = {};
          matching.forEach(s => {
            const yr = s.year || 'Unknown';
            if (!byYear[yr]) byYear[yr] = {};
            const cl = s.classId || 'Unknown';
            byYear[yr][cl] = (byYear[yr][cl] || 0) + 1;
          });
          drillContent = (
            <div style={{ display: 'flex', flexDirection: 'column', gap: 16 }}>
              {Object.entries(byYear).sort().map(([yr, classes]) => (
                <div key={yr}>
                  <div style={{ fontSize: 11, fontWeight: 700, color: 'var(--fg-3)', textTransform: 'uppercase', letterSpacing: '0.08em', marginBottom: 8 }}>
                    {yr === 'F' ? 'Foundation' : yr === 'Y1' ? 'Year 1' : yr}
                  </div>
                  <table style={{ width: '100%', borderCollapse: 'collapse', fontSize: 13 }}>
                    <thead>
                      <tr>
                        <th style={{ textAlign: 'left', padding: '8px 12px', fontSize: 11, fontWeight: 700, textTransform: 'uppercase', letterSpacing: '0.08em', color: 'var(--fg-3)', borderBottom: '1px solid var(--border-subtle)', background: 'var(--ink-50)' }}>Class</th>
                        <th style={{ textAlign: 'right', padding: '8px 12px', fontSize: 11, fontWeight: 700, textTransform: 'uppercase', letterSpacing: '0.08em', color: 'var(--fg-3)', borderBottom: '1px solid var(--border-subtle)', background: 'var(--ink-50)' }}>Students</th>
                      </tr>
                    </thead>
                    <tbody>
                      {Object.entries(classes).sort((a,b) => b[1]-a[1]).map(([cl, n]) => (
                        <tr key={cl} style={{ borderBottom: '1px solid var(--border-subtle)' }}>
                          <td style={{ padding: '10px 12px', fontWeight: 600 }}>
                            {onOpenClass ? (
                              <button onClick={() => onOpenClass(cl)}
                                style={{ background: 'none', border: 0, padding: 0, color: 'var(--accent)', fontWeight: 700, fontSize: 13, cursor: 'pointer', fontFamily: 'inherit' }}>
                                Class {cl}
                              </button>
                            ) : `Class ${cl}`}
                          </td>
                          <td style={{ padding: '10px 12px', textAlign: 'right', fontFamily: 'var(--font-mono)', fontWeight: 700 }}>
                            {n} <span style={{ color: 'var(--fg-3)', fontWeight: 400 }}>({Math.round(n/matching.length*100)}%)</span>
                          </td>
                        </tr>
                      ))}
                    </tbody>
                  </table>
                </div>
              ))}
            </div>
          );
        } else {
          // District level: breakdown by school
          const bySchool = {};
          matching.forEach(s => {
            const k = s.schoolName || s.schoolId || 'Unknown';
            bySchool[k] = (bySchool[k] || 0) + 1;
          });
          drillContent = (
            <table style={{ width: '100%', borderCollapse: 'collapse', fontSize: 13 }}>
              <thead>
                <tr>
                  <th style={{ textAlign: 'left', padding: '8px 12px', fontSize: 11, fontWeight: 700, textTransform: 'uppercase', letterSpacing: '0.08em', color: 'var(--fg-3)', borderBottom: '1px solid var(--border-subtle)', background: 'var(--ink-50)' }}>School</th>
                  <th style={{ textAlign: 'right', padding: '8px 12px', fontSize: 11, fontWeight: 700, textTransform: 'uppercase', letterSpacing: '0.08em', color: 'var(--fg-3)', borderBottom: '1px solid var(--border-subtle)', background: 'var(--ink-50)' }}>Students</th>
                </tr>
              </thead>
              <tbody>
                {Object.entries(bySchool).sort((a,b) => b[1]-a[1]).map(([name, n]) => (
                  <tr key={name} style={{ borderBottom: '1px solid var(--border-subtle)' }}>
                    <td style={{ padding: '10px 12px', fontWeight: 600 }}>{name}</td>
                    <td style={{ padding: '10px 12px', textAlign: 'right', fontFamily: 'var(--font-mono)', fontWeight: 700 }}>
                      {n} <span style={{ color: 'var(--fg-3)', fontWeight: 400 }}>({Math.round(n/matching.length*100)}%)</span>
                    </td>
                  </tr>
                ))}
              </tbody>
            </table>
          );
        }

        return (
          <div style={{ marginTop: 18, background: 'var(--ink-50)', border: '1px solid var(--border-default)', borderRadius: 10, overflow: 'hidden' }}>
            <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'flex-start', padding: '16px 20px 12px', borderBottom: '1px solid var(--border-subtle)', background: '#fff' }}>
              <div>
                <div style={{ fontSize: 11, fontWeight: 700, textTransform: 'uppercase', letterSpacing: '0.08em', color: 'var(--fg-3)', marginBottom: 4 }}>
                  {TERM_LABEL[fromTerm]} → {TERM_LABEL[toTerm]}
                </div>
                <h3 style={{ fontSize: 16, fontWeight: 700, margin: 0 }}>
                  <span style={{ color: SANKEY_COLORS[fromRisk] }}>●</span> {BAND_LABELS[fromRisk]}
                  {' '}<span style={{ color: 'var(--fg-3)', fontWeight: 400 }}>→</span>{' '}
                  <span style={{ color: SANKEY_COLORS[toRisk] }}>●</span> {BAND_LABELS[toRisk]}
                </h3>
                <div style={{ fontSize: 12, color: 'var(--fg-3)', marginTop: 4, fontFamily: 'var(--font-mono)' }}>
                  {matching.length} student{matching.length !== 1 ? 's' : ''}
                </div>
              </div>
              <button onClick={() => setDrillFlow(null)}
                style={{ background: 'var(--ink-100)', border: '1px solid var(--border-default)', borderRadius: 6, width: 28, height: 28, cursor: 'pointer', fontSize: 16, color: 'var(--fg-2)', display: 'flex', alignItems: 'center', justifyContent: 'center', flexShrink: 0 }}>×</button>
            </div>
            <div style={{ padding: '12px 0' }}>{drillContent}</div>
          </div>
        );
      })()}

      {/* Tooltip */}
      {hovered && tooltip && (
        <div style={{
          position: 'fixed', left: tooltip.x + 14, top: tooltip.y - 10,
          background: 'var(--ink-800)', color: '#fff',
          padding: '8px 12px', borderRadius: 8, fontSize: 12, lineHeight: 1.5,
          pointerEvents: 'none', boxShadow: 'var(--sh-3)', zIndex: 100, whiteSpace: 'nowrap',
        }}>
          <span style={{ color: SANKEY_COLORS[hovered.from], fontWeight: 700 }}>{BAND_LABELS[hovered.from]}</span>
          <span style={{ color: 'rgba(255,255,255,0.5)' }}> ({TERM_LABEL[hovered.fromTerm]}) → </span>
          <span style={{ color: SANKEY_COLORS[hovered.to], fontWeight: 700 }}>{BAND_LABELS[hovered.to]}</span>
          <span style={{ color: 'rgba(255,255,255,0.5)' }}> ({TERM_LABEL[hovered.toTerm]})</span>
          <br />
          <span style={{ fontFamily: 'var(--font-mono)', fontWeight: 700 }}>{hovered.n}</span>
          <span style={{ color: 'rgba(255,255,255,0.6)' }}> students</span>
        </div>
      )}
    </div>
  );
}

// Publish the component on window, matching how this file consumes React/useMediaQuery as globals.
window.RiskTransitionSankey = RiskTransitionSankey;
