推断已训练模型 API

编辑

评估已训练的模型。该模型可以是任何由数据框分析训练或导入的监督模型。

对于启用了缓存的模型部署,结果可以直接从推断缓存返回。

请求

编辑

POST _ml/trained_models/<model_id>/_infer POST _ml/trained_models/<deployment_id>/_infer

路径参数

编辑
<model_id>
(可选,字符串) 已训练模型或模型别名的唯一标识符。

如果在 API 调用中指定了 model_id,并且该模型有多个部署,则将使用随机部署。如果 model_id 与其中一个部署的 ID 匹配,则将使用该部署。

<deployment_id>
(可选,字符串) 模型部署的唯一标识符。

查询参数

编辑
timeout
(可选,时间) 控制等待推断结果的时间。默认为 10 秒。

请求主体

编辑
docs
(必需,数组) 要传递给模型进行推断的对象数组。这些对象应包含与您配置的已训练模型输入匹配的字段。通常,对于 NLP 模型,字段名称为 text_field。此属性中指定的每个推断输入字段都必须是单个字符串,而不是字符串数组。
inference_config

(可选,对象) 推断的默认配置。这可以是:regressionclassificationfill_masknerquestion_answeringtext_classificationtext_embeddingzero_shot_classification。如果为 regressionclassification,则它必须与基础 definition.trained_modeltarget_type 匹配。如果为 fill_masknerquestion_answeringtext_classificationtext_embedding,则 model_type 必须为 pytorch。如果未指定,则使用创建模型时的 inference_config

inference_config 的属性
classification

(可选,对象) 用于推断的分类配置。

分类推断的属性
num_top_classes
(可选,整数) 指定要返回的最高类预测数。默认为 0。
num_top_feature_importance_values
(可选,整数) 指定每个文档的特征重要性值的最大数量。默认为 0,这意味着不进行特征重要性计算。
prediction_field_type
(可选,字符串) 指定要写入的预测字段的类型。有效值为:stringnumberboolean。当提供 boolean 时,1.0 将转换为 true0.0 将转换为 false
results_field
(可选,字符串) 添加到传入文档以包含推断预测的字段。默认为 predicted_value
top_classes_results_field
(可选,字符串) 指定写入最高类的字段。默认为 top_classes
fill_mask

(可选,对象) 用于填充掩码自然语言处理 (NLP) 任务的配置。填充掩码任务适用于为填充掩码操作优化的模型。例如,对于 BERT 模型,可以提供以下文本:“法国的首都是 [MASK]。”。响应指示最有可能替换 [MASK] 的值。在这种情况下,最可能的标记是 paris

填充掩码推断的属性
num_top_classes
(可选,整数) 要返回的用于替换掩码标记的最高预测标记数。默认为 0
results_field
(可选,字符串) 添加到传入文档以包含推断预测的字段。默认为 predicted_value
tokenization

(可选,对象) 指示要执行的标记化以及所需的设置。默认的标记化配置为 bert。有效的标记化值为

  • bert:用于 BERT 风格的模型
  • deberta_v2:用于 DeBERTa v2 和 v3 风格的模型
  • mpnet:用于 MPNet 风格的模型
  • roberta:用于 RoBERTa 风格和 BART 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 xlm_roberta:用于 XLMRoBERTa 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 bert_ja:用于为日语训练的 BERT 风格的模型。
标记化的属性
bert

(可选,对象) 将使用封闭的设置执行 BERT 风格的标记化。

bert 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

deberta_v2

(可选,对象) 将使用封闭的设置执行 DeBERTa 风格的标记化。

deberta_v2 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • balanced:可以截断第一个和/或第二个序列,以平衡两个序列中包含的标记。
  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。
roberta

(可选,对象) 将使用封闭的设置执行 RoBERTa 风格的标记化。

roberta 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

mpnet

(可选,对象) 将使用封闭的设置执行 MPNet 风格的标记化。

mpnet 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

xlm_roberta

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行 XLMRoBERTa 风格的标记化。

xlm_roberta 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

bert_ja

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行日语文本的 BERT 风格标记化。

bert_ja 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

ner

(可选,对象) 配置命名实体识别 (NER) 任务。NER 是标记分类的特例。序列中的每个标记都根据提供的分类标签进行分类。目前,NER 任务需要 classification_labels 内部-外部-开始 (IOB) 格式的标签。仅支持人员、组织、位置和杂项。

ner 推断的属性
results_field
(可选,字符串) 添加到传入文档以包含推断预测的字段。默认为 predicted_value
tokenization

(可选,对象) 指示要执行的标记化以及所需的设置。默认的标记化配置为 bert。有效的标记化值为

  • bert:用于 BERT 风格的模型
  • deberta_v2:用于 DeBERTa v2 和 v3 风格的模型
  • mpnet:用于 MPNet 风格的模型
  • roberta:用于 RoBERTa 风格和 BART 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 xlm_roberta:用于 XLMRoBERTa 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 bert_ja:用于为日语训练的 BERT 风格的模型。
标记化的属性
bert

(可选,对象) 将使用封闭的设置执行 BERT 风格的标记化。

bert 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

deberta_v2

(可选,对象) 将使用封闭的设置执行 DeBERTa 风格的标记化。

deberta_v2 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • balanced:可以截断第一个和/或第二个序列,以平衡两个序列中包含的标记。
  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。
roberta

(可选,对象) 将使用封闭的设置执行 RoBERTa 风格的标记化。

roberta 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

mpnet

(可选,对象) 将使用封闭的设置执行 MPNet 风格的标记化。

mpnet 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

xlm_roberta

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行 XLMRoBERTa 风格的标记化。

xlm_roberta 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

bert_ja

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行日语文本的 BERT 风格标记化。

bert_ja 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

pass_through

(可选,对象) 配置 pass_through 任务。此任务对于调试很有用,因为不对推断输出进行后处理,并且原始池化层结果将返回给调用者。

pass_through 推断的属性
results_field
(可选,字符串) 添加到传入文档以包含推断预测的字段。默认为 predicted_value
tokenization

(可选,对象) 指示要执行的标记化以及所需的设置。默认的标记化配置为 bert。有效的标记化值为

  • bert:用于 BERT 风格的模型
  • deberta_v2:用于 DeBERTa v2 和 v3 风格的模型
  • mpnet:用于 MPNet 风格的模型
  • roberta:用于 RoBERTa 风格和 BART 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 xlm_roberta:用于 XLMRoBERTa 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 bert_ja:用于为日语训练的 BERT 风格的模型。
标记化的属性
bert

(可选,对象) 将使用封闭的设置执行 BERT 风格的标记化。

bert 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

deberta_v2

(可选,对象) 将使用封闭的设置执行 DeBERTa 风格的标记化。

deberta_v2 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • balanced:可以截断第一个和/或第二个序列,以平衡两个序列中包含的标记。
  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。
roberta

(可选,对象) 将使用封闭的设置执行 RoBERTa 风格的标记化。

roberta 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

mpnet

(可选,对象) 将使用封闭的设置执行 MPNet 风格的标记化。

mpnet 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

xlm_roberta

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行 XLMRoBERTa 风格的标记化。

xlm_roberta 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

bert_ja

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行日语文本的 BERT 风格标记化。

bert_ja 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

question_answering

(可选,对象) 配置问题解答自然语言处理 (NLP) 任务。问题解答可用于从大型文本语料库中提取特定问题的答案。

问题解答推断的属性
max_answer_length
(可选,整数) 答案中的最大单词数。默认为 15
num_top_classes
(可选,整数) 要返回的前几个找到的答案的数量。默认为 0,表示仅返回找到的最佳答案。
question
(必需,字符串) 用于提取答案的问题
results_field
(可选,字符串) 添加到传入文档以包含推断预测的字段。默认为 predicted_value
tokenization

(可选,对象) 指示要执行的标记化以及所需的设置。默认的标记化配置为 bert。有效的标记化值为

  • bert:用于 BERT 风格的模型
  • deberta_v2:用于 DeBERTa v2 和 v3 风格的模型
  • mpnet:用于 MPNet 风格的模型
  • roberta:用于 RoBERTa 风格和 BART 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 xlm_roberta:用于 XLMRoBERTa 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 bert_ja:用于为日语训练的 BERT 风格的模型。

建议将 max_sequence_length 设置为 386span 设置为 128,并将 truncate 设置为 none

标记化的属性
bert

(可选,对象) 将使用封闭的设置执行 BERT 风格的标记化。

bert 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

deberta_v2

(可选,对象) 将使用封闭的设置执行 DeBERTa 风格的标记化。

deberta_v2 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • balanced:可以截断第一个和/或第二个序列,以平衡两个序列中包含的标记。
  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。
roberta

(可选,对象) 将使用封闭的设置执行 RoBERTa 风格的标记化。

roberta 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

mpnet

(可选,对象) 将使用封闭的设置执行 MPNet 风格的标记化。

mpnet 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

xlm_roberta

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行 XLMRoBERTa 风格的标记化。

xlm_roberta 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

bert_ja

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行日语文本的 BERT 风格标记化。

bert_ja 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

regression

(可选,对象) 用于推断的回归配置。

回归推断的属性
num_top_feature_importance_values
(可选,整数) 指定每个文档的特征重要性值的最大数量。默认情况下,它为零,不进行特征重要性计算。
results_field
(可选,字符串) 添加到传入文档以包含推断预测的字段。默认为 predicted_value
text_classification

(可选,对象) 文本分类任务。文本分类将提供的文本序列分类为先前已知的目标类。一个具体的例子是情感分析,它返回指示文本情感的可能目标类,例如“悲伤”、“快乐”或“愤怒”。

文本分类推断的属性
classification_labels
(可选,字符串) 分类标签数组。
num_top_classes
(可选,整数) 指定要返回的最高类预测数。默认为所有类 (-1)。
results_field
(可选,字符串) 添加到传入文档以包含推断预测的字段。默认为 predicted_value
tokenization

(可选,对象) 指示要执行的标记化以及所需的设置。默认的标记化配置为 bert。有效的标记化值为

  • bert:用于 BERT 风格的模型
  • deberta_v2:用于 DeBERTa v2 和 v3 风格的模型
  • mpnet:用于 MPNet 风格的模型
  • roberta:用于 RoBERTa 风格和 BART 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 xlm_roberta:用于 XLMRoBERTa 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 bert_ja:用于为日语训练的 BERT 风格的模型。
标记化的属性
bert

(可选,对象) 将使用封闭的设置执行 BERT 风格的标记化。

bert 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

deberta_v2

(可选,对象) 将使用封闭的设置执行 DeBERTa 风格的标记化。

deberta_v2 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • balanced:可以截断第一个和/或第二个序列,以平衡两个序列中包含的标记。
  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。
roberta

(可选,对象) 将使用封闭的设置执行 RoBERTa 风格的标记化。

roberta 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

mpnet

(可选,对象) 将使用封闭的设置执行 MPNet 风格的标记化。

mpnet 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

xlm_roberta

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行 XLMRoBERTa 风格的标记化。

xlm_roberta 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

bert_ja

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行日语文本的 BERT 风格标记化。

bert_ja 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

text_embedding

(可选,对象) 文本嵌入将输入序列转换为数字向量。这些嵌入不仅捕获标记,还捕获语义含义和上下文。这些嵌入可以用于稠密向量字段,以获得强大的见解。

文本嵌入推断的属性
results_field
(可选,字符串) 添加到传入文档以包含推断预测的字段。默认为 predicted_value
tokenization

(可选,对象) 指示要执行的标记化以及所需的设置。默认的标记化配置为 bert。有效的标记化值为

  • bert:用于 BERT 风格的模型
  • deberta_v2:用于 DeBERTa v2 和 v3 风格的模型
  • mpnet:用于 MPNet 风格的模型
  • roberta:用于 RoBERTa 风格和 BART 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 xlm_roberta:用于 XLMRoBERTa 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 bert_ja:用于为日语训练的 BERT 风格的模型。
标记化的属性
bert

(可选,对象) 将使用封闭的设置执行 BERT 风格的标记化。

bert 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

deberta_v2

(可选,对象) 将使用封闭的设置执行 DeBERTa 风格的标记化。

deberta_v2 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • balanced:可以截断第一个和/或第二个序列,以平衡两个序列中包含的标记。
  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。
roberta

(可选,对象) 将使用封闭的设置执行 RoBERTa 风格的标记化。

roberta 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

mpnet

(可选,对象) 将使用封闭的设置执行 MPNet 风格的标记化。

mpnet 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

xlm_roberta

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行 XLMRoBERTa 风格的标记化。

xlm_roberta 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

bert_ja

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行日语文本的 BERT 风格标记化。

bert_ja 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

text_similarity

(可选,对象) 文本相似度接受一个输入序列,并将其与另一个输入序列进行比较。这通常称为交叉编码。当将文档文本与提供的另一个文本输入进行比较时,此任务对于对文档文本进行排名非常有用。

文本相似度推断的属性
span_score_combination_function

(可选,字符串)标识当提供的文本段落长度超过 max_sequence_length 且必须自动分割进行多次调用时,如何组合生成的相似度得分。这仅在 truncatenonespan 为非负数时适用。默认值为 max。可用选项包括:

  • max: 返回所有跨度中的最大得分。
  • mean: 返回所有跨度的平均得分。
text
(必需,字符串)这是用来与所有提供的文档文本输入进行比较的文本。
tokenization

(可选,对象) 指示要执行的标记化以及所需的设置。默认的标记化配置为 bert。有效的标记化值为

  • bert:用于 BERT 风格的模型
  • deberta_v2:用于 DeBERTa v2 和 v3 风格的模型
  • mpnet:用于 MPNet 风格的模型
  • roberta:用于 RoBERTa 风格和 BART 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 xlm_roberta:用于 XLMRoBERTa 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 bert_ja:用于为日语训练的 BERT 风格的模型。
标记化的属性
bert

(可选,对象) 将使用封闭的设置执行 BERT 风格的标记化。

bert 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

with_special_tokens

(可选,布尔值)使用特殊标记进行分词。BERT 风格分词中通常包含的标记是:

  • [CLS]:被分类序列的第一个标记。
  • [SEP]:表示序列分隔。
deberta_v2

(可选,对象) 将使用封闭的设置执行 DeBERTa 风格的标记化。

deberta_v2 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

with_special_tokens

(可选,布尔值)使用特殊标记进行分词。BERT 风格分词中通常包含的标记是:

  • [CLS]:被分类序列的第一个标记。
  • [SEP]:表示序列分隔。
roberta

(可选,对象) 将使用封闭的设置执行 RoBERTa 风格的标记化。

roberta 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

mpnet

(可选,对象) 将使用封闭的设置执行 MPNet 风格的标记化。

mpnet 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

xlm_roberta

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行 XLMRoBERTa 风格的标记化。

xlm_roberta 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

bert_ja

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行日语文本的 BERT 风格标记化。

bert_ja 的属性
span

(可选,整数) 当 truncatenone 时,您可以对较长的文本序列进行分区以进行推断。该值指示每个子序列之间重叠的标记数。

默认值为 -1,表示不进行窗口化或跨度。

当您的典型输入仅略大于 max_sequence_length 时,最好直接截断;第二个子序列中的信息很少。

with_special_tokens
(可选,布尔值)如果 true,则使用特殊标记进行分词。
zero_shot_classification

(对象,可选)配置零样本分类任务。零样本分类允许在没有预定标签的情况下进行文本分类。在推理时,可以调整要分类的标签。这使得这种类型的模型和任务非常灵活。

如果始终对相同的标签进行分类,最好使用微调的文本分类模型。

零样本分类推理的属性
labels
(可选,数组)要分类的标签。可以在创建时设置默认标签,然后在推理期间更新。
multi_label
(可选,布尔值)表示给定输入是否可能存在多个 true 标签。当标记的文本可能与多个输入标签相关时,这很有用。默认为 false
results_field
(可选,字符串) 添加到传入文档以包含推断预测的字段。默认为 predicted_value
tokenization

(可选,对象) 指示要执行的标记化以及所需的设置。默认的标记化配置为 bert。有效的标记化值为

  • bert:用于 BERT 风格的模型
  • deberta_v2:用于 DeBERTa v2 和 v3 风格的模型
  • mpnet:用于 MPNet 风格的模型
  • roberta:用于 RoBERTa 风格和 BART 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 xlm_roberta:用于 XLMRoBERTa 风格的模型
  • [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 bert_ja:用于为日语训练的 BERT 风格的模型。
标记化的属性
bert

(可选,对象) 将使用封闭的设置执行 BERT 风格的标记化。

bert 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

deberta_v2

(可选,对象) 将使用封闭的设置执行 DeBERTa 风格的标记化。

deberta_v2 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • balanced:可以截断第一个和/或第二个序列,以平衡两个序列中包含的标记。
  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。
roberta

(可选,对象) 将使用封闭的设置执行 RoBERTa 风格的标记化。

roberta 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

mpnet

(可选,对象) 将使用封闭的设置执行 MPNet 风格的标记化。

mpnet 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

xlm_roberta

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行 XLMRoBERTa 风格的标记化。

xlm_roberta 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

bert_ja

(可选,对象) [预览] 此功能为技术预览版,可能会在未来的版本中更改或删除。Elastic 将致力于解决任何问题,但技术预览版中的功能不受官方 GA 功能的支持 SLA 的约束。 将使用封闭的设置执行日语文本的 BERT 风格标记化。

bert_ja 的属性
truncate

(可选,字符串) 指示当标记超过 max_sequence_length 时如何截断标记。默认值为 first

  • none:不进行截断;推断请求收到错误。
  • first:仅截断第一个序列。
  • second:仅截断第二个序列。如果只有一个序列,则截断该序列。

对于 zero_shot_classification,假设序列始终是第二个序列。因此,在这种情况下不要使用 second

示例

编辑

响应取决于模型的类型。

例如,对于语言识别,响应是预测的语言和得分

resp = client.ml.infer_trained_model(
    model_id="lang_ident_model_1",
    docs=[
        {
            "text": "The fool doth think he is wise, but the wise man knows himself to be a fool."
        }
    ],
)
print(resp)
response = client.ml.infer_trained_model(
  model_id: 'lang_ident_model_1',
  body: {
    docs: [
      {
        text: 'The fool doth think he is wise, but the wise man knows himself to be a fool.'
      }
    ]
  }
)
puts response
const response = await client.ml.inferTrainedModel({
  model_id: "lang_ident_model_1",
  docs: [
    {
      text: "The fool doth think he is wise, but the wise man knows himself to be a fool.",
    },
  ],
});
console.log(response);
POST _ml/trained_models/lang_ident_model_1/_infer
{
  "docs":[{"text": "The fool doth think he is wise, but the wise man knows himself to be a fool."}]
}

以下是预测英语的结果,具有较高的概率。

{
  "inference_results": [
    {
      "predicted_value": "en",
      "prediction_probability": 0.9999658805366392,
      "prediction_score": 0.9999658805366392
    }
  ]
}

当它是文本分类模型时,响应是得分和预测的分类。

例如

resp = client.ml.infer_trained_model(
    model_id="model2",
    docs=[
        {
            "text_field": "The movie was awesome!!"
        }
    ],
)
print(resp)
response = client.ml.infer_trained_model(
  model_id: 'model2',
  body: {
    docs: [
      {
        text_field: 'The movie was awesome!!'
      }
    ]
  }
)
puts response
const response = await client.ml.inferTrainedModel({
  model_id: "model2",
  docs: [
    {
      text_field: "The movie was awesome!!",
    },
  ],
});
console.log(response);
POST _ml/trained_models/model2/_infer
{
	"docs": [{"text_field": "The movie was awesome!!"}]
}

API 返回预测的标签和置信度。

{
  "inference_results": [{
    "predicted_value" : "POSITIVE",
    "prediction_probability" : 0.9998667964092964
  }]
}

对于命名实体识别 (NER) 模型,响应包含带注释的文本输出和识别出的实体。

resp = client.ml.infer_trained_model(
    model_id="model2",
    docs=[
        {
            "text_field": "Hi my name is Josh and I live in Berlin"
        }
    ],
)
print(resp)
response = client.ml.infer_trained_model(
  model_id: 'model2',
  body: {
    docs: [
      {
        text_field: 'Hi my name is Josh and I live in Berlin'
      }
    ]
  }
)
puts response
const response = await client.ml.inferTrainedModel({
  model_id: "model2",
  docs: [
    {
      text_field: "Hi my name is Josh and I live in Berlin",
    },
  ],
});
console.log(response);
POST _ml/trained_models/model2/_infer
{
	"docs": [{"text_field": "Hi my name is Josh and I live in Berlin"}]
}

在这种情况下,API 返回:

{
  "inference_results": [{
    "predicted_value" : "Hi my name is [Josh](PER&Josh) and I live in [Berlin](LOC&Berlin)",
    "entities" : [
      {
        "entity" : "Josh",
        "class_name" : "PER",
        "class_probability" : 0.9977303419824,
        "start_pos" : 14,
        "end_pos" : 18
      },
      {
        "entity" : "Berlin",
        "class_name" : "LOC",
        "class_probability" : 0.9992474323902818,
        "start_pos" : 33,
        "end_pos" : 39
      }
    ]
  }]
}

零样本分类模型需要额外的配置来定义类标签。这些标签在零样本推理配置中传递。

resp = client.ml.infer_trained_model(
    model_id="model2",
    docs=[
        {
            "text_field": "This is a very happy person"
        }
    ],
    inference_config={
        "zero_shot_classification": {
            "labels": [
                "glad",
                "sad",
                "bad",
                "rad"
            ],
            "multi_label": False
        }
    },
)
print(resp)
response = client.ml.infer_trained_model(
  model_id: 'model2',
  body: {
    docs: [
      {
        text_field: 'This is a very happy person'
      }
    ],
    inference_config: {
      zero_shot_classification: {
        labels: [
          'glad',
          'sad',
          'bad',
          'rad'
        ],
        multi_label: false
      }
    }
  }
)
puts response
const response = await client.ml.inferTrainedModel({
  model_id: "model2",
  docs: [
    {
      text_field: "This is a very happy person",
    },
  ],
  inference_config: {
    zero_shot_classification: {
      labels: ["glad", "sad", "bad", "rad"],
      multi_label: false,
    },
  },
});
console.log(response);
POST _ml/trained_models/model2/_infer
{
  "docs": [
    {
      "text_field": "This is a very happy person"
    }
  ],
  "inference_config": {
    "zero_shot_classification": {
      "labels": [
        "glad",
        "sad",
        "bad",
        "rad"
      ],
      "multi_label": false
    }
  }
}

API 返回预测的标签和置信度,以及最上面的类

{
  "inference_results": [{
    "predicted_value" : "glad",
    "top_classes" : [
      {
        "class_name" : "glad",
        "class_probability" : 0.8061155063386439,
        "class_score" : 0.8061155063386439
      },
      {
        "class_name" : "rad",
        "class_probability" : 0.18218006158387956,
        "class_score" : 0.18218006158387956
      },
      {
        "class_name" : "bad",
        "class_probability" : 0.006325615787634201,
        "class_score" : 0.006325615787634201
      },
      {
        "class_name" : "sad",
        "class_probability" : 0.0053788162898424545,
        "class_score" : 0.0053788162898424545
      }
    ],
    "prediction_probability" : 0.8061155063386439
  }]
}

问答模型需要额外的配置来定义要回答的问题。

resp = client.ml.infer_trained_model(
    model_id="model2",
    docs=[
        {
            "text_field": "<long text to extract answer>"
        }
    ],
    inference_config={
        "question_answering": {
            "question": "<question to be answered>"
        }
    },
)
print(resp)
response = client.ml.infer_trained_model(
  model_id: 'model2',
  body: {
    docs: [
      {
        text_field: '<long text to extract answer>'
      }
    ],
    inference_config: {
      question_answering: {
        question: '<question to be answered>'
      }
    }
  }
)
puts response
const response = await client.ml.inferTrainedModel({
  model_id: "model2",
  docs: [
    {
      text_field: "<long text to extract answer>",
    },
  ],
  inference_config: {
    question_answering: {
      question: "<question to be answered>",
    },
  },
});
console.log(response);
POST _ml/trained_models/model2/_infer
{
  "docs": [
    {
      "text_field": "<long text to extract answer>"
    }
  ],
  "inference_config": {
    "question_answering": {
      "question": "<question to be answered>"
    }
  }
}

API 返回类似于以下的响应:

{
    "predicted_value": <string subsection of the text that is the answer>,
    "start_offset": <character offset in document to start>,
    "end_offset": <character offset end of the answer,
    "prediction_probability": <prediction score>
}

文本相似度模型需要至少两个文本序列进行比较。可以提供多个文本字符串与另一个文本序列进行比较

resp = client.ml.infer_trained_model(
    model_id="cross-encoder__ms-marco-tinybert-l-2-v2",
    docs=[
        {
            "text_field": "Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."
        },
        {
            "text_field": "New York City is famous for the Metropolitan Museum of Art."
        }
    ],
    inference_config={
        "text_similarity": {
            "text": "How many people live in Berlin?"
        }
    },
)
print(resp)
const response = await client.ml.inferTrainedModel({
  model_id: "cross-encoder__ms-marco-tinybert-l-2-v2",
  docs: [
    {
      text_field:
        "Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
    },
    {
      text_field: "New York City is famous for the Metropolitan Museum of Art.",
    },
  ],
  inference_config: {
    text_similarity: {
      text: "How many people live in Berlin?",
    },
  },
});
console.log(response);
POST _ml/trained_models/cross-encoder__ms-marco-tinybert-l-2-v2/_infer
{
  "docs":[{ "text_field": "Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."}, {"text_field": "New York City is famous for the Metropolitan Museum of Art."}],
  "inference_config": {
    "text_similarity": {
      "text": "How many people live in Berlin?"
    }
  }
}

响应包含与 text_similaritytext 字段中提供的文本进行比较的每个字符串的预测

{
  "inference_results": [
    {
      "predicted_value": 7.235751628875732
    },
    {
      "predicted_value": -11.562295913696289
    }
  ]
}

调用 API 时可以覆盖分词截断选项

resp = client.ml.infer_trained_model(
    model_id="model2",
    docs=[
        {
            "text_field": "The Amazon rainforest covers most of the Amazon basin in South America"
        }
    ],
    inference_config={
        "ner": {
            "tokenization": {
                "bert": {
                    "truncate": "first"
                }
            }
        }
    },
)
print(resp)
response = client.ml.infer_trained_model(
  model_id: 'model2',
  body: {
    docs: [
      {
        text_field: 'The Amazon rainforest covers most of the Amazon basin in South America'
      }
    ],
    inference_config: {
      ner: {
        tokenization: {
          bert: {
            truncate: 'first'
          }
        }
      }
    }
  }
)
puts response
const response = await client.ml.inferTrainedModel({
  model_id: "model2",
  docs: [
    {
      text_field:
        "The Amazon rainforest covers most of the Amazon basin in South America",
    },
  ],
  inference_config: {
    ner: {
      tokenization: {
        bert: {
          truncate: "first",
        },
      },
    },
  },
});
console.log(response);
POST _ml/trained_models/model2/_infer
{
  "docs": [{"text_field": "The Amazon rainforest covers most of the Amazon basin in South America"}],
  "inference_config": {
    "ner": {
      "tokenization": {
        "bert": {
          "truncate": "first"
        }
      }
    }
  }
}

当输入由于模型 max_sequence_length 的限制而被截断时,响应中会出现 is_truncated 字段。

{
  "inference_results": [{
    "predicted_value" : "The [Amazon](LOC&Amazon) rainforest covers most of the [Amazon](LOC&Amazon) basin in [South America](LOC&South+America)",
    "entities" : [
      {
        "entity" : "Amazon",
        "class_name" : "LOC",
        "class_probability" : 0.9505460915724254,
        "start_pos" : 4,
        "end_pos" : 10
      },
      {
        "entity" : "Amazon",
        "class_name" : "LOC",
        "class_probability" : 0.9969992804311777,
        "start_pos" : 41,
        "end_pos" : 47
      }
    ],
    "is_truncated" : true
  }]
}