I keep getting this error

Post by Anonymous »

I am training a Transformer model with RLlib's PPO algorithm, but I keep running into a device mismatch error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Even though I move every model component to the GPU with .to(self.device), the error persists. CUDA is available and the model is supposed to run on the GPU.
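
As far as I can tell, the failing frame in the traceback is Categorical.log_prob, i.e. logits that live on the GPU being combined with an actions tensor that is still on the CPU. A minimal standalone snippet (the tensors below are just placeholders, not taken from my pipeline) reproduces the same message for me:

Code: Select all

import torch
from torch.distributions import Categorical

if torch.cuda.is_available():
    logits = torch.randn(4, 3, device="cuda")  # policy output on the GPU, like my model's logits
    actions = torch.tensor([0, 1, 2, 0])       # action batch left on the CPU
    dist = Categorical(logits=logits)
    dist.log_prob(actions)  # RuntimeError: Expected all tensors to be on the same device ...

For reference, here is my custom model: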

Code: Select all

import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class SimpleTransformer(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Configuration
        custom_config = model_config["custom_model_config"]
        self.input_dim = 76
        self.seq_len = custom_config["seq_len"]
        self.embed_size = custom_config["embed_size"]
        self.nheads = custom_config["nhead"]
        self.nlayers = custom_config["nlayers"]
        self.dropout = custom_config["dropout"]
        self.values_out = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Input layer
        self.input_embed = nn.Linear(self.input_dim, self.embed_size).to(self.device)

        # Positional encoding
        self.pos_encoding = nn.Embedding(self.seq_len, self.embed_size).to(self.device)

        # Transformer
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.embed_size,
                nhead=self.nheads,
                dropout=self.dropout,
                activation='gelu',
                device=self.device),
            num_layers=self.nlayers
        )

        # Policy and value heads
        self.policy_head = nn.Sequential(
            nn.Linear(self.embed_size + 2, 64),  # Add dynamic features (wallet balance, unrealized PnL)
            nn.ReLU(),
            nn.Linear(64, num_outputs)  # Action space size
        ).to(self.device)

        self.value_head = nn.Sequential(
            nn.Linear(self.embed_size + 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        ).to(self.device)

    def forward(self, input_dict, state, seq_len):
        # Process input
        x = input_dict["obs"].view(-1, self.seq_len, self.input_dim).to(self.device)
        dynamic_features = x[:, -1, 2:4].clone().to(self.device)
        x = self.input_embed(x)

        position = torch.arange(0, self.seq_len).unsqueeze(0).expand(x.size(0), -1).to(self.device)
        x = x + self.pos_encoding(position)

        transformer_out = self.transformer(x)
        last_out = transformer_out[:, -1, :]

        combined = torch.cat((last_out, dynamic_features), dim=1)

        actions = self.policy_head(combined)
        self.values_out = self.value_head(combined).squeeze(1)

        return actions, state
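
In case it matters, the custom_model_config keys read in __init__ are supplied through the model section of the PPO config; roughly like this (the registered model name and the hyperparameter values below are placeholders, not my exact settings):

Code: Select all

from ray.rllib.models import ModelCatalog

# Register the custom model under a name RLlib can look up
ModelCatalog.register_custom_model("simple_transformer", SimpleTransformer)

config = {
    "framework": "torch",
    "num_gpus": 1,
    "model": {
        "custom_model": "simple_transformer",
        "custom_model_config": {
            "seq_len": 32,      # placeholder values
            "embed_size": 64,
            "nhead": 4,
            "nlayers": 2,
            "dropout": 0.1,
        },
    },
}
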
Here is the full error message:
Trial status: 1 ERROR
Current time: 2025-04-11 20:44:55.  Total running time: 14s
Logical resource usage: 0/12 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
╭──────────────────────────────────────╮
│ Trial name                  status   │
├──────────────────────────────────────┤
│ PPO_CryptoEnv_a50d0_00000   ERROR    │
╰──────────────────────────────────────╯

Number of errored trials: 1
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                    # failures   error file                                                                                                                                                                                                      │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ PPO_CryptoEnv_a50d0_00000              1   C:/Users/tmpou/AppData/Local/Temp/ray/session_2025-04-11_20-44-35_479257_23712/artifacts/2025-04-11_20-44-40/PPO_2025-04-11_20-44-40/driver_artifacts/PPO_CryptoEnv_a50d0_00000_0_2025-04-11_20-44-40/error.txt │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Traceback (most recent call last):
File "C:\Users\tmpou\Developer\MSc AI\Deep Learning and Multi-media data\crypto_rl_bot\train.py", line 14, in 
tune.run(
File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\tune\tune.py", line 1042, in run
raise TuneError("Trials did not complete", incomplete_trials)
ray.tune.error.TuneError: ('Trials did not complete', [PPO_CryptoEnv_a50d0_00000])
(PPO pid=31224) 2025-04-11 20:44:55,030 ERROR actor_manager.py:517 -- Ray error, taking actor 1 out of service.  The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=3964, ip=127.0.0.1, actor_id=b2fed95453b6755f07372fcb01000000, repr=)
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(PPO pid=31224)     return method(__ray_actor, *args, **kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 535, in __init__
(PPO pid=31224)     self._update_policy_map(policy_dict=self.policy_dict)
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1743, in _update_policy_map
(PPO pid=31224)     self._build_policy_map(
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1854, in _build_policy_map
(PPO pid=31224)     new_policy = create_policy_for_framework(
(PPO pid=31224)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(PPO pid=31224)     return policy_class(observation_space, action_space, merged_config)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__
(PPO pid=31224)     self._initialize_loss_from_dummy_batch()
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\policy\policy.py", line 1484, in _initialize_loss_from_dummy_batch
(PPO pid=31224)     self.loss(self.model, self.dist_class, train_batch)
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(PPO pid=31224)     curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(PPO pid=31224)     return self.dist.log_prob(actions)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\torch\distributions\categorical.py", line 143, in log_prob
(PPO pid=31224)     return log_pmf.gather(-1, value).squeeze(-1)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
(PPO pid=31224) Exception raised in creation task:  The actor died because of an error raised in its creation task, ray::PPO.__init__() (pid=31224, ip=127.0.0.1, actor_id=f5d50e01341cb51a747d8a3e01000000, repr=PPO)
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\worker_set.py", line 229, in _setup
(PPO pid=31224)     self.add_workers(
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\worker_set.py", line 682, in add_workers
(PPO pid=31224)     raise result.get()
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\actor_manager.py", line 497, in _fetch_result
(PPO pid=31224)     result = ray.get(r)
(PPO pid=31224)              ^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\auto_init_hook.py", line 21, in auto_init_wrapper
(PPO pid=31224)     return fn(*args, **kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\client_mode_hook.py", line 103, in wrapper
(PPO pid=31224)     return func(*args, **kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\worker.py", line 2667, in get
(PPO pid=31224)     values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
(PPO pid=31224)                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\worker.py", line 866, in get_objects
(PPO pid=31224)     raise value
(PPO pid=31224) ray.exceptions.RayActorError:  The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=3964, ip=127.0.0.1, actor_id=b2fed95453b6755f07372fcb01000000, repr=)
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(PPO pid=31224)     return method(__ray_actor, *args, **kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 535, in __init__
(PPO pid=31224)     self._update_policy_map(policy_dict=self.policy_dict)
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1743, in _update_policy_map
(PPO pid=31224)     self._build_policy_map(
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1854, in _build_policy_map
(PPO pid=31224)     new_policy = create_policy_for_framework(
(PPO pid=31224)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(PPO pid=31224)     return policy_class(observation_space, action_space, merged_config)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__
(PPO pid=31224)     self._initialize_loss_from_dummy_batch()
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\policy\policy.py", line 1484, in _initialize_loss_from_dummy_batch
(PPO pid=31224)     self.loss(self.model, self.dist_class, train_batch)
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(PPO pid=31224)     curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(PPO pid=31224)     return self.dist.log_prob(actions)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\torch\distributions\categorical.py", line 143, in log_prob
(PPO pid=31224)     return log_pmf.gather(-1, value).squeeze(-1)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) RuntimeError:  Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
(PPO pid=31224)
(PPO pid=31224) During handling of the above exception, another exception occurred:
(PPO pid=31224)
(PPO pid=31224) ray::PPO.__init__() (pid=31224, ip=127.0.0.1, actor_id=f5d50e01341cb51a747d8a3e01000000, repr=PPO)
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1984, in ray._raylet.execute_task
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(PPO pid=31224)     return method(__ray_actor, *args, **kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 533, in __init__
(PPO pid=31224)     super().__init__(
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\tune\trainable\trainable.py", line 161, in __init__
(PPO pid=31224)     self.setup(copy.deepcopy(self.config))
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 631, in setup
(PPO pid=31224)     self.workers = WorkerSet(
(PPO pid=31224)                    ^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\worker_set.py", line 181, in __init__
(PPO pid=31224)     raise e.args[0].args[2]
(PPO pid=31224) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
(RolloutWorker pid=3964) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=3964, ip=127.0.0.1, actor_id=b2fed95453b6755f07372fcb01000000, repr=)
(RolloutWorker pid=3964)   File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(RolloutWorker pid=3964)   File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(RolloutWorker pid=3964)     return method(__ray_actor, *args, **kwargs)
(RolloutWorker pid=3964)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span [repeated 3x across cluster] (Ray deduplicates logs by default.  Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RolloutWorker pid=3964)     return method(self, *_args, **_kwargs) [repeated 3x across cluster]
(RolloutWorker pid=3964)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 3x across cluster]
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__ [repeated 2x across cluster]
(RolloutWorker pid=3964)     self._update_policy_map(policy_dict=self.policy_dict)
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1743, in _update_policy_map
(RolloutWorker pid=3964)     self._build_policy_map(
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1854, in _build_policy_map
(RolloutWorker pid=3964)     new_policy = create_policy_for_framework(
(RolloutWorker pid=3964)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(RolloutWorker pid=3964)     return policy_class(observation_space, action_space, merged_config)
(RolloutWorker pid=3964)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964)     self._initialize_loss_from_dummy_batch()
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\policy\policy.py", line 1484, in _initialize_loss_from_dummy_batch
(RolloutWorker pid=3964)     self.loss(self.model, self.dist_class, train_batch)
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(RolloutWorker pid=3964)     curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(RolloutWorker pid=3964)     return self.dist.log_prob(actions)
(RolloutWorker pid=3964)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\torch\distributions\categorical.py", line 143, in log_prob
(RolloutWorker pid=3964)     return log_pmf.gather(-1, value).squeeze(-1)
(RolloutWorker pid=3964)            ^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
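
One thing I am wondering about, in case it points in the right direction: instead of hardcoding self.device = "cuda" in __init__ and calling .to(self.device) on every submodule, should I let RLlib place the module itself and just follow the device of the incoming batch inside forward? A rough sketch of what I mean (only forward changes, and it assumes the explicit .to(self.device) calls in __init__ are removed so the submodules sit wherever RLlib puts the model):

Code: Select all

def forward(self, input_dict, state, seq_lens):
    obs = input_dict["obs"]
    device = obs.device  # follow whatever device RLlib put the batch on

    x = obs.view(-1, self.seq_len, self.input_dim)
    dynamic_features = x[:, -1, 2:4].clone()
    x = self.input_embed(x)

    # build the position indices on the same device as the embeddings
    position = torch.arange(self.seq_len, device=device).unsqueeze(0).expand(x.size(0), -1)
    x = x + self.pos_encoding(position)

    transformer_out = self.transformer(x)
    last_out = transformer_out[:, -1, :]

    combined = torch.cat((last_out, dynamic_features), dim=1)

    actions = self.policy_head(combined)
    self.values_out = self.value_head(combined).squeeze(1)
    return actions, state

Would that be a sensible way to avoid the cuda:0 vs cpu mismatch, or is the problem somewhere else?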
