Skip to content

Commit

Permalink
Merge pull request #1112 from newrelic/develop-ai-limited-preview-3
Browse files Browse the repository at this point in the history
Develop ai limited preview 3
  • Loading branch information
umaannamalai authored Mar 27, 2024
2 parents 43e5e25 + a21115e commit 0e67af5
Show file tree
Hide file tree
Showing 75 changed files with 25,170 additions and 328 deletions.
8 changes: 7 additions & 1 deletion newrelic/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ def __asgi_application(*args, **kwargs):
from newrelic.api.message_transaction import (
wrap_message_transaction as __wrap_message_transaction,
)
from newrelic.api.ml_model import (
record_llm_feedback_event as __record_llm_feedback_event,
)
from newrelic.api.ml_model import set_llm_token_count_callback as __set_llm_token_count_callback
from newrelic.api.ml_model import wrap_mlmodel as __wrap_mlmodel
from newrelic.api.profile_trace import ProfileTraceWrapper as __ProfileTraceWrapper
from newrelic.api.profile_trace import profile_trace as __profile_trace
Expand All @@ -174,10 +178,10 @@ def __asgi_application(*args, **kwargs):
from newrelic.api.web_transaction import web_transaction as __web_transaction
from newrelic.api.web_transaction import wrap_web_transaction as __wrap_web_transaction
from newrelic.common.object_names import callable_name as __callable_name
from newrelic.common.object_wrapper import CallableObjectProxy as __CallableObjectProxy
from newrelic.common.object_wrapper import FunctionWrapper as __FunctionWrapper
from newrelic.common.object_wrapper import InFunctionWrapper as __InFunctionWrapper
from newrelic.common.object_wrapper import ObjectProxy as __ObjectProxy
from newrelic.common.object_wrapper import CallableObjectProxy as __CallableObjectProxy
from newrelic.common.object_wrapper import ObjectWrapper as __ObjectWrapper
from newrelic.common.object_wrapper import OutFunctionWrapper as __OutFunctionWrapper
from newrelic.common.object_wrapper import PostFunctionWrapper as __PostFunctionWrapper
Expand Down Expand Up @@ -343,3 +347,5 @@ def __asgi_application(*args, **kwargs):
insert_html_snippet = __wrap_api_call(__insert_html_snippet, "insert_html_snippet")
verify_body_exists = __wrap_api_call(__verify_body_exists, "verify_body_exists")
wrap_mlmodel = __wrap_api_call(__wrap_mlmodel, "wrap_mlmodel")
record_llm_feedback_event = __wrap_api_call(__record_llm_feedback_event, "record_llm_feedback_event")
set_llm_token_count_callback = __wrap_api_call(__set_llm_token_count_callback, "set_llm_token_count_callback")
91 changes: 91 additions & 0 deletions newrelic/api/ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys
import uuid
import warnings

from newrelic.api.transaction import current_transaction
from newrelic.common.object_names import callable_name
from newrelic.core.config import global_settings
from newrelic.hooks.mlmodel_sklearn import _nr_instrument_model

_logger = logging.getLogger(__name__)


def wrap_mlmodel(model, name=None, version=None, feature_names=None, label_names=None, metadata=None):
model_callable_name = callable_name(model)
Expand All @@ -33,3 +40,87 @@ def wrap_mlmodel(model, name=None, version=None, feature_names=None, label_names
model._nr_wrapped_label_names = label_names
if metadata:
model._nr_wrapped_metadata = metadata


def record_llm_feedback_event(trace_id, rating, category=None, message=None, metadata=None):
transaction = current_transaction()
if not transaction:
warnings.warn(
"No message feedback events will be recorded. record_llm_feedback_event must be called within the "
"scope of a transaction."
)
return

feedback_event_id = str(uuid.uuid4())
feedback_event = metadata.copy() if metadata else {}
feedback_event.update(
{
"id": feedback_event_id,
"trace_id": trace_id,
"rating": rating,
"category": category,
"message": message,
"ingest_source": "Python",
}
)

transaction.record_custom_event("LlmFeedbackMessage", feedback_event)


def set_llm_token_count_callback(callback, application=None):
"""
Set the current callback to be used to calculate LLM token counts.
Arguments:
callback -- the user-defined callback that will calculate and return the total token count as an integer or None if it does not know
application -- optional application object to associate call with
"""
if callback and not callable(callback):
_logger.error(
"callback passed to set_llm_token_count_callback must be a Callable type or None to unset the callback."
)
return

from newrelic.api.application import application_instance

# Check for activated application if it exists and was not given.
application = application or application_instance(activate=False)

# Get application settings if it exists, or fallback to global settings object.
settings = application.settings if application else global_settings()

if not settings:
_logger.error(
"Failed to set llm_token_count_callback. Settings not found on application or in global_settings."
)
return

if not callback:
settings.ai_monitoring._llm_token_count_callback = None
return

def _wrap_callback(model, content):
if model is None:
_logger.debug(
"The model argument passed to the user-defined token calculation callback is None. The callback will not be run."
)
return None

if content is None:
_logger.debug(
"The content argument passed to the user-defined token calculation callback is None. The callback will not be run."
)
return None

token_count_val = callback(model, content)

if not isinstance(token_count_val, int) or token_count_val < 0:
_logger.warning(
"llm_token_count_callback returned an invalid value of %s. This value must be a positive integer and will not be recorded for the token_count."
% token_count_val
)
return None

return token_count_val

settings.ai_monitoring._llm_token_count_callback = _wrap_callback
43 changes: 27 additions & 16 deletions newrelic/api/time_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
)
from newrelic.core.config import is_expected_error, should_ignore_error
from newrelic.core.trace_cache import trace_cache

from newrelic.packages import six

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -260,6 +259,11 @@ def _observe_exception(self, exc_info=None, ignore=None, expected=None, status_c
module, name, fullnames, message_raw = parse_exc_info((exc, value, tb))
fullname = fullnames[0]

# In case message is in JSON format for OpenAI models
# this will result in a "cleaner" message format
if getattr(value, "_nr_message", None):
message_raw = value._nr_message

# Check to see if we need to strip the message before recording it.

if settings.strip_exception_messages.enabled and fullname not in settings.strip_exception_messages.allowlist:
Expand Down Expand Up @@ -422,23 +426,32 @@ def notice_error(self, error=None, attributes=None, expected=None, ignore=None,
input_attributes = {}
input_attributes.update(transaction._custom_params)
input_attributes.update(attributes)
error_group_name_raw = settings.error_collector.error_group_callback(value, {
"traceback": tb,
"error.class": exc,
"error.message": message_raw,
"error.expected": is_expected,
"custom_params": input_attributes,
"transactionName": getattr(transaction, "name", None),
"response.status": getattr(transaction, "_response_code", None),
"request.method": getattr(transaction, "_request_method", None),
"request.uri": getattr(transaction, "_request_uri", None),
})
error_group_name_raw = settings.error_collector.error_group_callback(
value,
{
"traceback": tb,
"error.class": exc,
"error.message": message_raw,
"error.expected": is_expected,
"custom_params": input_attributes,
"transactionName": getattr(transaction, "name", None),
"response.status": getattr(transaction, "_response_code", None),
"request.method": getattr(transaction, "_request_method", None),
"request.uri": getattr(transaction, "_request_uri", None),
},
)
if error_group_name_raw:
_, error_group_name = process_user_attribute("error.group.name", error_group_name_raw)
if error_group_name is None or not isinstance(error_group_name, six.string_types):
raise ValueError("Invalid attribute value for error.group.name. Expected string, got: %s" % repr(error_group_name_raw))
raise ValueError(
"Invalid attribute value for error.group.name. Expected string, got: %s"
% repr(error_group_name_raw)
)
except Exception:
_logger.error("Encountered error when calling error group callback:\n%s", "".join(traceback.format_exception(*sys.exc_info())))
_logger.error(
"Encountered error when calling error group callback:\n%s",
"".join(traceback.format_exception(*sys.exc_info())),
)
error_group_name = None

transaction._create_error_node(
Expand Down Expand Up @@ -595,13 +608,11 @@ def update_async_exclusive_time(self, min_child_start_time, exclusive_duration):
def process_child(self, node, is_async):
self.children.append(node)
if is_async:

# record the lowest start time
self.min_child_start_time = min(self.min_child_start_time, node.start_time)

# if there are no children running, finalize exclusive time
if self.child_count == len(self.children):

exclusive_duration = node.end_time - self.min_child_start_time

self.update_async_exclusive_time(self.min_child_start_time, exclusive_duration)
Expand Down
18 changes: 14 additions & 4 deletions newrelic/api/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __init__(self, application, enabled=None, source=None):

self.thread_id = None

self._transaction_id = id(self)
self._identity = id(self)
self._transaction_lock = threading.Lock()

self._dead = False
Expand All @@ -193,6 +193,7 @@ def __init__(self, application, enabled=None, source=None):
self._frameworks = set()
self._message_brokers = set()
self._dispatchers = set()
self._ml_models = set()

self._frozen_path = None

Expand Down Expand Up @@ -274,6 +275,7 @@ def __init__(self, application, enabled=None, source=None):
trace_id = "%032x" % random.getrandbits(128)

# 16-digit random hex. Padded with zeros in the front.
# This is the official transactionId in the UI.
self.guid = trace_id[:16]

# 32-digit random hex. Padded with zeros in the front.
Expand Down Expand Up @@ -421,7 +423,7 @@ def __exit__(self, exc, value, tb):
if not self.enabled:
return

if self._transaction_id != id(self):
if self._identity != id(self):
return

if not self._settings:
Expand Down Expand Up @@ -568,6 +570,10 @@ def __exit__(self, exc, value, tb):
for dispatcher, version in self._dispatchers:
self.record_custom_metric("Python/Dispatcher/%s/%s" % (dispatcher, version), 1)

if self._ml_models:
for ml_model, version in self._ml_models:
self.record_custom_metric("Supportability/Python/ML/%s/%s" % (ml_model, version), 1)

if self._settings.distributed_tracing.enabled:
# Sampled and priority need to be computed at the end of the
# transaction when distributed tracing or span events are enabled.
Expand Down Expand Up @@ -1715,7 +1721,7 @@ def record_custom_event(self, event_type, params):
if not settings.custom_insights_events.enabled:
return

event = create_custom_event(event_type, params)
event = create_custom_event(event_type, params, settings=settings)
if event:
self._custom_events.add(event, priority=self.priority)

Expand All @@ -1728,7 +1734,7 @@ def record_ml_event(self, event_type, params):
if not settings.ml_insights_events.enabled:
return

event = create_custom_event(event_type, params)
event = create_custom_event(event_type, params, settings=settings, is_ml_event=True)
if event:
self._ml_events.add(event, priority=self.priority)

Expand Down Expand Up @@ -1835,6 +1841,10 @@ def add_dispatcher_info(self, name, version=None):
if name:
self._dispatchers.add((name, version))

def add_ml_model_info(self, name, version=None):
if name:
self._ml_models.add((name, version))

def dump(self, file):
"""Dumps details about the transaction to the file object."""

Expand Down
Loading

0 comments on commit 0e67af5

Please sign in to comment.