Add speech recognition to chat interface #1541

Open · wants to merge 9 commits into main
34 changes: 34 additions & 0 deletions src/lib/components/animations/LoadingAnimation.svelte
@@ -0,0 +1,34 @@
<script lang="ts">
export let classNames = "";
</script>

<div class={"loading-animation " + classNames}>
<div class="spinner"></div>
</div>

<style>
.loading-animation {
display: flex;
justify-content: center;
align-items: center;
height: 100%;
}

.spinner {
border: 4px solid rgba(0, 0, 0, 0.1);
border-left-color: #000;
border-radius: 50%;
width: 36px;
height: 36px;
animation: spin 1s linear infinite;
}

@keyframes spin {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
</style>
43 changes: 43 additions & 0 deletions src/lib/components/animations/TranscriptionAnimation.svelte
@@ -0,0 +1,43 @@
<script lang="ts">
export let classNames = "";
</script>

<div class={"transcription-animation " + classNames}>
<div class="wave"></div>
<div class="wave"></div>
<div class="wave"></div>
</div>

<style>
.transcription-animation {
display: flex;
justify-content: center;
align-items: center;
height: 100%;
}

.wave {
background: #000;
width: 5px;
height: 20px;
margin: 0 2px;
animation: wave 1.2s linear infinite;
}

.wave:nth-child(2) {
animation-delay: -1.1s;
}

.wave:nth-child(3) {
animation-delay: -1s;
}

@keyframes wave {
0%, 40%, 100% {
transform: scaleY(0.4);
}
20% {
transform: scaleY(1);
}
}
</style>
2 changes: 1 addition & 1 deletion src/lib/components/chat/ChatMessage.svelte
@@ -47,7 +47,7 @@
.replaceAll("<", "&lt;")
.trim();

- for (const stop of [...(model.parameters?.stop ?? []), "<|endoftext|>"]) {
+ for (const stop of [...(model.parameters?.stop ?? [])]) {
if (ret.endsWith(stop)) {
ret = ret.slice(0, -stop.length).trim();
}
81 changes: 79 additions & 2 deletions src/lib/components/chat/ChatWindow.svelte
@@ -1,12 +1,13 @@
<script lang="ts">
import type { Message, MessageFile } from "$lib/types/Message";
- import { createEventDispatcher, onDestroy, tick } from "svelte";
+ import { createEventDispatcher, onDestroy, tick, onMount } from "svelte";

import CarbonSendAltFilled from "~icons/carbon/send-alt-filled";
import CarbonExport from "~icons/carbon/export";
import CarbonStopFilledAlt from "~icons/carbon/stop-filled-alt";
import CarbonCheckmark from "~icons/carbon/checkmark";
import CarbonCaretDown from "~icons/carbon/caret-down";
import CarbonMicrophone from "~icons/carbon/microphone";

import EosIconsLoading from "~icons/eos-icons/loading";

@@ -37,6 +38,10 @@
import { useSettingsStore } from "$lib/stores/settings";
import type { ToolFront } from "$lib/types/Tool";
import ModelSwitch from "./ModelSwitch.svelte";
import TranscriptionAnimation from "../animations/TranscriptionAnimation.svelte";

import { AutomaticSpeechRecognitionPipeline, pipeline } from "@huggingface/transformers";
import LoadingAnimation from "../animations/LoadingAnimation.svelte";

export let messages: Message[] = [];
export let loading = false;
@@ -218,12 +223,68 @@
];

$: isFileUploadEnabled = activeMimeTypes.length > 0;

let transcriber: AutomaticSpeechRecognitionPipeline;
let isRecording = false;
let mediaRecorder: MediaRecorder;
let audioChunks: BlobPart[] = [];
let isLoadingModel = false;
let isTranscribing = false;
let webgpuSupported = false;
let microphoneButton: HTMLButtonElement;

async function initializeTranscriber() {
if (!transcriber) {
isLoadingModel = true;
transcriber = await pipeline("automatic-speech-recognition", "onnx-community/whisper-small", {
device: "webgpu",
});
isLoadingModel = false;
}
}

async function startRecording() {
await initializeTranscriber();
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
mediaRecorder = new MediaRecorder(stream);
audioChunks = [];

mediaRecorder.ondataavailable = (event) => {
audioChunks.push(event.data);
};

mediaRecorder.onstop = async () => {
const audioBlob = new Blob(audioChunks, { type: "audio/wav" });
const audioUrl = URL.createObjectURL(audioBlob);
isTranscribing = true;
const userLanguage = navigator.language;
console.log("Detected language:", userLanguage);
const firstTwoChars = userLanguage.slice(0, 2).toLowerCase();
const output = await transcriber(audioUrl, { language: firstTwoChars, task: "transcribe" });
message = output.text;
isTranscribing = false;
};

mediaRecorder.start();
isRecording = true;
}

function stopRecording() {
mediaRecorder.stop();
isRecording = false;
}

onMount(() => {
if (navigator.gpu) {
webgpuSupported = true;
}
});
</script>

<svelte:window
on:dragenter={onDragEnter}
on:dragleave={onDragLeave}
on:dragover|preventDefault
on:drop|preventDefault={() => (onDrag = false)}
/>

@@ -449,6 +510,16 @@
>
<CarbonSendAltFilled />
</button>
{#if webgpuSupported}
<button
bind:this={microphoneButton}
class="btn mx-1 my-1 h-[2.4rem] self-end rounded-lg bg-transparent p-1 px-[0.7rem] text-gray-400 enabled:hover:text-gray-700 disabled:opacity-60 enabled:dark:hover:text-gray-100 dark:disabled:opacity-40"
on:click={isRecording ? stopRecording : startRecording}
type="button"
>
<CarbonMicrophone />
</button>
{/if}
{/if}
</div>
{/if}
@@ -506,5 +577,11 @@
{/if}
</div>
</div>
{#if isLoadingModel}
<LoadingAnimation classNames="absolute inset-0 z-10" />
{/if}
{#if isTranscribing}
<TranscriptionAnimation classNames="absolute inset-0 z-10" />
{/if}
</div>
</div>
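
For reference, here is a condensed standalone sketch of the recording-to-text flow this diff adds to ChatWindow.svelte. It is not part of the PR: the getTranscriber and transcribeBlob helpers are illustrative names, while the model id ("onnx-community/whisper-small"), the WebGPU device option, and the navigator.language handling mirror the diff.

import { pipeline, type AutomaticSpeechRecognitionPipeline } from "@huggingface/transformers";

let transcriber: AutomaticSpeechRecognitionPipeline | undefined;

// Lazily create the Whisper pipeline on WebGPU, as initializeTranscriber() does in the diff.
async function getTranscriber(): Promise<AutomaticSpeechRecognitionPipeline> {
	transcriber ??= await pipeline("automatic-speech-recognition", "onnx-community/whisper-small", {
		device: "webgpu",
	});
	return transcriber;
}

// Transcribe a recorded audio Blob; the language defaults to the browser locale's
// two-letter code, matching the navigator.language handling in the diff.
async function transcribeBlob(
	blob: Blob,
	language = navigator.language.slice(0, 2).toLowerCase()
): Promise<string> {
	const asr = await getTranscriber();
	const url = URL.createObjectURL(blob);
	try {
		const output = await asr(url, { language, task: "transcribe" });
		// A single audio input yields a single result object with a `text` field.
		return Array.isArray(output) ? output.map((o) => o.text).join(" ") : output.text;
	} finally {
		URL.revokeObjectURL(url);
	}
}

One note on the diff itself: MediaRecorder usually encodes to webm/opus rather than WAV, so the "audio/wav" type given to the Blob is only a label; the object URL is read through the browser's audio decoder either way, which typically handles that encoding fine.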
23 changes: 23 additions & 0 deletions src/lib/components/icons/IconMicrophone.svelte
@@ -0,0 +1,23 @@
<script lang="ts">
export let classNames = "";
export let webgpuSupported = true;
</script>

{#if webgpuSupported}
<svg
class={classNames}
xmlns="http://www.w3.org/2000/svg"
aria-hidden="true"
focusable="false"
role="img"
width="1em"
height="1em"
fill="currentColor"
preserveAspectRatio="xMidYMid meet"
viewBox="0 0 24 24"
>
<path
d="M12 14a3.5 3.5 0 0 0 3.5-3.5V5.5A3.5 3.5 0 0 0 12 2a3.5 3.5 0 0 0-3.5 3.5v5A3.5 3.5 0 0 0 12 14zm6.5-3.5a6.5 6.5 0 0 1-13 0h-2a8.5 8.5 0 0 0 17 0h-2zM11 18v4h2v-4h-2z"
/>
</svg>
{/if}
2 changes: 1 addition & 1 deletion src/lib/server/generateFromDefaultEndpoint.ts
@@ -18,7 +18,7 @@ export async function generateFromDefaultEndpoint({
// if generated_text is not present, the generation is not finished yet
if (output.generated_text) {
let generated_text = output.generated_text;
- for (const stop of [...(smallModel.parameters?.stop ?? []), "<|endoftext|>"]) {
+ for (const stop of [...(smallModel.parameters?.stop ?? [])]) {
if (generated_text.endsWith(stop)) {
generated_text = generated_text.slice(0, -stop.length).trimEnd();
}
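
For context, a minimal sketch (not from the PR) of what the trimming loop above does after this change: only the stop strings configured on the model are stripped from the end of the generated text, and the previously hard-coded "<|endoftext|>" token is no longer removed unless it is listed in parameters.stop. The stripStopSuffix name and the example stop list are illustrative.

function stripStopSuffix(text: string, stop: string[] = []): string {
	let ret = text;
	for (const s of stop) {
		if (ret.endsWith(s)) {
			ret = ret.slice(0, -s.length).trimEnd();
		}
	}
	return ret;
}

// With a model configured with stop = ["</s>"]:
//   stripStopSuffix("Hello world</s>", ["</s>"])    -> "Hello world"
//   stripStopSuffix("Hello<|endoftext|>", ["</s>"]) -> "Hello<|endoftext|>" (no longer trimmed)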