SageMaker
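Amazon SageMaker is a system that can build, train, and deploy machine learning (ML) models for any use case with fully managed infrastructure, tools, and workflows.
This tutorial covers how to use an LLM hosted on a SageMaker endpoint.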
!pip3 install langchain boto3
Set up
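You must set the following required parameters of the SagemakerEndpoint call:
endpoint_name: The name of the endpoint of the deployed SageMaker model. It must be unique within an AWS Region.
credentials_profile_name: The name of the profile in the ~/.aws/credentials or ~/.aws/config file that specifies either access keys or role information. If not specified, the default credential profile is used or, on an EC2 instance, credentials from IMDS. See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html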
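To illustrate how a profile name maps onto those two files, here is a minimal sketch that checks whether a given profile is defined, using only the standard library. The helper name profile_is_defined and the file paths passed to it are hypothetical and not part of LangChain or boto3; boto3 performs its own, more complete lookup (environment variables, IMDS and so on). The detail worth noting is that the credentials file uses bare section names such as [dev], while the config file writes named profiles as [profile dev].

```python
import configparser
from pathlib import Path


def profile_is_defined(profile_name: str, credentials_path: Path, config_path: Path) -> bool:
    """Return True if profile_name appears in either AWS-style INI file.

    Hypothetical helper for illustration only; it is not a boto3 API.
    """
    # The credentials file uses bare section names: [my-profile]
    credentials = configparser.ConfigParser(interpolation=None)
    credentials.read(credentials_path)  # a missing file is silently skipped
    if credentials.has_section(profile_name):
        return True

    # The config file prefixes named profiles: [profile my-profile].
    # The default profile is written as [default] in both files.
    config = configparser.ConfigParser(interpolation=None)
    config.read(config_path)
    section = "default" if profile_name == "default" else f"profile {profile_name}"
    return config.has_section(section)
```

A check like profile_is_defined("credentials-profile-name", Path.home() / ".aws" / "credentials", Path.home() / ".aws" / "config") could be run before constructing SagemakerEndpoint to catch a misspelled profile name early.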
Example
from langchain.docstore.document import Document
example_doc_1 = """
Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay beside her until she gets well.
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
"""
docs = [
    Document(
        page_content=example_doc_1,
    )
]
from typing import Dict
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.chains.question_answering import load_qa_chain
import json
query = """How long was Elizabeth hospitalized?
"""
prompt_template = """Use the following pieces of context to answer the question at the end.
{context}
Question: {question}
Answer:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
class ContentHandler(ContentHandlerBase):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Use the string key "prompt"; this assumes the deployed model reads its input from that field.
        input_str = json.dumps({"prompt": prompt, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # The response body is read as a stream before decoding it as JSON.
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]
content_handler = ContentHandler()
chain = load_qa_chain(
    llm=SagemakerEndpoint(
        endpoint_name="endpoint-name",
        credentials_profile_name="credentials-profile-name",
        region_name="us-west-2",
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)

chain({"input_documents": docs, "question": query}, return_only_outputs=True)
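Before pointing the chain at a real endpoint, it can help to exercise the two content handler transforms locally. The following is a minimal sketch, not the LangChain implementation: the two functions are standalone copies of the handler methods so the block runs without LangChain or AWS access, and the request field "prompt" and the response shape [{"generated_text": ...}] are assumptions taken from this example. The actual schema is defined by the model container behind your endpoint. An io.BytesIO object stands in for the streaming response body.

```python
import io
import json
from typing import Dict


def transform_input(prompt: str, model_kwargs: Dict) -> bytes:
    # Assumed request schema: {"prompt": <text>, plus any generation parameters}.
    return json.dumps({"prompt": prompt, **model_kwargs}).encode("utf-8")


def transform_output(output) -> str:
    # `output` is any file-like object exposing .read(), as a streaming body does.
    response_json = json.loads(output.read().decode("utf-8"))
    return response_json[0]["generated_text"]


# Build the request body the endpoint would receive.
request_body = transform_input(
    "How long was Elizabeth hospitalized?", {"temperature": 1e-10}
)

# Simulate an endpoint response (hypothetical payload) and parse it.
fake_response = io.BytesIO(
    json.dumps([{"generated_text": "3 days"}]).encode("utf-8")
)
answer = transform_output(fake_response)
```

If the model expects a different field name or returns a different structure, only these two functions need to change; the rest of the chain is unaffected.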