import os

from scale_gp import SGPClient
from scale_gp.types.model_template_create_params import (
    VendorConfiguration,
    VendorConfigurationBundleConfig,
    VendorConfigurationEndpointConfig,
)

# End-to-end example: register a model template, create a model instance from
# it, deploy that instance, and run one completion request against it.

# Credentials come from the environment so they are never hard-coded here.
account_id = os.environ.get("SGP_ACCOUNT_ID")
api_key = os.environ.get("SGP_API_KEY")

# Validate with explicit raises instead of `assert`: assertions are removed
# when Python runs with -O, which would let a missing value reach the API
# calls below. An empty string is rejected as well as an unset variable.
if not account_id:
    raise RuntimeError(
        "You need to set the SGP_ACCOUNT_ID - you can find it at https://egp.dashboard.scale.com/admin/accounts"
    )
if not api_key:
    raise RuntimeError("You need to provide your API key - see https://egp.dashboard.scale.com/admin/api-key")

client = SGPClient(api_key=api_key)

# Vendor configuration for the template: the bundle (image / registry / tag)
# plus the endpoint settings.
bundle_config = VendorConfigurationBundleConfig(image="gemini-pro", registry="aws-registry", tag="latest")

endpoint_config = VendorConfigurationEndpointConfig(
    max_workers=3,
)

vendor_configuration = VendorConfiguration(
    bundle_config=bundle_config,
    endpoint_config=endpoint_config,
)

# Step 1: create the model template that carries the vendor configuration.
model_template = client.model_templates.create(
    account_id=account_id,
    endpoint_type="SYNC",
    model_type="COMPLETION",
    name="Gemini-Pro Template",
    vendor_configuration=vendor_configuration,
)

# Step 2: create a model instance that references the template by id.
model_instance = client.models.create(
    account_id=account_id,
    model_type="COMPLETION",
    name="gemini-pro",
    model_vendor="GOOGLE",
    model_template_id=model_template.id,
)

# Step 3: create a deployment for that model instance.
model_deployment = client.models.deployments.create(
    model_instance_id=model_instance.id, name="Gemini-Pro Deployment", account_id=account_id
)

print(model_deployment)

# Step 4: execute the deployment; the prompt list is passed through
# `extra_body` rather than as a named argument.
model_completion_response = client.models.deployments.execute(
    model_deployment_id=model_deployment.id,
    model_instance_id=model_instance.id,
    extra_body={"prompts": ["What is the capital of Canada?"]},
)

print(model_completion_response)
# Sample output (appears to be the result of `print(model_deployment)`):
#
# ModelDeployment(
#     id='d4a457c3-7b56-4b0d-b6f1-45e5809907dd',
#     account_id='66049ada2fc77c99ef015be7',
#     created_at=datetime.datetime(2024, 9, 26, 19, 58, 51, 105175),
#     created_by_user_id='42a5c8af-f698-43d0-923e-ba70102a2887',
#     name='Gemini-Pro Deployment',
#     status='READY',
#     deployment_metadata=None,
#     model_creation_parameters=None,
#     model_endpoint_id=None,
#     model_instance_id='6f6b4a0e-0ae2-43f2-9b46-b783d83a729f',
#     vendor_configuration=None
# )
import os

from scale_gp import SGPClient
from scale_gp.types.model_template_create_params import (
    VendorConfiguration,
    VendorConfigurationBundleConfig,
    VendorConfigurationEndpointConfig,
)

# End-to-end example: register a model template, create a model instance from
# it, deploy that instance, and run one completion request against it.

# Credentials come from the environment so they are never hard-coded here.
account_id = os.environ.get("SGP_ACCOUNT_ID")
api_key = os.environ.get("SGP_API_KEY")

# Validate with explicit raises instead of `assert`: assertions are removed
# when Python runs with -O, which would let a missing value reach the API
# calls below. An empty string is rejected as well as an unset variable.
if not account_id:
    raise RuntimeError(
        "You need to set the SGP_ACCOUNT_ID - you can find it at https://egp.dashboard.scale.com/admin/accounts"
    )
if not api_key:
    raise RuntimeError("You need to provide your API key - see https://egp.dashboard.scale.com/admin/api-key")

client = SGPClient(api_key=api_key)

# Vendor configuration for the template: the bundle (image / registry / tag)
# plus the endpoint settings.
bundle_config = VendorConfigurationBundleConfig(image="gemini-pro", registry="aws-registry", tag="latest")

endpoint_config = VendorConfigurationEndpointConfig(
    max_workers=3,
)

vendor_configuration = VendorConfiguration(
    bundle_config=bundle_config,
    endpoint_config=endpoint_config,
)

# Step 1: create the model template that carries the vendor configuration.
model_template = client.model_templates.create(
    account_id=account_id,
    endpoint_type="SYNC",
    model_type="COMPLETION",
    name="Gemini-Pro Template",
    vendor_configuration=vendor_configuration,
)

# Step 2: create a model instance that references the template by id.
model_instance = client.models.create(
    account_id=account_id,
    model_type="COMPLETION",
    name="gemini-pro",
    model_vendor="GOOGLE",
    model_template_id=model_template.id,
)

# Step 3: create a deployment for that model instance.
model_deployment = client.models.deployments.create(
    model_instance_id=model_instance.id, name="Gemini-Pro Deployment", account_id=account_id
)

print(model_deployment)

# Step 4: execute the deployment; the prompt list is passed through
# `extra_body` rather than as a named argument.
model_completion_response = client.models.deployments.execute(
    model_deployment_id=model_deployment.id,
    model_instance_id=model_instance.id,
    extra_body={"prompts": ["What is the capital of Canada?"]},
)

print(model_completion_response)
# Sample output (appears to be the result of `print(model_deployment)`):
#
# ModelDeployment(
#     id='d4a457c3-7b56-4b0d-b6f1-45e5809907dd',
#     account_id='66049ada2fc77c99ef015be7',
#     created_at=datetime.datetime(2024, 9, 26, 19, 58, 51, 105175),
#     created_by_user_id='42a5c8af-f698-43d0-923e-ba70102a2887',
#     name='Gemini-Pro Deployment',
#     status='READY',
#     deployment_metadata=None,
#     model_creation_parameters=None,
#     model_endpoint_id=None,
#     model_instance_id='6f6b4a0e-0ae2-43f2-9b46-b783d83a729f',
#     vendor_configuration=None
# )