Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions examples/example_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class ExampleExecutor final : public server::AgentExecutor {
lf::a2a::v1::Task task;
task.set_id(task_id);
task.mutable_status()->set_state(lf::a2a::v1::TASK_STATE_WORKING);
task.mutable_status()->mutable_message()->set_role("agent");
task.mutable_status()->mutable_message()->add_parts()->mutable_text()->set_text("ack");
task.mutable_status()->mutable_message()->set_role(lf::a2a::v1::ROLE_AGENT);
task.mutable_status()->mutable_message()->add_parts()->set_text("ack");
task_ = task;

lf::a2a::v1::SendMessageResponse response;
Expand All @@ -77,7 +77,6 @@ class ExampleExecutor final : public server::AgentExecutor {
completed.mutable_status_update()->set_task_id(request.message().task_id());
completed.mutable_status_update()->mutable_status()->set_state(
lf::a2a::v1::TASK_STATE_COMPLETED);
completed.mutable_status_update()->set_final(true);

std::vector<lf::a2a::v1::StreamResponse> events;
events.push_back(working);
Expand Down Expand Up @@ -124,21 +123,29 @@ class ExampleExecutor final : public server::AgentExecutor {

inline lf::a2a::v1::AgentCard BuildRestAgentCard(std::string_view name, std::string_view url) {
lf::a2a::v1::AgentCard card;
card.set_protocol_version("1.0");
card.set_name(std::string(name));
card.set_description("example rest agent");
card.set_version("1.0.0");
card.add_default_input_modes("text/plain");
card.add_default_output_modes("text/plain");
auto* iface = card.add_supported_interfaces();
iface->set_transport(lf::a2a::v1::TRANSPORT_PROTOCOL_REST);
iface->set_url(std::string(url));
iface->set_protocol_binding("HTTP+JSON");
iface->set_protocol_version("1.0");
return card;
}

inline lf::a2a::v1::AgentCard BuildJsonRpcAgentCard(std::string_view name, std::string_view url) {
lf::a2a::v1::AgentCard card;
card.set_protocol_version("1.0");
card.set_name(std::string(name));
card.set_description("example json-rpc agent");
card.set_version("1.0.0");
card.add_default_input_modes("text/plain");
card.add_default_output_modes("text/plain");
auto* iface = card.add_supported_interfaces();
iface->set_transport(lf::a2a::v1::TRANSPORT_PROTOCOL_JSON_RPC);
iface->set_url(std::string(url));
iface->set_protocol_binding("JSONRPC");
iface->set_protocol_version("1.0");
return card;
}

Expand Down
7 changes: 4 additions & 3 deletions include/a2a/client/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class ClientTransport {

[[nodiscard]] virtual core::Result<lf::a2a::v1::TaskPushNotificationConfig>
CreateTaskPushNotificationConfig(const lf::a2a::v1::TaskPushNotificationConfig& request,
const CallOptions& options) = 0;
const CallOptions& options) = 0;

[[nodiscard]] virtual core::Result<lf::a2a::v1::TaskPushNotificationConfig>
GetTaskPushNotificationConfig(const lf::a2a::v1::GetTaskPushNotificationConfigRequest& request,
Expand Down Expand Up @@ -161,8 +161,9 @@ class A2AClient final {
[[nodiscard]] core::Result<lf::a2a::v1::Task> CancelTask(
const lf::a2a::v1::CancelTaskRequest& request, const CallOptions& options = {});

[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig> CreateTaskPushNotificationConfig(
const lf::a2a::v1::TaskPushNotificationConfig& request, const CallOptions& options = {});
[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig>
CreateTaskPushNotificationConfig(const lf::a2a::v1::TaskPushNotificationConfig& request,
const CallOptions& options = {});

[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig> GetTaskPushNotificationConfig(
const lf::a2a::v1::GetTaskPushNotificationConfigRequest& request,
Expand Down
5 changes: 3 additions & 2 deletions include/a2a/client/grpc_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ class GrpcTransport final : public ClientTransport {
[[nodiscard]] core::Result<lf::a2a::v1::Task> CancelTask(
const lf::a2a::v1::CancelTaskRequest& request, const CallOptions& options) override;

[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig> CreateTaskPushNotificationConfig(
const lf::a2a::v1::TaskPushNotificationConfig& request, const CallOptions& options) override;
[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig>
CreateTaskPushNotificationConfig(const lf::a2a::v1::TaskPushNotificationConfig& request,
const CallOptions& options) override;

[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig> GetTaskPushNotificationConfig(
const lf::a2a::v1::GetTaskPushNotificationConfigRequest& request,
Expand Down
5 changes: 3 additions & 2 deletions include/a2a/client/http_json_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ class HttpJsonTransport final : public ClientTransport {
[[nodiscard]] core::Result<lf::a2a::v1::Task> CancelTask(
const lf::a2a::v1::CancelTaskRequest& request, const CallOptions& options) override;

[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig> CreateTaskPushNotificationConfig(
const lf::a2a::v1::TaskPushNotificationConfig& request, const CallOptions& options) override;
[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig>
CreateTaskPushNotificationConfig(const lf::a2a::v1::TaskPushNotificationConfig& request,
const CallOptions& options) override;

[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig> GetTaskPushNotificationConfig(
const lf::a2a::v1::GetTaskPushNotificationConfigRequest& request,
Expand Down
5 changes: 3 additions & 2 deletions include/a2a/client/json_rpc_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ class JsonRpcTransport final : public ClientTransport {
[[nodiscard]] core::Result<lf::a2a::v1::Task> CancelTask(
const lf::a2a::v1::CancelTaskRequest& request, const CallOptions& options) override;

[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig> CreateTaskPushNotificationConfig(
const lf::a2a::v1::TaskPushNotificationConfig& request, const CallOptions& options) override;
[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig>
CreateTaskPushNotificationConfig(const lf::a2a::v1::TaskPushNotificationConfig& request,
const CallOptions& options) override;

[[nodiscard]] core::Result<lf::a2a::v1::TaskPushNotificationConfig> GetTaskPushNotificationConfig(
const lf::a2a::v1::GetTaskPushNotificationConfigRequest& request,
Expand Down
136 changes: 62 additions & 74 deletions src/client/discovery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,45 +45,38 @@ bool HasHostPortShape(std::string_view endpoint) {
return endpoint.find(':') != std::string_view::npos;
}

bool IsValidInterfaceEndpoint(lf::a2a::v1::TransportProtocol transport, std::string_view endpoint) {
switch (transport) {
case lf::a2a::v1::TRANSPORT_PROTOCOL_REST:
case lf::a2a::v1::TRANSPORT_PROTOCOL_JSON_RPC:
return HasHttpScheme(endpoint);
case lf::a2a::v1::TRANSPORT_PROTOCOL_GRPC:
return HasGrpcScheme(endpoint) || HasHttpScheme(endpoint) || HasHostPortShape(endpoint);
case lf::a2a::v1::TRANSPORT_PROTOCOL_UNSPECIFIED:
case lf::a2a::v1::TransportProtocol_INT_MIN_SENTINEL_DO_NOT_USE_:
case lf::a2a::v1::TransportProtocol_INT_MAX_SENTINEL_DO_NOT_USE_:
return false;
constexpr std::string_view kProtocolBindingHttpJson = "HTTP+JSON";
constexpr std::string_view kProtocolBindingJsonRpc = "JSONRPC";
constexpr std::string_view kProtocolBindingGrpc = "GRPC";

bool IsValidInterfaceEndpoint(std::string_view protocol_binding, std::string_view endpoint) {
if (protocol_binding == kProtocolBindingHttpJson || protocol_binding == kProtocolBindingJsonRpc) {
return HasHttpScheme(endpoint);
}
if (protocol_binding == kProtocolBindingGrpc) {
return HasGrpcScheme(endpoint) || HasHttpScheme(endpoint) || HasHostPortShape(endpoint);
}
return false;
}

PreferredTransport ToPreferredTransport(lf::a2a::v1::TransportProtocol transport) {
switch (transport) {
case lf::a2a::v1::TRANSPORT_PROTOCOL_REST:
return PreferredTransport::kRest;
case lf::a2a::v1::TRANSPORT_PROTOCOL_JSON_RPC:
return PreferredTransport::kJsonRpc;
case lf::a2a::v1::TRANSPORT_PROTOCOL_GRPC:
return PreferredTransport::kGrpc;
case lf::a2a::v1::TRANSPORT_PROTOCOL_UNSPECIFIED:
case lf::a2a::v1::TransportProtocol_INT_MIN_SENTINEL_DO_NOT_USE_:
case lf::a2a::v1::TransportProtocol_INT_MAX_SENTINEL_DO_NOT_USE_:
break;
PreferredTransport ToPreferredTransport(std::string_view protocol_binding) {
if (protocol_binding == kProtocolBindingHttpJson) {
return PreferredTransport::kRest;
}
return PreferredTransport::kRest;
if (protocol_binding == kProtocolBindingJsonRpc) {
return PreferredTransport::kJsonRpc;
}
return PreferredTransport::kGrpc;
}

std::optional<lf::a2a::v1::TransportProtocol> ToWireTransport(PreferredTransport transport) {
std::optional<std::string_view> ToWireTransport(PreferredTransport transport) {
switch (transport) {
case PreferredTransport::kRest:
return lf::a2a::v1::TRANSPORT_PROTOCOL_REST;
return kProtocolBindingHttpJson;
case PreferredTransport::kJsonRpc:
return lf::a2a::v1::TRANSPORT_PROTOCOL_JSON_RPC;
return kProtocolBindingJsonRpc;
case PreferredTransport::kGrpc:
return lf::a2a::v1::TRANSPORT_PROTOCOL_GRPC;
return kProtocolBindingGrpc;
}
return std::nullopt;
}
Expand Down Expand Up @@ -197,40 +190,38 @@ core::Result<std::string> DiscoveryClient::BuildExtendedDiscoveryUrl(std::string
}

core::Result<void> DiscoveryClient::ValidateAgentCard(const lf::a2a::v1::AgentCard& card) {
if (card.protocol_version().empty()) {
return core::Error::Validation("Agent Card protocol_version is required");
}
if (!core::Version::IsSupported(card.protocol_version())) {
return core::Error::UnsupportedVersion("Only A2A protocol version 1.0 is supported");
}
if (card.supported_interfaces().empty()) {
return core::Error::Validation("Agent Card must include at least one supported interface");
}

for (const auto& iface : card.supported_interfaces()) {
if (iface.transport() == lf::a2a::v1::TRANSPORT_PROTOCOL_UNSPECIFIED) {
return core::Error::Validation("Agent Card contains an interface with unspecified transport");
if (iface.protocol_binding().empty()) {
return core::Error::Validation(
"Agent Card contains an interface with unspecified protocol binding");
}
if (iface.protocol_version().empty()) {
return core::Error::Validation("Agent Card contains an interface with no protocol version");
}
if (!core::Version::IsSupported(iface.protocol_version())) {
return core::Error::UnsupportedVersion("Only A2A protocol version 1.0 is supported");
}
if (iface.url().empty()) {
return core::Error::Validation("Agent Card contains an interface without a URL");
}
if (!IsValidInterfaceEndpoint(iface.transport(), iface.url())) {
return core::Error::Validation("Agent Card interface endpoint is invalid for its transport");
if (!IsValidInterfaceEndpoint(iface.protocol_binding(), iface.url())) {
return core::Error::Validation(
"Agent Card interface endpoint is invalid for its protocol binding");
}
for (const auto& requirement : iface.security_requirements()) {
if (!card.security_schemes().contains(requirement)) {
return core::Error::Validation(
"Agent Card interface references an unknown security scheme: " + requirement);
for (const auto& requirement : card.security_requirements()) {
for (const auto& [scheme_name, _] : requirement.schemes()) {
if (!card.security_schemes().contains(scheme_name)) {
return core::Error::Validation(
"Agent Card security requirement references an unknown security scheme: " +
scheme_name);
}
}
}
}

for (const auto& requirement : card.default_security_requirements()) {
if (!card.security_schemes().contains(requirement)) {
return core::Error::Validation("Agent Card default security requirement is not defined: " +
requirement);
}
}
return {};
}

Expand All @@ -241,23 +232,22 @@ core::Result<ResolvedInterface> AgentCardResolver::SelectPreferredInterface(
return core::Error::Validation("Invalid preferred transport requested");
}

std::array<lf::a2a::v1::TransportProtocol, 3> order = {preferred_wire.value(),
lf::a2a::v1::TRANSPORT_PROTOCOL_REST,
lf::a2a::v1::TRANSPORT_PROTOCOL_JSON_RPC};
if (preferred_wire.value() == lf::a2a::v1::TRANSPORT_PROTOCOL_REST) {
order[1] = lf::a2a::v1::TRANSPORT_PROTOCOL_JSON_RPC;
order[2] = lf::a2a::v1::TRANSPORT_PROTOCOL_GRPC;
} else if (preferred_wire.value() == lf::a2a::v1::TRANSPORT_PROTOCOL_JSON_RPC) {
order[1] = lf::a2a::v1::TRANSPORT_PROTOCOL_REST;
order[2] = lf::a2a::v1::TRANSPORT_PROTOCOL_GRPC;
std::array<std::string_view, 3> order = {preferred_wire.value(), kProtocolBindingHttpJson,
kProtocolBindingJsonRpc};
if (preferred_wire.value() == kProtocolBindingHttpJson) {
order[1] = kProtocolBindingJsonRpc;
order[2] = kProtocolBindingGrpc;
} else if (preferred_wire.value() == kProtocolBindingJsonRpc) {
order[1] = kProtocolBindingHttpJson;
order[2] = kProtocolBindingGrpc;
} else {
order[1] = lf::a2a::v1::TRANSPORT_PROTOCOL_REST;
order[2] = lf::a2a::v1::TRANSPORT_PROTOCOL_JSON_RPC;
order[1] = kProtocolBindingHttpJson;
order[2] = kProtocolBindingJsonRpc;
}

for (const auto transport : order) {
for (const auto& iface : card.supported_interfaces()) {
if (iface.transport() != transport) {
if (iface.protocol_binding() != transport) {
continue;
}
const auto valid = ValidateInterface(iface);
Expand All @@ -266,16 +256,14 @@ core::Result<ResolvedInterface> AgentCardResolver::SelectPreferredInterface(
}

ResolvedInterface resolved;
resolved.transport = ToPreferredTransport(transport);
resolved.transport = ToPreferredTransport(iface.protocol_binding());
resolved.url = iface.url();
if (iface.security_requirements().empty()) {
resolved.security_requirements.insert(resolved.security_requirements.end(),
card.default_security_requirements().begin(),
card.default_security_requirements().end());
} else {
resolved.security_requirements.insert(resolved.security_requirements.end(),
iface.security_requirements().begin(),
iface.security_requirements().end());
if (!card.security_requirements().empty()) {
for (const auto& requirement : card.security_requirements()) {
for (const auto& [scheme_name, _] : requirement.schemes()) {
resolved.security_requirements.push_back(scheme_name);
}
}
}
for (const auto& name : resolved.security_requirements) {
const auto scheme = card.security_schemes().find(name);
Expand All @@ -291,14 +279,14 @@ core::Result<ResolvedInterface> AgentCardResolver::SelectPreferredInterface(
}

core::Result<void> AgentCardResolver::ValidateInterface(const lf::a2a::v1::AgentInterface& iface) {
if (iface.transport() == lf::a2a::v1::TRANSPORT_PROTOCOL_UNSPECIFIED) {
return core::Error::Validation("Unsupported transport");
if (iface.protocol_binding().empty()) {
return core::Error::Validation("Unsupported protocol binding");
}
if (iface.url().empty()) {
return core::Error::Validation("Missing interface URL");
}
if (!IsValidInterfaceEndpoint(iface.transport(), iface.url())) {
return core::Error::Validation("Interface endpoint is invalid for its transport");
if (!IsValidInterfaceEndpoint(iface.protocol_binding(), iface.url())) {
return core::Error::Validation("Interface endpoint is invalid for its protocol binding");
}
return {};
}
Expand Down
6 changes: 4 additions & 2 deletions src/client/grpc_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,17 @@ core::Result<lf::a2a::v1::Task> GrpcTransport::CancelTask(
return response;
}

core::Result<lf::a2a::v1::TaskPushNotificationConfig> GrpcTransport::CreateTaskPushNotificationConfig(
core::Result<lf::a2a::v1::TaskPushNotificationConfig>
GrpcTransport::CreateTaskPushNotificationConfig(
const lf::a2a::v1::TaskPushNotificationConfig& request, const CallOptions& options) {
auto context_result = BuildContext(options);
if (!context_result.ok()) {
return context_result.error();
}
auto context = std::move(context_result.value());
lf::a2a::v1::TaskPushNotificationConfig response;
const auto status = rpc_client_->CreateTaskPushNotificationConfig(context.get(), request, &response);
const auto status =
rpc_client_->CreateTaskPushNotificationConfig(context.get(), request, &response);
if (!status.ok()) {
return BuildGrpcError(status);
}
Expand Down
8 changes: 4 additions & 4 deletions src/client/http_json_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,8 @@ core::Result<lf::a2a::v1::Task> HttpJsonTransport::GetTask(
}

std::string endpoint = BuildTaskPath(request.id());
if (!request.history_length().empty()) {
endpoint += "?historyLength=" + request.history_length();
if (request.has_history_length()) {
endpoint += "?historyLength=" + std::to_string(request.history_length());
}

const auto response = SendRequest({.method = "GET", .endpoint = endpoint}, {}, options);
Expand Down Expand Up @@ -528,8 +528,8 @@ core::Result<std::unique_ptr<StreamHandle>> HttpJsonTransport::SubscribeTask(
}

std::string endpoint = BuildTaskPath(request.id()) + ":subscribe";
if (!request.history_length().empty()) {
endpoint += "?historyLength=" + request.history_length();
if (request.has_history_length()) {
endpoint += "?historyLength=" + std::to_string(request.history_length());
}

return StartSseStream({.method = "GET", .endpoint = endpoint}, {}, observer, options);
Expand Down
6 changes: 4 additions & 2 deletions src/server/rest_server_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,10 @@ RestServerTransport::RestServerTransport(Dispatcher* dispatcher, lf::a2a::v1::Ag
RestServerTransportOptions options)
: transport_(dispatcher), agent_card_(std::move(agent_card)), options_(std::move(options)) {
options_.rest_api_base_path = NormalizeBasePath(options_.rest_api_base_path);
if (agent_card_.protocol_version().empty()) {
agent_card_.set_protocol_version(core::Version::HeaderValue());
for (auto& iface : *agent_card_.mutable_supported_interfaces()) {
if (iface.protocol_version().empty()) {
iface.set_protocol_version(core::Version::HeaderValue());
}
}
}

Expand Down
6 changes: 5 additions & 1 deletion src/server/rest_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,11 @@ std::optional<DispatchRequest> RestTransport::BuildDispatchRequest(const RestReq
payload.set_id(task_id.value());
if (const auto history_length = LookupQuery(request, "historyLength");
history_length.has_value()) {
payload.set_history_length(*history_length);
const int parsed_history_length = ParsePageSize(*history_length);
if (parsed_history_length < 0) {
return std::nullopt;
}
payload.set_history_length(parsed_history_length);
}
return DispatchRequest{.operation = DispatcherOperation::kGetTask, .payload = payload};
}
Expand Down
Loading
Loading