Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factor out ActionExecutionClient #5796

Merged
merged 29 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions openhands/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.impl.e2b.sandbox import E2BBox
from openhands.runtime.impl.eventstream.eventstream_runtime import (
EventStreamRuntime,
from openhands.runtime.impl.docker.docker_runtime import (
DockerRuntime,
)
from openhands.runtime.impl.e2b.sandbox import E2BBox
from openhands.runtime.impl.modal.modal_runtime import ModalRuntime
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
from openhands.runtime.impl.runloop.runloop_runtime import RunloopRuntime


def get_runtime_cls(name: str):
# Local imports to avoid circular imports
if name == 'eventstream':
return EventStreamRuntime
if name == 'eventstream' or name == 'docker':
return DockerRuntime
elif name == 'e2b':
return E2BBox
elif name == 'remote':
Expand All @@ -30,6 +30,6 @@ def get_runtime_cls(name: str):
'RemoteRuntime',
'ModalRuntime',
'RunloopRuntime',
'EventStreamRuntime',
'DockerRuntime',
'get_runtime_cls',
]
289 changes: 289 additions & 0 deletions openhands/runtime/impl/action_execution/action_execution_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
import os
import tempfile
import threading
from abc import abstractmethod
from pathlib import Path
from typing import Any
from zipfile import ZipFile

import requests

from openhands.core.config import AppConfig
from openhands.core.exceptions import (
AgentRuntimeTimeoutError,
)
from openhands.events import EventStream
from openhands.events.action import (
ActionConfirmationStatus,
BrowseInteractiveAction,
BrowseURLAction,
CmdRunAction,
FileEditAction,
FileReadAction,
FileWriteAction,
IPythonRunCellAction,
)
from openhands.events.action.action import Action
from openhands.events.observation import (
ErrorObservation,
NullObservation,
Observation,
UserRejectObservation,
)
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.request import send_request


class ActionExecutionClient(Runtime):
"""Base class for runtimes that interact with the action execution server.

This class contains shared logic between DockerRuntime and RemoteRuntime
for interacting with the HTTP server defined in action_execution_server.py.
"""

def __init__(
self,
config: AppConfig,
event_stream: EventStream,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_callback: Any | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
):
self.session = requests.Session()
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
self._runtime_initialized: bool = False
self._vscode_token: str | None = None # initial dummy value
super().__init__(
config,
event_stream,
sid,
plugins,
env_vars,
status_callback,
attach_to_existing,
headless_mode,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most thing LGTM! Do we keep the LogStreamer log in the client or still in each runtime implementation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh does the LogStreamer use the API now, instead of docker logs? If so we can centralize it!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like no--we need different log logic for docker, remote, modal, etc.

Mostly it seems like LogStreamer is a convenience when running locally, so only DockerRuntime has it for now. I removed the Modal and Runloop implementations since they're unused and hard to maintain


@abstractmethod
def _get_action_execution_server_host(self) -> str:
pass

def _send_action_server_request(
self,
method: str,
url: str,
**kwargs,
) -> requests.Response:
"""Send a request to the action execution server.

Args:
method: HTTP method (GET, POST, etc.)
url: URL to send the request to
**kwargs: Additional arguments to pass to requests.request()

Returns:
Response from the server

Raises:
AgentRuntimeError: If the request fails
"""
return send_request(self.session, method, url, **kwargs)

def check_if_alive(self) -> None:
with self._send_action_server_request(
'GET',
f'{self._get_action_execution_server_host()}/alive',
timeout=5,
):
pass

def list_files(self, path: str | None = None) -> list[str]:
"""List files in the sandbox.

If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
"""

try:
data = {}
if path is not None:
data['path'] = path

with send_request(
self.session,
'POST',
f'{self._get_action_execution_server_host()}/list_files',
json=data,
timeout=10,
) as response:
response_json = response.json()
assert isinstance(response_json, list)
return response_json
except requests.Timeout:
raise TimeoutError('List files operation timed out')

def copy_from(self, path: str) -> Path:
"""Zip all files in the sandbox and return as a stream of bytes."""

try:
params = {'path': path}
with send_request(
self.session,
'GET',
f'{self._get_action_execution_server_host()}/download_files',
params=params,
stream=True,
timeout=30,
) as response:
temp_file = tempfile.NamedTemporaryFile(delete=False)
for chunk in response.iter_content(chunk_size=8192):
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)
return Path(temp_file.name)
except requests.Timeout:
raise TimeoutError('Copy operation timed out')

def copy_to(
self, host_src: str, sandbox_dest: str, recursive: bool = False
) -> None:
if not os.path.exists(host_src):
raise FileNotFoundError(f'Source file {host_src} does not exist')

try:
if recursive:
with tempfile.NamedTemporaryFile(
suffix='.zip', delete=False
) as temp_zip:
temp_zip_path = temp_zip.name

with ZipFile(temp_zip_path, 'w') as zipf:
for root, _, files in os.walk(host_src):
for file in files:
file_path = os.path.join(root, file)
arcname = os.path.relpath(
file_path, os.path.dirname(host_src)
)
zipf.write(file_path, arcname)

upload_data = {'file': open(temp_zip_path, 'rb')}
else:
upload_data = {'file': open(host_src, 'rb')}

params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}

with self._send_action_server_request(
'POST',
f'{self._get_action_execution_server_host()}/upload_file',
files=upload_data,
params=params,
timeout=300,
) as response:
self.log(
'debug',
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
)
finally:
if recursive:
os.unlink(temp_zip_path)
self.log(
'debug', f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}'
)

def get_vscode_token(self) -> str:
if self.vscode_enabled and self._runtime_initialized:
if self._vscode_token is not None: # cached value
return self._vscode_token
with send_request(
self.session,
'GET',
f'{self._get_action_execution_server_host()}/vscode/connection_token',
timeout=10,
) as response:
response_json = response.json()
assert isinstance(response_json, dict)
if response_json['token'] is None:
return ''
self._vscode_token = response_json['token']
return response_json['token']
else:
return ''

def send_action_for_execution(self, action: Action) -> Observation:
if isinstance(action, FileEditAction):
return self.edit(action)

# set timeout to default if not set
if action.timeout is None:
action.timeout = self.config.sandbox.timeout

with self.action_semaphore:
if not action.runnable:
return NullObservation('')
if (
hasattr(action, 'confirmation_state')
and action.confirmation_state
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
return NullObservation('')
action_type = action.action # type: ignore[attr-defined]
if action_type not in ACTION_TYPE_TO_CLASS:
raise ValueError(f'Action {action_type} does not exist.')
if not hasattr(self, action_type):
return ErrorObservation(
f'Action {action_type} is not supported in the current runtime.',
error_id='AGENT_ERROR$BAD_ACTION',
)
if (
getattr(action, 'confirmation_state', None)
== ActionConfirmationStatus.REJECTED
):
return UserRejectObservation(
'Action has been rejected by the user! Waiting for further user input.'
)

assert action.timeout is not None

try:
with send_request(
self.session,
'POST',
f'{self._get_action_execution_server_host()}/execute_action',
json={'action': event_to_dict(action)},
# wait a few more seconds to get the timeout error from client side
timeout=action.timeout + 5,
) as response:
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
except requests.Timeout:
raise AgentRuntimeTimeoutError(
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
)

return obs

def run(self, action: CmdRunAction) -> Observation:
return self.send_action_for_execution(action)

def run_ipython(self, action: IPythonRunCellAction) -> Observation:
return self.send_action_for_execution(action)

def read(self, action: FileReadAction) -> Observation:
return self.send_action_for_execution(action)

def write(self, action: FileWriteAction) -> Observation:
return self.send_action_for_execution(action)

def browse(self, action: BrowseURLAction) -> Observation:
return self.send_action_for_execution(action)

def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
return self.send_action_for_execution(action)

def close(self) -> None:
self.session.close()
Loading
Loading