diff --git a/public/models/ggml-model-whisper-base.en-q5_1.bin b/public/models/ggml-model-whisper-base.en-q5_1.bin deleted file mode 100644 index 1f4c0515..00000000 Binary files a/public/models/ggml-model-whisper-base.en-q5_1.bin and /dev/null differ diff --git a/public/models/ggml-model-whisper-tiny.en-q5_1.bin b/public/models/ggml-model-whisper-tiny.en-q5_1.bin deleted file mode 100644 index ae17d4ee..00000000 Binary files a/public/models/ggml-model-whisper-tiny.en-q5_1.bin and /dev/null differ diff --git a/src/components/api/returnAPI.tsx b/src/components/api/returnAPI.tsx index 3524d5c6..1adcf4c1 100644 --- a/src/components/api/returnAPI.tsx +++ b/src/components/api/returnAPI.tsx @@ -1,205 +1,277 @@ -// import * as sdk from 'microsoft-cognitiveservices-speech-sdk' -import installCOIServiceWorker from './coi-serviceworker' -import { API, PlaybackStatus } from '../../react-redux&middleware/redux/typesImports'; +// src/components/api/returnAPI.tsx + +import { useEffect, useState } from 'react'; +import { useDispatch, useSelector, batch } from 'react-redux'; +import type { Dispatch } from 'redux'; + +import type { RootState } from '../../store'; + import { - ApiStatus, - AzureStatus, - ControlStatus, - SRecognition, - StreamTextStatus, - ScribearServerStatus + API, + ApiStatus, + AzureStatus, + ControlStatus, + PlaybackStatus, + ScribearServerStatus, + StreamTextStatus, + STATUS, } from '../../react-redux&middleware/redux/typesImports'; -import { useEffect, useState } from 'react'; -import { batch, useDispatch, useSelector } from 'react-redux'; -import { AzureRecognizer } from './azure/azureRecognizer'; -import { Dispatch } from 'redux'; -import { Recognizer } from './recognizer'; -import { RootState } from '../../store'; -import { StreamTextRecognizer } from './streamtext/streamTextRecognizer'; import { TranscriptBlock } from '../../react-redux&middleware/redux/types/TranscriptTypes'; + +// Recognizer interface + concrete backends +import { Recognizer } from './recognizer'; import { WebSpeechRecognizer } from './web-speech/webSpeechRecognizer'; +import { AzureRecognizer } from './azure/azureRecognizer'; +import { StreamTextRecognizer } from './streamtext/streamTextRecognizer'; import { WhisperRecognizer } from './whisper/whisperRecognizer'; +import { ScribearRecognizer } from './scribearServer/scribearRecognizer'; import { PlaybackRecognizer } from './playback/playbackRecognizer'; -import { ScribearRecognizer } from './scribearServer/scribearRecognizer'; +// Model selection type import type { SelectedOption } from '../../react-redux&middleware/redux/types/modelSelection'; -// import { PlaybackReducer } from '../../react-redux&middleware/redux/reducers/apiReducers'; -// controls what api to send and what to do when error handling. - -// NOTES: this needs to do everything I think. Handler should be returned which allows -// event call like stop and the event should be returned... (maybe the recognition? idk.) 
- -/* -* === * === * DO NOT DELETE IN ANY CIRCUMSTANCE * === * === * -* === * TRIBUTE TO THE ORIGINAL AUTHOR OF THIS CODE: Will * === * -DO NOT DELETE IN ANY CIRCUMSTANCE -export const returnRecogAPI = (api : ApiStatus, control : ControlStatus, azure : AzureStatus) => { - // const apiStatus = useSelector((state: RootState) => { - // return state.APIStatusReducer as ApiStatus; - // }) - // const control = useSelector((state: RootState) => { - // return state.ControlReducer as ControlStatus; - // }); - // const azureStatus = useSelector((state: RootState) => { - // return state.AzureReducer as AzureStatus; - // }) - const recognition : Promise = getRecognition(api.currentApi, control, azure); - const useRecognition : Object = makeRecognition(api.currentApi); - // const recogHandler : Function = handler(api.currentApi); - - - return ({ useRecognition, recognition }); -} -* === * === * DO NOT DELETE IN ANY CIRCUMSTANCE * === * === * -* === * TRIBUTE TO THE ORIGINAL AUTHOR OF THIS CODE: Will * === * -*/ - - - -function toWhisperCode(bcp47: string): string { - // Accept "en", "en-US", etc. Return "en" for "en-US". - if (!bcp47) return "en"; - const base = bcp47.split('-')[0].toLowerCase(); - const supported = new Set([ - 'en','es','fr','de','it','pt','nl','sv','da','nb','fi','pl','cs','sk','sl','hr','sr','bg','ro', - 'hu','el','tr','ru','uk','ar','he','fa','ur','hi','bn','ta','te','ml','mr','gu','kn','pa', - 'id','ms','vi','th','zh','ja','ko' - ]); - return supported.has(base) ? base : 'en'; - } - -const createRecognizer = (currentApi: number, control: ControlStatus, azure: AzureStatus, streamTextConfig: StreamTextStatus, scribearServerStatus: ScribearServerStatus, selectedModelOption: SelectedOption, playbackStatus: PlaybackStatus): Recognizer => { - if (currentApi === API.SCRIBEAR_SERVER) { - return new ScribearRecognizer(scribearServerStatus, selectedModelOption, control.speechLanguage.CountryCode); - } else if (currentApi === API.PLAYBACK) { - return new PlaybackRecognizer(playbackStatus); - } - if (currentApi === API.WEBSPEECH) { - return new WebSpeechRecognizer(null, control.speechLanguage.CountryCode); - } else if (currentApi === API.AZURE_TRANSLATION) { - return new AzureRecognizer(null, control.speechLanguage.CountryCode, azure); - } - else if (currentApi === API.AZURE_CONVERSATION) { - throw new Error("Not implemented"); - } - else if (currentApi === API.STREAM_TEXT) { - // Placeholder - this is just WebSpeech for now - return new StreamTextRecognizer(streamTextConfig.streamTextEvent, 'en', streamTextConfig.startPosition); - } else if (currentApi === API.WHISPER) { - return new WhisperRecognizer( - null, - toWhisperCode(control.speechLanguage.CountryCode), - 4 - ); - } else { - throw new Error(`Unexpcted API_CODE: ${currentApi}`); - } -} /** - * Make a callback function that updates the Redux transcript using new final blocks and new - * in-progress block - * - * We have to do things in this roundabout way to have access to dispatch in a callback function, - * see https://stackoverflow.com/questions/59456816/how-to-call-usedispatch-in-a-callback - * @param dispatch A Redux dispatch function + * Normalize BCP-47 speech language like "en-US" into a Whisper language code ("en"). 
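 * (Illustrative behavior only, derived from the implementation below: "en-US"
 * becomes "en", a bare "fr" passes through unchanged, and an empty or missing
 * value falls back to "en".)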
*/ -const updateTranscript = (dispatch: Dispatch) => (newFinalBlocks: Array, newInProgressBlock: TranscriptBlock): void => { - // console.log(`Updating transcript using these blocks: `, newFinalBlocks, newInProgressBlock) - // batch makes these dispatches only cause one re-rendering - batch(() => { - for (const block of newFinalBlocks) { - dispatch({ type: "transcript/new_final_block", payload: block }); - } - dispatch({ type: 'transcript/update_in_progress_block', payload: newInProgressBlock }); - }) -} +const toWhisperLanguage = (bcp47: string | undefined | null): string => { + if (!bcp47) return 'en'; + const lower = bcp47.toLowerCase(); + const dash = lower.indexOf('-'); + return dash > 0 ? lower.slice(0, dash) : lower; +}; + +/** + * Normalize model keys so the UI and the recognizer agree. + * In particular, tiny-multi β†’ tiny-q5_1 (the multilingual tiny model). + */ +const normalizeModelKey = (raw: string | undefined | null): string => { + if (!raw) return 'tiny-en-q5_1'; + if (raw === 'tiny-multi' || raw === 'tiny-multi-q5_1') return 'tiny-q5_1'; + return raw; +}; + +/** + * Try to extract the model key from whatever shape SelectedOption currently has. + */ +const extractModelKeyFromSelected = (selected?: SelectedOption | null): string => { + if (!selected) return 'tiny-en-q5_1'; + + const obj: any = selected; + const raw = + obj.model_key ?? + obj.id ?? + obj.key ?? + obj.value ?? + obj.name ?? + obj.label ?? + 'tiny-en-q5_1'; + + return normalizeModelKey(raw); +}; + +const isEnglishOnlyModel = (modelKey: string): boolean => + modelKey.includes('-en') || modelKey.endsWith('.en'); + +/** + * Given the current API choice and config, build the appropriate Recognizer. + */ +const createRecognizer = ( + currentApi: number, + control: ControlStatus, + azure: AzureStatus, + streamTextConfig: StreamTextStatus, + scribearServerStatus: ScribearServerStatus, + selectedModelOption: SelectedOption | null, + playbackStatus: PlaybackStatus, +): Recognizer => { + if (currentApi === API.SCRIBEAR_SERVER) { + return new ScribearRecognizer( + scribearServerStatus, + selectedModelOption, + control.speechLanguage.CountryCode, + ); + } + + if (currentApi === API.PLAYBACK) { + return new PlaybackRecognizer(playbackStatus); + } + + if (currentApi === API.WEBSPEECH) { + return new WebSpeechRecognizer(null, control.speechLanguage.CountryCode); + } + + if (currentApi === API.AZURE_TRANSLATION) { + return new AzureRecognizer(null, control.speechLanguage.CountryCode, azure); + } + + if (currentApi === API.AZURE_CONVERSATION) { + throw new Error('Azure Conversation API is not implemented'); + } + + if (currentApi === API.STREAM_TEXT) { + // Placeholder – this recognizer will likely be replaced with a real StreamText backend. + return new StreamTextRecognizer( + streamTextConfig.streamTextEvent, + 'en', + streamTextConfig.startPosition, + ); + } + + if (currentApi === API.WHISPER) { + const modelKey = extractModelKeyFromSelected(selectedModelOption); + const userLang = toWhisperLanguage(control.speechLanguage.CountryCode); + + // If a multilingual model is chosen and user language is still default 'en', + // let whisper auto-detect instead of pinning to English. + const whisperLang = + isEnglishOnlyModel(modelKey) || userLang !== 'en' ? userLang : 'auto'; + + return new WhisperRecognizer( + null, + whisperLang, + 4, // number of threads + modelKey, + ); + } + + throw new Error(`Unexpected API code: ${currentApi}`); +}; + +/** + * Build a transcript-update callback that dispatches Redux actions. 
+ */ +const makeTranscriptUpdater = + (dispatch: Dispatch) => + (newFinalBlocks: TranscriptBlock[], newInProgressBlock: TranscriptBlock) => { + batch(() => { + newFinalBlocks.forEach((block) => { + dispatch({ + type: 'transcript/new_final_block', + payload: block, + }); + }); + + dispatch({ + type: 'transcript/update_in_progress_block', + payload: newInProgressBlock, + }); + }); + }; /** - * Syncs up the recognizer with the API selection and listening status - * - Creates new recognizer and stop old ones when API is changed - * - Start / stop recognizer as listening changes - * - Feed any phrase list updates to azure recognizer - * - * @param recog - * @param api - * @param control - * @param azure - * - * @return transcripts, resetTranscript, recogHandler + * Hook that manages the lifetime of the current Recognizer and returns + * the combined transcript string for the main speaker (index 0). + * + * RecogComponent should call this hook and pass the transcript into STTRenderer. */ -export const useRecognition = (sRecog: SRecognition, api: ApiStatus, control: ControlStatus, - azure: AzureStatus, streamTextConfig: StreamTextStatus, scribearServerStatus, selectedModelOption: SelectedOption, playbackStatus: PlaybackStatus) => { +export const useRecognition = ( + _sRecog: any, // kept for compatibility with existing call sites + apiStatus: ApiStatus, + control: ControlStatus, + azure: AzureStatus, + streamTextConfig: StreamTextStatus, + scribearServerStatus: ScribearServerStatus, + selectedModelOption: SelectedOption | null, + playbackStatus: PlaybackStatus, +): string => { + const dispatch = useDispatch(); + + // Multi-speaker transcript object is stored in Redux. + const transcriptState: any = useSelector( + (state: RootState) => (state as any).TranscriptReducer, + ); - const [recognizer, setRecognizer] = useState(); - // TODO: Add a reset button to utitlize resetTranscript - // const [resetTranscript, setResetTranscript] = useState<() => string>(() => () => dispatch('RESET_TRANSCRIPT')); - const dispatch = useDispatch(); + const [recognizer, setRecognizer] = useState(null); - // Register service worker for whisper on launch - useEffect(() => { - installCOIServiceWorker(); - }, []) + // (Re)create recognizer whenever API or its configuration changes + useEffect(() => { + const newRecognizer = createRecognizer( + apiStatus.currentApi, + control, + azure, + streamTextConfig, + scribearServerStatus, + selectedModelOption, + playbackStatus, + ); - // Change recognizer, if api changed - useEffect(() => { - console.log("UseRecognition, switching to new recognizer: ", api.currentApi); + newRecognizer.onTranscribed(makeTranscriptUpdater(dispatch)); + newRecognizer.onError((err) => { + console.error('Recognizer error:', err); + dispatch({ + type: 'CHANGE_API_STATUS', + payload: { + ...apiStatus, + webspeechStatus: STATUS.ERROR, + }, + }); + }); + + // Stop any previous recognizer before switching + setRecognizer((prev) => { + if (prev) { + try { + prev.stop(); + } catch (e) { + console.warn('Error stopping old recognizer', e); + } + } + return newRecognizer; + }); - let newRecognizer: Recognizer | null; + // Cleanup when this effect is torn down (e.g. 
API changed) + return () => { try { - // Create new recognizer, and subscribe to its events - newRecognizer = createRecognizer(api.currentApi, control, azure, streamTextConfig, scribearServerStatus, selectedModelOption, playbackStatus); - newRecognizer.onTranscribed(updateTranscript(dispatch)); - setRecognizer(newRecognizer) - - // Start new recognizer if necessary - if (control.listening) { - console.log("UseRecognition, attempting to start recognizer after switching") - newRecognizer.start() - } + newRecognizer.stop(); } catch (e) { - console.log("UseRecognition, failed to switch to new recognizer: ", e); + console.warn('Error stopping recognizer on cleanup', e); } + }; + }, [ + apiStatus, + control, + azure, + streamTextConfig, + scribearServerStatus, + selectedModelOption, + playbackStatus, + dispatch, + ]); - return () => { - // Stop current recognizer when switching to another one, if possible - newRecognizer?.stop(); - } - }, [api.currentApi, azure, control, streamTextConfig, playbackStatus, scribearServerStatus, selectedModelOption]); + // Start/stop recognizer when the listening flag changes + useEffect(() => { + if (!recognizer) return; - // Start / stop recognizer, if listening toggled - useEffect(() => { - if (!recognizer) { // whipser won't have recogHandler - return; + if (control.listening) { + try { + recognizer.start(); + } catch (e) { + console.error('Error starting recognizer', e); } - if (control.listening) { - console.log("UseRecognition, sending start signal to recognizer") - recognizer.start(); - } else if (!control.listening) { - console.log("UseRecognition, sending stop signal to recognizer") - recognizer.stop(); + } else { + try { + recognizer.stop(); + } catch (e) { + console.error('Error stopping recognizer', e); } - }, [control.listening]); + } + }, [control.listening, recognizer]); - // Update domain phrases for azure recognizer - useEffect(() => { - console.log("UseRecognition, changing azure phrases", azure.phrases); - if (api.currentApi === API.AZURE_TRANSLATION && recognizer) { - (recognizer as AzureRecognizer).addPhrases(azure.phrases); - } - }, [azure.phrases]); - - // TODO: whisper's transcript is not in redux store but only in sessionStorage at the moment. 
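// A minimal sketch of how a caller such as RecogComponent might invoke the
// reworked hook above; the argument sources shown are assumptions based on the
// selectors used elsewhere in this file, not part of the change itself.
//   const transcript = useRecognition(
//     null,                // legacy sRecog slot, kept for compatibility
//     apiStatus,           // state.APIStatusReducer
//     control,             // state.ControlReducer
//     azure,               // state.AzureReducer
//     streamTextConfig,
//     scribearServerStatus,
//     selectedModelOption, // from selectSelectedModel
//     playbackStatus,
//   );
//   // RecogComponent then passes `transcript` into STTRenderer.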
- let transcript: string = useSelector((state: RootState) => { - return state.TranscriptReducer.transcripts[0].toString() - }); - // if (api.currentApi === API.WHISPER) { - // // TODO: inefficient to get it from sessionStorage everytime - // // TODO: add whisper_transcript to redux store after integrating "whisper" folder (containing stream.js) into ScribeAR - // transcript = sessionStorage.getItem('whisper_transcript') || ''; - // return transcript; - // } - - return transcript; -} + // Turn the multi-speaker transcript into a single string for speaker 0 + const transcript0 = transcriptState?.transcripts?.[0]; + + let transcriptText = ''; + if (transcript0) { + if (typeof transcript0.toString === 'function') { + transcriptText = transcript0.toString(); + } else if (typeof transcript0.text === 'string') { + transcriptText = transcript0.text; + } + } + + return transcriptText; +}; + +// Allow default import as well: import useRecognition from "./returnAPI"; +export default useRecognition; diff --git a/src/components/api/whisper/indexedDB.js b/src/components/api/whisper/indexedDB.js index e906a4f0..14299aff 100644 --- a/src/components/api/whisper/indexedDB.js +++ b/src/components/api/whisper/indexedDB.js @@ -1,169 +1,133 @@ -// HACK: moving global variables from index.html to here for loadRemote - -let dbVersion = 1 -let dbName = 'whisper.ggerganov.com'; -let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB - -// fetch a remote file from remote URL using the Fetch API -async function fetchRemote(url, cbProgress, cbPrint) { - cbPrint('fetchRemote: downloading with fetch()...'); - - const response = await fetch( - url, - { - method: 'GET', - } - ); - - if (!response.ok) { - cbPrint('fetchRemote: failed to fetch ' + url); - return; - } - - const contentLength = response.headers.get('content-length'); - const total = parseInt(contentLength, 10); - const reader = response.body.getReader(); - - var chunks = []; - var receivedLength = 0; - var progressLast = -1; - - while (true) { - const { done, value } = await reader.read(); - - if (done) { - break; - } - - chunks.push(value); - receivedLength += value.length; - - if (contentLength) { - cbProgress(receivedLength/total); - - var progressCur = Math.round((receivedLength / total) * 10); - if (progressCur != progressLast) { - cbPrint('fetchRemote: fetching ' + 10*progressCur + '% ...'); - progressLast = progressCur; - } - } - } +// src/components/api/whisper/indexedDB.js +// Silent IndexedDB caching loader with multi-URL fallbacks. + +var dbVersion = 1; +var dbName = 'whisper.ggerganov.com'; +var indexedDB = + (window.indexedDB || + window.mozIndexedDB || + window.webkitIndexedDB || + window.msIndexedDB); + +// fetch a URL as a Uint8Array (progress callback optional) +async function fetchBinary(url, cbProgress, cbPrint) { + cbPrint && cbPrint('fetchBinary: GET ' + url); + + const res = await fetch(url, { method: 'GET', cache: 'no-cache' }); + if (!res.ok) throw new Error('http-' + res.status); + + // No stream? Just read all at once. 
+ if (!res.body || !res.body.getReader) { + const buf = new Uint8Array(await res.arrayBuffer()); + cbProgress && cbProgress(1); + return buf; + } + + const total = parseInt(res.headers.get('content-length') || '0', 10) || 0; + const reader = res.body.getReader(); + + let received = 0; + const chunks = []; + while (true) { + const { done, value } = await reader.read(); + if (done) break; + chunks.push(value); + received += value.length; + if (total && cbProgress) cbProgress(received / total); + } + + const out = new Uint8Array(received); + let pos = 0; + for (const c of chunks) { out.set(c, pos); pos += c.length; } + if (!total && cbProgress) cbProgress(1); + return out; +} - var position = 0; - var chunksAll = new Uint8Array(receivedLength); +function openDB() { + return new Promise((resolve, reject) => { + const rq = indexedDB.open(dbName, dbVersion); + rq.onupgradeneeded = (ev) => { + const db = ev.target.result; + if (!db.objectStoreNames.contains('models')) { + db.createObjectStore('models', { autoIncrement: false }); + } + }; + rq.onsuccess = () => resolve(rq.result); + rq.onerror = () => reject(new Error('idb-open')); + rq.onblocked = () => reject(new Error('idb-blocked')); + rq.onabort = () => reject(new Error('idb-abort')); + }); +} - for (var chunk of chunks) { - chunksAll.set(chunk, position); - position += chunk.length; - } +async function getCached(db, key) { + try { + return await new Promise((resolve) => { + const tx = db.transaction(['models'], 'readonly'); + const os = tx.objectStore('models'); + const g = os.get(key); + g.onsuccess = () => { + let v = g.result; + if (v && v instanceof ArrayBuffer) v = new Uint8Array(v); + resolve(v || null); + }; + g.onerror = () => resolve(null); + }); + } catch { return null; } +} - return chunksAll; +async function putCached(db, key, data, cbPrint) { + try { + await new Promise((resolve) => { + const tx = db.transaction(['models'], 'readwrite'); + const os = tx.objectStore('models'); + const p = os.put(data, key); + p.onsuccess = () => { cbPrint && cbPrint('IDB: stored ' + key); resolve(); }; + p.onerror = () => resolve(); + }); + } catch (e) { cbPrint && cbPrint('IDB store error: ' + e); } } -// load remote data -// - check if the data is already in the IndexedDB -// - if not, fetch it from the remote URL and store it in the IndexedDB -export function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) { - if (!navigator.storage || !navigator.storage.estimate) { - cbPrint('loadRemote: navigator.storage.estimate() is not supported'); - } else { - // query the storage quota and print it - navigator.storage.estimate().then(function (estimate) { - cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes'); - cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes'); - }); +// MAIN: try urls in order; cache successful one under that exact url key. 
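// Call-shape sketch (parameter roles inferred from the implementation below; the
// URL is a placeholder, real lists are assembled by whisperRecognizer.load_model):
//   loadRemoteWithFallbacks(
//     ['https://example.com/ggml-tiny.en-q5_1.bin'], // urls, tried in order
//     'whisper.bin',                                 // dst, passed through to cbReady
//     31,                                            // sizeMB, used only for logging
//     (p) => {},                                     // cbProgress: fraction 0..1
//     (dst, data) => {},                             // cbReady(dst, Uint8Array)
//     () => {},                                      // cbCancel: every URL failed
//     console.log,                                   // cbPrint
//   );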
+export async function loadRemoteWithFallbacks(urls, dst, sizeMB, cbProgress, cbReady, cbCancel, cbPrint) { + try { + if (navigator.storage && navigator.storage.estimate) { + const est = await navigator.storage.estimate(); + cbPrint && cbPrint(`IDB quota ~${est.quota}, used ~${est.usage}`); } - - // check if the data is already in the IndexedDB - var rq = indexedDB.open(dbName, dbVersion); - - rq.onupgradeneeded = function (event) { - var db = event.target.result; - if (db.version == 1) { - var os = db.createObjectStore('models', { autoIncrement: false }); - cbPrint('loadRemote: created IndexedDB ' + db.name + ' version ' + db.version); - } else { - // clear the database - var os = event.currentTarget.transaction.objectStore('models'); - os.clear(); - cbPrint('loadRemote: cleared IndexedDB ' + db.name + ' version ' + db.version); + } catch {} + + let db = null; + try { db = await openDB(); } catch { db = null; } + + const tinyBlob = (b) => !b || (b.length || 0) < 4096; + const urlList = (Array.isArray(urls) ? urls : [urls]).filter(Boolean); + + if (urlList.length === 0) { + cbCancel && cbCancel(); + return; + } + + for (const url of urlList) { + try { + // cache hit? + if (db) { + const cached = await getCached(db, url); + if (cached && !tinyBlob(cached)) { + cbReady(dst, cached); + return; } - }; - - rq.onsuccess = function (event) { - var db = event.target.result; - var tx = db.transaction(['models'], 'readonly'); - var os = tx.objectStore('models'); - var rq = os.get(url); - - rq.onsuccess = function (event) { - if (rq.result) { - cbPrint('loadRemote: "' + url + '" is already in the IndexedDB'); - cbReady(dst, rq.result); - } else { - // data is not in the IndexedDB - cbPrint('loadRemote: "' + url + '" is not in the IndexedDB'); - - // alert and ask the user to confirm - if (!confirm( - 'You are about to download ' + size_mb + ' MB of data.\n' + - 'The model data will be cached in the browser for future use.\n\n' + - 'Press OK to continue.')) { - cbCancel(); - return; - } - - fetchRemote(url, cbProgress, cbPrint).then(function (data) { - if (data) { - // store the data in the IndexedDB - var rq = indexedDB.open(dbName, dbVersion); - rq.onsuccess = function (event) { - var db = event.target.result; - var tx = db.transaction(['models'], 'readwrite'); - var os = tx.objectStore('models'); - - var rq = null; - try { - var rq = os.put(data, url); - } catch (e) { - cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB: \n' + e); - cbCancel(); - return; - } - - rq.onsuccess = function (event) { - cbPrint('loadRemote: "' + url + '" stored in the IndexedDB'); - cbReady(dst, data); - }; - - rq.onerror = function (event) { - cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB'); - cbCancel(); - }; - }; - } - }); - } - }; - - rq.onerror = function (event) { - cbPrint('loadRemote: failed to get data from the IndexedDB'); - cbCancel(); - }; - }; - - rq.onerror = function (event) { - cbPrint('loadRemote: failed to open IndexedDB'); - cbCancel(); - }; - - rq.onblocked = function (event) { - cbPrint('loadRemote: failed to open IndexedDB: blocked'); - cbCancel(); - }; + } + + cbPrint && cbPrint(`loadRemote: fetching ~${sizeMB} MB from ${url}`); + const data = await fetchBinary(url, cbProgress, cbPrint); + if (tinyBlob(data)) throw new Error('tiny-blob'); + if (db) await putCached(db, url, data, cbPrint); + cbReady(dst, data); + return; + } catch (e) { + cbPrint && cbPrint(`loadRemote: failed on ${url} (${e}), trying next...`); + } + } - rq.onabort = function (event) { - 
cbPrint('loadRemote: failed to open IndexedDB: abort'); - cbCancel(); - }; + cbCancel(); } diff --git a/src/components/api/whisper/whisperRecognizer.tsx b/src/components/api/whisper/whisperRecognizer.tsx index ad61e21f..fb55ea65 100644 --- a/src/components/api/whisper/whisperRecognizer.tsx +++ b/src/components/api/whisper/whisperRecognizer.tsx @@ -1,281 +1,404 @@ +// src/components/api/whisper/whisperRecognizer.tsx + import { Recognizer } from '../recognizer'; import { TranscriptBlock } from '../../../react-redux&middleware/redux/types/TranscriptTypes'; import makeWhisper from './libstream'; -import { loadRemote } from './indexedDB' +import { loadRemoteWithFallbacks } from './indexedDB'; import { SIPOAudioBuffer } from './sipo-audio-buffer'; -import RecordRTC, {StereoAudioRecorder} from 'recordrtc'; +import RecordRTC, { StereoAudioRecorder } from 'recordrtc'; import WavDecoder from 'wav-decoder'; -// https://stackoverflow.com/questions/4554252/typed-arrays-in-gecko-2-float32array-concatenation-and-expansion -function Float32Concat(first, second) { - const firstLength = first.length; - const result = new Float32Array(firstLength + second.length); - - result.set(first); - result.set(second, firstLength); +// Base URL for model binaries. +// You can override this with REACT_APP_WEBMODELS_BASE_URL if needed. +const WEBMODEL_BASE_URL: string = + (process.env.REACT_APP_WEBMODELS_BASE_URL as string | undefined) ?? + 'https://raw.githubusercontent.com/scribear/webmodels/main/data/'; - return result; +// https://stackoverflow.com/questions/4554252/typed-arrays-in-gecko-2-float32array-concatenation-and-expansion +function Float32Concat(first: Float32Array, second: Float32Array): Float32Array { + const firstLength = first.length; + const result = new Float32Array(firstLength + second.length); + result.set(first); + result.set(second, firstLength); + return result; } /** * Wrapper for Web Assembly implementation of whisper.cpp */ export class WhisperRecognizer implements Recognizer { - - // - // Audio params - // - private kSampleRate = 16000; - // Length of the suffix of the captured audio that whisper processes each time (in seconds) - private kWindowLength = 5; - // Number of samples in each audio chunk the audio worklet receives - private kChunkLength = 128; - - /** - * Audio context and buffer used to capture speech - */ - private context: AudioContext; - private audio_buffer: SIPOAudioBuffer; - private recorder?: RecordRTC; - - /** - * Instance of the whisper wasm module, and its variables - */ - private whisper: any = null; - private model_name: string = ""; - private model_index: Number = -1; - private language: string = "en"; - private num_threads: number; - - private transcribed_callback: ((newFinalBlocks: Array, newInProgressBlock: TranscriptBlock) => void) | null = null; - - /** - * Creates an Whisper recognizer instance that listens to the default microphone - * and expects speech in the given language - * @param audioSource Not implemented yet - * @param language Not implemented yet - * @parem num_threads Number of worker threads that whisper uses - */ - constructor(audioSource: any, language: string, num_threads: number = 4, model: string = "tiny-en-q5_1") { - this.num_threads = num_threads; - this.language = language; - this.model_name = model; - - const num_chunks = this.kWindowLength * this.kSampleRate / this.kChunkLength; - this.audio_buffer = new SIPOAudioBuffer(num_chunks, this.kChunkLength); - - this.context = new AudioContext({ - sampleRate: this.kSampleRate, - }); - } - - 
private print = (text: string) => { - if (this.transcribed_callback != null) { - let block = new TranscriptBlock(); - block.text = text; - this.transcribed_callback([block], new TranscriptBlock()); - } - } - - private printDebug = (text: string) => { - console.log(text); + // + // Audio params + // + // Whisper expects 16kHz mono PCM + private kSampleRate = 16000; + + // Length of the suffix of the captured audio that whisper processes each time (in seconds) + private kWindowLength = 5; + + // Number of samples in each audio chunk the recorder gives us + private kChunkLength = 128; + + /** + * Audio context and buffer used to capture speech + */ + private context: AudioContext; + private audio_buffer: SIPOAudioBuffer; + private recorder?: RecordRTC; + + /** + * Instance of the whisper wasm module, and its variables + */ + private whisper: any = null; + private model_name: string = ''; + private model_index: number = -1; + private language: string = 'en'; + private num_threads: number; + + private transcribed_callback: + | ((newFinalBlocks: Array, newInProgressBlock: TranscriptBlock) => void) + | null = null; + + /** + * Creates an Whisper recognizer instance that listens to the default microphone + * and expects speech in the given language + * + * @param audioSource Not implemented yet + * @param language Language code for whisper (e.g. "en") + * @param num_threads Number of worker threads that whisper uses + * @param model Whisper model key (e.g. "tiny-en-q5_1", "tiny-q5_1", "tiny-multi") + */ + constructor( + audioSource: any, + language: string, + num_threads: number = 4, + model: string = 'tiny-en-q5_1', + ) { + this.num_threads = num_threads; + this.language = language; + this.model_name = model; + + const num_chunks = (this.kWindowLength * this.kSampleRate) / this.kChunkLength; + this.audio_buffer = new SIPOAudioBuffer(num_chunks, this.kChunkLength); + + this.context = new AudioContext({ + sampleRate: this.kSampleRate, + }); + } + + private print = (text: string) => { + if (this.transcribed_callback != null) { + const block = new TranscriptBlock(); + block.text = text; + this.transcribed_callback([block], new TranscriptBlock()); } - - private storeFS(fname, buf) { - // write to WASM file using FS_createDataFile - // if the file exists, delete it + }; + + private printDebug = (text: string) => { + console.log(text); + }; + + private isEnglishOnlyModel(): boolean { + // Whisper model keys containing "-en" or ending with ".en" are English-only. + return this.model_name.includes('-en') || this.model_name.endsWith('.en'); + } + + // Try multiple setter name shapes to match whatever the wasm build exported. 
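// Sketch of the intended call pattern (the setter names are candidates, not
// guaranteed exports of this particular wasm build):
//   this.trySetter(['set_language', 'setLanguage'], this.model_index, 'en');
// The call returns true as soon as one listed export exists and does not throw.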
+ private trySetter(names: string[], ...args: any[]): boolean { + for (const n of names) { + const fn = (this.whisper as any)?.[n]; + if (typeof fn === 'function') { try { - this.whisper.FS_unlink(fname); + fn.apply(this.whisper, args); + return true; } catch (e) { - // ignore + console.debug(`Whisper: setter ${n} threw`, e); } - this.whisper.FS_createDataFile("/", fname, buf, true, true); - this.printDebug('storeFS: stored model: ' + fname + ' size: ' + buf.length); - } - - - /** - * Async load the WASM module, ggml model, and Audio Worklet needed for whisper to work - */ - public async loadWhisper() { - // Load wasm and ggml - this.whisper = await makeWhisper({ - print: this.print, - printErr: this.printDebug, - setStatus: function(text) { - this.printErr('js: ' + text); - }, - monitorRunDependencies: function(left) { - } - }) - await this.load_model(this.model_name); - this.model_index = this.whisper.init('whisper.bin'); - - console.log("Whisper: Done instantiating whisper", this.whisper, this.model_index); - - // Set up audio source - let mic_stream = await navigator.mediaDevices.getUserMedia({audio: true, video: false}); - // let source = this.context.createMediaStreamSource(mic_stream); - - let last_suffix = new Float32Array(0); - this.recorder = new RecordRTC(mic_stream, { - type: 'audio', - mimeType: 'audio/wav', - desiredSampRate: this.kSampleRate, - timeSlice: 250, - ondataavailable: async (blob: Blob) => { - // Convert wav chunk to PCM - const array_buffer = await blob.arrayBuffer(); - const {channelData} = await WavDecoder.decode(array_buffer); - // Should be 16k, float32, stereo pcm data - // Just get 1 channel - let pcm_data = channelData[0]; - - // Prepend previous suffix and update with current suffix - pcm_data = Float32Concat(last_suffix, pcm_data); - last_suffix = pcm_data.slice(-(pcm_data.length % 128)) - - // Feed process_recorder_message audio in 128 sample chunks - for (let i = 0; i < pcm_data.length - 127; i+= 128) { - const audio_chunk = pcm_data.subarray(i, i + 128) - - this.process_recorder_message(audio_chunk); - } - - }, - recorderType: StereoAudioRecorder, - numberOfAudioChannels: 1, - }); - - this.recorder.startRecording(); - console.log("Whisper: Done setting up audio context"); + } } - - private async load_model(model: string) { - let urls = { - 'tiny.en': 'ggml-model-whisper-tiny.en.bin', - 'tiny': 'ggml-model-whisper-tiny.bin', - 'base.en': 'ggml-model-whisper-base.en.bin', - 'base': 'ggml-model-whisper-base.bin', - 'small.en': 'ggml-model-whisper-small.en.bin', - 'small': 'ggml-model-whisper-small.bin', - - 'tiny-en-q5_1': 'ggml-model-whisper-tiny.en-q5_1.bin', - 'tiny-q5_1': 'ggml-model-whisper-tiny-q5_1.bin', - 'base-en-q5_1': 'ggml-model-whisper-base.en-q5_1.bin', - 'base-q5_1': 'ggml-model-whisper-base-q5_1.bin', - 'small-en-q5_1': 'ggml-model-whisper-small.en-q5_1.bin', - 'small-q5_1': 'ggml-model-whisper-small-q5_1.bin', - 'medium-en-q5_0':'ggml-model-whisper-medium.en-q5_0.bin', - 'medium-q5_0': 'ggml-model-whisper-medium-q5_0.bin', - 'large-q5_0': 'ggml-model-whisper-large-q5_0.bin', - }; - let sizes = { - 'tiny.en': 75, - 'tiny': 75, - 'base.en': 142, - 'base': 142, - 'small.en': 466, - 'small': 466, - - 'tiny-en-q5_1': 31, - 'tiny-q5_1': 31, - 'base-en-q5_1': 57, - 'base-q5_1': 57, - 'small-en-q5_1': 182, - 'small-q5_1': 182, - 'medium-en-q5_0': 515, - 'medium-q5_0': 515, - 'large-q5_0': 1030, - }; - - let url = process.env.PUBLIC_URL + "/models/" + urls[model]; - let dst = 'whisper.bin'; - let size_mb = sizes[model]; - - // HACK: turn 
loadRemote into a promise so that we can chain .then - let that = this; - return new Promise((resolve, reject) => { - loadRemote(url, - dst, - size_mb, - (text) => {}, - (fname, buf) => { - that.storeFS(fname, buf); - resolve(); - }, - () => { - reject(); - }, - that.printDebug, - ); - }) + return false; + } + + /** Apply language/translation/thread settings to the wasm instance (best-effort). */ + private applyLanguageSettings() { + if (!this.whisper || this.model_index < 0) return; + + const lang = this.language || 'en'; + // If model is multilingual and language is still 'en', prefer auto-detect. + const detect = + lang === 'auto' || (!this.isEnglishOnlyModel() && lang === 'en') ? 1 : 0; + + // language / detect + this.trySetter(['set_detect_language', 'setDetectLanguage'], this.model_index, detect ? 1 : 0); + if (!detect) { + this.trySetter(['set_language', 'setLanguage'], this.model_index, lang); } - /** - * Helper method that stores audio chunks from the raw recorder in buffer - * @param audio_chunk Float32Array containing an audio chunk - */ - private process_recorder_message(audio_chunk: Float32Array) { - this.audio_buffer.push(audio_chunk); - if (this.audio_buffer.isFull()) { - this.whisper.set_audio(this.model_index, this.audio_buffer.getAll()); - this.audio_buffer.clear(); - } - } + // Force transcribe (task 0) and translate disabled + const translateSet = + this.trySetter(['set_translate', 'setTranslate'], this.model_index, 0) || + this.trySetter(['set_translate', 'setTranslate'], this.model_index, false); + const taskSet = this.trySetter(['set_task', 'setTask'], this.model_index, 0); - private process_analyzer_result(features: any) { + // Threads if supported + this.trySetter(['set_threads', 'setThreads'], this.model_index, this.num_threads); + if (!translateSet || !taskSet) { + console.debug('Whisper: translate/task setters not found; build may default to translate=true'); } - - /** - * Makes the Whisper recognizer start transcribing speech, if not already started - * Throws exception if recognizer fails to start - */ - start() { - console.log("trying to start whisper"); - let that = this; - if (this.whisper == null || this.model_index == -1) { - this.loadWhisper().then(() => { - that.whisper.setStatus(""); - that.context.resume(); - }) - } else { - this.whisper.setStatus(""); - this.context.resume(); - } + } + + /** Allow runtime language updates if control language changes. 
*/ + public setLanguage(language: string) { + if (!language) return; + this.language = language; + this.applyLanguageSettings(); + } + + private storeFS(fname: string, buf: Uint8Array) { + // write to WASM file using FS_createDataFile + // if the file exists, delete it + try { + this.whisper.FS_unlink(fname); + } catch (e) { + // ignore } - - /** - * Makes the Whisper recognizer stop transcribing speech asynchronously - * Throws exception if recognizer fails to stop - */ - stop() { - if (this.whisper == null || this.model_index === -1) { - return; + this.whisper.FS_createDataFile('/', fname, buf, true, true); + this.printDebug('storeFS: stored model: ' + fname + ' size: ' + buf.length); + } + + /** + * Async load the WASM module, ggml model, and Audio Worklet needed for whisper to work + */ + public async loadWhisper() { + // Load wasm and ggml + this.whisper = await makeWhisper({ + print: this.print, + printErr: this.printDebug, + setStatus: (text: string) => { + this.printDebug('js: ' + text); + }, + monitorRunDependencies: (_left: number) => {}, + }); + + await this.load_model(this.model_name); + + this.model_index = this.whisper.init('whisper.bin'); + console.log('Whisper: Done instantiating whisper', this.whisper, this.model_index); + this.applyLanguageSettings(); + + // Set up audio source + const mic_stream = await navigator.mediaDevices.getUserMedia({ + audio: true, + video: false, + }); + + let last_suffix = new Float32Array(0); + + this.recorder = new RecordRTC(mic_stream, { + type: 'audio', + mimeType: 'audio/wav', + desiredSampRate: this.kSampleRate, + timeSlice: 250, + ondataavailable: async (blob: Blob) => { + // Convert wav chunk to PCM + const array_buffer = await blob.arrayBuffer(); + const { channelData } = await WavDecoder.decode(array_buffer); + + // Should be 16k, float32, stereo pcm data + // Just get 1 channel + let pcm_data = channelData[0]; + + // Prepend previous suffix and update with current suffix + pcm_data = Float32Concat(last_suffix, pcm_data); + last_suffix = pcm_data.slice(-(pcm_data.length % this.kChunkLength)); + + // Feed process_recorder_message audio in kChunkLength sample chunks + for (let i = 0; i + this.kChunkLength <= pcm_data.length; i += this.kChunkLength) { + const audio_chunk = pcm_data.subarray(i, i + this.kChunkLength); + this.process_recorder_message(audio_chunk); } - this.whisper.set_status("paused"); - this.context.suspend(); - this.recorder?.stopRecording(); + }, + recorderType: StereoAudioRecorder, + numberOfAudioChannels: 1, + }); + + this.recorder.startRecording(); + console.log('Whisper: Done setting up audio context'); + } + + private async load_model(model: string) { + const urls: Record = { + // Prefer the filenames that actually exist in webmodels/; we add + // whisper-prefixed variants as fallbacks below. + // Prefer q5_1 variants we actually host; fallbacks add other name shapes. 
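// For example, 'tiny-en-q5_1' resolves to 'ggml-tiny.en-q5_1.bin' here, and
// buildUrls() below also queues the 'ggml-model-whisper-tiny.en-q5_1.bin'
// spelling as a fallback before the download is abandoned.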
+ 'tiny.en': 'ggml-tiny.en-q5_1.bin', + tiny: 'ggml-tiny-q5_1.bin', + 'base.en': 'ggml-base.en-q5_1.bin', + base: 'ggml-base-q5_1.bin', + 'small.en': 'ggml-small.en-q5_1.bin', + small: 'ggml-small-q5_1.bin', + + 'tiny-en-q5_1': 'ggml-tiny.en-q5_1.bin', + 'tiny-q5_1': 'ggml-tiny-q5_1.bin', + + // multilingual tiny aliases + 'tiny-multi': 'ggml-tiny-q5_1.bin', + 'tiny-multi-q5_1': 'ggml-tiny-q5_1.bin', + + 'base-en-q5_1': 'ggml-base.en-q5_1.bin', + 'base-q5_1': 'ggml-base-q5_1.bin', + 'small-en-q5_1': 'ggml-small.en-q5_1.bin', + 'small-q5_1': 'ggml-small-q5_1.bin', + 'medium-en-q5_0': 'ggml-medium.en-q5_0.bin', + 'medium-q5_0': 'ggml-medium-q5_0.bin', + 'large-q5_0': 'ggml-large-q5_0.bin', + }; + + const sizes: Record = { + 'tiny.en': 75, + tiny: 75, + 'base.en': 142, + base: 142, + 'small.en': 466, + small: 466, + + 'tiny-en-q5_1': 31, + 'tiny-q5_1': 31, + 'tiny-multi': 31, + 'tiny-multi-q5_1': 31, + + 'base-en-q5_1': 57, + 'base-q5_1': 57, + 'small-en-q5_1': 182, + 'small-q5_1': 182, + 'medium-en-q5_0': 515, + 'medium-q5_0': 515, + 'large-q5_0': 1030, + }; + + const filename = urls[model]; + + if (!filename) { + this.printDebug(`Whisper: unknown model key "${model}", falling back to "tiny-en-q5_1"`); + this.model_name = 'tiny-en-q5_1'; + return this.load_model(this.model_name); } - /** - * Subscribe a callback function to the transcript update event, which is usually triggered - * when the recognizer has processed more speech or finalized some in-progress part - * @param callback A callback function called with updates to the transcript - */ - onTranscribed(callback: (newFinalBlocks: Array, newInProgressBlock: TranscriptBlock) => void) { - this.transcribed_callback = callback; + const size_mb = sizes[model] ?? 0; + const dst = 'whisper.bin'; + + const buildUrls = (fname: string): string[] => { + const variants = new Set(); + + // 1) canonical + variants.add(fname); + + // 2) add whisper-prefixed variant if not already present + if (!fname.includes('model-whisper-')) { + const withWhisper = fname.replace('ggml-', 'ggml-model-whisper-'); + variants.add(withWhisper); + } + + // 3) add plain ggml variant (in case fname carried the whisper prefix) + if (fname.includes('model-whisper-')) { + const plain = fname.replace('ggml-model-whisper-', 'ggml-'); + variants.add(plain); + } + + const base = WEBMODEL_BASE_URL.endsWith('/') + ? 
WEBMODEL_BASE_URL.slice(0, -1) + : WEBMODEL_BASE_URL; + + return Array.from(variants).map((v) => `${base}/${v}`); + }; + + const urlsToTry = buildUrls(filename); + + const that = this; + + return new Promise((resolve, reject) => { + loadRemoteWithFallbacks( + urlsToTry, + dst, + size_mb, + (_text: string) => {}, + (fname: string, buf: Uint8Array) => { + that.storeFS(fname, buf); + resolve(); + }, + () => { + reject(); + }, + that.printDebug, + ); + }); + } + + /** + * Helper method that stores audio chunks from the raw recorder in buffer + * @param audio_chunk Float32Array containing an audio chunk + */ + private process_recorder_message(audio_chunk: Float32Array) { + this.audio_buffer.push(audio_chunk); + + if (this.audio_buffer.isFull()) { + this.whisper.set_audio(this.model_index, this.audio_buffer.getAll()); + this.audio_buffer.clear(); } - - /** - * Subscribe a callback function to the error event, which is triggered - * when the recognizer has encountered an error - * @param callback A callback function called with the error object when the event is triggered - */ - onError(callback: (e: Error) => void) { - + } + + // Placeholder – we don't currently use analyzer features + private process_analyzer_result(_features: any) {} + + /** + * Makes the Whisper recognizer start transcribing speech, if not already started + * Throws exception if recognizer fails to start + */ + start() { + console.log('trying to start whisper'); + + if (this.whisper == null || this.model_index === -1) { + this.loadWhisper().then(() => { + this.whisper.setStatus(''); + this.context.resume(); + }); + } else { + this.whisper.setStatus(''); + this.context.resume(); } - + } + + /** + * Makes the Whisper recognizer stop transcribing speech asynchronously + * Throws exception if recognizer fails to stop + */ + stop() { + if (this.whisper == null || this.model_index === -1) { + return; + } + this.whisper.set_status('paused'); + this.context.suspend(); + this.recorder?.stopRecording(); + } + + /** + * Subscribe a callback function to the transcript update event, which is usually triggered + * when the recognizer has processed more speech or finalized some in-progress part + * @param callback A callback function called with updates to the transcript + */ + onTranscribed( + callback: (newFinalBlocks: Array, newInProgressBlock: TranscriptBlock) => void, + ) { + this.transcribed_callback = callback; + } + + /** + * Subscribe a callback function to the error event, which is triggered + * when the recognizer has encountered an error + * @param callback A callback function called with the error object when the event is triggered + */ + onError(_callback: (e: Error) => void) { + // TODO: wire up whisper error channel if/when libstream exposes one + } } - diff --git a/src/components/navbars/sidebar/model/menu.tsx b/src/components/navbars/sidebar/model/menu.tsx index 72809abc..4093a65f 100644 --- a/src/components/navbars/sidebar/model/menu.tsx +++ b/src/components/navbars/sidebar/model/menu.tsx @@ -1,56 +1,64 @@ +// src/components/navbars/sidebar/model/menu.tsx + import React from 'react'; +import { useDispatch, useSelector } from 'react-redux'; +import { + Autocomplete, + TextField, + Tooltip, + ListItem, +} from '../../../../muiImports'; -import { List, ListItemText, Collapse, ListItem, MemoryIcon, Autocomplete, TextField, Tooltip } from '../../../../muiImports'; -import { useSelector } from 'react-redux'; import type { RootState } from '../../../../store'; -import { API, type ApiStatus } from 
'../../../../react-redux&middleware/redux/typesImports'; -import { useDispatch } from 'react-redux'; -import { selectModelOptions, selectSelectedModel, setSelectedModel } from '../../../../react-redux&middleware/redux/reducers/modelSelectionReducers'; +import type { SelectedOption } from '../../../../react-redux&middleware/redux/types/modelSelection'; +import { + selectModelOptions, + selectSelectedModel, + setSelectedModel, +} from '../../../../react-redux&middleware/redux/reducers/modelSelectionReducers'; -export default function ModelMenu(props) { +export default function ModelMenu(_props: any) { const dispatch = useDispatch(); - const APIStatus = useSelector((state: RootState) => { - return state.APIStatusReducer as ApiStatus; - }); - const modelOptions = useSelector(selectModelOptions); - const selected = useSelector(selectSelectedModel); - const modelSelectEnable = APIStatus.currentApi !== API.SCRIBEAR_SERVER; - - return ( -
- {props.listItemHeader("Model", "model", MemoryIcon)} - - - - - - - - { - dispatch(setSelectedModel(val)) - }} - defaultValue={selected} - getOptionLabel={(v) => v.display_name} - isOptionEqualToValue={(a, b) => a.model_key === b.model_key} - renderInput={(params) => } - renderOption={(props, option) => { - return - - {option.display_name} - - - }} - /> - + const modelOptions = useSelector((state: RootState) => + selectModelOptions(state), + ); + const selected = useSelector((state: RootState) => + selectSelectedModel(state), + ); - - -
+ return ( + + sx={{ width: 300 }} + disablePortal + options={modelOptions} + value={selected} + onChange={(_, val) => { + if (val) { + dispatch(setSelectedModel(val)); + } + }} + getOptionLabel={(v: SelectedOption | null) => + v ? v.display_name : '' + } + isOptionEqualToValue={( + a: SelectedOption | null, + b: SelectedOption | null, + ) => !!a && !!b && a.model_key === b.model_key} + renderInput={(params) => } + renderOption={(props, option) => { + const opt = option as SelectedOption | null; + if (!opt) return null; + return ( + + {opt.display_name} + + ); + }} + /> ); -} \ No newline at end of file +} diff --git a/src/components/navbars/topbar/api/WhisperDropdown.tsx b/src/components/navbars/topbar/api/WhisperDropdown.tsx index c18229b6..6850da1f 100644 --- a/src/components/navbars/topbar/api/WhisperDropdown.tsx +++ b/src/components/navbars/topbar/api/WhisperDropdown.tsx @@ -1,13 +1,13 @@ +// src/components/navbars/topbar/api/WhisperDropdown.tsx import React, { useState } from 'react'; import { useDispatch, useSelector } from 'react-redux'; -// Pull UI pieces from the same aggregator used elsewhere import { IconButton, Tooltip, Menu, MenuItem, - SettingsIcon, // this should be re-exported by your muiImports like other icons + SettingsIcon, } from '../../../../muiImports'; import { selectSelectedModel } from @@ -15,64 +15,70 @@ import { selectSelectedModel } from type Props = { onPicked?: () => void }; -type ModelKey = 'tiny' | 'base'; - -// Helper to normalize current selection into 'tiny' | 'base' | undefined -function normalizeSelected(selected: unknown): ModelKey | undefined { - if (typeof selected === 'string') { - return selected === 'tiny' || selected === 'base' ? selected : undefined; - } - if (selected && typeof selected === 'object' && 'key' in (selected as any)) { - const k = (selected as any).key; - return k === 'tiny' || k === 'base' ? k : undefined; - } - return undefined; -} +// Only keep tiny models (keep legacy 'tiny' for backward-compat) +type ModelKey = 'tiny-en-q5_1' | 'tiny-q5_1' | 'tiny'; export default function WhisperDropdown({ onPicked }: Props) { const dispatch = useDispatch(); - const selected = useSelector(selectSelectedModel); - const selectedKey = normalizeSelected(selected); const [anchorEl, setAnchorEl] = useState(null); - const open = Boolean(anchorEl); + const isShown = Boolean(anchorEl); + + const selected = useSelector(selectSelectedModel); + const selectedKey: ModelKey = React.useMemo(() => { + const allowed = new Set(['tiny-en-q5_1', 'tiny-q5_1']); + if (typeof selected === 'string') { + if (selected === 'tiny') return 'tiny-en-q5_1'; + return (allowed.has(selected) ? selected : 'tiny-en-q5_1') as ModelKey; + } + if (selected && typeof selected === 'object') { + const k = (selected as any).model_key || (selected as any).key; + if (k === 'tiny') return 'tiny-en-q5_1'; + return (allowed.has(k) ? 
k : 'tiny-en-q5_1') as ModelKey; + } + return 'tiny-en-q5_1'; + }, [selected]); - const handleOpen = (e: React.MouseEvent) => setAnchorEl(e.currentTarget); + const showPopup = (e: React.MouseEvent) => setAnchorEl(e.currentTarget); const handleClose = () => setAnchorEl(null); const pick = (which: ModelKey) => { - // Keep the existing flags your loader is watching - sessionStorage.setItem('isDownloadTiny', String(which === 'tiny')); - sessionStorage.setItem('isDownloadBase', String(which === 'base')); + // Maintain legacy session flags for any old code paths + sessionStorage.setItem('isDownloadTiny', 'true'); + sessionStorage.setItem('isDownloadBase', 'false'); - // Update Redux – if you have a real action creator, use it here dispatch({ type: 'SET_SELECTED_MODEL', payload: which as any }); - - handleClose(); onPicked?.(); + handleClose(); }; return ( <> - {/* Right-aligned gear like other providers */} - - + + - pick('tiny')}> - TINY (75 MB) + pick('tiny-en-q5_1')} + > + TINY (EN, q5_1) ~31 MB - pick('base')}> - BASE (145 MB) + pick('tiny-q5_1')} + > + TINY (Multi, q5_1) ~31 MB diff --git a/src/components/navbars/topbar/api/WhisperSettings.tsx b/src/components/navbars/topbar/api/WhisperSettings.tsx deleted file mode 100644 index 19e5f518..00000000 --- a/src/components/navbars/topbar/api/WhisperSettings.tsx +++ /dev/null @@ -1,104 +0,0 @@ -// src/components/navbars/topbar/api/WhisperSettings.tsx -import * as React from 'react'; -import { useDispatch, useSelector } from 'react-redux'; - -import { - Box, - Menu, - List, - ListItem, - IconButton, - SettingsIcon, - Button, -} from '../../../../muiImports'; - -import { selectSelectedModel } from - '../../../../react-redux&middleware/redux/reducers/modelSelectionReducers'; - -type ModelKey = 'tiny' | 'base'; - -export default function WhisperSettings() { - const [anchorEl, setAnchorEl] = React.useState(null); - const isShown = Boolean(anchorEl); - - const dispatch = useDispatch(); - const selected = useSelector(selectSelectedModel); - - // Normalize current selection to 'tiny' | 'base' | undefined - const selectedKey: ModelKey | undefined = React.useMemo(() => { - if (typeof selected === 'string') { - return selected === 'tiny' || selected === 'base' ? selected : undefined; - } - if (selected && typeof selected === 'object' && 'key' in (selected as any)) { - const k = (selected as any).key; - return k === 'tiny' || k === 'base' ? 
k : undefined; - } - return undefined; - }, [selected]); - - const showPopup = (e: React.MouseEvent) => setAnchorEl(e.currentTarget); - const closePopup = () => setAnchorEl(null); - - const pick = (which: ModelKey) => { - // Keep flags if your loader watches them - sessionStorage.setItem('isDownloadTiny', (which === 'tiny').toString()); - sessionStorage.setItem('isDownloadBase', (which === 'base').toString()); - - // Update redux - dispatch({ type: 'SET_SELECTED_MODEL', payload: which as any }); - - closePopup(); - }; - - return ( - <> - {/* Right-side gear (same pattern as Azure/ScribeAR Server) */} - - - - - - - - - - - - - - - - - ); -} diff --git a/src/components/navbars/topbar/api/pickApi.tsx b/src/components/navbars/topbar/api/pickApi.tsx index bba360f8..57cd3c51 100644 --- a/src/components/navbars/topbar/api/pickApi.tsx +++ b/src/components/navbars/topbar/api/pickApi.tsx @@ -28,7 +28,7 @@ import { ListItemText, ThemeProvider, createTheme, - Chip, // from muiImports barrel + Chip, } from '../../../../muiImports'; import { ListItem } from '@mui/material'; @@ -36,7 +36,7 @@ import AzureSettings from './AzureSettings'; import StreamTextSettings from './StreamTextSettings'; import ScribearServerSettings from './ScribearServerSettings'; import PlaybackSettings from './PlaybackSettings'; -import WhisperSettings from './WhisperSettings'; // gear dialog for Whisper +import WhisperSettings from './WhisperDropdown'; import swal from 'sweetalert'; import { testAzureTranslRecog } from '../../../api/azure/azureTranslRecog'; @@ -83,15 +83,36 @@ const IconStatus = (currentApi: any) => { } }; -function getSelectedModelKey(selected: unknown): 'tiny' | 'base' | undefined { +/** Normalize various saved shapes into one of our TWO tiny model keys; default to tiny-en-q5_1 */ +function getSelectedModelKey( + selected: unknown +): 'tiny-en-q5_1' | 'tiny-q5_1' { + const allowed = new Set(['tiny-en-q5_1', 'tiny-q5_1']); + if (typeof selected === 'string') { - return selected === 'tiny' || selected === 'base' ? selected : undefined; + if (selected === 'tiny') return 'tiny-en-q5_1'; + // if someone still has 'base' persisted, fall back to tiny-en + if (selected === 'base') return 'tiny-en-q5_1'; + return allowed.has(selected) ? (selected as any) : 'tiny-en-q5_1'; + } + + if (selected && typeof selected === 'object') { + const k = (selected as any).model_key ?? (selected as any).key; + if (k === 'tiny') return 'tiny-en-q5_1'; + if (k === 'base') return 'tiny-en-q5_1'; + return allowed.has(k) ? (k as any) : 'tiny-en-q5_1'; } - if (selected && typeof selected === 'object' && 'key' in (selected as any)) { - const k = (selected as any).key; - return k === 'tiny' || k === 'base' ? 
k : undefined; + + return 'tiny-en-q5_1'; +} + +/** Short label for topbar chip */ +function getSelectedModelLabel(k?: string) { + switch (k) { + case 'tiny-en-q5_1': return 'TINY-EN'; + case 'tiny-q5_1': return 'TINY-MULTI'; + default: return 'TINY-EN'; } - return undefined; } export default function PickApi() { @@ -261,7 +282,7 @@ export default function PickApi() { - + @@ -306,7 +327,7 @@ export default function PickApi() { )} diff --git a/src/ml/inference.js b/src/ml/inference.js index d3a824e7..5cb42117 100644 --- a/src/ml/inference.js +++ b/src/ml/inference.js @@ -3,57 +3,109 @@ /*global BigInt64Array */ import { loadTokenizer } from './bert_tokenizer.ts'; -// import * as wasmFeatureDetect from 'wasm-feature-detect'; - -//Setup onnxruntime +// Setup onnxruntime const ort = require('onnxruntime-web'); -//requires Cross-Origin-*-policy headers https://web.dev/coop-coep/ -/** -const simdResolver = wasmFeatureDetect.simd().then(simdSupported => { - console.log("simd is supported? "+ simdSupported); - if (simdSupported) { - ort.env.wasm.numThreads = 3; - ort.env.wasm.simd = true; - } else { - ort.env.wasm.numThreads = 1; - ort.env.wasm.simd = false; - } -}); -*/ +// --- CONFIG ----------------------------------------------------------------- + +// We resolve URLs against PUBLIC_URL (prod) or root (dev) +function resolvePublicUrl(path) { + const base = + (process.env.NODE_ENV === 'development' + ? '' + : (process.env.PUBLIC_URL || '') + ).replace(/\/$/, ''); + return `${base}${path.startsWith('/') ? '' : '/'}${path}`; +} -const options = { - executionProviders: ['wasm'], - graphOptimizationLevel: 'all' +// Place these files in public/models/onnx/ +const LM_MODEL_URL = resolvePublicUrl('/models/onnx/xtremedistill-go-emotion-int8.onnx'); +// Note: parentheses are valid in URLs; we encode to be safe. +const GRU_MODEL_URL = resolvePublicUrl('/models/onnx/' + encodeURI('gru_embedder(1,40,80).onnx')); + +const ORT_OPTIONS = { + executionProviders: ['wasm'], + graphOptimizationLevel: 'all', }; -var downLoadingModel = true; -const model = "./xtremedistill-go-emotion-int8.onnx"; -// const gruModel = "./gru_embedder(1,21,80).onnx"; -const gruModel = "./gru_embedder(1,40,80).onnx"; -const gruSession = ort.InferenceSession.create(gruModel, options); +// --- INTERNAL STATE --------------------------------------------------------- + +let downloadingModel = false; + +let lmSessionPromise = null; // Promise +let gruSessionPromise = null; // Promise + +let lmDisabled = false; +let gruDisabled = false; + +// tokenizer promise (unchanged) +const tokenizer = loadTokenizer(); + +// --- HELPERS ---------------------------------------------------------------- + +async function fetchModelBytes(url) { + // Fetch to ArrayBuffer first so we can reject HTML/JSON and tiny files + const res = await fetch(url, { cache: 'no-cache' }); + if (!res.ok) throw new Error(`HTTP ${res.status} @ ${url}`); + const ct = (res.headers.get('content-type') || '').toLowerCase(); + const buf = await res.arrayBuffer(); + + // Reject clearly-wrong payloads (HTML/JSON or super tiny files) + if (buf.byteLength < 2048 || ct.includes('text/html') || ct.includes('json')) { + throw new Error(`Invalid model payload (type=${ct}, size=${buf.byteLength}) @ ${url}`); + } + return buf; +} + +function sortResult(a, b) { + if (a[1] === b[1]) return 0; + return (a[1] < b[1]) ? 
1 : -1; +} + +function sigmoid(t) { + return 1 / (1 + Math.exp(-t)); +} + +// Build BERT inputs (BigInt) +function create_model_input(encoded) { + let input_ids = new Array(encoded.length + 2); + let attention_mask = new Array(encoded.length + 2); + let token_type_ids = new Array(encoded.length + 2); + input_ids[0] = BigInt(101); + attention_mask[0] = BigInt(1); + token_type_ids[0] = BigInt(0); -const session = ort.InferenceSession.create(model, options); -session.then(t => { - downLoadingModel = false; - //warmup the VM - for(var i = 0; i < 10; i++) { - console.log("Inference warmup " + i); - lm_inference("this is a warmup inference"); + let i = 0; + for (; i < encoded.length; i++) { + input_ids[i + 1] = BigInt(encoded[i]); + attention_mask[i + 1] = BigInt(1); + token_type_ids[i + 1] = BigInt(0); } -}); -const tokenizer = loadTokenizer() + input_ids[i + 1] = BigInt(102); + attention_mask[i + 1] = BigInt(1); + token_type_ids[i + 1] = BigInt(0); + + const seq = input_ids.length; + + input_ids = new ort.Tensor('int64', BigInt64Array.from(input_ids), [1, seq]); + attention_mask = new ort.Tensor('int64', BigInt64Array.from(attention_mask), [1, seq]); + token_type_ids = new ort.Tensor('int64', BigInt64Array.from(token_type_ids), [1, seq]); + + return { input_ids, attention_mask, token_type_ids }; +} + +// --- CONSTANTS -------------------------------------------------------------- const EMOJI_DEFAULT_DISPLAY = [ - ["Emotion", "Score"], - ['admiration πŸ‘',0], - ['amusement πŸ˜‚', 0], - ['neutral 😐',0], - ['approval πŸ‘',0], - ['joy πŸ˜ƒ',0], - ['gratitude πŸ™',0], + ['Emotion', 'Score'], + ['admiration πŸ‘', 0], + ['amusement πŸ˜‚', 0], + ['neutral 😐', 0], + ['approval πŸ‘', 0], + ['joy πŸ˜ƒ', 0], + ['gratitude πŸ™', 0], ]; const EMOJIS = [ @@ -80,112 +132,128 @@ const EMOJIS = [ 'optimism 🀞', 'pride 😌', 'realization πŸ’‘', - 'reliefπŸ˜…', - 'remorse 😞', + 'relief πŸ˜…', + 'remorse 😞', 'sadness 😞', 'surprise 😲', - 'neutral 😐' + 'neutral 😐', ]; -function isDownloading() { - return downLoadingModel; -} +// --- LAZY SESSION INITIALIZERS --------------------------------------------- -function sortResult(a, b) { - if (a[1] === b[1]) { - return 0; - } - else { - return (a[1] < b[1]) ? 
1 : -1; - } +async function ensureLmSession() { + if (lmDisabled) return null; + if (lmSessionPromise) return lmSessionPromise; + + downloadingModel = true; + lmSessionPromise = (async () => { + try { + const bytes = await fetchModelBytes(LM_MODEL_URL); + const session = await ort.InferenceSession.create(bytes, ORT_OPTIONS); + return session; + } catch (err) { + console.warn('[inference] LM disabled:', err); + lmDisabled = true; + return null; + } finally { + downloadingModel = false; + } + })(); + + return lmSessionPromise; } -function sigmoid(t) { - return 1/(1+Math.pow(Math.E, -t)); +async function ensureGruSession() { + if (gruDisabled) return null; + if (gruSessionPromise) return gruSessionPromise; + + downloadingModel = true; + gruSessionPromise = (async () => { + try { + const bytes = await fetchModelBytes(GRU_MODEL_URL); + const session = await ort.InferenceSession.create(bytes, ORT_OPTIONS); + return session; + } catch (err) { + console.warn('[inference] GRU disabled:', err); + gruDisabled = true; + return null; + } finally { + downloadingModel = false; + } + })(); + + return gruSessionPromise; } -function create_model_input(encoded) { - var input_ids = new Array(encoded.length+2); - var attention_mask = new Array(encoded.length+2); - var token_type_ids = new Array(encoded.length+2); - input_ids[0] = BigInt(101); - attention_mask[0] = BigInt(1); - token_type_ids[0] = BigInt(0); - var i = 0; - for(; i < encoded.length; i++) { - input_ids[i+1] = BigInt(encoded[i]); - attention_mask[i+1] = BigInt(1); - token_type_ids[i+1] = BigInt(0); - } - input_ids[i+1] = BigInt(102); - attention_mask[i+1] = BigInt(1); - token_type_ids[i+1] = BigInt(0); - const sequence_length = input_ids.length; - input_ids = new ort.Tensor('int64', BigInt64Array.from(input_ids), [1,sequence_length]); - attention_mask = new ort.Tensor('int64', BigInt64Array.from(attention_mask), [1,sequence_length]); - token_type_ids = new ort.Tensor('int64', BigInt64Array.from(token_type_ids), [1,sequence_length]); - return { - input_ids: input_ids, - attention_mask: attention_mask, - token_type_ids:token_type_ids - } +// --- PUBLIC API ------------------------------------------------------------- + +function isDownloading() { + return downloadingModel; } async function lm_inference(text) { - try { - const encoded_ids = await tokenizer.then(t => { - return t.tokenize(text); - }); - if(encoded_ids.length === 0) { + try { + const session = await ensureLmSession(); + if (!session) return [0.0, EMOJI_DEFAULT_DISPLAY]; + + const encoded_ids = await tokenizer.then(t => t.tokenize(text)); + if (!encoded_ids || encoded_ids.length === 0) { return [0.0, EMOJI_DEFAULT_DISPLAY]; } + const start = performance.now(); - const model_input = create_model_input(encoded_ids); - const output = await session.then(s => { return s.run(model_input,['output_0'])}); + const feeds = create_model_input(encoded_ids); + const output = await session.run(feeds, ['output_0']); + const duration = (performance.now() - start).toFixed(1); - const probs = output['output_0'].data.map(sigmoid).map( t => Math.floor(t*100)); - // console.log(147, 'probs: ', probs); + const probs = output['output_0'].data + .map(sigmoid) + .map(t => Math.floor(t * 100)); + const result = []; - for(var i = 0; i < EMOJIS.length;i++) { - const t = [EMOJIS[i], probs[i]]; - result[i] = t; + for (let i = 0; i < EMOJIS.length; i++) { + result[i] = [EMOJIS[i], probs[i]]; } - result.sort(sortResult); - + result.sort(sortResult); + const result_list = []; - result_list[0] = ["Emotion",
"Score"]; - for(i = 0; i < 6; i++) { - result_list[i+1] = result[i]; - } - return [duration,result_list]; - } catch (e) { - return [0.0,EMOJI_DEFAULT_DISPLAY]; + result_list[0] = ['Emotion', 'Score']; + for (let i = 0; i < 6; i++) result_list[i + 1] = result[i]; + + return [duration, result_list]; + } catch (_e) { + // Swallow errors to keep UI smooth + return [0.0, EMOJI_DEFAULT_DISPLAY]; } -} +} async function gru_inference(melLogSpectrogram) { try { + const session = await ensureGruSession(); + if (!session) return null; + + // Expecting shape [1, 40, 80] + const input = new ort.Tensor( + 'float32', + melLogSpectrogram.flat(), + [1, 40, 80] + ); - const inputs = ort.InferenceSession.FeedsType = { - // input: new ort.Tensor('float32', melLogSpectrogram.flat(), [1, 21, 80]) - input: new ort.Tensor('float32', melLogSpectrogram.flat(), [1, 40, 80]) - }; - // console.log(171, 'input: ', inputs.input.data); - // console.log('input sum: ', inputs.inpumot.data.reduce((a, b) => a + b, 0)); - // console.log(inputs.input); - const output = await gruSession.then(s => { return s.run(inputs)}); - // console.log(175, 'output: ', output.output.data); - // sum all data - // const sum = output.output.data.reduce((a, b) => a + b, 0); - // console.log('output sum: ', sum); - return output; + const outputs = await session.run({ input }); + return outputs; } catch (e) { - console.log(e); + console.warn('[inference] GRU inference error:', e); + return null; } } +// Named exports kept for compatibility with your codebase +export let intent_inference = lm_inference; +export let columnNames = EMOJI_DEFAULT_DISPLAY; +export let modelDownloadInProgress = isDownloading; +export let gruInference = gru_inference; -export let intent_inference = lm_inference -export let columnNames = EMOJI_DEFAULT_DISPLAY -export let modelDownloadInProgress = isDownloading -export let gruInference = gru_inference +// Optional: explicit initializer if you ever want to prefetch after a user action +export async function initOnnx() { + await Promise.all([ensureLmSession(), ensureGruSession()]); +} diff --git a/src/react-redux&middleware/redux/reducers/modelSelectionReducers.tsx b/src/react-redux&middleware/redux/reducers/modelSelectionReducers.tsx index 20c1f005..35dec47a 100644 --- a/src/react-redux&middleware/redux/reducers/modelSelectionReducers.tsx +++ b/src/react-redux&middleware/redux/reducers/modelSelectionReducers.tsx @@ -1,52 +1,164 @@ -import type { ModelOptions, ModelSelection, SelectedOption } from '../types/modelSelection' -import type { RootState } from '../typesImports' +// src/react-redux&middleware/redux/reducers/modelSelectionReducers.tsx -const initialState: ModelSelection = { options: [], selected: null } +import { AnyAction } from 'redux'; +import type { SelectedOption } from '../types/modelSelection'; +import type { RootState } from '../../../store'; - -const saveLocally = (key: string, value: any) => { - localStorage.setItem(key, JSON.stringify(value)); +export interface ModelSelectionState { + options: SelectedOption[]; + selected: SelectedOption | null; } -const getLocalState = (key: string) => { - const localState = localStorage.getItem(key); - if (localState) { - return Object.assign(initialState, JSON.parse(localState)); +/** tiny-multi β†’ tiny-q5_1 (canonical key for multilingual tiny) */ +const normalizeModelKey = (raw: string): string => { + if (raw === 'tiny-multi' || raw === 'tiny-multi-q5_1') return 'tiny-q5_1'; + return raw; +}; + +/** Defaults before server sends a model list */ +const defaultOptions: 
SelectedOption[] = [ + { + model_key: 'tiny-en-q5_1', + display_name: 'tiny-en', + description: 'Whisper tiny English (quantized q5_1)', + available_features: 'en', + }, + { + model_key: 'tiny-q5_1', // multilingual tiny, shown as "tiny-multi" + display_name: 'tiny-multi', + description: 'Whisper tiny multilingual (quantized q5_1)', + available_features: 'multi', + }, +]; + +/** + * Take an arbitrary object and coerce it into a SelectedOption, + * filling in defaults for any missing fields. + */ +const toSelectedOption = (raw: any): SelectedOption => { + if (typeof raw === 'string') { + const model_key = normalizeModelKey(raw); + const fromDefaults = defaultOptions.find( + (opt) => !!opt && opt.model_key === model_key, + ); + return ( + fromDefaults ?? { + model_key, + display_name: model_key, + description: '', + available_features: '', + } + ); } - return initialState; -} + const model_key = normalizeModelKey( + raw?.model_key ?? + raw?.key ?? + raw?.id ?? + raw?.value ?? + 'tiny-en-q5_1', + ); + + const display_name = + raw?.display_name ?? + raw?.label ?? + raw?.name ?? + model_key; + + const description: string = raw?.description ?? ''; + const available_features: string = raw?.available_features ?? ''; -export const ModelSelectionReducer = ( - state = getLocalState('modelSelection'), - action -) => { - let newState; + return { + model_key, + display_name, + description, + available_features, + }; +}; + +export const initialModelSelectionState: ModelSelectionState = { + options: defaultOptions, + selected: defaultOptions[0], +}; + +const ModelSelectionReducer = ( + state: ModelSelectionState = initialModelSelectionState, + action: AnyAction, +): ModelSelectionState => { switch (action.type) { - case 'SET_MODEL_OPTIONS': - newState = { - ...state, - options: action.payload as ModelOptions + // Replace the list of options (e.g. when server sends models) + case 'MODEL_SELECTION/SET_OPTIONS': + case 'SET_MODEL_OPTIONS': { + const payload = action.payload; + let rawOptions: any[]; + + if (Array.isArray(payload)) { + rawOptions = payload; + } else if (Array.isArray(payload?.options)) { + rawOptions = payload.options; + } else { + return state; } - saveLocally('modelSelection', newState) - return newState; - case 'SET_SELECTED_MODEL': - newState = { - ...state, - selected: action.payload as SelectedOption + + const options: SelectedOption[] = rawOptions.map(toSelectedOption); + + // Keep current selection if still present, else fall back to first + let selected = state.selected; + + if (!selected) { + selected = options[0] ?? null; + } else { + const exists = options.some( + (o: SelectedOption | null) => + o?.model_key === selected!.model_key, + ); + if (!exists) { + selected = options[0] ?? 
null; + } } - saveLocally('modelSelection', newState) - return newState; + + return { + ...state, + options, + selected, + }; + } + + // Change which model is selected + case 'MODEL_SELECTION/SELECT': + case 'SELECT_MODEL': + case 'SET_SELECTED_MODEL': { + if (!action.payload) return state; + const opt = toSelectedOption(action.payload); + return { + ...state, + selected: opt, + }; + } + default: - return state + return state; } -} +}; -export const selectModelOptions = (state: RootState) => state.ModelSelectionReducer.options; -export const setModelOptions = (options: ModelOptions) => { - return { type: 'SET_MODEL_OPTIONS', payload: options } -} -export const selectSelectedModel = (state: RootState) => state.ModelSelectionReducer.selected; -export function setSelectedModel(selected: SelectedOption) { - return { type: 'SET_SELECTED_MODEL', payload: selected } -} \ No newline at end of file +export default ModelSelectionReducer; +// for `import { ModelSelectionReducer } from ...` +export { ModelSelectionReducer }; + +/** Action creators used by scribearRecognizer + UI menus */ +export const setModelOptions = (options: SelectedOption[]) => ({ + type: 'SET_MODEL_OPTIONS', + payload: options, +}); + +export const setSelectedModel = (option: SelectedOption | null) => ({ + type: 'SET_SELECTED_MODEL', + payload: option, +}); + +/** Selectors used across the app */ +export const selectModelOptions = (state: RootState): SelectedOption[] => + state.ModelSelectionReducer.options; + +export const selectSelectedModel = (state: RootState): SelectedOption | null => + state.ModelSelectionReducer.selected;
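For reference, a minimal sketch of how a consumer could use the new model-selection API exported from modelSelectionReducers.tsx. The hook name, file location, and relative import paths below are illustrative and not part of the diff; in the app itself this wiring happens in WhisperDropdown and pickApi.tsx.

import { useDispatch, useSelector } from 'react-redux';
import type { SelectedOption } from '../react-redux&middleware/redux/types/modelSelection';
import {
  selectModelOptions,
  selectSelectedModel,
  setSelectedModel,
} from '../react-redux&middleware/redux/reducers/modelSelectionReducers';

// Hypothetical hook: exposes the current Whisper model plus a setter keyed by model_key.
export function useWhisperModel() {
  const dispatch = useDispatch();
  const options = useSelector(selectModelOptions);   // seeded with tiny-en-q5_1 / tiny-q5_1 until the server sends a list
  const selected = useSelector(selectSelectedModel);

  const pickByKey = (key: string) => {
    const next = options.find((o: SelectedOption) => o.model_key === key);
    if (next) dispatch(setSelectedModel(next)); // the reducer ignores a null payload
  };

  return { options, selected, pickByKey };
}

Because the reducer seeds its state with the two tiny defaults and SET_MODEL_OPTIONS keeps the current selection only when it still appears in the incoming list, a consumer like this has a valid selection to show before the ScribeAR server reports its own model list.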