Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 116 additions & 5 deletions crates/openshell-providers/src/profiles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,16 @@ pub fn validate_profile_set(
message,
));
}
if credential.token_grant.is_none()
&& let Err(message) = validate_static_credential_header_name(credential)
{
diagnostics.push(ProfileValidationDiagnostic::error(
source,
profile_id,
"credentials.header_name",
message,
));
}
}

for (index, endpoint) in profile.endpoints.iter().enumerate() {
Expand Down Expand Up @@ -1542,6 +1552,19 @@ fn validate_token_grant_header_name(credential: &CredentialProfile) -> Result<()
"" | "bearer" | "header" => credential.header_name.trim(),
_ => return Ok(()),
};
validate_credential_header_name(header_name, "token_grant")
}

fn validate_static_credential_header_name(credential: &CredentialProfile) -> Result<(), String> {
let header_name = match credential.auth_style.trim().to_ascii_lowercase().as_str() {
"bearer" if credential.header_name.trim().is_empty() => "Authorization",
"bearer" | "header" => credential.header_name.trim(),
_ => return Ok(()),
};
validate_credential_header_name(header_name, "credential")
}

fn validate_credential_header_name(header_name: &str, label: &str) -> Result<(), String> {
if header_name.is_empty() {
return Ok(());
}
Expand All @@ -1566,13 +1589,14 @@ fn validate_token_grant_header_name(credential: &CredentialProfile) -> Result<()
)
});
if !valid {
return Err("token_grant header_name is not a valid HTTP header name".to_string());
return Err(format!(
"{label} header_name is not a valid HTTP header name"
));
}
match header_name.to_ascii_lowercase().as_str() {
"host" | "content-length" | "transfer-encoding" | "connection" => Err(
"token_grant header_name may not override HTTP framing or connection headers"
.to_string(),
),
"host" | "content-length" | "transfer-encoding" | "connection" => Err(format!(
"{label} header_name may not override HTTP framing or connection headers"
)),
_ => Ok(()),
}
}
Expand Down Expand Up @@ -2166,6 +2190,93 @@ credentials:
);
}

#[test]
fn validate_profile_set_rejects_static_credential_framing_header_name() {
let profile = parse_profile_yaml(
r"
id: framing-header-static
display_name: Framing Header Static
credentials:
- name: api_token
env_vars: [API_TOKEN]
auth_style: header
header_name: Host
",
)
.expect("profile should parse");

let diagnostics = validate_profile_set(&[("framing-static.yaml".to_string(), profile)]);
let diagnostic = diagnostics
.iter()
.find(|diagnostic| {
diagnostic.field == "credentials.header_name"
&& diagnostic.message.contains("HTTP framing")
})
.expect("framing header diagnostic should be reported for static credentials");

assert_eq!(
diagnostic.message,
"credential header_name may not override HTTP framing or connection headers"
);
}

#[test]
fn validate_profile_set_rejects_static_credential_invalid_header_name() {
let profile = parse_profile_yaml(
r"
id: invalid-header-static
display_name: Invalid Header Static
credentials:
- name: api_token
env_vars: [API_TOKEN]
auth_style: header
header_name: 'Invalid Header'
",
)
.expect("profile should parse");

let diagnostics =
validate_profile_set(&[("invalid-header-static.yaml".to_string(), profile)]);
let diagnostic = diagnostics
.iter()
.find(|diagnostic| {
diagnostic.field == "credentials.header_name"
&& diagnostic.message.contains("not a valid HTTP header name")
})
.expect("invalid header name diagnostic should be reported for static credentials");

assert_eq!(
diagnostic.message,
"credential header_name is not a valid HTTP header name"
);
}

#[test]
fn validate_profile_set_accepts_static_credential_valid_header_name() {
let profile = parse_profile_yaml(
r"
id: valid-header-static
display_name: Valid Header Static
credentials:
- name: api_token
env_vars: [API_TOKEN]
auth_style: header
header_name: X-Api-Key
endpoints:
- host: api.example.com
port: 443
",
)
.expect("profile should parse");

let diagnostics =
validate_profile_set(&[("valid-header-static.yaml".to_string(), profile)]);
assert!(
diagnostics.is_empty(),
"valid static credential header should produce no diagnostics, got: {diagnostics:?}"
);
}

#[test]
fn validate_profile_set_rejects_ambiguous_same_credential_audience_overrides() {
let profile = parse_profile_yaml(
Expand Down
78 changes: 65 additions & 13 deletions crates/openshell-server/src/grpc/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1219,8 +1219,12 @@ pub(super) async fn handle_get_sandbox_config(

let settings = merge_effective_settings(&global_settings, &sandbox_settings)?;
let config_revision = compute_config_revision(policy.as_ref(), &settings, policy_source);
let provider_env_revision =
compute_provider_env_revision(state.store.as_ref(), &sandbox_provider_names).await?;
let provider_env_revision = compute_provider_env_revision(
state.store.as_ref(),
&sandbox_provider_names,
providers_v2_enabled,
)
.await?;

Ok(Response::new(GetSandboxConfigResponse {
policy,
Expand All @@ -1237,9 +1241,11 @@ pub(super) async fn handle_get_sandbox_config(
pub(super) async fn compute_provider_env_revision(
store: &Store,
provider_names: &[String],
providers_v2_enabled: bool,
) -> Result<u64, Status> {
let mut hasher = Sha256::new();
hasher.update(b"openshell-provider-env-revision-v1");
hasher.update([u8::from(providers_v2_enabled)]);

for provider_name in provider_names {
hasher.update(provider_name.as_bytes());
Expand Down Expand Up @@ -1366,6 +1372,16 @@ async fn profile_provider_policy_layers(
Ok(layers)
}

pub async fn is_providers_v2_enabled(store: &Store) -> bool {
load_global_settings(store)
.await
.and_then(|s| bool_setting_enabled(&s, settings::PROVIDERS_V2_ENABLED_KEY))
.unwrap_or_else(|e| {
warn!("failed to read providers_v2_enabled setting, defaulting to false: {e}");
false
})
}

fn bool_setting_enabled(settings: &StoredSettings, key: &str) -> Result<bool, Status> {
match settings.settings.get(key) {
None => Ok(false),
Expand Down Expand Up @@ -1407,13 +1423,20 @@ pub(super) async fn handle_get_sandbox_provider_environment(
.spec
.ok_or_else(|| Status::internal("sandbox has no spec"))?;

let providers_v2_enabled = is_providers_v2_enabled(state.store.as_ref()).await;

let provider_names = spec.providers;
let provider_env_revision =
compute_provider_env_revision(state.store.as_ref(), &provider_names).await?;
let provider_environment =
super::provider::resolve_provider_environment(state.store.as_ref(), &provider_names)
compute_provider_env_revision(state.store.as_ref(), &provider_names, providers_v2_enabled)
.await?;

let provider_environment = super::provider::resolve_provider_environment(
state.store.as_ref(),
&provider_names,
providers_v2_enabled,
)
.await?;

info!(
sandbox_id = %sandbox_id,
provider_count = provider_names.len(),
Expand Down Expand Up @@ -5001,10 +5024,13 @@ mod tests {
.await
.unwrap();

let first =
compute_provider_env_revision(state.store.as_ref(), &["work-custom-token".to_string()])
.await
.unwrap();
let first = compute_provider_env_revision(
state.store.as_ref(),
&["work-custom-token".to_string()],
false,
)
.await
.unwrap();

tokio::time::sleep(Duration::from_millis(2)).await;
state
Expand All @@ -5015,17 +5041,43 @@ mod tests {
.await
.unwrap();

let second =
compute_provider_env_revision(state.store.as_ref(), &["work-custom-token".to_string()])
.await
.unwrap();
let second = compute_provider_env_revision(
state.store.as_ref(),
&["work-custom-token".to_string()],
false,
)
.await
.unwrap();

assert_ne!(
first, second,
"custom provider profile updates must trigger sandbox dynamic credential refresh"
);
}

#[tokio::test]
async fn provider_env_revision_changes_when_providers_v2_enabled_toggles() {
let state = test_server_state().await;
let store = state.store.as_ref();
let provider = test_provider("work-github", "github");
store.put_message(&provider).await.unwrap();

let revision_v2_off =
compute_provider_env_revision(store, &["work-github".to_string()], false)
.await
.unwrap();

let revision_v2_on =
compute_provider_env_revision(store, &["work-github".to_string()], true)
.await
.unwrap();

assert_ne!(
revision_v2_off, revision_v2_on,
"toggling providers_v2_enabled must change provider_env_revision so running sandboxes refresh"
);
}

#[tokio::test]
async fn sandbox_config_and_provider_env_follow_attached_provider_lifecycle() {
use crate::grpc::sandbox::{
Expand Down
Loading
Loading