Skip to content

Commit

Permalink
fix: change bedrock embed_model_name to embedder_model_name (#385)
Browse files (browse the repository at this point in the history)
  • Loading branch information
amadeusz-ds authored Feb 13, 2025
1 parent 4edcd1b commit 364fc9b
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 16 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## 0.5.3-dev2
## 0.5.3

### Enhancements

Expand All @@ -7,6 +7,8 @@

### Fixes

* **Fix bedrock embedder: rename embed_model_name to embedder_model_name**

## 0.5.2

### Enhancements
Expand Down
2 changes: 1 addition & 1 deletion test/unit/v2/embedders/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def generate_embedder_config_params() -> dict:
"region_name": fake.city(),
}
if random.random() < 0.5:
params["embed_model_name"] = fake.word()
params["embedder_model_name"] = fake.word()
return params


Expand Down
2 changes: 1 addition & 1 deletion test/unit/v2/embedders/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
def generate_embedder_config_params() -> dict:
params = {}
if random.random() < 0.5:
params["embed_model_name"] = fake.word() if random.random() < 0.5 else None
params["embedder_model_name"] = fake.word() if random.random() < 0.5 else None
params["embedder_model_kwargs"] = (
generate_random_dictionary(key_type=str, value_type=Any)
if random.random() < 0.5
Expand Down
2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.5.3-dev2" # pragma: no cover
__version__ = "0.5.3" # pragma: no cover
10 changes: 5 additions & 5 deletions unstructured_ingest/embed/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
aws_access_key_id: SecretStr
aws_secret_access_key: SecretStr
region_name: str = "us-west-2"
embed_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
embedder_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")

def wrap_error(self, e: Exception) -> Exception:
if is_internal_error(e=e):
Expand Down Expand Up @@ -130,15 +130,15 @@ def wrap_error(self, e: Exception) -> Exception:

def embed_query(self, query: str) -> list[float]:
"""Call out to Bedrock embedding endpoint."""
provider = self.config.embed_model_name.split(".")[0]
provider = self.config.embedder_model_name.split(".")[0]
body = conform_query(query=query, provider=provider)

bedrock_client = self.config.get_client()
# invoke bedrock API
try:
response = bedrock_client.invoke_model(
body=json.dumps(body),
modelId=self.config.embed_model_name,
modelId=self.config.embedder_model_name,
accept="application/json",
contentType="application/json",
)
Expand Down Expand Up @@ -173,15 +173,15 @@ def wrap_error(self, e: Exception) -> Exception:

async def embed_query(self, query: str) -> list[float]:
"""Call out to Bedrock embedding endpoint."""
provider = self.config.embed_model_name.split(".")[0]
provider = self.config.embedder_model_name.split(".")[0]
body = conform_query(query=query, provider=provider)
try:
async with self.config.get_async_client() as bedrock_client:
# invoke bedrock API
try:
response = await bedrock_client.invoke_model(
body=json.dumps(body),
modelId=self.config.embed_model_name,
modelId=self.config.embedder_model_name,
accept="application/json",
contentType="application/json",
)
Expand Down
16 changes: 9 additions & 7 deletions unstructured_ingest/v2/processes/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,20 @@ def get_octoai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":

return OctoAIEmbeddingEncoder(config=OctoAiEmbeddingConfig.model_validate(embedding_kwargs))

def get_bedrock_embedder(self) -> "BaseEmbeddingEncoder":
def get_bedrock_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
from unstructured_ingest.embed.bedrock import (
BedrockEmbeddingConfig,
BedrockEmbeddingEncoder,
)

embedding_kwargs = embedding_kwargs | {
"aws_access_key_id": self.embedding_aws_access_key_id,
"aws_secret_access_key": self.embedding_aws_secret_access_key.get_secret_value(),
"region_name": self.embedding_aws_region,
}

return BedrockEmbeddingEncoder(
config=BedrockEmbeddingConfig(
aws_access_key_id=self.embedding_aws_access_key_id,
aws_secret_access_key=self.embedding_aws_secret_access_key.get_secret_value(),
region_name=self.embedding_aws_region,
)
config=BedrockEmbeddingConfig.model_validate(embedding_kwargs)
)

def get_vertexai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
Expand Down Expand Up @@ -163,7 +165,7 @@ def get_embedder(self) -> "BaseEmbeddingEncoder":
return self.get_octoai_embedder(embedding_kwargs=kwargs)

if self.embedding_provider == "bedrock":
return self.get_bedrock_embedder()
return self.get_bedrock_embedder(embedding_kwargs=kwargs)

if self.embedding_provider == "vertexai":
return self.get_vertexai_embedder(embedding_kwargs=kwargs)
Expand Down

0 comments on commit 364fc9b

Please sign in to comment.