Skip to content

Commit

Permalink
auth azure: support access token
Browse files Browse the repository at this point in the history
  • Loading branch information
LiliDeng committed Dec 25, 2024
1 parent f2014f6 commit 6333f0f
Showing 1 changed file with 51 additions and 18 deletions.
69 changes: 51 additions & 18 deletions lisa/sut_orchestrator/azure/platform_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import re
import sys
import time
from copy import deepcopy
from dataclasses import InitVar, dataclass, field
from datetime import datetime
Expand All @@ -17,6 +18,7 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast

import requests
from azure.core.credentials import AccessToken, TokenCredential
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.identity import DefaultAzureCredential
from azure.mgmt.compute.models import (
Expand Down Expand Up @@ -400,6 +402,27 @@ def cloud(self, value: Optional[CloudSchema]) -> None:
self.cloud_raw = value.to_dict() # type: ignore


class StaticAccessTokenCredential(TokenCredential):
def __init__(self, token: str, expires_on: int) -> None:
"""
Initialize StaticAccessTokenCredential with the provided token and expiry time.
:param token: The Azure access token as a string.
:param expires_on: The expiry time of the token as an integer (Unix timestamp).
"""
self._token = token
self._expires_on = expires_on

def get_token(self, *scopes: str) -> AccessToken:
"""
Get the access token for the specified scopes.
:param scopes: The OAuth 2.0 scopes the token applies to.
:return: An AccessToken instance containing the token and its expiry time.
"""
return AccessToken(self._token, self._expires_on)


class AzurePlatform(Platform):
_diagnostic_storage_container_pattern = re.compile(
r"(https:\/\/)(?P<storage_name>.*)([.].*){4}\/(?P<container_name>.*)\/",
Expand Down Expand Up @@ -700,9 +723,9 @@ def _get_node_information(self, node: Node) -> Dict[str, str]:
node.log.debug(f"vm generation: {information[KEY_VM_GENERATION]}")
if node.capture_kernel_config:
node.log.debug("detecting mana driver enabled...")
information[
KEY_MANA_DRIVER_ENABLED
] = node.nics.is_mana_driver_enabled()
information[KEY_MANA_DRIVER_ENABLED] = (
node.nics.is_mana_driver_enabled()
)
node.log.debug(f"mana enabled: {information[KEY_MANA_DRIVER_ENABLED]}")
node.log.debug("detecting nvme driver enabled...")
_has_nvme_core = node.tools[KernelConfig].is_built_in(
Expand Down Expand Up @@ -927,19 +950,29 @@ def _initialize_credential(self) -> None:
logging.getLogger("azure").setLevel(azure_runbook.log_level)

if azure_runbook.service_principal_tenant_id:
os.environ[
"AZURE_TENANT_ID"
] = azure_runbook.service_principal_tenant_id
os.environ["AZURE_TENANT_ID"] = (
azure_runbook.service_principal_tenant_id
)
if azure_runbook.service_principal_client_id:
os.environ[
"AZURE_CLIENT_ID"
] = azure_runbook.service_principal_client_id
os.environ["AZURE_CLIENT_ID"] = (
azure_runbook.service_principal_client_id
)
if azure_runbook.service_principal_key:
os.environ["AZURE_CLIENT_SECRET"] = azure_runbook.service_principal_key

credential = DefaultAzureCredential(
authority=self.cloud.endpoints.active_directory,
)
if "AZURE_ACCESS_TOKEN" in os.environ:
token = os.environ["AZURE_ACCESS_TOKEN"]
else:
token = None

if token:
credential = StaticAccessTokenCredential(
token, int(time.time()) + 3600 * 24
)
else:
credential = DefaultAzureCredential(
authority=self.cloud.endpoints.active_directory,
)

with SubscriptionClient(
credential,
Expand Down Expand Up @@ -1726,14 +1759,14 @@ def _resource_sku_to_capability( # noqa: C901
azure_raw_capabilities["availability_zones"] = location_info.zones
for zone_details in location_info.zone_details:
for location_capability in zone_details.capabilities:
azure_raw_capabilities[
location_capability.name
] = location_capability.value
azure_raw_capabilities[location_capability.name] = (
location_capability.value
)
# Zones supporting the feature
if zone_details.additional_properties["Name"]:
azure_raw_capabilities[
"availability_zones"
] = zone_details.additional_properties["Name"]
azure_raw_capabilities["availability_zones"] = (
zone_details.additional_properties["Name"]
)

if resource_sku.capabilities:
for sku_capability in resource_sku.capabilities:
Expand Down

0 comments on commit 6333f0f

Please sign in to comment.