Skip to content

Commit

Permalink
Big improvements to the Ollama agent (#310)
Browse files Browse the repository at this point in the history
1. Switch to `OllamaSharp` to simplify API calls and support both streaming and non-streaming.
2. Add context support to enable the agent to remember previous responses.
3. Add configuration management support.
4. Improve chat interaction.
  • Loading branch information
kborowinski authored Dec 3, 2024
1 parent 3110c2c commit 943b701
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 166 deletions.
5 changes: 5 additions & 0 deletions shell/agents/AIShell.Ollama.Agent/AIShell.Ollama.Agent.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<SuppressNETCoreSdkPreviewMessage>true</SuppressNETCoreSdkPreviewMessage>
<CopyLocalLockFileAssemblies>true</CopyLocalLockFileAssemblies>

<!-- Disable deps.json generation -->
<GenerateDependencyFile>false</GenerateDependencyFile>
Expand All @@ -15,6 +16,10 @@
<DebugType>None</DebugType>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="OllamaSharp" Version="4.0.8" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\AIShell.Abstraction\AIShell.Abstraction.csproj">
<!-- Disable copying AIShell.Abstraction.dll to output folder -->
Expand Down
256 changes: 230 additions & 26 deletions shell/agents/AIShell.Ollama.Agent/OllamaAgent.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,33 @@
using System.Diagnostics;
using System.Text;
using System.Text.Json;
using System.Text.RegularExpressions;
using AIShell.Abstraction;
using OllamaSharp;
using OllamaSharp.Models;

namespace AIShell.Ollama.Agent;

public sealed class OllamaAgent : ILLMAgent
public sealed partial class OllamaAgent : ILLMAgent
{
private bool _reloadSettings;
private bool _isDisposed;
private string _configRoot;
private Settings _settings;
private OllamaApiClient _client;
private GenerateRequest _request;
private FileSystemWatcher _watcher;

/// <summary>
/// The name of setting file
/// </summary>
private const string SettingFileName = "ollama.config.json";

/// <summary>
/// Gets the settings.
/// </summary>
internal Settings Settings => _settings;

/// <summary>
/// The name of the agent
/// </summary>
Expand All @@ -13,7 +36,7 @@ public sealed class OllamaAgent : ILLMAgent
/// <summary>
/// The description of the agent to be shown at start up
/// </summary>
public string Description => "This is an AI assistant to interact with a language model running locally by utilizing the Ollama CLI tool. Be sure to follow all prerequisites in aka.ms/aish/ollama";
public string Description => "This is an AI assistant to interact with a language model running locally or remotely by utilizing the Ollama API. Be sure to follow all prerequisites in https://github.com/PowerShell/AIShell/tree/main/shell/agents/AIShell.Ollama.Agent";

/// <summary>
/// This is the company added to /like and /dislike verbiage for who the telemetry helps.
Expand All @@ -30,19 +53,25 @@ public sealed class OllamaAgent : ILLMAgent
/// <summary>
/// These are any optional legal/additional information links you want to provide at start up
/// </summary>
public Dictionary<string, string> LegalLinks { private set; get; }

/// <summary>
/// This is the chat service to call the API from
/// </summary>
private OllamaChatService _chatService;
public Dictionary<string, string> LegalLinks { private set; get; } = new(StringComparer.OrdinalIgnoreCase)
{
["Ollama Docs"] = "https://github.com/ollama/ollama",
["Prerequisites"] = "https://github.com/PowerShell/AIShell/tree/main/shell/agents/AIShell.Ollama.Agent"
};

/// <summary>
/// Dispose method to clean up the unmanaged resource of the chatService
/// </summary>
public void Dispose()
{
_chatService?.Dispose();
if (_isDisposed)
{
return;
}

GC.SuppressFinalize(this);
_watcher.Dispose();
_isDisposed = true;
}

/// <summary>
Expand All @@ -51,13 +80,31 @@ public void Dispose()
/// <param name="config">Agent configuration for any configuration file and other settings</param>
public void Initialize(AgentConfig config)
{
_chatService = new OllamaChatService();
_configRoot = config.ConfigurationRoot;

SettingFile = Path.Combine(_configRoot, SettingFileName);
_settings = ReadSettings();

if (_settings is null)
{
// Create the setting file with examples to serve as a template for user to update.
NewExampleSettingFile();
_settings = ReadSettings();
}

// Create Ollama request
_request = new GenerateRequest();

// Create Ollama client
_client = new OllamaApiClient(_settings.Endpoint);

LegalLinks = new(StringComparer.OrdinalIgnoreCase)
// Watch for changes to the settings file
_watcher = new FileSystemWatcher(_configRoot, SettingFileName)
{
["Ollama Docs"] = "https://github.com/ollama/ollama",
["Prerequisites"] = "https://aka.ms/ollama/readme"
NotifyFilter = NotifyFilters.LastWrite,
EnableRaisingEvents = true,
};
_watcher.Changed += OnSettingFileChange;
}

/// <summary>
Expand All @@ -68,7 +115,7 @@ public void Initialize(AgentConfig config)
/// <summary>
/// Gets the path to the setting file of the agent.
/// </summary>
public string SettingFile { private set; get; } = null;
public string SettingFile { private set; get; }

/// <summary>
/// Gets a value indicating whether the agent accepts a specific user action feedback.
Expand All @@ -87,7 +134,19 @@ public void OnUserAction(UserActionPayload actionPayload) {}
/// Refresh the current chat by starting a new chat session.
/// This method allows an agent to reset chat states, interact with user for authentication, print welcome message, and more.
/// </summary>
public Task RefreshChatAsync(IShell shell, bool force) => Task.CompletedTask;
public Task RefreshChatAsync(IShell shell, bool force)
{
if (force)
{
// Reload the setting file if needed.
ReloadSettings();

// Reset context
_request.Context = null;
}

return Task.CompletedTask;
}

/// <summary>
/// Main chat function that takes the users input and passes it to the LLM and renders it.
Expand All @@ -100,26 +159,171 @@ public async Task<bool> ChatAsync(string input, IShell shell)
// Get the shell host
IHost host = shell.Host;

// get the cancellation token
// Get the cancellation token
CancellationToken token = shell.CancellationToken;

if (Process.GetProcessesByName("ollama").Length is 0)
// Reload the setting file if needed.
ReloadSettings();

if (IsLocalHost().IsMatch(_client.Uri.Host) && Process.GetProcessesByName("ollama").Length is 0)
{
host.RenderFullResponse("Please be sure the Ollama is installed and server is running. Check all the prerequisites in the README of this agent are met.");
host.WriteErrorLine("Please be sure the Ollama is installed and server is running. Check all the prerequisites in the README of this agent are met.");
return false;
}

ResponseData ollamaResponse = await host.RunWithSpinnerAsync(
status: "Thinking ...",
func: async context => await _chatService.GetChatResponseAsync(context, input, token)
).ConfigureAwait(false);
// Prepare request
_request.Prompt = input;
_request.Model = _settings.Model;
_request.Stream = _settings.Stream;

if (ollamaResponse is not null)
try
{
// render the content
host.RenderFullResponse(ollamaResponse.response);
if (_request.Stream)
{
// Wait for the stream with the spinner running
var ollamaStreamEnumerator = await host.RunWithSpinnerAsync(
status: "Thinking ...",
func: async () =>
{
// Start generating the stream asynchronously and return an enumerator
var enumerator = _client.GenerateAsync(_request, token).GetAsyncEnumerator(token);
if (await enumerator.MoveNextAsync().ConfigureAwait(false))
{
return enumerator;
}
return null;
}
).ConfigureAwait(false);

if (ollamaStreamEnumerator is not null)
{
using IStreamRender streamingRender = host.NewStreamRender(token);

do
{
var currentStream = ollamaStreamEnumerator.Current;

// Update the render with stream response
streamingRender.Refresh(currentStream.Response);

if (currentStream.Done)
{
// If the stream is complete, update the request context with the last stream context
var ollamaLastStream = (GenerateDoneResponseStream)currentStream;
_request.Context = ollamaLastStream.Context;
}
}
while (await ollamaStreamEnumerator.MoveNextAsync().ConfigureAwait(false));
}
}
else
{
// Build single response with spinner
var ollamaResponse = await host.RunWithSpinnerAsync(
status: "Thinking ...",
func: async () => { return await _client.GenerateAsync(_request, token).StreamToEndAsync(); }
).ConfigureAwait(false);

// Update request context
_request.Context = ollamaResponse.Context;

// Render the full response
host.RenderFullResponse(ollamaResponse.Response);
}
}

catch (OperationCanceledException)
{
// Ignore the cancellation exception.
}
catch (HttpRequestException e)
{
host.WriteErrorLine($"{e.Message}");
host.WriteErrorLine($"Ollama model: \"{_settings.Model}\"");
host.WriteErrorLine($"Ollama endpoint: \"{_settings.Endpoint}\"");
host.WriteErrorLine($"Ollama settings: \"{SettingFile}\"");
}

return true;
}

private void ReloadSettings()
{
if (_reloadSettings)
{
_reloadSettings = false;
var settings = ReadSettings();
if (settings is null)
{
return;
}

_settings = settings;

// Check if the endpoint has changed
bool isEndpointChanged = !string.Equals(_settings.Endpoint, _client.Uri.OriginalString, StringComparison.OrdinalIgnoreCase);

if (isEndpointChanged)
{
// Create a new client with updated endpoint
_client = new OllamaApiClient(_settings.Endpoint);
}
}
}

private Settings ReadSettings()
{
Settings settings = null;
FileInfo file = new(SettingFile);

if (file.Exists)
{
try
{
using var stream = file.OpenRead();
var data = JsonSerializer.Deserialize(stream, SourceGenerationContext.Default.ConfigData);
settings = new Settings(data);
}
catch (Exception e)
{
throw new InvalidDataException($"Parsing settings from '{SettingFile}' failed with the following error: {e.Message}", e);
}
}

return settings;
}

private void OnSettingFileChange(object sender, FileSystemEventArgs e)
{
if (e.ChangeType is WatcherChangeTypes.Changed)
{
_reloadSettings = true;
}
}

private void NewExampleSettingFile()
{
string SampleContent = """
{
// To use Ollama API service:
// 1. Install Ollama: `winget install Ollama.Ollama`
// 2. Start Ollama API server: `ollama serve`
// 3. Install Ollama model: `ollama pull phi3`
// Declare Ollama model
"Model": "phi3",
// Declare Ollama endpoint
"Endpoint": "http://localhost:11434",
// Enable Ollama streaming
"Stream": false
}
""";
File.WriteAllText(SettingFile, SampleContent, Encoding.UTF8);
}

/// <summary>
/// Defines a generated regular expression to match localhost addresses
/// "localhost", "127.0.0.1" and "[::1]" with case-insensitivity.
/// </summary>
[GeneratedRegex("^(localhost|127\\.0\\.0\\.1|\\[::1\\])$", RegexOptions.IgnoreCase)]
internal partial Regex IsLocalHost();
}
Loading

0 comments on commit 943b701

Please sign in to comment.